Skip to content

Commit 06dacca

Browse files
committed
change: clean up Pipeline unit tests
1 parent d4203da commit 06dacca

File tree

8 files changed

+44
-340
lines changed

8 files changed

+44
-340
lines changed

tests/unit/sagemaker/workflow/conftest.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@
1616

1717
import pytest
1818

19-
from sagemaker import Session
2019
from sagemaker.workflow.pipeline_context import PipelineSession
2120

2221
REGION = "us-west-2"
2322
BUCKET = "my-bucket"
2423
ROLE = "DummyRole"
2524
IMAGE_URI = "fakeimage"
25+
INSTANCE_TYPE = "ml.m4.xlarge"
2626

2727

2828
@pytest.fixture(scope="module")
29-
def client():
29+
def mock_client():
3030
"""Mock client.
3131
3232
Considerations when appropriate:
@@ -38,11 +38,12 @@ def client():
3838
client_mock._client_config.user_agent = (
3939
"Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
4040
)
41+
client_mock.describe_model.return_value = {"PrimaryContainer": {}, "Containers": {}}
4142
return client_mock
4243

4344

4445
@pytest.fixture(scope="module")
45-
def boto_session(client):
46+
def mock_boto_session(client):
4647
role_mock = Mock()
4748
type(role_mock).arn = PropertyMock(return_value=ROLE)
4849

@@ -57,19 +58,9 @@ def boto_session(client):
5758

5859

5960
@pytest.fixture(scope="module")
60-
def pipeline_session(boto_session, client):
61+
def pipeline_session(mock_boto_session, mock_client):
6162
return PipelineSession(
62-
boto_session=boto_session,
63-
sagemaker_client=client,
64-
default_bucket=BUCKET,
65-
)
66-
67-
68-
@pytest.fixture(scope="module")
69-
def sagemaker_session(boto_session, client):
70-
return Session(
71-
boto_session=boto_session,
72-
sagemaker_client=client,
73-
sagemaker_runtime_client=client,
63+
boto_session=mock_boto_session,
64+
sagemaker_client=mock_client,
7465
default_bucket=BUCKET,
7566
)

tests/unit/sagemaker/workflow/test_automl_step.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -15,57 +15,14 @@
1515
import json
1616

1717
import pytest
18-
from mock import Mock, PropertyMock
1918
from sagemaker.automl.automl import AutoML, AutoMLInput
2019
from sagemaker.exceptions import AutoMLStepInvalidModeError
2120
from sagemaker.workflow import ParameterString
2221

2322
from sagemaker.workflow.automl_step import AutoMLStep
2423
from sagemaker.workflow.model_step import ModelStep
2524
from sagemaker.workflow.pipeline import Pipeline
26-
from sagemaker.workflow.pipeline_context import PipelineSession
27-
28-
REGION = "us-west-2"
29-
BUCKET = "my-bucket"
30-
ROLE = "DummyRole"
31-
32-
33-
@pytest.fixture
34-
def client():
35-
"""Mock client.
36-
Considerations when appropriate:
37-
* utilize botocore.stub.Stubber
38-
* separate runtime client from client
39-
"""
40-
client_mock = Mock()
41-
client_mock._client_config.user_agent = (
42-
"Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
43-
)
44-
return client_mock
45-
46-
47-
@pytest.fixture
48-
def boto_session(client):
49-
role_mock = Mock()
50-
type(role_mock).arn = PropertyMock(return_value=ROLE)
51-
52-
resource_mock = Mock()
53-
resource_mock.Role.return_value = role_mock
54-
55-
session_mock = Mock(region_name=REGION)
56-
session_mock.resource.return_value = resource_mock
57-
session_mock.client.return_value = client
58-
59-
return session_mock
60-
61-
62-
@pytest.fixture
63-
def pipeline_session(boto_session, client):
64-
return PipelineSession(
65-
boto_session=boto_session,
66-
sagemaker_client=client,
67-
default_bucket=BUCKET,
68-
)
25+
from tests.unit.sagemaker.workflow.conftest import ROLE
6926

7027

7128
def test_single_automl_step(pipeline_session):

tests/unit/sagemaker/workflow/test_model_step.py

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,7 @@
5454
from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum
5555
from tests.unit import DATA_DIR
5656
from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered
57-
from tests.unit.sagemaker.workflow.conftest import BUCKET, ROLE
58-
59-
_IMAGE_URI = "fakeimage"
60-
_INSTANCE_TYPE = "ml.m4.xlarge"
57+
from tests.unit.sagemaker.workflow.conftest import BUCKET, ROLE, IMAGE_URI, INSTANCE_TYPE
6158

6259
_SAGEMAKER_PROGRAM = SCRIPT_PARAM_NAME.upper()
6360
_SAGEMAKER_SUBMIT_DIRECTORY = DIR_PARAM_NAME.upper()
@@ -79,7 +76,7 @@ def model_data_param():
7976
def model(pipeline_session, model_data_param):
8077
return Model(
8178
name="MyModel",
82-
image_uri=_IMAGE_URI,
79+
image_uri=IMAGE_URI,
8380
model_data=model_data_param,
8481
sagemaker_session=pipeline_session,
8582
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
@@ -159,7 +156,7 @@ def test_register_model_with_runtime_repack(pipeline_session, model_data_param,
159156
assert arguments["ModelApprovalStatus"] == "PendingManualApproval"
160157
assert len(arguments["InferenceSpecification"]["Containers"]) == 1
161158
container = arguments["InferenceSpecification"]["Containers"][0]
162-
assert container["Image"] == _IMAGE_URI
159+
assert container["Image"] == IMAGE_URI
163160
assert container["ModelDataUrl"] == {
164161
"Get": f"Steps.{expected_repack_step_name}.ModelArtifacts.S3ModelArtifacts"
165162
}
@@ -238,7 +235,7 @@ def test_create_model_with_runtime_repack(pipeline_session, model_data_param, mo
238235
assert step["Name"] == f"MyModelStep-{_CREATE_MODEL_NAME_BASE}"
239236
arguments = step["Arguments"]
240237
container = arguments["PrimaryContainer"]
241-
assert container["Image"] == _IMAGE_URI
238+
assert container["Image"] == IMAGE_URI
242239
assert container["ModelDataUrl"] == {
243240
"Get": f"Steps.{expected_repack_step_name}.ModelArtifacts.S3ModelArtifacts"
244241
}
@@ -335,7 +332,7 @@ def test_create_pipeline_model_with_runtime_repack(pipeline_session, model_data_
335332
assert containers[0]["ModelDataUrl"] == {"Get": "Parameters.ModelData"}
336333
assert containers[1]["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME
337334
assert containers[1]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY] == _DIR_NAME
338-
assert containers[1]["Image"] == _IMAGE_URI
335+
assert containers[1]["Image"] == IMAGE_URI
339336
assert containers[1]["ModelDataUrl"] == {
340337
"Get": f"Steps.{expected_repack_step_name}.ModelArtifacts.S3ModelArtifacts"
341338
}
@@ -371,7 +368,7 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat
371368
)
372369
# The model need to runtime repack
373370
model = Model(
374-
image_uri=_IMAGE_URI,
371+
image_uri=IMAGE_URI,
375372
model_data=model_data_param,
376373
sagemaker_session=pipeline_session,
377374
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
@@ -431,7 +428,7 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat
431428
assert containers[0]["ModelDataUrl"] == {"Get": "Parameters.ModelData"}
432429
assert containers[0]["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME
433430
assert "s3://" in containers[0]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY]
434-
assert containers[1]["Image"] == _IMAGE_URI
431+
assert containers[1]["Image"] == IMAGE_URI
435432
assert containers[1]["ModelDataUrl"] == {
436433
"Get": f"Steps.{expected_repack_step_name}.ModelArtifacts.S3ModelArtifacts"
437434
}
@@ -459,7 +456,7 @@ def test_register_model_without_repack(pipeline_session):
459456
model_name = "MyModel"
460457
model = Model(
461458
name=model_name,
462-
image_uri=_IMAGE_URI,
459+
image_uri=IMAGE_URI,
463460
model_data=model_data,
464461
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
465462
sagemaker_session=pipeline_session,
@@ -489,7 +486,7 @@ def test_register_model_without_repack(pipeline_session):
489486
assert arguments["ModelApprovalStatus"] == "PendingManualApproval"
490487
containers = arguments["InferenceSpecification"]["Containers"]
491488
assert len(containers) == 1
492-
assert containers[0]["Image"] == _IMAGE_URI
489+
assert containers[0]["Image"] == IMAGE_URI
493490
assert containers[0]["ModelDataUrl"] == {"Get": "Parameters.ModelData"}
494491
assert containers[0]["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME
495492
assert (
@@ -506,7 +503,7 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session):
506503
model_name = "MyModel"
507504
model = Model(
508505
name=model_name,
509-
image_uri=_IMAGE_URI,
506+
image_uri=IMAGE_URI,
510507
model_data=f"s3://{BUCKET}/model.tar.gz",
511508
sagemaker_session=pipeline_session,
512509
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
@@ -527,7 +524,7 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session):
527524
assert len(step_dsl_list) == 2
528525
assert step_dsl_list[0]["Name"] == "MyModelStep-CreateModel"
529526
arguments = step_dsl_list[0]["Arguments"]
530-
assert arguments["PrimaryContainer"]["Image"] == _IMAGE_URI
527+
assert arguments["PrimaryContainer"]["Image"] == IMAGE_URI
531528
assert (
532529
arguments["PrimaryContainer"]["ModelDataUrl"] == f"s3://{BUCKET}/{model_name}/model.tar.gz"
533530
)
@@ -609,7 +606,7 @@ def test_conditional_model_create_and_regis(
609606
assert arguments["ModelApprovalStatus"] == "PendingManualApproval"
610607
assert len(arguments["InferenceSpecification"]["Containers"]) == 1
611608
container = arguments["InferenceSpecification"]["Containers"][0]
612-
assert container["Image"] == _IMAGE_URI
609+
assert container["Image"] == IMAGE_URI
613610
assert container["ModelDataUrl"] == {
614611
"Get": f"Steps.{expected_repack_step_name}.ModelArtifacts.S3ModelArtifacts"
615612
}
@@ -619,7 +616,7 @@ def test_conditional_model_create_and_regis(
619616
assert step["Name"] == f"MyModelStepCreate-{_CREATE_MODEL_NAME_BASE}"
620617
arguments = step["Arguments"]
621618
container = arguments["PrimaryContainer"]
622-
assert container["Image"] == _IMAGE_URI
619+
assert container["Image"] == IMAGE_URI
623620
assert container["ModelDataUrl"] == {"Get": "Parameters.ModelData"}
624621
assert not container.get("Environment", {})
625622
else:
@@ -645,7 +642,7 @@ def test_conditional_model_create_and_regis(
645642
SKLearnModel(
646643
name="MySKModel",
647644
model_data="dummy_model_data",
648-
image_uri=_IMAGE_URI,
645+
image_uri=IMAGE_URI,
649646
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
650647
role=ROLE,
651648
enable_network_isolation=True,
@@ -658,7 +655,7 @@ def test_conditional_model_create_and_regis(
658655
name="MYXGBoostModel",
659656
model_data="dummy_model_data",
660657
framework_version="1.11.0",
661-
image_uri=_IMAGE_URI,
658+
image_uri=IMAGE_URI,
662659
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
663660
role=ROLE,
664661
enable_network_isolation=False,
@@ -669,7 +666,7 @@ def test_conditional_model_create_and_regis(
669666
PyTorchModel(
670667
name="MyPyTorchModel",
671668
model_data="dummy_model_data",
672-
image_uri=_IMAGE_URI,
669+
image_uri=IMAGE_URI,
673670
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
674671
role=ROLE,
675672
framework_version="1.5.0",
@@ -681,7 +678,7 @@ def test_conditional_model_create_and_regis(
681678
MXNetModel(
682679
name="MyMXNetModel",
683680
model_data="dummy_model_data",
684-
image_uri=_IMAGE_URI,
681+
image_uri=IMAGE_URI,
685682
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
686683
role=ROLE,
687684
framework_version="1.2.0",
@@ -692,7 +689,7 @@ def test_conditional_model_create_and_regis(
692689
HuggingFaceModel(
693690
name="MyHuggingFaceModel",
694691
model_data="dummy_model_data",
695-
image_uri=_IMAGE_URI,
692+
image_uri=IMAGE_URI,
696693
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
697694
role=ROLE,
698695
),
@@ -702,7 +699,7 @@ def test_conditional_model_create_and_regis(
702699
TensorFlowModel(
703700
name="MyTensorFlowModel",
704701
model_data="dummy_model_data",
705-
image_uri=_IMAGE_URI,
702+
image_uri=IMAGE_URI,
706703
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
707704
role=ROLE,
708705
code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH,
@@ -713,7 +710,7 @@ def test_conditional_model_create_and_regis(
713710
ChainerModel(
714711
name="MyChainerModel",
715712
model_data="dummy_model_data",
716-
image_uri=_IMAGE_URI,
713+
image_uri=IMAGE_URI,
717714
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
718715
role=ROLE,
719716
),
@@ -824,7 +821,7 @@ def assert_test_result(steps: list):
824821
TensorFlowModel(
825822
model_data="dummy_model_step",
826823
role=ROLE,
827-
image_uri=_IMAGE_URI,
824+
image_uri=IMAGE_URI,
828825
entry_point=os.path.join(_TENSORFLOW_PATH, "inference.py"),
829826
),
830827
{
@@ -840,7 +837,7 @@ def assert_test_result(steps: list):
840837
TensorFlowModel(
841838
model_data="dummy_model_step",
842839
role=ROLE,
843-
image_uri=_IMAGE_URI,
840+
image_uri=IMAGE_URI,
844841
),
845842
{
846843
"expected_step_num": 1,
@@ -975,10 +972,10 @@ def test_model_step_with_lambda_property_reference(pipeline_session):
975972
[
976973
(
977974
Processor(
978-
image_uri=_IMAGE_URI,
975+
image_uri=IMAGE_URI,
979976
role=ROLE,
980977
instance_count=1,
981-
instance_type=_INSTANCE_TYPE,
978+
instance_type=INSTANCE_TYPE,
982979
),
983980
dict(target_fun="run", func_args={}),
984981
),
@@ -999,8 +996,8 @@ def test_model_step_with_lambda_property_reference(pipeline_session):
999996
estimator=Estimator(
1000997
role=ROLE,
1001998
instance_count=1,
1002-
instance_type=_INSTANCE_TYPE,
1003-
image_uri=_IMAGE_URI,
999+
instance_type=INSTANCE_TYPE,
1000+
image_uri=IMAGE_URI,
10041001
),
10051002
objective_metric_name="test:acc",
10061003
hyperparameter_ranges={"batch-size": IntegerParameter(64, 128)},
@@ -1011,8 +1008,8 @@ def test_model_step_with_lambda_property_reference(pipeline_session):
10111008
Estimator(
10121009
role=ROLE,
10131010
instance_count=1,
1014-
instance_type=_INSTANCE_TYPE,
1015-
image_uri=_IMAGE_URI,
1011+
instance_type=INSTANCE_TYPE,
1012+
image_uri=IMAGE_URI,
10161013
),
10171014
dict(target_fun="fit", func_args={}),
10181015
),

0 commit comments

Comments
 (0)