Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 95538ac

Browse files
Merge branch 'main' into add-support-for-eks
2 parents c81ee56 + 01e18c3 commit 95538ac

File tree

6 files changed

+321
-21
lines changed

6 files changed

+321
-21
lines changed

‎src/stepfunctions/steps/fields.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ class Field(Enum):
5959
HeartbeatSeconds = 'heartbeat_seconds'
6060
HeartbeatSecondsPath = 'heartbeat_seconds_path'
6161

62-
6362
# Retry and catch fields
6463
ErrorEquals = 'error_equals'
6564
IntervalSeconds = 'interval_seconds'

‎src/stepfunctions/steps/sagemaker.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from stepfunctions.inputs import Placeholder
2020
from stepfunctions.steps.states import Task
2121
from stepfunctions.steps.fields import Field
22-
from stepfunctions.steps.utils import tags_dict_to_kv_list
22+
from stepfunctions.steps.utils import merge_dicts, tags_dict_to_kv_list
2323
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn
2424

2525
from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
@@ -30,6 +30,7 @@
3030

3131
SAGEMAKER_SERVICE_NAME = "sagemaker"
3232

33+
3334
class SageMakerApi(Enum):
3435
CreateTrainingJob = "createTrainingJob"
3536
CreateTransformJob = "createTransformJob"
@@ -479,7 +480,9 @@ class ProcessingStep(Task):
479480
Creates a Task State to execute a SageMaker Processing Job.
480481
"""
481482

482-
def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, tags=None, **kwargs):
483+
def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None,
484+
container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True,
485+
tags=None, **kwargs):
483486
"""
484487
Args:
485488
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
@@ -491,15 +494,18 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
491494
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
492495
the processing job. These can be specified as either path strings or
493496
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
494-
experiment_config (dict, optional): Specify the experiment config for the processing. (Default: None)
495-
container_arguments ([str]): The arguments for a container used to run a processing job.
496-
container_entrypoint ([str]): The entrypoint for a container used to run a processing job.
497-
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
497+
experiment_config (dict or Placeholder, optional): Specify the experiment config for the processing. (Default: None)
498+
container_arguments ([str] or Placeholder): The arguments for a container used to run a processing job.
499+
container_entrypoint ([str] or Placeholder): The entrypoint for a container used to run a processing job.
500+
kms_key_id (str or Placeholder): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
498501
uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key,
499502
ARN of a KMS key, or alias of a KMS key.
500503
The KmsKeyId is applied to all outputs.
501504
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True)
502-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
505+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
506+
parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateProcessingJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html>`_.
507+
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
508+
503509
"""
504510
if wait_for_completion:
505511
"""
@@ -518,22 +524,25 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
518524
SageMakerApi.CreateProcessingJob)
519525

520526
if isinstance(job_name, str):
521-
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
527+
processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
522528
else:
523-
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)
529+
processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)
524530

525531
if isinstance(job_name, Placeholder):
526-
parameters['ProcessingJobName'] = job_name
532+
processing_parameters['ProcessingJobName'] = job_name
527533

528534
if experiment_config is not None:
529-
parameters['ExperimentConfig'] = experiment_config
530-
535+
processing_parameters['ExperimentConfig'] = experiment_config
536+
531537
if tags:
532-
parameters['Tags'] = tags_dict_to_kv_list(tags)
533-
534-
if 'S3Operations' in parameters:
535-
del parameters['S3Operations']
536-
537-
kwargs[Field.Parameters.value] = parameters
538+
processing_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
539+
540+
if 'S3Operations' in processing_parameters:
541+
del processing_parameters['S3Operations']
542+
543+
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
544+
# Update processing_parameters with input parameters
545+
merge_dicts(processing_parameters, kwargs[Field.Parameters.value])
538546

547+
kwargs[Field.Parameters.value] = processing_parameters
539548
super(ProcessingStep, self).__init__(state_id, **kwargs)

‎src/stepfunctions/steps/utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import boto3
1616
import logging
17+
from stepfunctions.inputs import Placeholder
1718

1819
logger = logging.getLogger('stepfunctions')
1920

@@ -45,3 +46,28 @@ def get_aws_partition():
4546
return cur_partition
4647

4748
return cur_partition
49+
50+
51+
def merge_dicts(target, source):
    """
    Merges source dictionary into the target dictionary, in place.

    When the same key holds a dictionary in both arguments, the two nested
    dictionaries are merged recursively. For any other key, source's value is
    written into target; an info-level message is logged whenever this replaces
    an existing, different value. If either argument is not a dictionary the
    call does nothing.

    Args:
        target (dict): Base dictionary into which source is merged. It is modified in place.
        source (dict): Dictionary used to update target. If the same key is present in both dictionaries, source's value
            will overwrite target's value for the corresponding key

    Returns:
        None. The result of the merge is observable through ``target``.
    """
    if not (isinstance(target, dict) and isinstance(source, dict)):
        return

    for key, value in source.items():
        if key not in target:
            target[key] = value
            continue

        current = target[key]
        if isinstance(current, dict) and isinstance(value, dict):
            merge_dicts(current, value)
        elif current != value:
            # Lazy %-style arguments: the message is only rendered when the
            # 'stepfunctions' logger actually emits INFO records.
            logger.info(
                "Property: <%s> with value: <%s> will be overwritten with provided value: <%s>",
                key, current, value)
            target[key] = value

‎tests/integ/test_sagemaker_steps.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sagemaker.tuner import HyperparameterTuner
3030
from sagemaker.processing import ProcessingInput, ProcessingOutput
3131

32+
from stepfunctions.inputs import ExecutionInput
3233
from stepfunctions.steps import Chain
3334
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep
3435
from stepfunctions.workflow import Workflow
@@ -352,3 +353,98 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
352353
# Cleanup
353354
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
354355
# End of Cleanup
356+
357+
358+
def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn,
                                           sagemaker_role_arn):
    """
    Runs a ProcessingStep whose request fields are supplied through ExecutionInput placeholders.

    The step is built with placeholder values for the container entrypoint and arguments, plus a
    `parameters` dict of placeholders that is merged into the CreateProcessingJob payload. The workflow
    is then executed with concrete inputs and the test asserts that the processing job completes.
    """
    region = boto3.session.Session().region_name
    input_data = f"s3://sagemaker-sample-data-{region}/processing/census/census-income.csv"

    input_s3 = sagemaker_session.upload_data(
        path=os.path.join(DATA_DIR, 'sklearn_processing'),
        bucket=sagemaker_session.default_bucket(),
        key_prefix='integ-test-data/sklearn_processing/code'
    )

    output_s3 = f"s3://{sagemaker_session.default_bucket()}/integ-test-data/sklearn_processing"

    inputs = [
        ProcessingInput(source=input_data, destination='/opt/ml/processing/input', input_name='input-1'),
        ProcessingInput(source=input_s3 + '/preprocessor.py', destination='/opt/ml/processing/input/code',
                        input_name='code'),
    ]

    outputs = [
        ProcessingOutput(source='/opt/ml/processing/train', destination=output_s3 + '/train_data',
                         output_name='train_data'),
        ProcessingOutput(source='/opt/ml/processing/test', destination=output_s3 + '/test_data',
                         output_name='test_data'),
    ]

    # Build workflow definition
    execution_input = ExecutionInput(schema={
        'image_uri': str,
        'instance_count': int,
        # The entrypoint supplied at execution time is a list of strings, like container_arguments.
        'entrypoint': [str],
        'role': str,
        'volume_size_in_gb': int,
        'max_runtime_in_seconds': int,
        'container_arguments': [str],
    })

    parameters = {
        'AppSpecification': {
            'ContainerEntrypoint': execution_input['entrypoint'],
            'ImageUri': execution_input['image_uri']
        },
        'ProcessingResources': {
            'ClusterConfig': {
                'InstanceCount': execution_input['instance_count'],
                'VolumeSizeInGB': execution_input['volume_size_in_gb']
            }
        },
        'RoleArn': execution_input['role'],
        'StoppingCondition': {
            'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds']
        }
    }

    job_name = generate_job_name()
    processing_step = ProcessingStep('create_processing_job_step',
                                     processor=sklearn_processor_fixture,
                                     job_name=job_name,
                                     inputs=inputs,
                                     outputs=outputs,
                                     container_arguments=execution_input['container_arguments'],
                                     container_entrypoint=execution_input['entrypoint'],
                                     parameters=parameters
                                     )
    workflow_graph = Chain([processing_step])

    # Concrete values for the placeholders above. Kept under a separate name so the
    # ExecutionInput placeholder object is not shadowed by a plain dict.
    # NOTE(review): the image URI is pinned to us-east-1 while the input data bucket follows
    # the session region - confirm this test is only expected to run in us-east-1.
    workflow_inputs = {
        'image_uri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3',
        'instance_count': 1,
        'entrypoint': ['python3', '/opt/ml/processing/input/code/preprocessor.py'],
        'role': sagemaker_role_arn,
        'volume_size_in_gb': 30,
        'max_runtime_in_seconds': 500,
        'container_arguments': ['--train-test-split-ratio', '0.2']
    }

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base("integ-test-processing-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn
        )

        try:
            # Execute workflow
            execution = workflow.execute(inputs=workflow_inputs)
            execution_output = execution.get_output(wait=True)

            # Check workflow output
            assert execution_output.get("ProcessingJobStatus") == "Completed"
        finally:
            # Cleanup - runs even when the assertion fails so the state machine is not leaked
            state_machine_delete_wait(sfn_client, workflow.state_machine_arn)

‎tests/unit/test_sagemaker_steps.py

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727

2828
from unittest.mock import MagicMock, patch
2929
from stepfunctions.inputs import ExecutionInput, StepInput
30-
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, ProcessingStep
30+
from stepfunctions.steps.fields import Field
31+
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep,\
32+
ProcessingStep
3133
from stepfunctions.steps.sagemaker import tuning_config
3234

3335
from tests.unit.utils import mock_boto_api_call
@@ -962,3 +964,136 @@ def test_processing_step_creation(sklearn_processor):
962964
'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync',
963965
'End': True
964966
}
967+
968+
969+
def test_processing_step_creation_with_placeholders(sklearn_processor):
    """
    Checks the state definition produced by a ProcessingStep built from placeholders.

    Placeholder values are passed both through dedicated arguments (container entrypoint,
    container arguments, KMS key) and through the `parameters` dict. The test asserts that
    every placeholder is rendered as a `<Key>.$` JSONPath entry in the resulting Task state,
    alongside the static values derived from the processor, inputs and outputs.
    """
    execution_input = ExecutionInput(schema={
        'image_uri': str,
        'instance_count': int,
        # ContainerEntrypoint is a list of strings, like container_arguments below.
        'entrypoint': [str],
        'output_kms_key': str,
        'role': str,
        # NOTE(review): Environment is presumably a string-to-string map in the
        # CreateProcessingJob request; the declared type is left as-is - confirm.
        'env': str,
        'volume_size_in_gb': int,
        'volume_kms_key': str,
        'max_runtime_in_seconds': int,
        'tags': [{str: str}],
        'container_arguments': [str]
    })

    step_input = StepInput(schema={
        'instance_type': str
    })

    parameters = {
        'AppSpecification': {
            'ContainerEntrypoint': execution_input['entrypoint'],
            'ImageUri': execution_input['image_uri']
        },
        'Environment': execution_input['env'],
        'ProcessingOutputConfig': {
            'KmsKeyId': execution_input['output_kms_key']
        },
        'ProcessingResources': {
            'ClusterConfig': {
                'InstanceCount': execution_input['instance_count'],
                'InstanceType': step_input['instance_type'],
                'VolumeKmsKeyId': execution_input['volume_kms_key'],
                'VolumeSizeInGB': execution_input['volume_size_in_gb']
            }
        },
        'RoleArn': execution_input['role'],
        'StoppingCondition': {
            'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds']
        },
        'Tags': execution_input['tags']
    }

    inputs = [ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')]
    outputs = [
        ProcessingOutput(source='/opt/ml/processing/output/train'),
        ProcessingOutput(source='/opt/ml/processing/output/validation'),
        ProcessingOutput(source='/opt/ml/processing/output/test')
    ]
    step = ProcessingStep(
        'Feature Transformation',
        sklearn_processor,
        'MyProcessingJob',
        container_entrypoint=execution_input['entrypoint'],
        container_arguments=execution_input['container_arguments'],
        kms_key_id=execution_input['output_kms_key'],
        inputs=inputs,
        outputs=outputs,
        parameters=parameters
    )
    assert step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'AppSpecification': {
                'ContainerArguments.$': "$$.Execution.Input['container_arguments']",
                'ContainerEntrypoint.$': "$$.Execution.Input['entrypoint']",
                'ImageUri.$': "$$.Execution.Input['image_uri']"
            },
            'Environment.$': "$$.Execution.Input['env']",
            'ProcessingInputs': [
                {
                    'InputName': None,
                    'AppManaged': False,
                    'S3Input': {
                        'LocalPath': '/opt/ml/processing/input',
                        'S3CompressionType': 'None',
                        'S3DataDistributionType': 'FullyReplicated',
                        'S3DataType': 'S3Prefix',
                        'S3InputMode': 'File',
                        'S3Uri': 'dataset.csv'
                    }
                }
            ],
            'ProcessingOutputConfig': {
                'KmsKeyId.$': "$$.Execution.Input['output_kms_key']",
                'Outputs': [
                    {
                        'OutputName': None,
                        'AppManaged': False,
                        'S3Output': {
                            'LocalPath': '/opt/ml/processing/output/train',
                            'S3UploadMode': 'EndOfJob',
                            'S3Uri': None
                        }
                    },
                    {
                        'OutputName': None,
                        'AppManaged': False,
                        'S3Output': {
                            'LocalPath': '/opt/ml/processing/output/validation',
                            'S3UploadMode': 'EndOfJob',
                            'S3Uri': None
                        }
                    },
                    {
                        'OutputName': None,
                        'AppManaged': False,
                        'S3Output': {
                            'LocalPath': '/opt/ml/processing/output/test',
                            'S3UploadMode': 'EndOfJob',
                            'S3Uri': None
                        }
                    }
                ]
            },
            'ProcessingResources': {
                'ClusterConfig': {
                    'InstanceCount.$': "$$.Execution.Input['instance_count']",
                    'InstanceType.$': "$['instance_type']",
                    'VolumeKmsKeyId.$': "$$.Execution.Input['volume_kms_key']",
                    'VolumeSizeInGB.$': "$$.Execution.Input['volume_size_in_gb']"
                }
            },
            'ProcessingJobName': 'MyProcessingJob',
            'RoleArn.$': "$$.Execution.Input['role']",
            'Tags.$': "$$.Execution.Input['tags']",
            'StoppingCondition': {'MaxRuntimeInSeconds.$': "$$.Execution.Input['max_runtime_in_seconds']"},
        },
        'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync',
        'End': True
    }

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /