Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 8dd38ec

Browse files
Merge branch 'main' into support-placeholders-for-tuning-step
2 parents 3b38a6a + 19761b0 commit 8dd38ec

File tree

3 files changed

+158
-62
lines changed

3 files changed

+158
-62
lines changed

‎src/stepfunctions/steps/sagemaker.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,13 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
7474
If there are duplicate entries, the value provided through this property will be used. (Default: Hyperparameters specified in the estimator.)
7575
* (Placeholder, optional) - The TrainingStep will use the hyperparameters specified by the Placeholder's value instead of the hyperparameters specified in the estimator.
7676
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
77-
experiment_config (dict, optional): Specify the experiment config for the training. (Default: None)
77+
experiment_config (dict or Placeholder, optional): Specify the experiment config for the training. (Default: None)
7878
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
79-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
79+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
8080
output_data_config_path (str or Placeholder, optional): S3 location for saving the training result (model
8181
artifacts and output files). If specified, it overrides the `output_path` property of `estimator`.
82+
parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateTrainingJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html>`_. (Default: None)
83+
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
8284
"""
8385
self.estimator = estimator
8486
self.job_name = job_name
@@ -105,44 +107,48 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
105107
data = data.to_jsonpath()
106108

107109
if isinstance(job_name, str):
108-
parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
110+
training_parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
109111
else:
110-
parameters = training_config(estimator=estimator, inputs=data, mini_batch_size=mini_batch_size)
112+
training_parameters = training_config(estimator=estimator, inputs=data, mini_batch_size=mini_batch_size)
111113

112114
if estimator.debugger_hook_config != None and estimator.debugger_hook_config is not False:
113-
parameters['DebugHookConfig'] = estimator.debugger_hook_config._to_request_dict()
115+
training_parameters['DebugHookConfig'] = estimator.debugger_hook_config._to_request_dict()
114116

115117
if estimator.rules != None:
116-
parameters['DebugRuleConfigurations'] = [rule.to_debugger_rule_config_dict() for rule in estimator.rules]
118+
training_parameters['DebugRuleConfigurations'] = [rule.to_debugger_rule_config_dict() for rule in estimator.rules]
117119

118120
if isinstance(job_name, Placeholder):
119-
parameters['TrainingJobName'] = job_name
121+
training_parameters['TrainingJobName'] = job_name
120122

121123
if output_data_config_path is not None:
122-
parameters['OutputDataConfig']['S3OutputPath'] = output_data_config_path
124+
training_parameters['OutputDataConfig']['S3OutputPath'] = output_data_config_path
123125

124126
if data is not None and is_data_placeholder:
125127
# Replace the 'S3Uri' key with one that supports JSONpath value.
126128
# Support for uri str only: The list will only contain 1 element
127-
data_uri = parameters['InputDataConfig'][0]['DataSource']['S3DataSource'].pop('S3Uri', None)
128-
parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri
129+
data_uri = training_parameters['InputDataConfig'][0]['DataSource']['S3DataSource'].pop('S3Uri', None)
130+
training_parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri
129131

130132
if hyperparameters is not None:
131133
if not isinstance(hyperparameters, Placeholder):
132134
if estimator.hyperparameters() is not None:
133135
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
134-
parameters['HyperParameters'] = hyperparameters
136+
training_parameters['HyperParameters'] = hyperparameters
135137

136138
if experiment_config is not None:
137-
parameters['ExperimentConfig'] = experiment_config
139+
training_parameters['ExperimentConfig'] = experiment_config
138140

139-
if 'S3Operations' in parameters:
140-
del parameters['S3Operations']
141+
if 'S3Operations' in training_parameters:
142+
del training_parameters['S3Operations']
141143

142144
if tags:
143-
parameters['Tags'] = tags_dict_to_kv_list(tags)
145+
training_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
144146

145-
kwargs[Field.Parameters.value] = parameters
147+
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
148+
# Update training parameters with input parameters
149+
merge_dicts(training_parameters, kwargs[Field.Parameters.value])
150+
151+
kwargs[Field.Parameters.value] = training_parameters
146152
super(TrainingStep, self).__init__(state_id, **kwargs)
147153

148154
def get_expected_model(self, model_name=None):

‎tests/integ/test_sagemaker_steps.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,64 @@ def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client, sf
107107
# End of Cleanup
108108

109109

110+
def test_training_step_with_placeholders(pca_estimator_fixture, record_set_fixture, sfn_client, sfn_role_arn):
111+
execution_input = ExecutionInput(schema={
112+
'JobName': str,
113+
'HyperParameters': str,
114+
'InstanceCount': int,
115+
'InstanceType': str,
116+
'MaxRun': int
117+
})
118+
119+
parameters = {
120+
'HyperParameters': execution_input['HyperParameters'],
121+
'ResourceConfig': {
122+
'InstanceCount': execution_input['InstanceCount'],
123+
'InstanceType': execution_input['InstanceType']
124+
},
125+
'StoppingCondition': {
126+
'MaxRuntimeInSeconds': execution_input['MaxRun']
127+
}
128+
}
129+
130+
training_step = TrainingStep('create_training_job_step', estimator=pca_estimator_fixture,
131+
job_name=execution_input['JobName'], data=record_set_fixture, mini_batch_size=200,
132+
parameters=parameters)
133+
training_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
134+
workflow_graph = Chain([training_step])
135+
136+
with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
137+
# Create workflow and check definition
138+
workflow = create_workflow_and_check_definition(
139+
workflow_graph=workflow_graph,
140+
workflow_name=unique_name_from_base("integ-test-training-step-workflow"),
141+
sfn_client=sfn_client,
142+
sfn_role_arn=sfn_role_arn
143+
)
144+
145+
inputs = {
146+
'JobName': generate_job_name(),
147+
'HyperParameters': {
148+
"num_components": "48",
149+
"feature_dim": "784",
150+
"mini_batch_size": "250"
151+
},
152+
'InstanceCount': INSTANCE_COUNT,
153+
'InstanceType': INSTANCE_TYPE,
154+
'MaxRun': 100000
155+
}
156+
157+
# Execute workflow
158+
execution = workflow.execute(inputs=inputs)
159+
execution_output = execution.get_output(wait=True)
160+
161+
# Check workflow output
162+
assert execution_output.get("TrainingJobStatus") == "Completed"
163+
164+
# Cleanup
165+
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
166+
167+
110168
def test_model_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_arn):
111169
# Build workflow definition
112170
model_name = generate_job_name()

‎tests/unit/test_sagemaker_steps.py

Lines changed: 78 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -336,68 +336,100 @@ def test_training_step_creation_with_placeholders(pca_estimator):
336336
execution_input = ExecutionInput(schema={
337337
'Data': str,
338338
'OutputPath': str,
339-
'HyperParameters': str
339+
'HyperParameters': str,
340+
'ExperimentConfig': str,
341+
'Tags': str,
342+
'InstanceCount': int,
343+
'InstanceType': str,
344+
'MaxRun': int,
345+
'MetricDefinitions': str,
346+
'MaxWait': int,
347+
'CheckpointS3Uri': str,
348+
'CheckpointLocalPath': str,
349+
'EnableSagemakerMetrics': bool,
350+
'EnableNetworkIsolation': bool,
351+
'Environment': str
340352
})
341353

342354
step_input = StepInput(schema={
343355
'JobName': str,
344356
})
345357

346-
step = TrainingStep('Training',
347-
estimator=pca_estimator,
348-
job_name=step_input['JobName'],
349-
data=execution_input['Data'],
350-
output_data_config_path=execution_input['OutputPath'],
351-
experiment_config={
352-
'ExperimentName': 'pca_experiment',
353-
'TrialName': 'pca_trial',
354-
'TrialComponentDisplayName': 'Training'
355-
},
356-
tags=DEFAULT_TAGS,
357-
hyperparameters=execution_input['HyperParameters']
358-
)
359-
assert step.to_dict() == {
360-
'Type': 'Task',
361-
'Parameters': {
358+
parameters = {
362359
'AlgorithmSpecification': {
363360
'TrainingImage': PCA_IMAGE,
364-
'TrainingInputMode': 'File'
361+
'TrainingInputMode': 'File',
362+
'MetricDefinitions': execution_input['MetricDefinitions'],
363+
'EnableSageMakerMetricsTimeSeries': execution_input['EnableSagemakerMetrics']
365364
},
366-
'OutputDataConfig': {
367-
'S3OutputPath.$': "$$.Execution.Input['OutputPath']"
365+
'CheckpointConfig': {
366+
'S3Uri': execution_input['CheckpointS3Uri'],
367+
'LocalPath': execution_input['CheckpointLocalPath']
368368
},
369+
'EnableNetworkIsolation': execution_input['EnableNetworkIsolation'],
369370
'StoppingCondition': {
370-
'MaxRuntimeInSeconds': 86400
371+
'MaxRuntimeInSeconds': execution_input['MaxRun'],
372+
'MaxWaitTimeInSeconds': execution_input['MaxWait']
371373
},
372374
'ResourceConfig': {
373-
'InstanceCount': 1,
374-
'InstanceType': 'ml.c4.xlarge',
375-
'VolumeSizeInGB': 30
375+
'InstanceCount': execution_input['InstanceCount'],
376+
'InstanceType': execution_input['InstanceType']
376377
},
377-
'RoleArn': EXECUTION_ROLE,
378-
'HyperParameters.$': "$$.Execution.Input['HyperParameters']",
379-
'InputDataConfig': [
380-
{
381-
'ChannelName': 'training',
382-
'DataSource': {
383-
'S3DataSource': {
384-
'S3DataDistributionType': 'FullyReplicated',
385-
'S3DataType': 'S3Prefix',
386-
'S3Uri.$': "$$.Execution.Input['Data']"
387-
}
378+
'Environment': execution_input['Environment'],
379+
'ExperimentConfig': execution_input['ExperimentConfig']
380+
}
381+
382+
step = TrainingStep('Training',
383+
estimator=pca_estimator,
384+
job_name=step_input['JobName'],
385+
data=execution_input['Data'],
386+
output_data_config_path=execution_input['OutputPath'],
387+
experiment_config=execution_input['ExperimentConfig'],
388+
tags=execution_input['Tags'],
389+
mini_batch_size=1000,
390+
hyperparameters=execution_input['HyperParameters'],
391+
parameters=parameters
392+
)
393+
assert step.to_dict()['Parameters'] == {
394+
'AlgorithmSpecification': {
395+
'EnableSageMakerMetricsTimeSeries.$': "$$.Execution.Input['EnableSagemakerMetrics']",
396+
'MetricDefinitions.$': "$$.Execution.Input['MetricDefinitions']",
397+
'TrainingImage': PCA_IMAGE,
398+
'TrainingInputMode': 'File'
399+
},
400+
'CheckpointConfig': {'LocalPath.$': "$$.Execution.Input['CheckpointLocalPath']",
401+
'S3Uri.$': "$$.Execution.Input['CheckpointS3Uri']"},
402+
'EnableNetworkIsolation.$': "$$.Execution.Input['EnableNetworkIsolation']",
403+
'Environment.$': "$$.Execution.Input['Environment']",
404+
'OutputDataConfig': {
405+
'S3OutputPath.$': "$$.Execution.Input['OutputPath']"
406+
},
407+
'StoppingCondition': {
408+
'MaxRuntimeInSeconds.$': "$$.Execution.Input['MaxRun']",
409+
'MaxWaitTimeInSeconds.$': "$$.Execution.Input['MaxWait']"
410+
},
411+
'ResourceConfig': {
412+
'InstanceCount.$': "$$.Execution.Input['InstanceCount']",
413+
'InstanceType.$': "$$.Execution.Input['InstanceType']",
414+
'VolumeSizeInGB': 30
415+
},
416+
'RoleArn': EXECUTION_ROLE,
417+
'HyperParameters.$': "$$.Execution.Input['HyperParameters']",
418+
'InputDataConfig': [
419+
{
420+
'ChannelName': 'training',
421+
'DataSource': {
422+
'S3DataSource': {
423+
'S3DataDistributionType': 'FullyReplicated',
424+
'S3DataType': 'S3Prefix',
425+
'S3Uri.$': "$$.Execution.Input['Data']"
388426
}
389427
}
390-
],
391-
'ExperimentConfig': {
392-
'ExperimentName': 'pca_experiment',
393-
'TrialName': 'pca_trial',
394-
'TrialComponentDisplayName': 'Training'
395-
},
396-
'TrainingJobName.$': "$['JobName']",
397-
'Tags': DEFAULT_TAGS_LIST
398-
},
399-
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
400-
'End': True
428+
}
429+
],
430+
'ExperimentConfig.$': "$$.Execution.Input['ExperimentConfig']",
431+
'TrainingJobName.$': "$['JobName']",
432+
'Tags.$': "$$.Execution.Input['Tags']"
401433
}
402434

403435

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /