Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 1ea8346

Browse files
feat: Support placeholders for TuningStep (#173)
1 parent 2091850 commit 1ea8346

File tree

3 files changed

+340
-16
lines changed

3 files changed

+340
-16
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,10 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
465465
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
466466
where each instance is a different channel of training data.
467467
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the tuning job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the tuning job and proceed to the next step. (default: True)
468-
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
468+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
469+
parameters(dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateHyperParameterTuningJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateHyperParameterTuningJob.html>`_.
470+
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
471+
469472
"""
470473
if wait_for_completion:
471474
"""
@@ -483,19 +486,22 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
483486
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
484487
SageMakerApi.CreateHyperParameterTuningJob)
485488

486-
parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
489+
tuning_parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
487490

488491
if job_name is not None:
489-
parameters['HyperParameterTuningJobName'] = job_name
492+
tuning_parameters['HyperParameterTuningJobName'] = job_name
490493

491-
if 'S3Operations' in parameters:
492-
del parameters['S3Operations']
494+
if 'S3Operations' in tuning_parameters:
495+
del tuning_parameters['S3Operations']
493496

494497
if tags:
495-
parameters['Tags'] = tags_dict_to_kv_list(tags)
498+
tuning_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
496499

497-
kwargs[Field.Parameters.value] = parameters
500+
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
501+
# Update tuning parameters with input parameters
502+
merge_dicts(tuning_parameters, kwargs[Field.Parameters.value])
498503

504+
kwargs[Field.Parameters.value] = tuning_parameters
499505
super(TuningStep, self).__init__(state_id, **kwargs)
500506

501507

‎tests/integ/test_sagemaker_steps.py

Lines changed: 94 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client, sf
104104

105105
# Cleanup
106106
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
107-
# End of Cleanup
108107

109108

110109
def test_training_step_with_placeholders(pca_estimator_fixture, record_set_fixture, sfn_client, sfn_role_arn):
@@ -193,7 +192,7 @@ def test_model_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_a
193192
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
194193
model_name = get_resource_name_from_arn(execution_output.get("ModelArn")).split("/")[1]
195194
delete_sagemaker_model(model_name, sagemaker_session)
196-
# End of Cleanup
195+
197196

198197

199198
def test_model_step_with_placeholders(trained_estimator, sfn_client, sagemaker_session, sfn_role_arn):
@@ -288,7 +287,6 @@ def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):
288287

289288
# Cleanup
290289
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
291-
# End of Cleanup
292290

293291

294292
def test_transform_step_with_placeholder(trained_estimator, sfn_client, sfn_role_arn):
@@ -413,7 +411,7 @@ def test_endpoint_config_step(trained_estimator, sfn_client, sagemaker_session,
413411
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
414412
delete_sagemaker_endpoint_config(endpoint_config_name, sagemaker_session)
415413
delete_sagemaker_model(model.name, sagemaker_session)
416-
# End of Cleanup
414+
417415

418416
def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client, sagemaker_session, sfn_role_arn):
419417
# Setup: Create model and endpoint config for trained estimator in SageMaker
@@ -456,7 +454,7 @@ def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client,
456454
delete_sagemaker_endpoint(endpoint_name, sagemaker_session)
457455
delete_sagemaker_endpoint_config(model.name, sagemaker_session)
458456
delete_sagemaker_model(model.name, sagemaker_session)
459-
# End of Cleanup
457+
460458

461459
def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
462460
job_name = generate_job_name()
@@ -507,7 +505,97 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
507505

508506
# Cleanup
509507
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
510-
# End of Cleanup
508+
509+
510+
def test_tuning_step_with_placeholders(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
511+
kmeans = KMeans(
512+
role=sagemaker_role_arn,
513+
instance_count=1,
514+
instance_type=INSTANCE_TYPE,
515+
k=10
516+
)
517+
518+
hyperparameter_ranges = {
519+
"extra_center_factor": IntegerParameter(4, 10),
520+
"mini_batch_size": IntegerParameter(10, 100),
521+
"epochs": IntegerParameter(1, 2),
522+
"init_method": CategoricalParameter(["kmeans++", "random"]),
523+
}
524+
525+
tuner = HyperparameterTuner(
526+
estimator=kmeans,
527+
objective_metric_name="test:msd",
528+
hyperparameter_ranges=hyperparameter_ranges,
529+
objective_type="Maximize",
530+
max_jobs=2,
531+
max_parallel_jobs=1,
532+
)
533+
534+
execution_input = ExecutionInput(schema={
535+
'job_name': str,
536+
'objective_metric_name': str,
537+
'objective_type': str,
538+
'max_jobs': int,
539+
'max_parallel_jobs': int,
540+
'early_stopping_type': str,
541+
'strategy': str,
542+
})
543+
544+
parameters = {
545+
'HyperParameterTuningJobConfig': {
546+
'HyperParameterTuningJobObjective': {
547+
'MetricName': execution_input['objective_metric_name'],
548+
'Type': execution_input['objective_type']
549+
},
550+
'ResourceLimits': {'MaxNumberOfTrainingJobs': execution_input['max_jobs'],
551+
'MaxParallelTrainingJobs': execution_input['max_parallel_jobs']},
552+
'Strategy': execution_input['strategy'],
553+
'TrainingJobEarlyStoppingType': execution_input['early_stopping_type']
554+
},
555+
'TrainingJobDefinition': {
556+
'AlgorithmSpecification': {
557+
'TrainingInputMode': 'File'
558+
}
559+
}
560+
}
561+
562+
# Build workflow definition
563+
tuning_step = TuningStep('Tuning', tuner=tuner, job_name=execution_input['job_name'],
564+
data=record_set_for_hyperparameter_tuning, parameters=parameters)
565+
tuning_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
566+
workflow_graph = Chain([tuning_step])
567+
568+
with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
569+
# Create workflow and check definition
570+
workflow = create_workflow_and_check_definition(
571+
workflow_graph=workflow_graph,
572+
workflow_name=unique_name_from_base("integ-test-tuning-step-workflow"),
573+
sfn_client=sfn_client,
574+
sfn_role_arn=sfn_role_arn
575+
)
576+
577+
job_name = generate_job_name()
578+
579+
inputs = {
580+
'job_name': job_name,
581+
'objective_metric_name': 'test:msd',
582+
'objective_type': 'Minimize',
583+
'max_jobs': 2,
584+
'max_parallel_jobs': 2,
585+
'early_stopping_type': 'Off',
586+
'strategy': 'Bayesian',
587+
}
588+
589+
# Execute workflow
590+
execution = workflow.execute(inputs=inputs)
591+
execution_output = execution.get_output(wait=True)
592+
593+
# Check workflow output
594+
assert execution_output.get("HyperParameterTuningJobStatus") == "Completed"
595+
596+
# Cleanup
597+
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
598+
511599

512600
def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn):
513601
region = boto3.session.Session().region_name
@@ -561,7 +649,6 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
561649

562650
# Cleanup
563651
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
564-
# End of Cleanup
565652

566653

567654
def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn,

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /