Skip to content

Commit d73bb97

Browse files
Balaji SankarRuban Hussain
authored andcommitted
fix: fix broken unit tests due to refactoring
1 parent 6086451 commit d73bb97

File tree

14 files changed

+84
-11
lines changed

14 files changed

+84
-11
lines changed

src/sagemaker/local/local_session.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -682,8 +682,9 @@ def _initialize(
682682
self.s3_resource = boto_session.resource("s3", endpoint_url=self.s3_endpoint_url)
683683
self.s3_client = boto_session.client("s3", endpoint_url=self.s3_endpoint_url)
684684
self.sagemaker_config = (
685-
sagemaker_config if sagemaker_config else fetch_sagemaker_config(
686-
s3_resource=self.s3_resource)
685+
sagemaker_config
686+
if sagemaker_config
687+
else fetch_sagemaker_config(s3_resource=self.s3_resource)
687688
)
688689
else:
689690
self.sagemaker_config = (

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def sagemaker_session(
166166
sagemaker_client=sagemaker_client,
167167
sagemaker_runtime_client=runtime_client,
168168
sagemaker_metrics_client=metrics_client,
169-
sagemaker_config={"SchemaVersion": "1.0"},
169+
sagemaker_config={},
170170
)
171171

172172

tests/unit/sagemaker/model/test_deploy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def test_deploy_accelerator_type(
168168
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
169169
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
170170
def test_deploy_endpoint_name(sagemaker_session):
171+
sagemaker_session.sagemaker_config = {}
171172
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)
172173

173174
endpoint_name = "blah"
@@ -371,6 +372,7 @@ def test_deploy_async_inference(production_variant, name_from_base, sagemaker_se
371372
@patch("sagemaker.model.Model._create_sagemaker_model")
372373
@patch("sagemaker.production_variant")
373374
def test_deploy_serverless_inference(production_variant, create_sagemaker_model, sagemaker_session):
375+
sagemaker_session.sagemaker_config = {}
374376
model = Model(
375377
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
376378
)
@@ -439,8 +441,8 @@ def test_deploy_wrong_serverless_config(sagemaker_session):
439441
@patch("sagemaker.session.Session")
440442
@patch("sagemaker.local.LocalSession")
441443
def test_deploy_creates_correct_session(local_session, session):
442-
local_session.sagemaker_config = {}
443-
session.sagemaker_config = {}
444+
local_session.return_value.sagemaker_config = {}
445+
session.return_value.sagemaker_config = {}
444446
# We expect a LocalSession when deploying to instance_type = 'local'
445447
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE)
446448
model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1)

tests/unit/sagemaker/model/test_framework_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def test_git_support_repo_not_provided(sagemaker_session):
208208
),
209209
)
210210
def test_git_support_git_clone_fail(sagemaker_session):
211+
sagemaker_session.sagemaker_config = {}
211212
entry_point = "source_dir/entry_point"
212213
git_config = {"repo": "https://github.com/aws/no-such-repo.git", "branch": BRANCH}
213214
with pytest.raises(subprocess.CalledProcessError) as error:
@@ -257,6 +258,7 @@ def test_git_support_commit_not_exist(git_clone_repo, sagemaker_session):
257258
side_effect=ValueError("Entry point does not exist in the repo."),
258259
)
259260
def test_git_support_entry_point_not_exist(sagemaker_session):
261+
sagemaker_session.sagemaker_config = {}
260262
entry_point = "source_dir/entry_point"
261263
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
262264
with pytest.raises(ValueError) as error:
@@ -272,6 +274,7 @@ def test_git_support_entry_point_not_exist(sagemaker_session):
272274
side_effect=ValueError("Source directory does not exist in the repo."),
273275
)
274276
def test_git_support_source_dir_not_exist(sagemaker_session):
277+
sagemaker_session.sagemaker_config = {}
275278
entry_point = "entry_point"
276279
source_dir = "source_dir_that_does_not_exist"
277280
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
@@ -291,6 +294,7 @@ def test_git_support_source_dir_not_exist(sagemaker_session):
291294
side_effect=ValueError("Dependency no-such-dir does not exist in the repo."),
292295
)
293296
def test_git_support_dependencies_not_exist(sagemaker_session):
297+
sagemaker_session.sagemaker_config = {}
294298
entry_point = "entry_point"
295299
dependencies = ["foo", "no_such_dir"]
296300
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}

tests/unit/sagemaker/model/test_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,8 @@ def test_create_sagemaker_model_generates_model_name_each_time(
326326
@patch("sagemaker.session.Session")
327327
@patch("sagemaker.local.LocalSession")
328328
def test_create_sagemaker_model_creates_correct_session(local_session, session):
329+
local_session.return_value.sagemaker_config = {}
330+
session.return_value.sagemaker_config = {}
329331
model = Model(MODEL_IMAGE, MODEL_DATA)
330332
model._create_sagemaker_model("local")
331333
assert model.sagemaker_session == local_session.return_value
@@ -433,6 +435,8 @@ def test_model_create_transformer_base_name(sagemaker_session):
433435
@patch("sagemaker.session.Session")
434436
@patch("sagemaker.local.LocalSession")
435437
def test_transformer_creates_correct_session(local_session, session):
438+
local_session.return_value.sagemaker_config = {}
439+
session.return_value.sagemaker_config = {}
436440
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=None)
437441
transformer = model.transformer(instance_count=1, instance_type="local")
438442
assert model.sagemaker_session == local_session.return_value

