Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit f8008e8

Browse files
Update env param directly and add test
1 parent ea3e482 commit f8008e8

File tree

2 files changed

+89
-2
lines changed

2 files changed

+89
-2
lines changed

‎src/stepfunctions/steps/sagemaker.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,20 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
286286
instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
287287
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
288288
"""
289-
if isinstance(model, Model):
289+
if isinstance(model, FrameworkModel):
290290
parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
291291
if model_name:
292292
parameters['ModelName'] = model_name
293+
elif isinstance(model, Model):
294+
parameters = {
295+
'ExecutionRoleArn': model.role,
296+
'ModelName': model_name or model.name,
297+
'PrimaryContainer': {
298+
'Environment': model.env,
299+
'Image': model.image_uri,
300+
'ModelDataUrl': model.model_data
301+
}
302+
}
293303
else:
294304
raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__))
295305

‎tests/unit/test_sagemaker_steps.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,11 @@ def pca_estimator_with_env():
8181
environment={
8282
'JobName': "job_name",
8383
'ModelName': "model_name"
84-
}
84+
},
85+
subnets=[
86+
'subnet-00000000000000000',
87+
'subnet-00000000000000001'
88+
]
8589
)
8690

8791
pca.set_hyperparameters(
@@ -187,6 +191,31 @@ def pca_model():
187191
)
188192

189193

194+
@pytest.fixture
195+
def pca_model_with_env():
196+
model_data = 's3://sagemaker/models/pca.tar.gz'
197+
return Model(
198+
model_data=model_data,
199+
image_uri=PCA_IMAGE,
200+
role=EXECUTION_ROLE,
201+
name='pca-model',
202+
env={
203+
'JobName': "job_name",
204+
'ModelName': "model_name"
205+
},
206+
vpc_config={
207+
"SecurityGroupIds": ["sg-00000000000000000"],
208+
"Subnets": ["subnet-00000000000000000", "subnet-00000000000000001"]
209+
},
210+
image_config={
211+
"RepositoryAccessMode": "Vpc",
212+
"RepositoryAuthConfig": {
213+
"RepositoryCredentialsProviderArn": "arn"
214+
}
215+
}
216+
)
217+
218+
190219
@pytest.fixture
191220
def pca_transformer(pca_model):
192221
return Transformer(
@@ -855,6 +884,31 @@ def test_get_expected_model(pca_estimator):
855884
}
856885

857886

887+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
888+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
889+
def test_get_expected_model_with_env(pca_estimator_with_env):
890+
training_step = TrainingStep('Training', estimator=pca_estimator_with_env, job_name='TrainingJob')
891+
expected_model = training_step.get_expected_model()
892+
model_step = ModelStep('Create model', model=expected_model, model_name='pca-model')
893+
assert model_step.to_dict() == {
894+
'Type': 'Task',
895+
'Parameters': {
896+
'ExecutionRoleArn': EXECUTION_ROLE,
897+
'ModelName': 'pca-model',
898+
'PrimaryContainer': {
899+
'Environment': {
900+
'JobName': 'job_name',
901+
'ModelName': 'model_name'
902+
},
903+
'Image': expected_model.image_uri,
904+
'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']"
905+
}
906+
},
907+
'Resource': 'arn:aws:states:::sagemaker:createModel',
908+
'End': True
909+
}
910+
911+
858912
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
859913
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
860914
def test_get_expected_model_with_framework_estimator(tensorflow_estimator):
@@ -908,6 +962,29 @@ def test_model_step_creation(pca_model):
908962
}
909963

910964

965+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
966+
def test_model_step_creation_with_env(pca_model_with_env):
967+
step = ModelStep('Create model', model=pca_model_with_env, model_name='pca-model', tags=DEFAULT_TAGS)
968+
assert step.to_dict() == {
969+
'Type': 'Task',
970+
'Parameters': {
971+
'ExecutionRoleArn': EXECUTION_ROLE,
972+
'ModelName': 'pca-model',
973+
'PrimaryContainer': {
974+
'Environment': {
975+
'JobName': 'job_name',
976+
'ModelName': 'model_name'
977+
},
978+
'Image': pca_model_with_env.image_uri,
979+
'ModelDataUrl': pca_model_with_env.model_data
980+
},
981+
'Tags': DEFAULT_TAGS_LIST
982+
},
983+
'Resource': 'arn:aws:states:::sagemaker:createModel',
984+
'End': True
985+
}
986+
987+
911988
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
912989
def test_endpoint_config_step_creation(pca_model):
913990
data_capture_config = DataCaptureConfig(

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /