Rename est_type to estimator_type
This makes things clearer, as est is not a common abbreviation.
This is as suggested in #419.

Reviewers: mtrofin

Reviewed By: mtrofin

Pull Request: #422
boomanaiden154 authored Jan 28, 2025
1 parent 3a4a297 commit a59ca65
Showing 6 changed files with 64 additions and 60 deletions.
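For downstream callers the rename is mechanical: every est_type parameter, field, and keyword argument becomes estimator_type. A hedged sketch of a hypothetical call site, mirroring the construction in the test diff below (the step_size keyword name is an assumption, inferred from positional use in the tests):

    from compiler_opt.es import blackbox_optimizers

    # Hypothetical call site, for illustration only.
    opt = blackbox_optimizers.MonteCarloBlackboxOptimizer(
        precision_parameter=1.0,
        # Was: est_type=blackbox_optimizers.EstimatorType.ANTITHETIC
        estimator_type=blackbox_optimizers.EstimatorType.ANTITHETIC,
        normalize_fvalues=True,
        hyperparameters_update_method=blackbox_optimizers.UpdateMethod.NO_METHOD,
        extra_params=None,
        step_size=0.01,  # keyword name assumed, not confirmed by this diff
        num_top_directions=0)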
13 changes: 7 additions & 6 deletions compiler_opt/es/blackbox_evaluator.py
@@ -64,13 +64,13 @@ class SamplingBlackboxEvaluator(BlackboxEvaluator):
"""A blackbox evaluator that samples from a corpus to collect reward."""

def __init__(self, train_corpus: corpus.Corpus,
est_type: blackbox_optimizers.EstimatorType,
estimator_type: blackbox_optimizers.EstimatorType,
total_num_perturbations: int, num_ir_repeats_within_worker: int):
self._samples = []
self._train_corpus = train_corpus
self._total_num_perturbations = total_num_perturbations
self._num_ir_repeats_within_worker = num_ir_repeats_within_worker
self._est_type = est_type
self._estimator_type = estimator_type

super().__init__(train_corpus)

@@ -82,7 +82,8 @@ def get_results(
sample = self._train_corpus.sample(self._num_ir_repeats_within_worker)
self._samples.append(sample)
# add copy of sample for antithetic perturbation pair
if self._est_type == (blackbox_optimizers.EstimatorType.ANTITHETIC):
if self._estimator_type == (
blackbox_optimizers.EstimatorType.ANTITHETIC):
self._samples.append(sample)

compile_args = zip(perturbations, self._samples)
@@ -111,10 +112,10 @@ class TraceBlackboxEvaluator(BlackboxEvaluator):
"""A blackbox evaluator that utilizes trace based cost modelling."""

def __init__(self, train_corpus: corpus.Corpus,
est_type: blackbox_optimizers.EstimatorType, bb_trace_path: str,
function_index_path: str):
estimator_type: blackbox_optimizers.EstimatorType,
bb_trace_path: str, function_index_path: str):
self._train_corpus = train_corpus
self._est_type = est_type
self._estimator_type = estimator_type
self._bb_trace_path = bb_trace_path
self._function_index_path = function_index_path

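For context on the hunks above: the renamed estimator_type field decides whether SamplingBlackboxEvaluator enqueues each corpus sample once or twice. A minimal sketch of that behavior; the enum below is a local stand-in for blackbox_optimizers.EstimatorType, so its names and values are assumptions:

    import enum

    class EstimatorType(enum.Enum):
      # Local stand-in for blackbox_optimizers.EstimatorType.
      FORWARD_FD = 1
      ANTITHETIC = 2

    def collect_samples(corpus_samples, estimator_type):
      # Each perturbation needs its own sample; an antithetic run
      # evaluates both +p and -p, so the same sample is appended twice.
      samples = []
      for sample in corpus_samples:
        samples.append(sample)
        if estimator_type == EstimatorType.ANTITHETIC:
          samples.append(sample)
      return samples

    # ['a', 'b'] becomes ['a', 'a', 'b', 'b'] under ANTITHETIC.
    assert collect_samples(['a', 'b'],
                           EstimatorType.ANTITHETIC) == ['a', 'a', 'b', 'b']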
7 changes: 4 additions & 3 deletions compiler_opt/es/blackbox_learner.py
@@ -57,7 +57,7 @@ class BlackboxLearnerConfig:
# What kind of ES training?
# - antithetic: for each perturbation, try an antiperturbation
# - forward_fd: try total_num_perturbations independent perturbations
est_type: blackbox_optimizers.EstimatorType
estimator_type: blackbox_optimizers.EstimatorType

# Should the rewards for blackbox optimization in a single step be normalized?
fvalues_normalization: bool
@@ -164,7 +164,7 @@ def __init__(self,
self._summary_writer = tf.summary.create_file_writer(output_dir)

self._evaluator = self._config.evaluator(self._train_corpus,
self._config.est_type)
self._config.estimator_type)

def _get_perturbations(self) -> List[npt.NDArray[np.float32]]:
"""Get perturbations for the model weights."""
@@ -270,7 +270,8 @@ def run_step(self, pool: FixedWorkerPool) -> None:

initial_perturbations = self._get_perturbations()
# positive-negative pairs
if self._config.est_type == blackbox_optimizers.EstimatorType.ANTITHETIC:
if (self._config.estimator_type ==
blackbox_optimizers.EstimatorType.ANTITHETIC):
initial_perturbations = [
p for p in initial_perturbations for p in (p, -p)
]
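The positive-negative pairing in run_step above is easy to check in isolation. A standalone sketch, not the project's API (the inner comprehension variable is renamed here to avoid the shadowed p in the original, with identical output):

    import numpy as np

    def expand_antithetic(perturbations):
      # Interleave each perturbation with its negation:
      # [p1, -p1, p2, -p2, ...], matching the even/odd ordering that
      # filter_top_directions expects.
      return [q for p in perturbations for q in (p, -p)]

    pairs = expand_antithetic([np.array([1.0, 2.0]), np.array([3.0, -1.0])])
    # [array([1., 2.]), array([-1., -2.]), array([3., -1.]), array([-3., 1.])]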
4 changes: 2 additions & 2 deletions compiler_opt/es/blackbox_learner_test.py
@@ -53,7 +53,7 @@ def setUp(self):
self._learner_config = blackbox_learner.BlackboxLearnerConfig(
total_steps=1,
blackbox_optimizer=blackbox_optimizers.Algorithm.MONTE_CARLO,
est_type=blackbox_optimizers.EstimatorType.ANTITHETIC,
estimator_type=blackbox_optimizers.EstimatorType.ANTITHETIC,
fvalues_normalization=True,
hyperparameters_update_method=blackbox_optimizers.UpdateMethod
.NO_METHOD,
@@ -117,7 +117,7 @@ def _policy_saver_fn(parameters: npt.NDArray[np.float32],
self._learner = blackbox_learner.BlackboxLearner(
blackbox_opt=blackbox_optimizers.MonteCarloBlackboxOptimizer(
precision_parameter=1.0,
est_type=blackbox_optimizers.EstimatorType.ANTITHETIC,
estimator_type=blackbox_optimizers.EstimatorType.ANTITHETIC,
normalize_fvalues=True,
hyperparameters_update_method=blackbox_optimizers.UpdateMethod
.NO_METHOD,
68 changes: 34 additions & 34 deletions compiler_opt/es/blackbox_optimizers.py
@@ -124,7 +124,7 @@ class UpdateMethod(enum.Enum):

def filter_top_directions(
perturbations: FloatArray2D, function_values: FloatArray,
est_type: EstimatorType,
estimator_type: EstimatorType,
num_top_directions: int) -> Tuple[FloatArray, FloatArray]:
"""Select the subset of top-performing perturbations.
@@ -134,7 +134,7 @@
p, -p in the even/odd entries, so the directions p_1,...,p_n
will be ordered (p_1, -p_1, p_2, -p_2,...)
function_values: np array of reward values (maximization)
est_type: (forward_fd | antithetic)
estimator_type: (forward_fd | antithetic)
num_top_directions: the number of top directions to include
For antithetic, the total number of perturbations will
be 2* this number, because we count p, -p as a single
@@ -148,16 +148,16 @@
"""
if not num_top_directions > 0:
return (perturbations, function_values)
if est_type == EstimatorType.FORWARD_FD:
if estimator_type == EstimatorType.FORWARD_FD:
top_index = np.argsort(-function_values)
elif est_type == EstimatorType.ANTITHETIC:
elif estimator_type == EstimatorType.ANTITHETIC:
top_index = np.argsort(-np.abs(function_values[0::2] -
function_values[1::2]))
top_index = top_index[:num_top_directions]
if est_type == EstimatorType.FORWARD_FD:
if estimator_type == EstimatorType.FORWARD_FD:
perturbations = perturbations[top_index]
function_values = function_values[top_index]
elif est_type == EstimatorType.ANTITHETIC:
elif estimator_type == EstimatorType.ANTITHETIC:
perturbations = np.concatenate(
(perturbations[2 * top_index], perturbations[2 * top_index + 1]),
axis=0)
@@ -245,11 +245,11 @@ class StatefulOptimizer(BlackboxOptimizer):
Class contains common methods for handling the state.
"""

def __init__(self, est_type: EstimatorType, normalize_fvalues: bool,
def __init__(self, estimator_type: EstimatorType, normalize_fvalues: bool,
hyperparameters_update_method: UpdateMethod,
extra_params: Optional[Sequence[int]]):

self.est_type = est_type
self.estimator_type = estimator_type
self.normalize_fvalues = normalize_fvalues
self.hyperparameters_update_method = hyperparameters_update_method
if hyperparameters_update_method == UpdateMethod.STATE_NORMALIZATION:
@@ -321,7 +321,7 @@ class MonteCarloBlackboxOptimizer(StatefulOptimizer):

def __init__(self,
precision_parameter: float,
est_type: EstimatorType,
estimator_type: EstimatorType,
normalize_fvalues: bool,
hyperparameters_update_method: UpdateMethod,
extra_params: Optional[Sequence[int]],
@@ -342,8 +342,8 @@ def __init__(self,
self.precision_parameter = precision_parameter
self.num_top_directions = num_top_directions
self.gradient_ascent_optimizer = gradient_ascent_optimizer
super().__init__(est_type, normalize_fvalues, hyperparameters_update_method,
extra_params)
super().__init__(estimator_type, normalize_fvalues,
hyperparameters_update_method, extra_params)

# TODO: Issue #285
def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,
@@ -358,14 +358,14 @@ def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,
function_values = np.array(normalized_values[:-1])
current_value = normalized_values[-1]
top_ps, top_fs = filter_top_directions(perturbations, function_values,
self.est_type,
self.estimator_type,
self.num_top_directions)
gradient = np.zeros(dim)
for i, perturbation in enumerate(top_ps):
function_value = top_fs[i]
if self.est_type == EstimatorType.FORWARD_FD:
if self.estimator_type == EstimatorType.FORWARD_FD:
gradient_sample = (function_value - current_value) * perturbation
elif self.est_type == EstimatorType.ANTITHETIC:
elif self.estimator_type == EstimatorType.ANTITHETIC:
gradient_sample = function_value * perturbation
gradient_sample /= self.precision_parameter**2
gradient += gradient_sample
@@ -374,7 +374,7 @@ def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,
# in that code, the denominator for antithetic was num_top_directions.
# we maintain compatibility for now so that the same hyperparameters
# currently used in Toaster will have the same effect
if self.est_type == EstimatorType.ANTITHETIC and \
if self.estimator_type == EstimatorType.ANTITHETIC and \
len(top_ps) < len(perturbations):
gradient *= 2
# Use the gradient ascent optimizer to compute the next parameters with the
@@ -396,7 +396,7 @@ class SklearnRegressionBlackboxOptimizer(StatefulOptimizer):
def __init__(self,
regression_method: RegressionType,
regularizer: float,
est_type: EstimatorType,
estimator_type: EstimatorType,
normalize_fvalues: bool,
hyperparameters_update_method: UpdateMethod,
extra_params: Optional[Sequence[int]],
@@ -422,8 +422,8 @@ def __init__(self,
else:
raise ValueError('Optimization procedure option not available')
self.gradient_ascent_optimizer = gradient_ascent_optimizer
super().__init__(est_type, normalize_fvalues, hyperparameters_update_method,
extra_params)
super().__init__(estimator_type, normalize_fvalues,
hyperparameters_update_method, extra_params)

def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,
current_input: FloatArray, current_value: float) -> FloatArray:
@@ -439,11 +439,11 @@ def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,

matrix = None
b_vector = None
if self.est_type == EstimatorType.FORWARD_FD:
if self.estimator_type == EstimatorType.FORWARD_FD:
matrix = np.array(perturbations)
b_vector = (
function_values - np.array([current_value] * len(function_values)))
elif self.est_type == EstimatorType.ANTITHETIC:
elif self.estimator_type == EstimatorType.ANTITHETIC:
matrix = np.array(perturbations[::2])
function_even_values = np.array(function_values.tolist()[::2])
function_odd_values = np.array(function_values.tolist()[1::2])
@@ -495,20 +495,20 @@ def normalize_function_values(


def monte_carlo_gradient(precision_parameter: float,
est_type: EstimatorType,
estimator_type: EstimatorType,
perturbations: FloatArray2D,
function_values: FloatArray,
current_value: float,
energy: Optional[float] = 0) -> FloatArray:
"""Calculates Monte Carlo gradient.
There are several ways of estimating the gradient. This is specified by the
attribute self.est_type. Currently, forward finite difference (FFD) and
attribute self.estimator_type. Currently, forward finite difference (FFD) and
antithetic are supported.
Args:
precision_parameter: sd of Gaussian perturbations
est_type: 'forward_fd' (FFD) or 'antithetic'
estimator_type: 'forward_fd' (FFD) or 'antithetic'
perturbations: the simulated perturbations
function_values: reward from perturbations (possibly normalized)
current_value: estimated reward at current point (possibly normalized)
@@ -522,11 +522,11 @@ def monte_carlo_gradient(precision_parameter: float,
"""
dim = len(perturbations[0])
b_vector = None
if est_type == EstimatorType.FORWARD_FD:
if estimator_type == EstimatorType.FORWARD_FD:
b_vector = (function_values -
np.array([current_value] * len(function_values))) / (
precision_parameter * precision_parameter)
elif est_type == EstimatorType.ANTITHETIC:
elif estimator_type == EstimatorType.ANTITHETIC:
b_vector = function_values / (2.0 * precision_parameter *
precision_parameter)
else:
@@ -543,15 +543,15 @@ def monte_carlo_gradient(precision_parameter: float,
return gradient


def sklearn_regression_gradient(clf: LinearModel, est_type: EstimatorType,
def sklearn_regression_gradient(clf: LinearModel, estimator_type: EstimatorType,
perturbations: FloatArray2D,
function_values: FloatArray,
current_value: float) -> FloatArray:
"""Calculates gradient by function difference regression.
Args:
clf: an object (SkLearn linear model) which fits Ax = b
est_type: 'forward_fd' (FFD) or 'antithetic'
estimator_type: 'forward_fd' (FFD) or 'antithetic'
perturbations: the simulated perturbations
function_values: reward from perturbations (possibly normalized)
current_value: estimated reward at current point (possibly normalized)
@@ -565,11 +565,11 @@ def sklearn_regression_gradient(clf: LinearModel, est_type: EstimatorType,
matrix = None
b_vector = None
dim = perturbations[0].size
if est_type == EstimatorType.FORWARD_FD:
if estimator_type == EstimatorType.FORWARD_FD:
matrix = np.array(perturbations)
b_vector = (
function_values - np.array([current_value] * len(function_values)))
elif est_type == EstimatorType.ANTITHETIC:
elif estimator_type == EstimatorType.ANTITHETIC:
matrix = np.array(perturbations[::2])
function_even_values = np.array(function_values.tolist()[::2])
function_odd_values = np.array(function_values.tolist()[1::2])
@@ -903,14 +903,14 @@ class TrustRegionOptimizer(StatefulOptimizer):
schedule that would have to be tuned.
"""

def __init__(self, precision_parameter: float, est_type: EstimatorType,
def __init__(self, precision_parameter: float, estimator_type: EstimatorType,
normalize_fvalues: bool,
hyperparameters_update_method: UpdateMethod,
extra_params: Optional[Sequence[int]], tr_params: Mapping[str,
Any]):
self.precision_parameter = precision_parameter
super().__init__(est_type, normalize_fvalues, hyperparameters_update_method,
extra_params)
super().__init__(estimator_type, normalize_fvalues,
hyperparameters_update_method, extra_params)

self.accepted_quadratic_model = None
self.accepted_function_value = None
@@ -1147,12 +1147,12 @@ def update_quadratic_model(self, perturbations: FloatArray2D,
current_value = normalized_values[1]
self.normalized_current_value = current_value
if self.params['grad_type'] == GradientType.REGRESSION:
new_gradient = sklearn_regression_gradient(self.clf, self.est_type,
new_gradient = sklearn_regression_gradient(self.clf, self.estimator_type,
perturbations, function_values,
current_value)
else:
new_gradient = monte_carlo_gradient(self.precision_parameter,
self.est_type, perturbations,
self.estimator_type, perturbations,
function_values, current_value)
new_gradient *= -1 # TR subproblem solver performs minimization
if not is_update:
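To tie the renamed parameter back to the math: monte_carlo_gradient branches on estimator_type when it builds the b-vector, as the hunks above show. A simplified, self-contained sketch of those two branches, using strings instead of the EstimatorType enum for brevity (an assumption; this is not the project's implementation):

    import numpy as np

    def mc_gradient_sketch(precision_parameter, estimator_type, perturbations,
                           function_values, current_value):
      perturbations = np.asarray(perturbations, dtype=np.float64)
      function_values = np.asarray(function_values, dtype=np.float64)
      sigma_sq = precision_parameter * precision_parameter
      if estimator_type == 'forward_fd':
        # FFD: weight each perturbation by (f(x + p) - f(x)) / sigma^2.
        b_vector = (function_values - current_value) / sigma_sq
      elif estimator_type == 'antithetic':
        # Antithetic: +p and -p both appear, hence the extra factor of 2.
        b_vector = function_values / (2.0 * sigma_sq)
      else:
        raise ValueError('estimator_type not supported')
      return perturbations.T @ b_vector  # sum_i b_i * p_i

    gradient = mc_gradient_sketch(0.1, 'antithetic',
                                  [[1.0, 0.0], [-1.0, 0.0]], [1.2, 0.8], 1.0)
    # gradient == array([20., 0.])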
24 changes: 13 additions & 11 deletions compiler_opt/es/blackbox_optimizers_test.py
@@ -72,10 +72,10 @@ class BlackboxOptimizationAlgorithmsTest(parameterized.TestCase):
blackbox_optimizers.EstimatorType.FORWARD_FD, 5,
np.array([[4, 2], [8, -6], [-1, 5], [0, -3], [2, -1]
]), np.array([10, 8, 4, 2, 1])))
def test_filtering(self, perturbations, function_values, est_type,
def test_filtering(self, perturbations, function_values, estimator_type,
num_top_directions, expected_ps, expected_fs):
top_ps, top_fs = blackbox_optimizers.filter_top_directions(
perturbations, function_values, est_type, num_top_directions)
perturbations, function_values, estimator_type, num_top_directions)
np.testing.assert_array_equal(expected_ps, top_ps)
np.testing.assert_array_equal(expected_fs, top_fs)

@@ -88,13 +88,14 @@ def test_filtering(self, perturbations, function_values, est_type,
blackbox_optimizers.EstimatorType.ANTITHETIC, 0, np.array([102, -34])),
(perturbation_array, function_value_array,
blackbox_optimizers.EstimatorType.FORWARD_FD, 0, np.array([74, -34])))
def test_monte_carlo_gradient(self, perturbations, function_values, est_type,
num_top_directions, expected_gradient):
def test_monte_carlo_gradient(self, perturbations, function_values,
estimator_type, num_top_directions,
expected_gradient):
precision_parameter = 0.1
step_size = 0.01
current_value = 2
blackbox_object = blackbox_optimizers.MonteCarloBlackboxOptimizer(
precision_parameter, est_type, False,
precision_parameter, estimator_type, False,
blackbox_optimizers.UpdateMethod.NO_METHOD, None, step_size,
num_top_directions)
current_input = np.zeros(2)
@@ -118,7 +119,7 @@ def test_monte_carlo_gradient(self, perturbations, function_values, est_type,
(perturbation_array, function_value_array,
blackbox_optimizers.EstimatorType.FORWARD_FD, 0, np.array([74, -34])))
def test_monte_carlo_gradient_with_gradient_ascent_optimizer(
self, perturbations, function_values, est_type, num_top_directions,
self, perturbations, function_values, estimator_type, num_top_directions,
expected_gradient):
precision_parameter = 0.1
step_size = 0.01
@@ -128,7 +129,7 @@ def test_monte_carlo_gradient_with_gradient_ascent_optimizer(
step_size, 0.0))
blackbox_object = (
blackbox_optimizers.MonteCarloBlackboxOptimizer(
precision_parameter, est_type, False,
precision_parameter, estimator_type, False,
blackbox_optimizers.UpdateMethod.NO_METHOD, None, None,
num_top_directions, gradient_ascent_optimizer))
current_input = np.zeros(2)
@@ -154,8 +155,9 @@ def test_monte_carlo_gradient_with_gradient_ascent_optimizer(
(perturbation_array, function_value_array,
blackbox_optimizers.EstimatorType.FORWARD_FD, 0,
np.array([0.030203, 0.001796])))
def test_sklearn_gradient(self, perturbations, function_values, est_type,
num_top_directions, expected_gradient):
def test_sklearn_gradient(self, perturbations, function_values,
estimator_type, num_top_directions,
expected_gradient):
precision_parameter = 0.1
step_size = 0.01
current_value = 2
@@ -164,8 +166,8 @@ def test_sklearn_gradient(self, perturbations, function_values, est_type,
gradient_ascent_optimization_algorithms.MomentumOptimizer(
step_size, 0.0))
blackbox_object = blackbox_optimizers.SklearnRegressionBlackboxOptimizer(
blackbox_optimizers.RegressionType.RIDGE, regularizer, est_type, True,
blackbox_optimizers.UpdateMethod.NO_METHOD, [], None,
blackbox_optimizers.RegressionType.RIDGE, regularizer, estimator_type,
True, blackbox_optimizers.UpdateMethod.NO_METHOD, [], None,
gradient_ascent_optimizer)
current_input = np.zeros(2)
step = blackbox_object.run_step(perturbations, function_values,