Skip to content

Commit 21acd45

Browse files
fix: repack model locally when local_code local mode
1 parent 57d4763 commit 21acd45

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

src/sagemaker/model.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import json
1818
import logging
1919
import os
20+
import pdb
2021
import re
2122
import copy
2223

@@ -454,9 +455,15 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
454455
if is_pipeline_variable(self.model_data):
455456
# model is not yet there, defer repacking to later during pipeline execution
456457
return
457-
458-
bucket = self.bucket or self.sagemaker_session.default_bucket()
459-
repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"])
458+
pdb.set_trace()
459+
if local_code and self.model_data.startswith("file://"):
460+
repacked_model_data = self.model_data
461+
else:
462+
bucket = self.bucket or self.sagemaker_session.default_bucket()
463+
repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"])
464+
self.uploaded_code = fw_utils.UploadedCode(
465+
s3_prefix=repacked_model_data, script_name=os.path.basename(self.entry_point)
466+
)
460467

461468
utils.repack_model(
462469
inference_script=self.entry_point,
@@ -469,9 +476,6 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
469476
)
470477

471478
self.repacked_model_data = repacked_model_data
472-
self.uploaded_code = fw_utils.UploadedCode(
473-
s3_prefix=self.repacked_model_data, script_name=os.path.basename(self.entry_point)
474-
)
475479

476480
def _script_mode_env_vars(self):
477481
"""Returns a mapping of environment variables for script mode execution"""

tests/integ/test_local_mode.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import tests.integ.lock as lock
2626
from tests.integ import DATA_DIR
27+
from mock import Mock, ANY
2728

2829
from sagemaker import image_uris
2930

@@ -221,6 +222,13 @@ def test_mxnet_local_data_local_script(
221222
):
222223
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
223224
script_path = os.path.join(data_path, "mnist.py")
225+
local_no_s3_session = LocalNoS3Session()
226+
local_no_s3_session.boto_session.resource = Mock(
227+
side_effect=local_no_s3_session.boto_session.resource
228+
)
229+
local_no_s3_session.boto_session.client = Mock(
230+
side_effect=local_no_s3_session.boto_session.client
231+
)
224232

225233
mx = MXNet(
226234
entry_point=script_path,
@@ -229,7 +237,7 @@ def test_mxnet_local_data_local_script(
229237
instance_type="local",
230238
framework_version=mxnet_training_latest_version,
231239
py_version=mxnet_training_latest_py_version,
232-
sagemaker_session=LocalNoS3Session(),
240+
sagemaker_session=local_no_s3_session,
233241
)
234242

235243
train_input = "file://" + os.path.join(data_path, "train")
@@ -243,6 +251,11 @@ def test_mxnet_local_data_local_script(
243251
predictor = mx.deploy(1, "local", endpoint_name=endpoint_name)
244252
data = numpy.zeros(shape=(1, 1, 28, 28))
245253
predictor.predict(data)
254+
# check if no boto_session s3 calls were made
255+
with pytest.raises(AssertionError):
256+
local_no_s3_session.boto_session.resource.assert_called_with("s3", region_name=ANY)
257+
with pytest.raises(AssertionError):
258+
local_no_s3_session.boto_session.client.assert_called_with("s3", region_name=ANY)
246259
finally:
247260
predictor.delete_endpoint()
248261

0 commit comments

Comments
 (0)