Skip to content

Commit aeeee7f

Browse files
committed
change: Add more tests
1 parent b32fe65 commit aeeee7f

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

tests/unit/sagemaker/feature_store/feature_processor/test_config_uploader.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def runtime_env_manager():
5050
return mocked_runtime_env_manager
5151

5252

53+
def custom_file_filter():
54+
pass
55+
56+
5357
@pytest.fixture
5458
def remote_decorator_config(sagemaker_session):
5559
return Mock(
@@ -73,6 +77,24 @@ def config_uploader(remote_decorator_config, runtime_env_manager):
7377
return ConfigUploader(remote_decorator_config, runtime_env_manager)
7478

7579

80+
@pytest.fixture
81+
def remote_decorator_config_with_filter(sagemaker_session):
82+
return Mock(
83+
_JobSettings,
84+
sagemaker_session=sagemaker_session,
85+
s3_root_uri="some_s3_uri",
86+
s3_kms_key="some_kms",
87+
spark_config=SparkConfig(),
88+
dependencies=None,
89+
include_local_workdir=True,
90+
pre_execution_commands="some_commands",
91+
pre_execution_script="some_path",
92+
python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH,
93+
environment_variables={"REMOTE_FUNCTION_SECRET_KEY": "some_secret_key"},
94+
custom_file_filter=custom_file_filter,
95+
)
96+
97+
7698
@patch("sagemaker.feature_store.feature_processor._config_uploader.StoredFunction")
7799
def test_prepare_and_upload_callable(mock_stored_function, config_uploader, wrapped_func):
78100
mock_stored_function.save(wrapped_func).return_value = None
@@ -113,6 +135,41 @@ def test_prepare_and_upload_dependencies(mock_upload, config_uploader):
113135
)
114136

115137

138+
@patch(
139+
"sagemaker.feature_store.feature_processor._config_uploader._prepare_and_upload_dependencies",
140+
return_value="some_s3_uri",
141+
)
142+
def test_prepare_and_upload_dependencies_with_filter(
143+
mock_job_upload, remote_decorator_config_with_filter, runtime_env_manager
144+
):
145+
config_uploader_with_filter = ConfigUploader(
146+
remote_decorator_config=remote_decorator_config_with_filter,
147+
runtime_env_manager=runtime_env_manager,
148+
)
149+
remote_decorator_config = config_uploader_with_filter.remote_decorator_config
150+
config_uploader_with_filter._prepare_and_upload_dependencies(
151+
local_dependencies_path="some/path/to/dependency",
152+
include_local_workdir=True,
153+
pre_execution_commands=remote_decorator_config.pre_execution_commands,
154+
pre_execution_script_local_path=remote_decorator_config.pre_execution_script,
155+
s3_base_uri=remote_decorator_config.s3_root_uri,
156+
s3_kms_key=remote_decorator_config.s3_kms_key,
157+
sagemaker_session=sagemaker_session,
158+
custom_file_filter=remote_decorator_config_with_filter.custom_file_filter,
159+
)
160+
161+
mock_job_upload.assert_called_once_with(
162+
local_dependencies_path="some/path/to/dependency",
163+
include_local_workdir=True,
164+
pre_execution_commands=remote_decorator_config.pre_execution_commands,
165+
pre_execution_script_local_path=remote_decorator_config.pre_execution_script,
166+
s3_base_uri=remote_decorator_config.s3_root_uri,
167+
s3_kms_key=remote_decorator_config.s3_kms_key,
168+
sagemaker_session=sagemaker_session,
169+
custom_file_filter=custom_file_filter,
170+
)
171+
172+
116173
@patch(
117174
"sagemaker.feature_store.feature_processor._config_uploader._prepare_and_upload_runtime_scripts",
118175
return_value="some_s3_uri",

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,32 @@ def square(x):
175175
assert mock_job_settings.call_args.kwargs["image_uri"] == IMAGE
176176

177177

178+
@patch(
179+
"sagemaker.remote_function.client.serialization.deserialize_obj_from_s3",
180+
return_value=EXPECTED_JOB_RESULT,
181+
)
182+
@patch("sagemaker.remote_function.client._JobSettings")
183+
@patch("sagemaker.remote_function.client._Job.start")
184+
def test_decorator_with_custom_file_filter(
185+
mock_start, mock_job_settings, mock_deserialize_obj_from_s3
186+
):
187+
mock_job = Mock(job_name=TRAINING_JOB_NAME)
188+
mock_job.describe.return_value = COMPLETED_TRAINING_JOB
189+
190+
mock_start.return_value = mock_job
191+
192+
def custom_file_filter():
193+
pass
194+
195+
@remote(image_uri=IMAGE, s3_root_uri=S3_URI, custom_file_filter=custom_file_filter)
196+
def square(x):
197+
return x * x
198+
199+
result = square(5)
200+
assert result == EXPECTED_JOB_RESULT
201+
assert mock_job_settings.call_args.kwargs["custom_file_filter"] == custom_file_filter
202+
203+
178204
@patch(
179205
"sagemaker.remote_function.client.serialization.deserialize_exception_from_s3",
180206
return_value=ZeroDivisionError("division by zero"),

0 commit comments

Comments
 (0)