Skip to content

Commit 1ed9349

Browse files
committed
fix: un-break local mode on FrameworkProcessor
Make Processors default to overriding LocalSession config local_code=False, since local_code is not (yet) supported for local mode processing. Fix corresponding integration test and add a unit test for the defaulting behavior. Clarify the error message.
1 parent bb9613b commit 1ed9349

File tree

4 files changed

+55
-11
lines changed

4 files changed

+55
-11
lines changed

src/sagemaker/local/local_session.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -475,10 +475,30 @@ def invoke_endpoint(
475475

476476

477477
class LocalSession(Session):
478-
"""A LocalSession class definition."""
478+
"""A SageMaker ``Session`` class for Local Mode.
479479
480-
def __init__(self, boto_session=None, s3_endpoint_url=None):
480+
This class provides alternative Local Mode implementations for the functionality of
481+
:class:`~sagemaker.session.Session`.
482+
"""
483+
484+
def __init__(self, boto_session=None, s3_endpoint_url=None, disable_local_code=False):
485+
"""Create a Local SageMaker Session.
486+
487+
Args:
488+
boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
489+
calls are delegated to (default: None). If not provided, one is created with
490+
default AWS configuration chain.
491+
s3_endpoint_url (str): Override the default endpoint URL for Amazon S3, if set
492+
(default: None).
493+
disable_local_code (bool): Set ``True`` to override the default AWS configuration
494+
chain to disable the ``local.local_code`` setting, which may not be supported for
495+
some SDK features (default: False).
496+
"""
481497
self.s3_endpoint_url = s3_endpoint_url
498+
# We use this local variable to avoid disrupting the __init__->_initialize API of the
499+
# parent class... But overwriting it after constructor won't do anything, so prefix _ to
500+
# discourage external use:
501+
self._disable_local_code = disable_local_code
482502

483503
super(LocalSession, self).__init__(boto_session)
484504

@@ -530,6 +550,8 @@ def _initialize(
530550
raise e
531551

532552
self.config = yaml.load(open(sagemaker_config_file, "r"))
553+
if self._disable_local_code and "local" in self.config:
554+
self.config["local"]["local_code"] = False
533555

534556
def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
535557
"""A no-op method meant to override the sagemaker client.

src/sagemaker/processing.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ def __init__(
128128

129129
if self.instance_type in ("local", "local_gpu"):
130130
if not isinstance(sagemaker_session, LocalSession):
131-
sagemaker_session = LocalSession()
131+
# Until Local Mode Processing supports local code, we need to disable it:
132+
sagemaker_session = LocalSession(disable_local_code=True)
132133

133134
self.sagemaker_session = sagemaker_session or Session()
134135

@@ -1568,10 +1569,11 @@ def _pack_and_upload_code(self, code, source_dir, dependencies, git_config, job_
15681569

15691570
local_code = get_config_value("local.local_code", self.sagemaker_session.config)
15701571
if self.sagemaker_session.local_mode and local_code:
1571-
# TODO: Can we be more prescriptive about how to not trigger this error?
1572-
# How can user or us force a local mode `Estimator` to run with `local_code=False`?
15731572
raise RuntimeError(
1574-
"Local *code* is not currently supported for SageMaker Processing in Local Mode"
1573+
"SageMaker Processing Local Mode does not currently support 'local code' mode. "
1574+
"Please use a LocalSession created with disable_local_code=True, or leave "
1575+
"sagemaker_session unspecified when creating your Processor to have one set up "
1576+
"automatically."
15751577
)
15761578

15771579
# Upload the bootstrapping code as s3://.../jobname/source/runproc.sh.

tests/integ/test_local_mode.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client,
5858
self.local_mode = True
5959

6060

61+
@pytest.fixture(scope="module")
62+
def sagemaker_local_session_no_local_code(boto_session):
63+
return LocalSession(boto_session=boto_session, disable_local_code=True)
64+
65+
6166
@pytest.fixture(scope="module")
6267
def sklearn_image_uri(
6368
sklearn_latest_version,
@@ -322,7 +327,7 @@ def test_local_transform_mxnet(
322327

323328

324329
@pytest.mark.local_mode
325-
def test_local_processing_sklearn(sagemaker_local_session, sklearn_latest_version):
330+
def test_local_processing_sklearn(sagemaker_local_session_no_local_code, sklearn_latest_version):
326331
script_path = os.path.join(DATA_DIR, "dummy_script.py")
327332
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
328333

@@ -332,7 +337,7 @@ def test_local_processing_sklearn(sagemaker_local_session, sklearn_latest_versio
332337
instance_type="local",
333338
instance_count=1,
334339
command=["python3"],
335-
sagemaker_session=sagemaker_local_session,
340+
sagemaker_session=sagemaker_local_session_no_local_code,
336341
)
337342

338343
sklearn_processor.run(
@@ -344,12 +349,12 @@ def test_local_processing_sklearn(sagemaker_local_session, sklearn_latest_versio
344349

345350
job_description = sklearn_processor.latest_job.describe()
346351

347-
assert len(job_description["ProcessingInputs"]) == 2
352+
assert len(job_description["ProcessingInputs"]) == 3
348353
assert job_description["ProcessingResources"]["ClusterConfig"]["InstanceCount"] == 1
349354
assert job_description["ProcessingResources"]["ClusterConfig"]["InstanceType"] == "local"
350355
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
351-
"python3",
352-
"/opt/ml/processing/input/code/dummy_script.py",
356+
"/bin/bash",
357+
"/opt/ml/processing/input/entrypoint/runproc.sh",
353358
]
354359
assert job_description["RoleArn"] == "<no_role>"
355360

tests/unit/test_processing.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from sagemaker.sklearn.processing import SKLearnProcessor
3333
from sagemaker.pytorch.processing import PyTorchProcessor
34+
from sagemaker.utils import get_config_value
3435
from sagemaker.xgboost.processing import XGBoostEstimator
3536
from sagemaker.network import NetworkConfig
3637
from sagemaker.processing import FeatureStoreOutput
@@ -161,6 +162,20 @@ def test_sklearn_with_all_parameters(
161162
sagemaker_session.process.assert_called_with(**expected_args)
162163

163164

165+
166+
def test_local_mode_disables_local_code_by_default(sklearn_latest_version):
167+
processor = SKLearnProcessor(
168+
framework_version=sklearn_latest_version,
169+
role=ROLE,
170+
instance_count=1,
171+
instance_type="local",
172+
)
173+
174+
# Most tests use a fixture for sagemaker_session for consistent behaviour, so this unit test
175+
# checks that the default initialization disables unsupported 'local_code' mode:
176+
assert not get_config_value("local.local_code", processor.sagemaker_session.config)
177+
178+
164179
@patch("sagemaker.utils._botocore_resolver")
165180
@patch("os.path.exists", return_value=True)
166181
@patch("os.path.isfile", return_value=True)

0 commit comments

Comments
 (0)