34
34
from stepfunctions .steps .sagemaker import TrainingStep , TransformStep , ModelStep , EndpointStep , EndpointConfigStep , TuningStep , ProcessingStep
35
35
from stepfunctions .workflow import Workflow
36
36
37
- from tests .integ import DATA_DIR , DEFAULT_TIMEOUT_MINUTES
37
+ from tests .integ import DATA_DIR , DEFAULT_TIMEOUT_MINUTES , SAGEMAKER_RETRY_STRATEGY
38
38
from tests .integ .timeout import timeout
39
39
from tests .integ .utils import (
40
40
state_machine_delete_wait ,
@@ -83,6 +83,7 @@ def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client, sf
83
83
# Build workflow definition
84
84
job_name = generate_job_name ()
85
85
training_step = TrainingStep ('create_training_job_step' , estimator = pca_estimator_fixture , job_name = job_name , data = record_set_fixture , mini_batch_size = 200 )
86
+ training_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
86
87
workflow_graph = Chain ([training_step ])
87
88
88
89
with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -110,6 +111,7 @@ def test_model_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_a
110
111
# Build workflow definition
111
112
model_name = generate_job_name ()
112
113
model_step = ModelStep ('create_model_step' , model = trained_estimator .create_model (), model_name = model_name )
114
+ model_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
113
115
workflow_graph = Chain ([model_step ])
114
116
115
117
with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -142,6 +144,7 @@ def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):
142
144
143
145
# Create a model step to save the model
144
146
model_step = ModelStep ('create_model_step' , model = trained_estimator .create_model (), model_name = job_name )
147
+ model_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
145
148
146
149
# Upload data for transformation to S3
147
150
data_path = os .path .join (DATA_DIR , "one_p_mnist" )
@@ -153,6 +156,7 @@ def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):
153
156
154
157
# Build workflow definition
155
158
transform_step = TransformStep ('create_transform_job_step' , pca_transformer , job_name = job_name , model_name = job_name , data = transform_input , content_type = "text/csv" )
159
+ transform_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
156
160
workflow_graph = Chain ([model_step , transform_step ])
157
161
158
162
with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -184,6 +188,7 @@ def test_endpoint_config_step(trained_estimator, sfn_client, sagemaker_session,
184
188
# Build workflow definition
185
189
endpoint_config_name = unique_name_from_base ("integ-test-endpoint-config" )
186
190
endpoint_config_step = EndpointConfigStep ('create_endpoint_config_step' , endpoint_config_name = endpoint_config_name , model_name = model .name , initial_instance_count = INSTANCE_COUNT , instance_type = INSTANCE_TYPE )
191
+ endpoint_config_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
187
192
workflow_graph = Chain ([endpoint_config_step ])
188
193
189
194
with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -224,6 +229,7 @@ def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client,
224
229
# Build workflow definition
225
230
endpoint_name = unique_name_from_base ("integ-test-endpoint" )
226
231
endpoint_step = EndpointStep ('create_endpoint_step' , endpoint_name = endpoint_name , endpoint_config_name = model .name )
232
+ endpoint_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
227
233
workflow_graph = Chain ([endpoint_step ])
228
234
229
235
with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -279,6 +285,7 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
279
285
280
286
# Build workflow definition
281
287
tuning_step = TuningStep ('Tuning' , tuner = tuner , job_name = job_name , data = record_set_for_hyperparameter_tuning )
288
+ tuning_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
282
289
workflow_graph = Chain ([tuning_step ])
283
290
284
291
with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -332,6 +339,7 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
332
339
container_arguments = ['--train-test-split-ratio' , '0.2' ],
333
340
container_entrypoint = ['python3' , '/opt/ml/processing/input/code/preprocessor.py' ],
334
341
)
342
+ processing_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
335
343
workflow_graph = Chain ([processing_step ])
336
344
337
345
with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -419,6 +427,7 @@ def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_
419
427
container_entrypoint = execution_input ['entrypoint' ],
420
428
parameters = parameters
421
429
)
430
+ processing_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
422
431
workflow_graph = Chain ([processing_step ])
423
432
424
433
with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
0 commit comments