tests/unit/sagemaker/model/test_neo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def test_compile_model_for_cloud_tflite(sagemaker_session):
170170
@patch("sagemaker.session.Session")
171171
def test_compile_creates_session(session):
172172
session.return_value.boto_region_name = REGION
173+
session.return_value.sagemaker_config = {}
173174

174175
model = _create_model()
175176
model.compile(
@@ -313,6 +314,7 @@ def test_compile_with_framework_version_16(sagemaker_session):
313314
@patch("sagemaker.session.Session")
314315
def test_compile_with_pytorch_neo_in_ml_inf(session):
315316
session.return_value.boto_region_name = REGION
317+
session.return_value.sagemaker_config = {}
316318

317319
model = _create_model()
318320
model.compile(
@@ -336,6 +338,7 @@ def test_compile_with_pytorch_neo_in_ml_inf(session):
336338
@patch("sagemaker.session.Session")
337339
def test_compile_with_tensorflow_neo_in_ml_inf(session):
338340
session.return_value.boto_region_name = REGION
341+
session.return_value.sagemaker_config = {}
339342

340343
model = _create_model()
341344
model.compile(

tests/unit/sagemaker/monitor/test_model_monitoring.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,8 +877,14 @@ def test_data_quality_batch_transform_monitor_create_schedule_with_sagemaker_con
877877
data_quality_monitor,
878878
sagemaker_session,
879879
):
880+
from sagemaker.utils import get_config_value
880881

881882
sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_MONITORING_SCHEDULE
883+
sagemaker_session._append_sagemaker_config_tags = Mock(
884+
name="_append_sagemaker_config_tags",
885+
side_effect=lambda tags, config_path_to_tags: tags
886+
+ get_config_value(config_path_to_tags, SAGEMAKER_CONFIG_MONITORING_SCHEDULE),
887+
)
882888

883889
sagemaker_session.sagemaker_client.create_monitoring_schedule = Mock()
884890
data_quality_monitor.sagemaker_session = sagemaker_session

tests/unit/sagemaker/spark/test_processing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def sagemaker_session():
6161
settings=SessionSettings(),
6262
)
6363
session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
64+
session_mock.sagemaker_config = {}
6465

6566
return session_mock
6667

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@ def test_pipeline_create_and_update_without_role_arn(sagemaker_session_mock):
7171

7272
def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock):
7373
# For tests which doesn't verify config file injection, operate with empty config
74+
pipeline_role_arn = "arn:aws:iam::111111111111:role/ConfigRole"
7475
sagemaker_session_mock.sagemaker_config = {
75-
"SageMaker": {"Pipeline": {"RoleArn": "ConfigRoleArn"}}
76+
"SchemaVersion": "1.0",
77+
"SageMaker": {"Pipeline": {"RoleArn": pipeline_role_arn}},
7678
}
7779
sagemaker_session_mock.sagemaker_client.describe_pipeline.return_value = {
7880
"PipelineArn": "pipeline-arn"
@@ -85,15 +87,21 @@ def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock
8587
)
8688
pipeline.create()
8789
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_with(
88-
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn="ConfigRoleArn"
90+
PipelineName="MyPipeline",
91+
PipelineDefinition=pipeline.definition(),
92+
RoleArn=pipeline_role_arn,
8993
)
9094
pipeline.update()
9195
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with(
92-
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn="ConfigRoleArn"
96+
PipelineName="MyPipeline",
97+
PipelineDefinition=pipeline.definition(),
98+
RoleArn=pipeline_role_arn,
9399
)
94100
pipeline.upsert()
95101
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
96-
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn="ConfigRoleArn"
102+
PipelineName="MyPipeline",
103+
PipelineDefinition=pipeline.definition(),
104+
RoleArn=pipeline_role_arn,
97105
)
98106

99107

@@ -129,6 +137,7 @@ def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_ar
129137

130138
@patch("sagemaker.s3.S3Uploader.upload_string_as_file_body")
131139
def test_large_pipeline_create(sagemaker_session_mock, role_arn):
140+
sagemaker_session_mock.sagemaker_config = {}
132141
parameter = ParameterString("MyStr")
133142
pipeline = Pipeline(
134143
name="MyPipeline",
@@ -151,6 +160,7 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn):
151160

152161

153162
def test_pipeline_update(sagemaker_session_mock, role_arn):
163+
sagemaker_session_mock.sagemaker_config = {}
154164
pipeline = Pipeline(
155165
name="MyPipeline",
156166
parameters=[],
@@ -201,6 +211,7 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_ar
201211

202212
@patch("sagemaker.s3.S3Uploader.upload_string_as_file_body")
203213
def test_large_pipeline_update(sagemaker_session_mock, role_arn):
214+
sagemaker_session_mock.sagemaker_config = {}
204215
parameter = ParameterString("MyStr")
205216
pipeline = Pipeline(
206217
name="MyPipeline",

tests/unit/sagemaker/workflow/test_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,11 @@ def test_inject_repack_script_s3(estimator, tmp, fake_s3):
226226
model_data = Properties(step_name="MyStep", shape_name="DescribeModelOutput")
227227
entry_point = "inference.py"
228228
source_dir_path = "s3://fake/location"
229+
session_mock = fake_s3.sagemaker_session
230+
session_mock.sagemaker_config = {}
229231
step = _RepackModelStep(
230232
name="MyRepackModelStep",
231-
sagemaker_session=fake_s3.sagemaker_session,
233+
sagemaker_session=session_mock,
232234
role=estimator.role,
233235
image_uri="foo",
234236
model_data=model_data,

0 commit comments

Comments
 (0)