Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 39640f8

Browse files
feat: Support placeholders for TuningStep parameters
1 parent 8b6d0eb commit 39640f8

File tree

3 files changed

+339
-9
lines changed

3 files changed

+339
-9
lines changed

‎src/stepfunctions/steps/sagemaker.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,10 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
454454
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
455455
where each instance is a different channel of training data.
456456
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the tuning job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the tuning job and proceed to the next step. (default: True)
457-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
457+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
458+
parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateHyperParameterTuningJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateHyperParameterTuningJob.html>`_.
459+
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
460+
458461
"""
459462
if wait_for_completion:
460463
"""
@@ -472,19 +475,22 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
472475
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
473476
SageMakerApi.CreateHyperParameterTuningJob)
474477

475-
parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
478+
tuning_parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
476479

477480
if job_name is not None:
478-
parameters['HyperParameterTuningJobName'] = job_name
481+
tuning_parameters['HyperParameterTuningJobName'] = job_name
479482

480-
if 'S3Operations' in parameters:
481-
del parameters['S3Operations']
483+
if 'S3Operations' in tuning_parameters:
484+
del tuning_parameters['S3Operations']
482485

483486
if tags:
484-
parameters['Tags'] = tags_dict_to_kv_list(tags)
487+
tuning_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
485488

486-
kwargs[Field.Parameters.value] = parameters
489+
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
490+
# Update tuning parameters with input parameters
491+
merge_dicts(tuning_parameters, kwargs[Field.Parameters.value])
487492

493+
kwargs[Field.Parameters.value] = tuning_parameters
488494
super(TuningStep, self).__init__(state_id, **kwargs)
489495

490496

‎tests/integ/test_sagemaker_steps.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client,
347347
delete_sagemaker_model(model.name, sagemaker_session)
348348
# End of Cleanup
349349

350+
350351
def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
351352
job_name = generate_job_name()
352353

@@ -398,6 +399,98 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
398399
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
399400
# End of Cleanup
400401

402+
403+
def test_tuning_step_with_placeholders(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
404+
kmeans = KMeans(
405+
role=sagemaker_role_arn,
406+
instance_count=1,
407+
instance_type=INSTANCE_TYPE,
408+
k=10
409+
)
410+
411+
hyperparameter_ranges = {
412+
"extra_center_factor": IntegerParameter(4, 10),
413+
"mini_batch_size": IntegerParameter(10, 100),
414+
"epochs": IntegerParameter(1, 2),
415+
"init_method": CategoricalParameter(["kmeans++", "random"]),
416+
}
417+
418+
tuner = HyperparameterTuner(
419+
estimator=kmeans,
420+
objective_metric_name="test:msd",
421+
hyperparameter_ranges=hyperparameter_ranges,
422+
objective_type="Maximize",
423+
max_jobs=2,
424+
max_parallel_jobs=1,
425+
)
426+
427+
execution_input = ExecutionInput(schema={
428+
'job_name': str,
429+
'objective_metric_name': str,
430+
'objective_type': str,
431+
'max_jobs': int,
432+
'max_parallel_jobs': int,
433+
'early_stopping_type': str,
434+
'strategy': str,
435+
})
436+
437+
parameters = {
438+
'HyperParameterTuningJobConfig': {
439+
'HyperParameterTuningJobObjective': {
440+
'MetricName': execution_input['objective_metric_name'],
441+
'Type': execution_input['objective_type']
442+
},
443+
'ResourceLimits': {'MaxNumberOfTrainingJobs': execution_input['max_jobs'],
444+
'MaxParallelTrainingJobs': execution_input['max_parallel_jobs']},
445+
'Strategy': execution_input['strategy'],
446+
'TrainingJobEarlyStoppingType': execution_input['early_stopping_type']
447+
},
448+
'TrainingJobDefinition': {
449+
'AlgorithmSpecification': {
450+
'TrainingInputMode': 'File'
451+
}
452+
}
453+
}
454+
455+
# Build workflow definition
456+
tuning_step = TuningStep('Tuning', tuner=tuner, job_name=execution_input['job_name'],
457+
data=record_set_for_hyperparameter_tuning, parameters=parameters)
458+
tuning_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
459+
workflow_graph = Chain([tuning_step])
460+
461+
with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
462+
# Create workflow and check definition
463+
workflow = create_workflow_and_check_definition(
464+
workflow_graph=workflow_graph,
465+
workflow_name=unique_name_from_base("integ-test-tuning-step-workflow"),
466+
sfn_client=sfn_client,
467+
sfn_role_arn=sfn_role_arn
468+
)
469+
470+
job_name = generate_job_name()
471+
472+
inputs = {
473+
'job_name': job_name,
474+
'objective_metric_name': 'test:msd',
475+
'objective_type': 'Minimize',
476+
'max_jobs': 2,
477+
'max_parallel_jobs': 2,
478+
'early_stopping_type': 'Off',
479+
'strategy': 'Bayesian',
480+
}
481+
482+
# Execute workflow
483+
execution = workflow.execute(inputs=inputs)
484+
execution_output = execution.get_output(wait=True)
485+
486+
# Check workflow output
487+
assert execution_output.get("HyperParameterTuningJobStatus") == "Completed"
488+
489+
# Cleanup
490+
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
491+
# End of Cleanup
492+
493+
401494
def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn):
402495
region = boto3.session.Session().region_name
403496
input_data = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region)

‎tests/unit/test_sagemaker_steps.py

Lines changed: 233 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@
2424
from sagemaker.debugger import Rule, rule_configs, DebuggerHookConfig, CollectionConfig
2525
from sagemaker.sklearn.processing import SKLearnProcessor
2626
from sagemaker.processing import ProcessingInput, ProcessingOutput
27+
from sagemaker.parameter import IntegerParameter, CategoricalParameter
28+
from sagemaker.tuner import HyperparameterTuner
2729

2830
from unittest.mock import MagicMock, patch
2931
from stepfunctions.inputs import ExecutionInput, StepInput
30-
from stepfunctions.steps.fields import Field
3132
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep,\
32-
ProcessingStep
33+
ProcessingStep, TuningStep
3334
from stepfunctions.steps.sagemaker import tuning_config
3435

3536
from tests.unit.utils import mock_boto_api_call
@@ -1412,3 +1413,233 @@ def test_processing_step_creation_with_placeholders(sklearn_processor):
14121413
'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync',
14131414
'End': True
14141415
}
1416+
1417+
1418+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
1419+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
1420+
def test_tuning_step_creation_with_framework(tensorflow_estimator):
1421+
hyperparameter_ranges = {
1422+
"extra_center_factor": IntegerParameter(4, 10),
1423+
"epochs": IntegerParameter(1, 2),
1424+
"init_method": CategoricalParameter(["kmeans++", "random"]),
1425+
}
1426+
1427+
tuner = HyperparameterTuner(
1428+
estimator=tensorflow_estimator,
1429+
objective_metric_name="test:msd",
1430+
hyperparameter_ranges=hyperparameter_ranges,
1431+
objective_type="Minimize",
1432+
max_jobs=2,
1433+
max_parallel_jobs=2,
1434+
)
1435+
1436+
step = TuningStep('Tuning',
1437+
tuner=tuner,
1438+
data={'train': 's3://sagemaker/train'},
1439+
job_name='tensorflow-job',
1440+
tags=DEFAULT_TAGS
1441+
)
1442+
1443+
state_machine_definition = step.to_dict()
1444+
# The sagemaker_job_name is generated - expected name will be taken from the generated definition
1445+
generated_sagemaker_job_name = state_machine_definition['Parameters']['TrainingJobDefinition']\
1446+
['StaticHyperParameters']['sagemaker_job_name']
1447+
expected_definition = {
1448+
'Type': 'Task',
1449+
'Parameters': {
1450+
'HyperParameterTuningJobConfig': {
1451+
'HyperParameterTuningJobObjective': {
1452+
'MetricName': 'test:msd',
1453+
'Type': 'Minimize'
1454+
},
1455+
'ParameterRanges': {
1456+
'CategoricalParameterRanges': [
1457+
{
1458+
'Name': 'init_method',
1459+
'Values': ['"kmeans++"', '"random"']
1460+
}],
1461+
'ContinuousParameterRanges': [],
1462+
'IntegerParameterRanges': [
1463+
{
1464+
'MaxValue': '10',
1465+
'MinValue': '4',
1466+
'Name': 'extra_center_factor',
1467+
'ScalingType': 'Auto'
1468+
},
1469+
{
1470+
'MaxValue': '2',
1471+
'MinValue': '1',
1472+
'Name': 'epochs',
1473+
'ScalingType': 'Auto'
1474+
}
1475+
]
1476+
},
1477+
'ResourceLimits': {'MaxNumberOfTrainingJobs': 2,
1478+
'MaxParallelTrainingJobs': 2},
1479+
'Strategy': 'Bayesian',
1480+
'TrainingJobEarlyStoppingType': 'Off'
1481+
},
1482+
'HyperParameterTuningJobName': 'tensorflow-job',
1483+
'Tags': [{'Key': 'Purpose', 'Value': 'unittests'}],
1484+
'TrainingJobDefinition': {
1485+
'AlgorithmSpecification': {
1486+
'TrainingImage': '520713654638.dkr.ecr.us-east-1.amazonaws.com/sagemaker-tensorflow:1.13-gpu-py2',
1487+
'TrainingInputMode': 'File'
1488+
},
1489+
'InputDataConfig': [{'ChannelName': 'train',
1490+
'DataSource': {'S3DataSource': {
1491+
'S3DataDistributionType': 'FullyReplicated',
1492+
'S3DataType': 'S3Prefix',
1493+
'S3Uri': 's3://sagemaker/train'}}}],
1494+
'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
1495+
'ResourceConfig': {'InstanceCount': 1,
1496+
'InstanceType': 'ml.p2.xlarge',
1497+
'VolumeSizeInGB': 30},
1498+
'RoleArn': 'execution-role',
1499+
'StaticHyperParameters': {
1500+
'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
1501+
'evaluation_steps': '100',
1502+
'sagemaker_container_log_level': '20',
1503+
'sagemaker_estimator_class_name': '"TensorFlow"',
1504+
'sagemaker_estimator_module': '"sagemaker.tensorflow.estimator"',
1505+
'sagemaker_job_name': generated_sagemaker_job_name,
1506+
'sagemaker_program': '"tf_train.py"',
1507+
'sagemaker_region': '"us-east-1"',
1508+
'sagemaker_submit_directory': '"s3://sagemaker/source"',
1509+
'training_steps': '1000'},
1510+
'StoppingCondition': {'MaxRuntimeInSeconds': 86400}}},
1511+
'Resource': 'arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync',
1512+
'End': True
1513+
}
1514+
1515+
assert state_machine_definition == expected_definition
1516+
1517+
1518+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
1519+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
1520+
def test_tuning_step_creation_with_placeholders(tensorflow_estimator):
1521+
execution_input = ExecutionInput(schema={
1522+
'data_input': str,
1523+
'tags': list,
1524+
'objective_metric_name': str,
1525+
'hyperparameter_ranges': str,
1526+
'objective_type': str,
1527+
'max_jobs': int,
1528+
'max_parallel_jobs': int,
1529+
'early_stopping_type': str,
1530+
'strategy': str,
1531+
})
1532+
1533+
step_input = StepInput(schema={
1534+
'job_name': str
1535+
})
1536+
1537+
hyperparameter_ranges = {
1538+
"extra_center_factor": IntegerParameter(4, 10),
1539+
"epochs": IntegerParameter(1, 2),
1540+
"init_method": CategoricalParameter(["kmeans++", "random"]),
1541+
}
1542+
1543+
tuner = HyperparameterTuner(
1544+
estimator=tensorflow_estimator,
1545+
objective_metric_name="test:msd",
1546+
hyperparameter_ranges=hyperparameter_ranges,
1547+
objective_type="Minimize",
1548+
max_jobs=2,
1549+
max_parallel_jobs=2,
1550+
)
1551+
1552+
parameters = {
1553+
'HyperParameterTuningJobConfig': {
1554+
'HyperParameterTuningJobObjective': {
1555+
'MetricName': execution_input['objective_metric_name'],
1556+
'Type': execution_input['objective_type']
1557+
},
1558+
'ResourceLimits': {'MaxNumberOfTrainingJobs': execution_input['max_jobs'],
1559+
'MaxParallelTrainingJobs': execution_input['max_parallel_jobs']},
1560+
'Strategy': execution_input['strategy'],
1561+
'TrainingJobEarlyStoppingType': execution_input['early_stopping_type']
1562+
},
1563+
'TrainingJobDefinition': {
1564+
'AlgorithmSpecification': {
1565+
'TrainingInputMode': 'File'
1566+
},
1567+
'HyperParameterRanges': execution_input['hyperparameter_ranges'],
1568+
'InputDataConfig': execution_input['data_input']
1569+
}
1570+
}
1571+
1572+
step = TuningStep('Tuning',
1573+
tuner=tuner,
1574+
data={'train': 's3://sagemaker/train'},
1575+
job_name=step_input['job_name'],
1576+
tags=execution_input['tags'],
1577+
parameters=parameters
1578+
)
1579+
1580+
state_machine_definition = step.to_dict()
1581+
# The sagemaker_job_name is generated - expected name will be taken from the generated definition
1582+
generated_sagemaker_job_name = state_machine_definition['Parameters']['TrainingJobDefinition']['StaticHyperParameters']['sagemaker_job_name']
1583+
expected_parameters = {
1584+
'HyperParameterTuningJobConfig': {
1585+
'HyperParameterTuningJobObjective': {
1586+
'MetricName.$': "$$.Execution.Input['objective_metric_name']",
1587+
'Type.$': "$$.Execution.Input['objective_type']"
1588+
},
1589+
'ParameterRanges': {
1590+
'CategoricalParameterRanges': [
1591+
{
1592+
'Name': 'init_method',
1593+
'Values': ['"kmeans++"', '"random"']
1594+
}],
1595+
'ContinuousParameterRanges': [],
1596+
'IntegerParameterRanges': [
1597+
{
1598+
'MaxValue': '10',
1599+
'MinValue': '4',
1600+
'Name': 'extra_center_factor',
1601+
'ScalingType': 'Auto'
1602+
},
1603+
{
1604+
'MaxValue': '2',
1605+
'MinValue': '1',
1606+
'Name': 'epochs',
1607+
'ScalingType': 'Auto'
1608+
}
1609+
]
1610+
},
1611+
'ResourceLimits': {'MaxNumberOfTrainingJobs.$': "$$.Execution.Input['max_jobs']",
1612+
'MaxParallelTrainingJobs.$': "$$.Execution.Input['max_parallel_jobs']"},
1613+
'Strategy.$': "$$.Execution.Input['strategy']",
1614+
'TrainingJobEarlyStoppingType.$': "$$.Execution.Input['early_stopping_type']"
1615+
},
1616+
'HyperParameterTuningJobName.$': "$['job_name']",
1617+
'Tags.$': "$$.Execution.Input['tags']",
1618+
'TrainingJobDefinition': {
1619+
'AlgorithmSpecification': {
1620+
'TrainingImage': '520713654638.dkr.ecr.us-east-1.amazonaws.com/sagemaker-tensorflow:1.13-gpu-py2',
1621+
'TrainingInputMode': 'File'
1622+
},
1623+
'HyperParameterRanges.$': "$$.Execution.Input['hyperparameter_ranges']",
1624+
'InputDataConfig.$': "$$.Execution.Input['data_input']",
1625+
'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
1626+
'ResourceConfig': {'InstanceCount': 1,
1627+
'InstanceType': 'ml.p2.xlarge',
1628+
'VolumeSizeInGB': 30},
1629+
'RoleArn': 'execution-role',
1630+
'StaticHyperParameters': {
1631+
'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
1632+
'evaluation_steps': '100',
1633+
'sagemaker_container_log_level': '20',
1634+
'sagemaker_estimator_class_name': '"TensorFlow"',
1635+
'sagemaker_estimator_module': '"sagemaker.tensorflow.estimator"',
1636+
'sagemaker_job_name': generated_sagemaker_job_name,
1637+
'sagemaker_program': '"tf_train.py"',
1638+
'sagemaker_region': '"us-east-1"',
1639+
'sagemaker_submit_directory': '"s3://sagemaker/source"',
1640+
'training_steps': '1000'},
1641+
'StoppingCondition': {'MaxRuntimeInSeconds': 86400}
1642+
}
1643+
}
1644+
1645+
assert state_machine_definition['Parameters'] == expected_parameters

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /