Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 71ffc5b

Browse files
feat: Support placeholders for TuningStep parameters
1 parent 23878de commit 71ffc5b

File tree

3 files changed

+364
-9
lines changed

3 files changed

+364
-9
lines changed

‎src/stepfunctions/steps/sagemaker.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,10 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
444444
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
445445
where each instance is a different channel of training data.
446446
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the tuning job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the tuning job and proceed to the next step. (default: True)
447-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
447+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
448+
parameters(dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateHyperParameterTuningJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateHyperParameterTuningJob.html>`_.
449+
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
450+
448451
"""
449452
if wait_for_completion:
450453
"""
@@ -462,19 +465,22 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
462465
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
463466
SageMakerApi.CreateHyperParameterTuningJob)
464467

465-
parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
468+
tuning_parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
466469

467470
if job_name is not None:
468-
parameters['HyperParameterTuningJobName'] = job_name
471+
tuning_parameters['HyperParameterTuningJobName'] = job_name
469472

470-
if 'S3Operations' in parameters:
471-
del parameters['S3Operations']
473+
if 'S3Operations' in tuning_parameters:
474+
del tuning_parameters['S3Operations']
472475

473476
if tags:
474-
parameters['Tags'] = tags_dict_to_kv_list(tags)
477+
tuning_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
475478

476-
kwargs[Field.Parameters.value] = parameters
479+
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
480+
# Update tuning parameters with input parameters
481+
merge_dicts(tuning_parameters, kwargs[Field.Parameters.value])
477482

483+
kwargs[Field.Parameters.value] = tuning_parameters
478484
super(TuningStep, self).__init__(state_id, **kwargs)
479485

480486

‎tests/integ/test_sagemaker_steps.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client,
257257
delete_sagemaker_model(model.name, sagemaker_session)
258258
# End of Cleanup
259259

260+
260261
def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
261262
job_name = generate_job_name()
262263

@@ -308,6 +309,123 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
308309
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
309310
# End of Cleanup
310311

312+
313+
def test_tuning_step_with_placeholders(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
314+
kmeans = KMeans(
315+
role=sagemaker_role_arn,
316+
instance_count=1,
317+
instance_type=INSTANCE_TYPE,
318+
k=10
319+
)
320+
321+
hyperparameter_ranges = {
322+
"extra_center_factor": IntegerParameter(4, 10),
323+
"mini_batch_size": IntegerParameter(10, 100),
324+
"epochs": IntegerParameter(1, 2),
325+
"init_method": CategoricalParameter(["kmeans++", "random"]),
326+
}
327+
328+
tuner = HyperparameterTuner(
329+
estimator=kmeans,
330+
objective_metric_name="test:msd",
331+
hyperparameter_ranges=hyperparameter_ranges,
332+
objective_type="Maximize",
333+
max_jobs=2,
334+
max_parallel_jobs=1,
335+
)
336+
337+
execution_input = ExecutionInput(schema={
338+
'job_name': str,
339+
'data_input': str,
340+
'objective_metric_name': str,
341+
'objective_type': str,
342+
'max_jobs': int,
343+
'max_parallel_jobs': int,
344+
'early_stopping_type': str,
345+
'strategy': str,
346+
})
347+
348+
parameters = {
349+
'HyperParameterTuningJobConfig': {
350+
'HyperParameterTuningJobObjective': {
351+
'MetricName': execution_input['objective_metric_name'],
352+
'Type': execution_input['objective_type']
353+
},
354+
'ResourceLimits': {'MaxNumberOfTrainingJobs': execution_input['max_jobs'],
355+
'MaxParallelTrainingJobs': execution_input['max_parallel_jobs']},
356+
'Strategy': execution_input['strategy'],
357+
'TrainingJobEarlyStoppingType': execution_input['early_stopping_type']
358+
},
359+
'TrainingJobDefinition': {
360+
'AlgorithmSpecification': {
361+
'TrainingInputMode': 'File'
362+
},
363+
'InputDataConfig': execution_input['data_input']
364+
}
365+
}
366+
367+
# Build workflow definition
368+
tuning_step = TuningStep('Tuning', tuner=tuner, job_name=execution_input['job_name'],
369+
data=record_set_for_hyperparameter_tuning, parameters=parameters)
370+
tuning_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
371+
workflow_graph = Chain([tuning_step])
372+
373+
with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
374+
# Create workflow and check definition
375+
workflow = create_workflow_and_check_definition(
376+
workflow_graph=workflow_graph,
377+
workflow_name=unique_name_from_base("integ-test-tuning-step-workflow"),
378+
sfn_client=sfn_client,
379+
sfn_role_arn=sfn_role_arn
380+
)
381+
382+
job_name = generate_job_name()
383+
data_input = [
384+
{
385+
"DataSource": {
386+
"S3DataSource": {
387+
"S3DataType": "ManifestFile",
388+
"S3Uri": "s3://sagemaker-us-east-1-585192044892/sagemaker-record-sets/PCA-2021年10月19日-00-19-10-799/.amazon.manifest",
389+
"S3DataDistributionType": "ShardedByS3Key"
390+
}
391+
},
392+
"ChannelName": "train"
393+
},
394+
{
395+
"DataSource": {
396+
"S3DataSource": {
397+
"S3DataType": "ManifestFile",
398+
"S3Uri": "s3://sagemaker-us-east-1-585192044892/sagemaker-record-sets/PCA-2021年10月19日-00-19-15-087/.amazon.manifest",
399+
"S3DataDistributionType": "ShardedByS3Key"
400+
}
401+
},
402+
"ChannelName": "test"
403+
}
404+
]
405+
406+
inputs = {
407+
'job_name': job_name,
408+
'data_input': data_input,
409+
'objective_metric_name': 'test:msd',
410+
'objective_type': 'Minimize',
411+
'max_jobs': 2,
412+
'max_parallel_jobs': 2,
413+
'early_stopping_type': 'Off',
414+
'strategy': 'Bayesian',
415+
}
416+
417+
# Execute workflow
418+
execution = workflow.execute(inputs=inputs)
419+
execution_output = execution.get_output(wait=True)
420+
421+
# Check workflow output
422+
assert execution_output.get("HyperParameterTuningJobStatus") == "Completed"
423+
424+
# Cleanup
425+
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
426+
# End of Cleanup
427+
428+
311429
def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn):
312430
region = boto3.session.Session().region_name
313431
input_data = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /