# Patch summary (reviewer commentary placed ahead of the first "diff --git"
# header; the diff text below is unchanged).
#
# src/stepfunctions/steps/sagemaker.py
#   * get_expected_model: when self.estimator.environment is truthy, it is
#     assigned to model.env before model.model_data is set.
#   * __init__(self, state_id, model, model_name=None, instance_type=None, ...):
#     the 'PrimaryContainer' -> 'Environment' parameter changes from a
#     hard-coded {} to model.env.
#     NOTE(review): the enclosing class is not visible in the hunk header;
#     presumably ModelStep, since the new tests construct ModelStep -- confirm.
#   * NOTE(review): get_expected_model reads self.estimator.environment
#     directly -- confirm every supported estimator type/version exposes that
#     attribute, otherwise a getattr(..., None) guard may be needed.
#   * NOTE(review): 'Environment' now depends on model.env never being None;
#     the previous literal {} had no such dependency -- confirm against the
#     Model class in use.
#
# tests/unit/test_sagemaker_steps.py
#   * New fixtures: pca_estimator_with_env (Estimator built with
#     environment={'JobName': ..., 'ModelName': ...}) and pca_model_with_env
#     (Model built with the same env dict plus vpc_config and image_config).
#   * New tests: test_training_step_creation_with_model_with_env,
#     test_get_expected_model_with_env, test_model_step_creation_with_env.
#     Each asserts that the ModelStep parameters carry the env dict under
#     'PrimaryContainer' -> 'Environment'.
#   * The expected TrainingStep parameters in
#     test_training_step_creation_with_model_with_env contain no Environment
#     entry; only the model step is asserted to carry the env values.
#   * NOTE(review): pca_model_with_env passes vpc_config and image_config, but
#     the expected dict in test_model_step_creation_with_env has no
#     corresponding keys -- confirm this is intentional.
#
# NOTE(review): several diff lines appear joined onto each physical line here,
# while the @@ hunk headers (e.g. "-159,6 +159,8") still count the original
# one-statement-per-line layout -- restore the line breaks before applying.
diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index a3ddd47..79a6bc4 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -159,6 +159,8 @@ def get_expected_model(self, model_name=None): model.name = model_name else: model.name = self.job_name + if self.estimator.environment: + model.env = self.estimator.environment model.model_data = self.output()["ModelArtifacts"]["S3ModelArtifacts"] return model @@ -295,7 +297,7 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No 'ExecutionRoleArn': model.role, 'ModelName': model_name or model.name, 'PrimaryContainer': { - 'Environment': {}, + 'Environment': model.env, 'Image': model.image_uri, 'ModelDataUrl': model.model_data } diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index 02c6083..e5c0548 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -68,6 +68,41 @@ def pca_estimator(): return pca +@pytest.fixture +def pca_estimator_with_env(): + s3_output_location = 's3://sagemaker/models' + + pca = sagemaker.estimator.Estimator( + PCA_IMAGE, + role=EXECUTION_ROLE, + instance_count=1, + instance_type='ml.c4.xlarge', + output_path=s3_output_location, + environment={ + 'JobName': "job_name", + 'ModelName': "model_name" + }, + subnets=[ + 'subnet-00000000000000000', + 'subnet-00000000000000001' + ] + ) + + pca.set_hyperparameters( + feature_dim=50000, + num_components=10, + subtract_mean=True, + algorithm_mode='randomized', + mini_batch_size=200 + ) + + pca.sagemaker_session = MagicMock() + pca.sagemaker_session.boto_region_name = 'us-east-1' + pca.sagemaker_session._default_bucket = 'sagemaker' + + return pca + + @pytest.fixture def pca_estimator_with_debug_hook(): s3_output_location = 's3://sagemaker/models' @@ -156,6 +191,31 @@ def pca_model(): ) +@pytest.fixture +def pca_model_with_env(): + model_data = 's3://sagemaker/models/pca.tar.gz' 
+ return Model( + model_data=model_data, + image_uri=PCA_IMAGE, + role=EXECUTION_ROLE, + name='pca-model', + env={ + 'JobName': "job_name", + 'ModelName': "model_name" + }, + vpc_config={ + "SecurityGroupIds": ["sg-00000000000000000"], + "Subnets": ["subnet-00000000000000000", "subnet-00000000000000001"] + }, + image_config={ + "RepositoryAccessMode": "Vpc", + "RepositoryAuthConfig": { + "RepositoryCredentialsProviderArn": "arn" + } + } + ) + + @pytest.fixture def pca_transformer(pca_model): return Transformer( @@ -537,6 +597,63 @@ def test_training_step_creation_with_model(pca_estimator): } +@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_training_step_creation_with_model_with_env(pca_estimator_with_env): + training_step = TrainingStep('Training', estimator=pca_estimator_with_env, job_name='TrainingJob') + model_step = ModelStep('Training - Save Model', training_step.get_expected_model(model_name=training_step.output()['TrainingJobName'])) + training_step.next(model_step) + assert training_step.to_dict() == { + 'Type': 'Task', + 'Parameters': { + 'AlgorithmSpecification': { + 'TrainingImage': PCA_IMAGE, + 'TrainingInputMode': 'File' + }, + 'OutputDataConfig': { + 'S3OutputPath': 's3://sagemaker/models' + }, + 'StoppingCondition': { + 'MaxRuntimeInSeconds': 86400 + }, + 'ResourceConfig': { + 'InstanceCount': 1, + 'InstanceType': 'ml.c4.xlarge', + 'VolumeSizeInGB': 30 + }, + 'RoleArn': EXECUTION_ROLE, + 'HyperParameters': { + 'feature_dim': '50000', + 'num_components': '10', + 'subtract_mean': 'True', + 'algorithm_mode': 'randomized', + 'mini_batch_size': '200' + }, + 'TrainingJobName': 'TrainingJob' + }, + 'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync', + 'Next': 'Training - Save Model' + } + + assert model_step.to_dict() == { + 'Type': 'Task', + 'Resource': 'arn:aws:states:::sagemaker:createModel', + 'Parameters': { + 'ExecutionRoleArn': 
EXECUTION_ROLE, + 'ModelName.$': "$['TrainingJobName']", + 'PrimaryContainer': { + 'Environment': { + 'JobName': 'job_name', + 'ModelName': 'model_name' + }, + 'Image': PCA_IMAGE, + 'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']" + } + }, + 'End': True + } + + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) @patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_training_step_creation_with_framework(tensorflow_estimator): @@ -806,6 +923,31 @@ def test_get_expected_model(pca_estimator): } +@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_get_expected_model_with_env(pca_estimator_with_env): + training_step = TrainingStep('Training', estimator=pca_estimator_with_env, job_name='TrainingJob') + expected_model = training_step.get_expected_model() + model_step = ModelStep('Create model', model=expected_model, model_name='pca-model') + assert model_step.to_dict() == { + 'Type': 'Task', + 'Parameters': { + 'ExecutionRoleArn': EXECUTION_ROLE, + 'ModelName': 'pca-model', + 'PrimaryContainer': { + 'Environment': { + 'JobName': 'job_name', + 'ModelName': 'model_name' + }, + 'Image': expected_model.image_uri, + 'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']" + } + }, + 'Resource': 'arn:aws:states:::sagemaker:createModel', + 'End': True + } + + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) @patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_get_expected_model_with_framework_estimator(tensorflow_estimator): @@ -859,6 +1001,29 @@ def test_model_step_creation(pca_model): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_model_step_creation_with_env(pca_model_with_env): + step = ModelStep('Create model', model=pca_model_with_env, model_name='pca-model', tags=DEFAULT_TAGS) + assert step.to_dict() == { + 'Type': 'Task', + 'Parameters': { 
+ 'ExecutionRoleArn': EXECUTION_ROLE, + 'ModelName': 'pca-model', + 'PrimaryContainer': { + 'Environment': { + 'JobName': 'job_name', + 'ModelName': 'model_name' + }, + 'Image': pca_model_with_env.image_uri, + 'ModelDataUrl': pca_model_with_env.model_data + }, + 'Tags': DEFAULT_TAGS_LIST + }, + 'Resource': 'arn:aws:states:::sagemaker:createModel', + 'End': True + } + + @patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_endpoint_config_step_creation(pca_model): data_capture_config = DataCaptureConfig(