Skip to content

Commit dba80dc

Browse files
committed
chore: add integ test for default training metrics
1 parent 8b7d1c5 commit dba80dc

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

tests/integ/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
4242

4343
TRAINING_DATASET_MODEL_DICT = {
4444
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
45+
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
4546
}
4647

4748

tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from __future__ import absolute_import
1414
import os
1515

16-
from sagemaker import hyperparameters, image_uris, model_uris, script_uris
16+
from sagemaker import hyperparameters, metric_definitions, image_uris, model_uris, script_uris
1717
from sagemaker.estimator import Estimator
1818
from sagemaker.jumpstart.constants import (
1919
INFERENCE_ENTRY_POINT_SCRIPT_NAME,
@@ -35,7 +35,7 @@
3535

3636
def test_jumpstart_transfer_learning_estimator_class(setup):
3737

38-
model_id, model_version = "huggingface-spc-bert-base-cased", "1.0.0"
38+
model_id, model_version = "huggingface-spc-bert-base-cased", "1.2.3"
3939
training_instance_type = "ml.p3.2xlarge"
4040
inference_instance_type = "ml.p2.xlarge"
4141
instance_count = 1
@@ -66,6 +66,11 @@ def test_jumpstart_transfer_learning_estimator_class(setup):
6666

6767
default_hyperparameters["epochs"] = "1"
6868

69+
default_metric_definitions = metric_definitions.retrieve_default(
70+
model_id=model_id,
71+
model_version=model_version,
72+
)
73+
6974
estimator = Estimator(
7075
image_uri=image_uri,
7176
source_dir=script_uri,
@@ -78,6 +83,7 @@ def test_jumpstart_transfer_learning_estimator_class(setup):
7883
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
7984
instance_count=instance_count,
8085
instance_type=training_instance_type,
86+
metric_definitions=default_metric_definitions,
8187
)
8288

8389
estimator.fit(

0 commit comments

Comments
 (0)