Skip to content

Commit 51cd999

Browse files
committed
Create fake uploaded_code in framework model
1 parent 93f0f58 commit 51cd999

File tree

3 files changed

+22
-19
lines changed

3 files changed

+22
-19
lines changed

src/sagemaker/model.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import sagemaker
1818

1919
from sagemaker.local import LocalSession
20-
from sagemaker.fw_utils import UploadedCode, tar_and_upload_dir, parse_s3_url, model_code_key_prefix
20+
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, model_code_key_prefix
2121
from sagemaker.session import Session
2222
from sagemaker.utils import name_from_image, get_config_value
2323

@@ -188,9 +188,7 @@ def prepare_container_def(self, instance_type): # pylint disable=unused-argumen
188188
def _upload_code(self, key_prefix):
189189
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
190190
if self.sagemaker_session.local_mode and local_code:
191-
script_name = self.entry_point
192-
dir_name = 'file://' + self.source_dir
193-
self.uploaded_code = UploadedCode(s3_prefix=dir_name, script_name=script_name)
191+
self.uploaded_code = None
194192
else:
195193
self.uploaded_code = tar_and_upload_dir(session=self.sagemaker_session.boto_session,
196194
bucket=self.bucket or self.sagemaker_session.default_bucket(),
@@ -204,7 +202,7 @@ def _framework_env_vars(self):
204202
dir_name = self.uploaded_code.s3_prefix
205203
else:
206204
script_name = self.entry_point
207-
dir_name = self.source_dir
205+
dir_name = 'file://' + self.source_dir
208206

209207
return {
210208
SCRIPT_PARAM_NAME.upper(): script_name,

src/sagemaker/workflow/airflow.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -293,22 +293,13 @@ def prepare_framework_container_def(model, instance_type, s3_operations):
293293
deploy_image = fw_utils.create_image_uri(
294294
region_name, model.__framework_name__, instance_type, model.framework_version, model.py_version)
295295

296-
deploy_env = dict(model.env)
297-
deploy_env.update(model._framework_env_vars())
298-
try:
299-
if model.model_server_workers:
300-
deploy_env[sagemaker.model.MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(model.model_server_workers)
301-
except AttributeError:
302-
# This applies to a FrameworkModel which is not SageMaker Deep Learning Framework Model
303-
pass
304-
305-
container_def = sagemaker.container_def(deploy_image, model.model_data, deploy_env)
306-
base_name = utils.base_name_from_image(container_def['Image'])
296+
base_name = utils.base_name_from_image(deploy_image)
307297
model.name = model.name or utils.airflow_name_from_base(base_name)
308298

309299
bucket = model.bucket or model.sagemaker_session._default_bucket
310-
key = '{}/source/sourcedir.tar.gz'.format(model.name)
311300
script = os.path.basename(model.entry_point)
301+
key = '{}/source/sourcedir.tar.gz'.format(model.name)
302+
312303
if model.source_dir and model.source_dir.lower().startswith('s3://'):
313304
model.uploaded_code = fw_utils.UploadedCode(s3_prefix=model.source_dir, script_name=script)
314305
else:
@@ -321,6 +312,18 @@ def prepare_framework_container_def(model, instance_type, s3_operations):
321312
'Tar': True
322313
}]
323314

315+
deploy_env = dict(model.env)
316+
deploy_env.update(model._framework_env_vars())
317+
318+
try:
319+
if model.model_server_workers:
320+
deploy_env[sagemaker.model.MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(model.model_server_workers)
321+
except AttributeError:
322+
# This applies to a FrameworkModel which is not SageMaker Deep Learning Framework Model
323+
pass
324+
325+
container_def = sagemaker.container_def(deploy_image, model.model_data, deploy_env)
326+
324327
return container_def
325328

326329

tests/unit/test_airflow.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ def test_byo_framework_model_config(sagemaker_session):
609609
'Environment': {
610610
'{{ key }}': '{{ value }}',
611611
'SAGEMAKER_PROGRAM': '{{ entry_point }}',
612-
'SAGEMAKER_SUBMIT_DIRECTORY': '{{ source_dir }}',
612+
'SAGEMAKER_SUBMIT_DIRECTORY': 's3://output/model/source/sourcedir.tar.gz',
613613
'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
614614
'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
615615
'SAGEMAKER_REGION': 'us-west-2'
@@ -648,7 +648,9 @@ def test_framework_model_config(sagemaker_session):
648648
'Image': '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:5.0.0-cpu-py3',
649649
'Environment': {
650650
'SAGEMAKER_PROGRAM': '{{ entry_point }}',
651-
'SAGEMAKER_SUBMIT_DIRECTORY': '{{ source_dir }}',
651+
'SAGEMAKER_SUBMIT_DIRECTORY': "s3://output/sagemaker-chainer-"
652+
"{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}"
653+
"/source/sourcedir.tar.gz",
652654
'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
653655
'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
654656
'SAGEMAKER_REGION': 'us-west-2',

0 commit comments

Comments
 (0)