Commit 9559849

ravinkohli and nabenabe0928 authored

[FIX] Additional metrics during training (#316)

* additional_metrics during training
* fix flake
* Add test for unsupported budget type
* Apply suggestions from code review

Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
1 parent f089845 commit 9559849
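
In outline, the fix gathers the name of the optimised metric together with every additional supported metric into a `metrics_dict` that is later merged into the fit dictionary, so the Trainer can report those metrics during training. A minimal sketch of that collection step, using a stand-in `Metric` class in place of autoPyTorch's `autoPyTorchMetric` (all names here are illustrative, not the library's API):

```python
from typing import Dict, List


class Metric:
    """Stand-in for autoPyTorch's autoPyTorchMetric (only the name attribute matters here)."""
    def __init__(self, name: str) -> None:
        self.name = name


optimize_metric = Metric('accuracy')
additional_metrics = [Metric('balanced_accuracy'), Metric('f1')]

# Mirrors hunk 1 of the abstract_evaluator.py diff below: the optimised
# metric's name goes first, then every additional metric's name.
metrics_dict: Dict[str, List[str]] = {'additional_metrics': [optimize_metric.name]}
for metric in additional_metrics:
    metrics_dict['additional_metrics'].append(metric.name)

print(metrics_dict)
# {'additional_metrics': ['accuracy', 'balanced_accuracy', 'f1']}
```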

File tree

- autoPyTorch/evaluation/abstract_evaluator.py
- test/test_evaluation/test_abstract_evaluator.py
- test/test_evaluation/test_train_evaluator.py

3 files changed: +123 -21 lines changed

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 59 additions & 21 deletions
```diff
@@ -490,37 +490,23 @@ def __init__(self, backend: Backend,
             ))
 
         self.additional_metrics: Optional[List[autoPyTorchMetric]] = None
+        metrics_dict: Optional[Dict[str, List[str]]] = None
         if all_supported_metrics:
             self.additional_metrics = get_metrics(dataset_properties=self.dataset_properties,
                                                   all_supported_metrics=all_supported_metrics)
+            # Update fit dictionary with metrics passed to the evaluator
+            metrics_dict = {'additional_metrics': []}
+            metrics_dict['additional_metrics'].append(self.metric.name)
+            for metric in self.additional_metrics:
+                metrics_dict['additional_metrics'].append(metric.name)
 
-        self.fit_dictionary: Dict[str, Any] = {'dataset_properties': self.dataset_properties}
         self._init_params = init_params
-        self.fit_dictionary.update({
-            'X_train': self.X_train,
-            'y_train': self.y_train,
-            'X_test': self.X_test,
-            'y_test': self.y_test,
-            'backend': self.backend,
-            'logger_port': logger_port,
-            'optimize_metric': self.metric.name
-        })
+
         assert self.pipeline_class is not None, "Could not infer pipeline class"
         pipeline_config = pipeline_config if pipeline_config is not None \
             else self.pipeline_class.get_default_pipeline_options()
         self.budget_type = pipeline_config['budget_type'] if budget_type is None else budget_type
         self.budget = pipeline_config[self.budget_type] if budget == 0 else budget
-        self.fit_dictionary = {**pipeline_config, **self.fit_dictionary}
-
-        # If the budget is epochs, we want to limit that in the fit dictionary
-        if self.budget_type == 'epochs':
-            self.fit_dictionary['epochs'] = budget
-            self.fit_dictionary.pop('runtime', None)
-        elif self.budget_type == 'runtime':
-            self.fit_dictionary['runtime'] = budget
-            self.fit_dictionary.pop('epochs', None)
-        else:
-            raise ValueError(f"Unsupported budget type {self.budget_type} provided")
 
         self.num_run = 0 if num_run is None else num_run
 
@@ -533,13 +519,65 @@ def __init__(self, backend: Backend,
             port=logger_port,
         )
 
+        self._init_fit_dictionary(logger_port=logger_port, pipeline_config=pipeline_config, metrics_dict=metrics_dict)
         self.Y_optimization: Optional[np.ndarray] = None
         self.Y_actual_train: Optional[np.ndarray] = None
         self.pipelines: Optional[List[BaseEstimator]] = None
         self.pipeline: Optional[BaseEstimator] = None
         self.logger.debug("Fit dictionary in Abstract evaluator: {}".format(dict_repr(self.fit_dictionary)))
         self.logger.debug("Search space updates :{}".format(self.search_space_updates))
 
+    def _init_fit_dictionary(
+        self,
+        logger_port: int,
+        pipeline_config: Dict[str, Any],
+        metrics_dict: Optional[Dict[str, List[str]]] = None,
+    ) -> None:
+        """
+        Initialises the fit dictionary
+
+        Args:
+            logger_port (int):
+                Logging is performed using a socket-server scheme to be robust against many
+                parallel entities that want to write to the same file. This integer states the
+                socket port for the communication channel.
+            pipeline_config (Dict[str, Any]):
+                Defines the content of the pipeline being evaluated. For example, it
+                contains pipeline specific settings like logging name, or whether or not
+                to use tensorboard.
+            metrics_dict (Optional[Dict[str, List[str]]]):
+                Contains a list of metric names to be evaluated in Trainer with key `additional_metrics`. Defaults to None.
+
+        Returns:
+            None
+        """
+
+        self.fit_dictionary: Dict[str, Any] = {'dataset_properties': self.dataset_properties}
+
+        if metrics_dict is not None:
+            self.fit_dictionary.update(metrics_dict)
+
+        self.fit_dictionary.update({
+            'X_train': self.X_train,
+            'y_train': self.y_train,
+            'X_test': self.X_test,
+            'y_test': self.y_test,
+            'backend': self.backend,
+            'logger_port': logger_port,
+            'optimize_metric': self.metric.name
+        })
+
+        self.fit_dictionary.update(pipeline_config)
+        # If the budget is epochs, we want to limit that in the fit dictionary
+        if self.budget_type == 'epochs':
+            self.fit_dictionary['epochs'] = self.budget
+            self.fit_dictionary.pop('runtime', None)
+        elif self.budget_type == 'runtime':
+            self.fit_dictionary['runtime'] = self.budget
+            self.fit_dictionary.pop('epochs', None)
+        else:
+            raise ValueError(f"budget type must be `epochs` or `runtime`, but got {self.budget_type}")
+
     def _get_pipeline(self) -> BaseEstimator:
         """
         Implements a pipeline object based on the self.configuration attribute.
```
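
Taken together, the two hunks move all fit-dictionary construction out of `__init__` into the new `_init_fit_dictionary`, and thread the collected metric names in via `metrics_dict`. Below is a minimal, self-contained sketch of the resulting merge-and-prune behaviour, not autoPyTorch's actual class; the free function `build_fit_dictionary` and its arguments are illustrative stand-ins:

```python
from typing import Any, Dict, List, Optional


def build_fit_dictionary(
    dataset_properties: Dict[str, Any],
    pipeline_config: Dict[str, Any],
    budget_type: str,
    budget: float,
    optimize_metric: str,
    metrics_dict: Optional[Dict[str, List[str]]] = None,
) -> Dict[str, Any]:
    """Illustrative stand-in for AbstractEvaluator._init_fit_dictionary."""
    fit_dictionary: Dict[str, Any] = {'dataset_properties': dataset_properties}

    # Additional metric names (if any) ride along under 'additional_metrics'.
    if metrics_dict is not None:
        fit_dictionary.update(metrics_dict)

    fit_dictionary['optimize_metric'] = optimize_metric

    # Merge pipeline defaults, then keep only the budget key that applies.
    fit_dictionary.update(pipeline_config)
    if budget_type == 'epochs':
        fit_dictionary['epochs'] = budget
        fit_dictionary.pop('runtime', None)
    elif budget_type == 'runtime':
        fit_dictionary['runtime'] = budget
        fit_dictionary.pop('epochs', None)
    else:
        raise ValueError(f"budget type must be `epochs` or `runtime`, but got {budget_type}")
    return fit_dictionary


if __name__ == '__main__':
    print(build_fit_dictionary(
        dataset_properties={'task_type': 'tabular_classification'},
        pipeline_config={'budget_type': 'epochs', 'epochs': 50, 'runtime': 3600},
        budget_type='epochs',
        budget=10,
        optimize_metric='accuracy',
        metrics_dict={'additional_metrics': ['accuracy', 'balanced_accuracy']},
    ))
```

One behavioural detail visible in the diff: the old code merged with `{**pipeline_config, **self.fit_dictionary}`, so evaluator-set values won over pipeline defaults, whereas the new `_init_fit_dictionary` calls `self.fit_dictionary.update(pipeline_config)` after the main update, so pipeline-config values win for any overlapping keys.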

test/test_evaluation/test_abstract_evaluator.py

Lines changed: 32 additions & 0 deletions
```diff
@@ -282,3 +282,35 @@ def test_file_output(self):
                                                     '.autoPyTorch', 'runs', '1_0_1.0')))
 
         shutil.rmtree(self.working_directory, ignore_errors=True)
+
+    def test_error_unsupported_budget_type(self):
+        shutil.rmtree(self.working_directory, ignore_errors=True)
+        os.mkdir(self.working_directory)
+
+        queue_mock = unittest.mock.Mock()
+
+        context = BackendContext(
+            prefix='autoPyTorch',
+            temporary_directory=os.path.join(self.working_directory, 'tmp'),
+            output_directory=os.path.join(self.working_directory, 'out'),
+            delete_tmp_folder_after_terminate=True,
+            delete_output_folder_after_terminate=True,
+        )
+        with unittest.mock.patch.object(Backend, 'load_datamanager') as load_datamanager_mock:
+            load_datamanager_mock.return_value = get_multiclass_classification_datamanager()
+
+            backend = Backend(context, prefix='autoPyTorch')
+
+            try:
+                AbstractEvaluator(
+                    backend=backend,
+                    output_y_hat_optimization=False,
+                    queue=queue_mock,
+                    pipeline_config={'budget_type': "error", 'error': 0},
+                    metric=accuracy,
+                    budget=0,
+                    configuration=1)
+            except Exception as e:
+                self.assertIsInstance(e, ValueError)
+
+        shutil.rmtree(self.working_directory, ignore_errors=True)
```
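
Note that the `try`/`except` pattern in this test passes silently if the constructor raises nothing at all. A hedged sketch of the stricter `assertRaises` idiom, exercising the hypothetical `build_fit_dictionary` stand-in from the earlier sketch rather than the real `AbstractEvaluator`:

```python
import unittest

# Reuses the hypothetical build_fit_dictionary sketch shown after the
# abstract_evaluator.py diff above; it is not autoPyTorch's API.


class TestUnsupportedBudgetType(unittest.TestCase):
    def test_unsupported_budget_type_raises(self) -> None:
        # assertRaises fails the test when no ValueError is raised,
        # unlike a bare try/except, which passes in that case.
        with self.assertRaises(ValueError):
            build_fit_dictionary(
                dataset_properties={},
                pipeline_config={'budget_type': 'error', 'error': 0},
                budget_type='error',
                budget=0,
                optimize_metric='accuracy',
            )


if __name__ == '__main__':
    unittest.main()
```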

test/test_evaluation/test_train_evaluator.py

Lines changed: 32 additions & 0 deletions
```diff
@@ -262,3 +262,35 @@ def test_get_results(self):
         self.assertEqual(len(result), 5)
         self.assertEqual(result[0][0], 0)
         self.assertAlmostEqual(result[0][1], 1.0)
+
+    @unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline')
+    def test_additional_metrics_during_training(self, pipeline_mock):
+        pipeline_mock.fit_dictionary = {'budget_type': 'epochs', 'epochs': 50}
+        # Binary iris, contains 69 train samples, 31 test samples
+        D = get_binary_classification_datamanager()
+        pipeline_mock.predict_proba.side_effect = \
+            lambda X, batch_size=None: np.tile([0.6, 0.4], (len(X), 1))
+        pipeline_mock.side_effect = lambda **kwargs: pipeline_mock
+        pipeline_mock.get_additional_run_info.return_value = None
+
+        # Binary iris, contains 69 train samples, 31 test samples
+        D = get_binary_classification_datamanager()
+
+        configuration = unittest.mock.Mock(spec=Configuration)
+        backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch')
+        backend_api.load_datamanager = lambda: D
+        queue_ = multiprocessing.Queue()
+
+        evaluator = TrainEvaluator(backend_api, queue_, configuration=configuration, metric=accuracy, budget=0,
+                                   pipeline_config={'budget_type': 'epochs', 'epochs': 50}, all_supported_metrics=True)
+        evaluator.file_output = unittest.mock.Mock(spec=evaluator.file_output)
+        evaluator.file_output.return_value = (None, {})
+
+        evaluator.fit_predict_and_loss()
+
+        rval = read_queue(evaluator.queue)
+        self.assertEqual(len(rval), 1)
+        result = rval[0]
+        self.assertIn('additional_run_info', result)
+        self.assertIn('opt_loss', result['additional_run_info'])
+        self.assertGreater(len(result['additional_run_info']['opt_loss'].keys()), 1)
```
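
For orientation, a sketch of the queue-entry shape this test asserts on. The key names come from the assertions above, but the concrete values and the metric set under `opt_loss` are invented assumptions; the real set depends on the dataset and task:

```python
# Illustrative only: values are made up, and the metric names under
# 'opt_loss' are assumptions -- the actual set depends on the task.
result = {
    'loss': 0.45,
    'additional_run_info': {
        # With all_supported_metrics=True, opt_loss reports the optimised
        # metric alongside the additional metrics, hence the "> 1 keys"
        # assertion in the test above.
        'opt_loss': {'accuracy': 0.45, 'balanced_accuracy': 0.47, 'f1': 0.44},
    },
}

assert 'additional_run_info' in result
assert 'opt_loss' in result['additional_run_info']
assert len(result['additional_run_info']['opt_loss'].keys()) > 1
```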
