Skip to content

Commit c0a5671

Browse files
nmadanNamrata Madan
authored andcommitted
fix: remote function include_local_workdir default value (#1342)
Co-authored-by: Namrata Madan <[email protected]>
1 parent 680cb67 commit c0a5671

File tree

3 files changed

+71
-4
lines changed

3 files changed

+71
-4
lines changed

src/sagemaker/remote_function/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def remote(
6565
pre_execution_script: str = None,
6666
environment_variables: Dict[str, str] = None,
6767
image_uri: str = None,
68-
include_local_workdir: bool = False,
68+
include_local_workdir: bool = None,
6969
custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
7070
instance_count: int = 1,
7171
instance_type: str = None,
@@ -495,7 +495,7 @@ def __init__(
495495
pre_execution_script: str = None,
496496
environment_variables: Dict[str, str] = None,
497497
image_uri: str = None,
498-
include_local_workdir: bool = False,
498+
include_local_workdir: bool = None,
499499
custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
500500
instance_count: int = 1,
501501
instance_type: str = None,

src/sagemaker/workflow/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -971,8 +971,8 @@ def result(self, step_name: str):
971971
try:
972972
self.wait()
973973
except WaiterError as e:
974-
if "Waiter encountered a terminal failure state" in str(e):
975-
pass
974+
if "Waiter encountered a terminal failure state" not in str(e):
975+
raise
976976
step = next(filter(lambda x: x["StepName"] == step_name, self.list_steps()), None)
977977
if not step:
978978
raise ValueError(f"Invalid step name {step_name}")

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import os
1516
import threading
1617
import time
1718

1819
import pytest
1920
from mock import MagicMock, patch, Mock, ANY, call
21+
22+
from sagemaker.config import load_sagemaker_config
2023
from sagemaker.exceptions import UnexpectedStatusException
2124

2225
from botocore.exceptions import ClientError
@@ -36,6 +39,7 @@
3639
RuntimeEnvironmentError,
3740
)
3841
from sagemaker.remote_function.job import _RunInfo
42+
from tests.unit import DATA_DIR
3943

4044
from tests.unit.sagemaker.experiments.helpers import (
4145
mock_tc_load_or_create_func,
@@ -50,6 +54,7 @@
5054
EXPECTED_JOB_RESULT = [1, 2, 3]
5155
PATH_TO_SRC_DIR = "path/to/src/dir"
5256
HMAC_KEY = "some-hmac-key"
57+
ROLE_ARN = "arn:aws:iam::555555555555:role/my_execution_role_arn"
5358

5459

5560
def describe_training_job_response(job_status):
@@ -175,6 +180,37 @@ def square(x):
175180
assert mock_job_settings.call_args.kwargs["image_uri"] == IMAGE
176181

177182

183+
@patch(
184+
"sagemaker.remote_function.client.serialization.deserialize_obj_from_s3",
185+
return_value=EXPECTED_JOB_RESULT,
186+
)
187+
@patch("sagemaker.remote_function.client._Job.start")
188+
@patch("sagemaker.remote_function.job.Session")
189+
def test_decorator_with_config_file(session, mock_start, mock_deserialize_obj_from_s3):
190+
session().get_caller_identity_arn = lambda: ROLE_ARN
191+
session().sagemaker_config = load_sagemaker_config(
192+
additional_config_paths=[os.path.join(DATA_DIR, "remote_function")]
193+
)
194+
195+
mock_job = Mock(job_name=TRAINING_JOB_NAME)
196+
mock_job.describe.return_value = COMPLETED_TRAINING_JOB
197+
198+
mock_start.return_value = mock_job
199+
200+
@remote(image_uri=IMAGE, s3_root_uri=S3_URI)
201+
def square(x):
202+
return x * x
203+
204+
result = square(5)
205+
assert result == EXPECTED_JOB_RESULT
206+
assert square.job_settings.image_uri == IMAGE
207+
assert square.job_settings.s3_root_uri == S3_URI
208+
# assert values are read from sagemaker defaults config file
209+
assert square.job_settings.include_local_workdir is True
210+
assert square.job_settings.custom_file_filter.ignore_name_patterns == ["data", "test"]
211+
assert square.job_settings.s3_kms_key == "someS3KmsKey"
212+
213+
178214
@patch(
179215
"sagemaker.remote_function.client.serialization.deserialize_obj_from_s3",
180216
return_value=EXPECTED_JOB_RESULT,
@@ -609,6 +645,37 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism):
609645
assert future_4.done()
610646

611647

648+
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
649+
@patch("sagemaker.remote_function.client._Job.start")
650+
@patch("sagemaker.session.Session")
651+
def test_executor_submit_with_config_file(session, mock_start):
652+
session().get_caller_identity_arn = lambda: ROLE_ARN
653+
session().sagemaker_config = load_sagemaker_config(
654+
additional_config_paths=[os.path.join(DATA_DIR, "remote_function")]
655+
)
656+
657+
mock_job = create_mock_job("job_1", COMPLETED_TRAINING_JOB)
658+
mock_start.side_effect = [mock_job]
659+
660+
with RemoteExecutor(
661+
max_parallel_jobs=1,
662+
s3_root_uri="s3://bucket/",
663+
image_uri=IMAGE,
664+
sagemaker_session=session(),
665+
) as e:
666+
future = e.submit(job_function, 1, 2, c=3, d=4)
667+
668+
# assert values are read from sagemaker defaults config file
669+
assert e.job_settings.include_local_workdir is True
670+
assert e.job_settings.custom_file_filter.ignore_name_patterns == ["data", "test"]
671+
assert e.job_settings.s3_kms_key == "someS3KmsKey"
672+
673+
mock_start.assert_called_with(ANY, job_function, (1, 2), {"c": 3, "d": 4}, None)
674+
mock_job.describe.assert_called()
675+
676+
assert future.done()
677+
678+
612679
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
613680
@patch("sagemaker.remote_function.client._JobSettings")
614681
@patch("sagemaker.remote_function.client._Job.start")

0 commit comments

Comments
 (0)