diff --git a/tests/integ/__init__.py b/tests/integ/__init__.py index 841ec02..5098ce8 100644 --- a/tests/integ/__init__.py +++ b/tests/integ/__init__.py @@ -14,5 +14,15 @@ import os +from stepfunctions.steps import Retry + DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") -DEFAULT_TIMEOUT_MINUTES = 25 \ No newline at end of file +DEFAULT_TIMEOUT_MINUTES = 25 + +# Default retry strategy for SageMaker steps used in integration tests +SAGEMAKER_RETRY_STRATEGY = Retry( + error_equals=["SageMaker.AmazonSageMakerException"], + interval_seconds=5, + max_attempts=5, + backoff_rate=2 +) \ No newline at end of file diff --git a/tests/integ/test_sagemaker_steps.py b/tests/integ/test_sagemaker_steps.py index f840302..d9f2b1c 100644 --- a/tests/integ/test_sagemaker_steps.py +++ b/tests/integ/test_sagemaker_steps.py @@ -34,7 +34,7 @@ from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep from stepfunctions.workflow import Workflow -from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES +from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES, SAGEMAKER_RETRY_STRATEGY from tests.integ.timeout import timeout from tests.integ.utils import ( state_machine_delete_wait, @@ -83,6 +83,7 @@ def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client, sf # Build workflow definition job_name = generate_job_name() training_step = TrainingStep('create_training_job_step', estimator=pca_estimator_fixture, job_name=job_name, data=record_set_fixture, mini_batch_size=200) + training_step.add_retry(SAGEMAKER_RETRY_STRATEGY) workflow_graph = Chain([training_step]) with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): @@ -110,6 +111,7 @@ def test_model_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_a # Build workflow definition model_name = generate_job_name() model_step = ModelStep('create_model_step', model=trained_estimator.create_model(), model_name=model_name) + model_step.add_retry(SAGEMAKER_RETRY_STRATEGY) workflow_graph = Chain([model_step]) with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): @@ -142,6 +144,7 @@ def test_transform_step(trained_estimator, sfn_client, sfn_role_arn): # Create a model step to save the model model_step = ModelStep('create_model_step', model=trained_estimator.create_model(), model_name=job_name) + model_step.add_retry(SAGEMAKER_RETRY_STRATEGY) # Upload data for transformation to S3 data_path = os.path.join(DATA_DIR, "one_p_mnist") @@ -153,6 +156,7 @@ def test_transform_step(trained_estimator, sfn_client, sfn_role_arn): # Build workflow definition transform_step = TransformStep('create_transform_job_step', pca_transformer, job_name=job_name, model_name=job_name, data=transform_input, content_type="text/csv") + transform_step.add_retry(SAGEMAKER_RETRY_STRATEGY) workflow_graph = Chain([model_step, transform_step]) with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): @@ -184,6 +188,7 @@ def test_endpoint_config_step(trained_estimator, sfn_client, sagemaker_session, # Build workflow definition endpoint_config_name = unique_name_from_base("integ-test-endpoint-config") endpoint_config_step = EndpointConfigStep('create_endpoint_config_step', endpoint_config_name=endpoint_config_name, model_name=model.name, initial_instance_count=INSTANCE_COUNT, instance_type=INSTANCE_TYPE) + endpoint_config_step.add_retry(SAGEMAKER_RETRY_STRATEGY) workflow_graph = Chain([endpoint_config_step]) with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): @@ -224,6 +229,7 @@ def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client, # Build workflow definition endpoint_name = unique_name_from_base("integ-test-endpoint") endpoint_step = EndpointStep('create_endpoint_step', endpoint_name=endpoint_name, endpoint_config_name=model.name) + endpoint_step.add_retry(SAGEMAKER_RETRY_STRATEGY) workflow_graph = Chain([endpoint_step]) with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): @@ -279,6 +285,7 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker # Build workflow definition tuning_step = TuningStep('Tuning', tuner=tuner, job_name=job_name, data=record_set_for_hyperparameter_tuning) + tuning_step.add_retry(SAGEMAKER_RETRY_STRATEGY) workflow_graph = Chain([tuning_step]) with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): @@ -332,6 +339,7 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien container_arguments=['--train-test-split-ratio', '0.2'], container_entrypoint=['python3', '/opt/ml/processing/input/code/preprocessor.py'], ) + processing_step.add_retry(SAGEMAKER_RETRY_STRATEGY) workflow_graph = Chain([processing_step]) with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): @@ -419,6 +427,7 @@ def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_ container_entrypoint=execution_input['entrypoint'], parameters=parameters ) + processing_step.add_retry(SAGEMAKER_RETRY_STRATEGY) workflow_graph = Chain([processing_step]) with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):