Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit a6cbd81

Browse files
Placeholder hyperparameters overwrite the estimator hyperparameters
1 parent 966d436 commit a6cbd81

File tree

2 files changed

+83
-6
lines changed

2 files changed

+83
-6
lines changed

‎src/stepfunctions/steps/sagemaker.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,10 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
6969
* (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
7070
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
7171
where each instance is a different channel of training data.
72-
hyperparameters (dict[str, str] or dict[str, Placeholder], optional): Parameters used for training.
73-
Hyperparameters supplied will be merged with the Hyperparameters specified in the estimator.
72+
hyperparameters: Parameters used for training.
73+
* (dict[str, str], optional) - Hyperparameters supplied will be merged with the Hyperparameters specified in the estimator.
7474
If there are duplicate entries, the value provided through this property will be used. (Default: Hyperparameters specified in the estimator.)
75+
* (Placeholder, optional) - Hyperparameters supplied will overwrite the Hyperparameters specified in the estimator.
7576
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
7677
experiment_config (dict, optional): Specify the experiment config for the training. (Default: None)
7778
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
@@ -127,8 +128,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
127128
parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri
128129

129130
if hyperparameters is not None:
130-
if estimator.hyperparameters() is not None:
131-
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
131+
if not isinstance(hyperparameters, Placeholder):
132+
if estimator.hyperparameters() is not None:
133+
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
132134
parameters['HyperParameters'] = hyperparameters
133135

134136
if experiment_config is not None:
@@ -173,7 +175,12 @@ def __merge_hyperparameters(self, training_step_hyperparameters, estimator_hyper
173175
estimator_hyperparameters (dict): Hyperparameters specified in the estimator
174176
"""
175177
merged_hyperparameters = estimator_hyperparameters.copy()
176-
merge_dicts(merged_hyperparameters, training_step_hyperparameters)
178+
for key, value in training_step_hyperparameters.items():
179+
if key in merged_hyperparameters:
180+
logger.info(
181+
f"hyperparameter property: <{key}> with value: <{merged_hyperparameters[key]}> provided in the"
182+
f" estimator will be overwritten with value provided in constructor: <{value}>")
183+
merged_hyperparameters[key] = value
177184
return merged_hyperparameters
178185

179186
class TransformStep(Task):

‎tests/unit/test_sagemaker_steps.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,77 @@ def test_training_step_creation(pca_estimator):
272272
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
273273
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
274274
def test_training_step_creation_with_placeholders(pca_estimator):
275+
execution_input = ExecutionInput(schema={
276+
'Data': str,
277+
'OutputPath': str,
278+
'HyperParameters': str
279+
})
280+
281+
step_input = StepInput(schema={
282+
'JobName': str,
283+
})
284+
285+
step = TrainingStep('Training',
286+
estimator=pca_estimator,
287+
job_name=step_input['JobName'],
288+
data=execution_input['Data'],
289+
output_data_config_path=execution_input['OutputPath'],
290+
experiment_config={
291+
'ExperimentName': 'pca_experiment',
292+
'TrialName': 'pca_trial',
293+
'TrialComponentDisplayName': 'Training'
294+
},
295+
tags=DEFAULT_TAGS,
296+
hyperparameters=execution_input['HyperParameters']
297+
)
298+
assert step.to_dict() == {
299+
'Type': 'Task',
300+
'Parameters': {
301+
'AlgorithmSpecification': {
302+
'TrainingImage': PCA_IMAGE,
303+
'TrainingInputMode': 'File'
304+
},
305+
'OutputDataConfig': {
306+
'S3OutputPath.$': "$$.Execution.Input['OutputPath']"
307+
},
308+
'StoppingCondition': {
309+
'MaxRuntimeInSeconds': 86400
310+
},
311+
'ResourceConfig': {
312+
'InstanceCount': 1,
313+
'InstanceType': 'ml.c4.xlarge',
314+
'VolumeSizeInGB': 30
315+
},
316+
'RoleArn': EXECUTION_ROLE,
317+
'HyperParameters.$': "$$.Execution.Input['HyperParameters']",
318+
'InputDataConfig': [
319+
{
320+
'ChannelName': 'training',
321+
'DataSource': {
322+
'S3DataSource': {
323+
'S3DataDistributionType': 'FullyReplicated',
324+
'S3DataType': 'S3Prefix',
325+
'S3Uri.$': "$$.Execution.Input['Data']"
326+
}
327+
}
328+
}
329+
],
330+
'ExperimentConfig': {
331+
'ExperimentName': 'pca_experiment',
332+
'TrialName': 'pca_trial',
333+
'TrialComponentDisplayName': 'Training'
334+
},
335+
'TrainingJobName.$': "$['JobName']",
336+
'Tags': DEFAULT_TAGS_LIST
337+
},
338+
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
339+
'End': True
340+
}
341+
342+
343+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
344+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
345+
def test_training_step_creation_with_hyperparameters_containing_placeholders(pca_estimator):
275346
execution_input = ExecutionInput(schema={
276347
'Data': str,
277348
'OutputPath': str,
@@ -353,7 +424,6 @@ def test_training_step_creation_with_placeholders(pca_estimator):
353424
'End': True
354425
}
355426

356-
357427
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
358428
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
359429
def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook):

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /