|
12 | 12 | # language governing permissions and limitations under the License.
|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
| 15 | +import os |
15 | 16 | import threading
|
16 | 17 | import time
|
17 | 18 |
|
18 | 19 | import pytest
|
19 | 20 | from mock import MagicMock, patch, Mock, ANY, call
|
| 21 | + |
| 22 | +from sagemaker.config import load_sagemaker_config |
20 | 23 | from sagemaker.exceptions import UnexpectedStatusException
|
21 | 24 |
|
22 | 25 | from botocore.exceptions import ClientError
|
|
36 | 39 | RuntimeEnvironmentError,
|
37 | 40 | )
|
38 | 41 | from sagemaker.remote_function.job import _RunInfo
|
| 42 | +from tests.unit import DATA_DIR |
39 | 43 |
|
40 | 44 | from tests.unit.sagemaker.experiments.helpers import (
|
41 | 45 | mock_tc_load_or_create_func,
|
|
50 | 54 | EXPECTED_JOB_RESULT = [1, 2, 3]
|
51 | 55 | PATH_TO_SRC_DIR = "path/to/src/dir"
|
52 | 56 | HMAC_KEY = "some-hmac-key"
|
| 57 | +ROLE_ARN = "arn:aws:iam::555555555555:role/my_execution_role_arn" |
53 | 58 |
|
54 | 59 |
|
55 | 60 | def describe_training_job_response(job_status):
|
@@ -175,6 +180,37 @@ def square(x):
|
175 | 180 | assert mock_job_settings.call_args.kwargs["image_uri"] == IMAGE
|
176 | 181 |
|
177 | 182 |
|
| 183 | +@patch( |
| 184 | + "sagemaker.remote_function.client.serialization.deserialize_obj_from_s3", |
| 185 | + return_value=EXPECTED_JOB_RESULT, |
| 186 | +) |
| 187 | +@patch("sagemaker.remote_function.client._Job.start") |
| 188 | +@patch("sagemaker.remote_function.job.Session") |
| 189 | +def test_decorator_with_config_file(session, mock_start, mock_deserialize_obj_from_s3): |
| 190 | + session().get_caller_identity_arn = lambda: ROLE_ARN |
| 191 | + session().sagemaker_config = load_sagemaker_config( |
| 192 | + additional_config_paths=[os.path.join(DATA_DIR, "remote_function")] |
| 193 | + ) |
| 194 | + |
| 195 | + mock_job = Mock(job_name=TRAINING_JOB_NAME) |
| 196 | + mock_job.describe.return_value = COMPLETED_TRAINING_JOB |
| 197 | + |
| 198 | + mock_start.return_value = mock_job |
| 199 | + |
| 200 | + @remote(image_uri=IMAGE, s3_root_uri=S3_URI) |
| 201 | + def square(x): |
| 202 | + return x * x |
| 203 | + |
| 204 | + result = square(5) |
| 205 | + assert result == EXPECTED_JOB_RESULT |
| 206 | + assert square.job_settings.image_uri == IMAGE |
| 207 | + assert square.job_settings.s3_root_uri == S3_URI |
| 208 | + # assert values are read from sagemaker defaults config file |
| 209 | + assert square.job_settings.include_local_workdir is True |
| 210 | + assert square.job_settings.custom_file_filter.ignore_name_patterns == ["data", "test"] |
| 211 | + assert square.job_settings.s3_kms_key == "someS3KmsKey" |
| 212 | + |
| 213 | + |
178 | 214 | @patch(
|
179 | 215 | "sagemaker.remote_function.client.serialization.deserialize_obj_from_s3",
|
180 | 216 | return_value=EXPECTED_JOB_RESULT,
|
@@ -609,6 +645,37 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism):
|
609 | 645 | assert future_4.done()
|
610 | 646 |
|
611 | 647 |
|
| 648 | +@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT) |
| 649 | +@patch("sagemaker.remote_function.client._Job.start") |
| 650 | +@patch("sagemaker.session.Session") |
| 651 | +def test_executor_submit_with_config_file(session, mock_start): |
| 652 | + session().get_caller_identity_arn = lambda: ROLE_ARN |
| 653 | + session().sagemaker_config = load_sagemaker_config( |
| 654 | + additional_config_paths=[os.path.join(DATA_DIR, "remote_function")] |
| 655 | + ) |
| 656 | + |
| 657 | + mock_job = create_mock_job("job_1", COMPLETED_TRAINING_JOB) |
| 658 | + mock_start.side_effect = [mock_job] |
| 659 | + |
| 660 | + with RemoteExecutor( |
| 661 | + max_parallel_jobs=1, |
| 662 | + s3_root_uri="s3://bucket/", |
| 663 | + image_uri=IMAGE, |
| 664 | + sagemaker_session=session(), |
| 665 | + ) as e: |
| 666 | + future = e.submit(job_function, 1, 2, c=3, d=4) |
| 667 | + |
| 668 | + # assert values are read from sagemaker defaults config file |
| 669 | + assert e.job_settings.include_local_workdir is True |
| 670 | + assert e.job_settings.custom_file_filter.ignore_name_patterns == ["data", "test"] |
| 671 | + assert e.job_settings.s3_kms_key == "someS3KmsKey" |
| 672 | + |
| 673 | + mock_start.assert_called_with(ANY, job_function, (1, 2), {"c": 3, "d": 4}, None) |
| 674 | + mock_job.describe.assert_called() |
| 675 | + |
| 676 | + assert future.done() |
| 677 | + |
| 678 | + |
612 | 679 | @patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
|
613 | 680 | @patch("sagemaker.remote_function.client._JobSettings")
|
614 | 681 | @patch("sagemaker.remote_function.client._Job.start")
|
|
0 commit comments