Skip to content

Commit d748287

Browse files
committed
fix: xfail gated training test if capacity error
1 parent 14fc003 commit d748287

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tests.integ.sagemaker.jumpstart.utils import (
2727
get_sm_session,
2828
get_training_dataset_for_model_and_version,
29+
x_fail_if_ice,
2930
)
3031

3132
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
@@ -75,8 +76,7 @@ def test_jumpstart_estimator(setup):
7576
assert response is not None
7677

7778

78-
# instance capacity errors require retries
79-
@pytest.mark.flaky(reruns=5, reruns_delay=60)
79+
@x_fail_if_ice
8080
@pytest.mark.skipif(
8181
tests.integ.test_region() not in GATED_TRAINING_MODEL_SUPPORTED_REGIONS,
8282
reason=f"JumpStart gated training models unavailable in {tests.integ.test_region()}.",

tests/integ/sagemaker/jumpstart/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
import functools
1415
import json
1516

1617
import uuid
@@ -19,6 +20,7 @@
1920
import pandas as pd
2021
import os
2122
from botocore.config import Config
23+
import pytest
2224

2325

2426
from tests.integ.sagemaker.jumpstart.constants import (
@@ -50,6 +52,19 @@ def get_training_dataset_for_model_and_version(model_id: str, version: str) -> d
5052
return TRAINING_DATASET_MODEL_DICT[(model_id, version)]
5153

5254

55+
def x_fail_if_ice(func):
56+
@functools.wraps(func)
57+
def wrapper(*args, **kwargs):
58+
try:
59+
return func(*args, **kwargs)
60+
except Exception as e:
61+
if "CapacityError" in str(e):
62+
pytest.xfail(str(e))
63+
raise
64+
65+
return wrapper
66+
67+
5368
def download_inference_assets():
5469

5570
if not os.path.exists(TMP_DIRECTORY_PATH):

0 commit comments

Comments
 (0)