Skip to content

Commit a27d9d7

Browse files
chuyang-dengChuyang Dengicywang86rui
authored
fix: preserve model_dir bool value (#1954)
* fix: preserve model_dir bool value * avoid calling hyperparaneters twince Co-authored-by: Chuyang Deng <[email protected]> Co-authored-by: icywang86rui <[email protected]>
1 parent 813625f commit a27d9d7

File tree

4 files changed

+41
-22
lines changed

4 files changed

+41
-22
lines changed

src/sagemaker/estimator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,8 +1113,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
11131113

11141114
config = _Job._load_config(inputs, estimator)
11151115

1116-
if estimator.hyperparameters() is not None:
1117-
hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()}
1116+
current_hyperparameters = estimator.hyperparameters()
1117+
if current_hyperparameters is not None:
1118+
hyperparameters = {str(k): str(v) for (k, v) in current_hyperparameters.items()}
11181119

11191120
train_args = config.copy()
11201121
train_args["input_mode"] = estimator.input_mode

src/sagemaker/tensorflow/estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ def __init__(
135135
kwargs["enable_sagemaker_metrics"] = True
136136

137137
super(TensorFlow, self).__init__(image_uri=image_uri, **kwargs)
138-
139138
self.model_dir = model_dir
140139
self.distribution = distribution or {}
141140

tests/component/test_tf_estimator.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import pytest
16-
from mock import Mock
16+
from mock import Mock, ANY
1717
from sagemaker.tensorflow import TensorFlow
1818

1919

@@ -24,13 +24,15 @@
2424
INSTANCE_COUNT = 1
2525
INSTANCE_TYPE_GPU = "ml.p2.xlarge"
2626
INSTANCE_TYPE_CPU = "ml.m4.xlarge"
27-
CPU_IMAGE_NAME = "sagemaker-tensorflow-py2-cpu"
28-
GPU_IMAGE_NAME = "sagemaker-tensorflow-py2-gpu"
27+
REPOSITORY = "tensorflow-inference"
28+
PROCESSOR = "cpu"
2929
REGION = "us-west-2"
30-
IMAGE_URI_FORMAT_STRING = "520713654638.dkr.ecr.{}.amazonaws.com/{}:{}-{}-{}"
30+
IMAGE_URI_FORMAT_STRING = "763104351884.dkr.ecr.{}.amazonaws.com/{}:{}-{}"
3131
REGION = "us-west-2"
3232
ROLE = "SagemakerRole"
3333
SOURCE_DIR = "s3://fefergerger"
34+
ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"}
35+
ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}
3436

3537

3638
@pytest.fixture()
@@ -39,48 +41,64 @@ def sagemaker_session():
3941
ims = Mock(
4042
name="sagemaker_session",
4143
boto_session=boto_mock,
44+
boto_region_name=REGION,
4245
config=None,
4346
local_mode=False,
44-
region_name=REGION,
47+
s3_resource=None,
48+
s3_client=None,
4549
)
4650
ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
4751
ims.expand_role = Mock(name="expand_role", return_value=ROLE)
4852
ims.sagemaker_client.describe_training_job = Mock(
4953
return_value={"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
5054
)
55+
ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
56+
ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5157
return ims
5258

5359

60+
def test_model_dir_false(sagemaker_session):
61+
estimator = TensorFlow(
62+
entry_point=SCRIPT,
63+
source_dir=SOURCE_DIR,
64+
role=ROLE,
65+
framework_version="2.3.0",
66+
py_version="py37",
67+
instance_type="ml.m4.xlarge",
68+
instance_count=1,
69+
model_dir=False,
70+
)
71+
estimator.hyperparameters()
72+
assert estimator.model_dir is False
73+
74+
5475
# Test that we pass all necessary fields from estimator to the session when we call deploy
55-
def test_deploy(sagemaker_session, tf_version):
76+
def test_deploy(sagemaker_session):
5677
estimator = TensorFlow(
5778
entry_point=SCRIPT,
5879
source_dir=SOURCE_DIR,
5980
role=ROLE,
60-
framework_version=tf_version,
61-
train_instance_count=2,
62-
train_instance_type=INSTANCE_TYPE_CPU,
81+
framework_version="2.3.0",
82+
py_version="py37",
83+
instance_count=2,
84+
instance_type=INSTANCE_TYPE_CPU,
6385
sagemaker_session=sagemaker_session,
6486
base_job_name="test-cifar",
6587
)
6688

6789
estimator.fit("s3://mybucket/train")
68-
print("job succeeded: {}".format(estimator.latest_training_job.name))
6990

7091
estimator.deploy(initial_instance_count=1, instance_type=INSTANCE_TYPE_CPU)
71-
image = IMAGE_URI_FORMAT_STRING.format(REGION, CPU_IMAGE_NAME, tf_version, "cpu", "py2")
92+
image = IMAGE_URI_FORMAT_STRING.format(REGION, REPOSITORY, "2.3.0", PROCESSOR)
7293
sagemaker_session.create_model.assert_called_with(
73-
estimator._current_job_name,
94+
ANY,
7495
ROLE,
7596
{
76-
"Environment": {
77-
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
78-
"SAGEMAKER_SUBMIT_DIRECTORY": SOURCE_DIR,
79-
"SAGEMAKER_REQUIREMENTS": "",
80-
"SAGEMAKER_REGION": REGION,
81-
"SAGEMAKER_PROGRAM": SCRIPT,
82-
},
8397
"Image": image,
98+
"Environment": {"SAGEMAKER_TFS_NGINX_LOGLEVEL": "info"},
8499
"ModelDataUrl": "s3://m/m.tar.gz",
85100
},
101+
vpc_config=None,
102+
enable_network_isolation=False,
103+
tags=None,
86104
)

tests/unit/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def create_model(
119119
entry_point=None,
120120
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
121121
enable_network_isolation=None,
122+
model_dir=None,
122123
**kwargs
123124
):
124125
if enable_network_isolation is None:

0 commit comments

Comments
 (0)