13
13
from __future__ import absolute_import
14
14
import os
15
15
16
- from sagemaker import hyperparameters , image_uris , model_uris , script_uris
16
+ from sagemaker import hyperparameters , metric_definitions , image_uris , model_uris , script_uris
17
17
from sagemaker .estimator import Estimator
18
18
from sagemaker .jumpstart .constants import (
19
19
INFERENCE_ENTRY_POINT_SCRIPT_NAME ,
35
35
36
36
def test_jumpstart_transfer_learning_estimator_class (setup ):
37
37
38
- model_id , model_version = "huggingface-spc-bert-base-cased" , "1.0.0 "
38
+ model_id , model_version = "huggingface-spc-bert-base-cased" , "1.2.3 "
39
39
training_instance_type = "ml.p3.2xlarge"
40
40
inference_instance_type = "ml.p2.xlarge"
41
41
instance_count = 1
@@ -66,6 +66,11 @@ def test_jumpstart_transfer_learning_estimator_class(setup):
66
66
67
67
default_hyperparameters ["epochs" ] = "1"
68
68
69
+ default_metric_definitions = metric_definitions .retrieve_default (
70
+ model_id = model_id ,
71
+ model_version = model_version ,
72
+ )
73
+
69
74
estimator = Estimator (
70
75
image_uri = image_uri ,
71
76
source_dir = script_uri ,
@@ -78,6 +83,7 @@ def test_jumpstart_transfer_learning_estimator_class(setup):
78
83
tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}],
79
84
instance_count = instance_count ,
80
85
instance_type = training_instance_type ,
86
+ metric_definitions = default_metric_definitions ,
81
87
)
82
88
83
89
estimator .fit (
0 commit comments