Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 597fe1d

Browse files
Use model_config to generate CreateModelStep parameters for models without instance type
1 parent f8bbfaf commit 597fe1d

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

‎src/stepfunctions/steps/sagemaker.py‎

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ def get_expected_model(self, model_name=None):
157157
model.name = model_name
158158
else:
159159
model.name = self.job_name
160+
if self.estimator.environment is not None:
161+
model.env = self.estimator.environment
160162
model.model_data = self.output()["ModelArtifacts"]["S3ModelArtifacts"]
161163
return model
162164

@@ -284,20 +286,10 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
284286
instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
285287
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
286288
"""
287-
if isinstance(model, FrameworkModel):
289+
if isinstance(model, Model):
288290
parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
289291
if model_name:
290292
parameters['ModelName'] = model_name
291-
elif isinstance(model, Model):
292-
parameters = {
293-
'ExecutionRoleArn': model.role,
294-
'ModelName': model_name or model.name,
295-
'PrimaryContainer': {
296-
'Environment': {},
297-
'Image': model.image_uri,
298-
'ModelDataUrl': model.model_data
299-
}
300-
}
301293
else:
302294
raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__))
303295

‎tests/unit/test_sagemaker_steps.py‎

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ def pca_estimator():
5050
role=EXECUTION_ROLE,
5151
instance_count=1,
5252
instance_type='ml.c4.xlarge',
53-
output_path=s3_output_location
53+
output_path=s3_output_location,
54+
environment={
55+
'JobName': "job_name",
56+
'ModelName': "model_name"
57+
}
5458
)
5559

5660
pca.set_hyperparameters(
@@ -489,7 +493,10 @@ def test_training_step_creation_with_model(pca_estimator):
489493
'ExecutionRoleArn': EXECUTION_ROLE,
490494
'ModelName.$': "$['TrainingJobName']",
491495
'PrimaryContainer': {
492-
'Environment': {},
496+
'Environment': {
497+
'JobName': 'job_name',
498+
'ModelName': 'model_name'
499+
},
493500
'Image': PCA_IMAGE,
494501
'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']"
495502
}
@@ -757,7 +764,10 @@ def test_get_expected_model(pca_estimator):
757764
'ExecutionRoleArn': EXECUTION_ROLE,
758765
'ModelName': 'pca-model',
759766
'PrimaryContainer': {
760-
'Environment': {},
767+
'Environment': {
768+
'JobName': 'job_name',
769+
'ModelName': 'model_name'
770+
},
761771
'Image': expected_model.image_uri,
762772
'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']"
763773
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /