Skip to content

Commit 1091279

Browse files
committed
fix visualizer for pipeline processing job steps
1 parent 6948f17 commit 1091279

File tree

8 files changed, with 102 additions and 49 deletions.

src/sagemaker/lineage/visualizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _get_start_arn_from_pipeline_execution_step(self, pipeline_execution_step):
105105
return None
106106

107107
metadata = pipeline_execution_step["Metadata"]
108-
jobs = ["TrainingJob", "ProccessingJob", "TransformJob"]
108+
jobs = ["TrainingJob", "ProcessingJob", "TransformJob"]
109109
for job in jobs:
110110
if job in metadata and metadata[job]:
111111
job_arn = metadata[job]["Arn"]
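
tests/unit/sagemaker/lineage/conftest.py (new file; path not shown in the captured page — inferred from the shared pytest fixtures it defines, confirm against the repository)

Lines changed: 27 additions & 0 deletions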
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from sagemaker.lineage import visualizer
17+
import unittest.mock
18+
19+
20+
@pytest.fixture
21+
def sagemaker_session():
22+
return unittest.mock.Mock()
23+
24+
25+
@pytest.fixture
26+
def viz(sagemaker_session):
27+
return visualizer.LineageTableVisualizer(sagemaker_session)

tests/unit/sagemaker/lineage/test_action.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,9 @@
1515
import datetime
1616
import unittest.mock
1717

18-
import pytest
1918
from sagemaker.lineage import action, _api_types
2019

2120

22-
@pytest.fixture
23-
def sagemaker_session():
24-
return unittest.mock.Mock()
25-
26-
2721
def test_create(sagemaker_session):
2822
sagemaker_session.sagemaker_client.create_action.return_value = {
2923
"ActionArn": "bazz",

tests/unit/sagemaker/lineage/test_artifact.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,9 @@
1515
import datetime
1616
import unittest.mock
1717

18-
import pytest
1918
from sagemaker.lineage import artifact, _api_types
2019

2120

22-
@pytest.fixture
23-
def sagemaker_session():
24-
return unittest.mock.Mock()
25-
26-
2721
def test_create(sagemaker_session):
2822
sagemaker_session.sagemaker_client.create_artifact.return_value = {
2923
"ArtifactArn": "bazz",

tests/unit/sagemaker/lineage/test_association.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,9 @@
1515
import datetime
1616
import unittest.mock
1717

18-
import pytest
1918
from sagemaker.lineage import association, _api_types
2019

2120

22-
@pytest.fixture
23-
def sagemaker_session():
24-
return unittest.mock.Mock()
25-
26-
2721
def test_create(sagemaker_session):
2822
sagemaker_session.sagemaker_client.add_association.return_value = {
2923
"AssociationArn": "bazz",

tests/unit/sagemaker/lineage/test_endpoint_context.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,9 @@
1414

1515
import unittest.mock
1616

17-
import pytest
1817
from sagemaker.lineage import context, _api_types
1918

2019

21-
@pytest.fixture
22-
def sagemaker_session():
23-
return unittest.mock.Mock()
24-
25-
2620
def test_models(sagemaker_session):
2721
obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn="bazz")
2822

tests/unit/sagemaker/lineage/test_model_artifact.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,9 @@
1414

1515
import unittest.mock
1616

17-
import pytest
1817
from sagemaker.lineage import artifact, _api_types
1918

2019

21-
@pytest.fixture
22-
def sagemaker_session():
23-
return unittest.mock.Mock()
24-
25-
2620
def test_trained_models(sagemaker_session):
2721
model_artifact_obj = artifact.ModelArtifact(
2822
sagemaker_session, artifact_arn="model-artifact-arn"

tests/unit/sagemaker/lineage/test_visualizer.py

Lines changed: 74 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,21 @@
1414

1515
import unittest.mock
1616

17-
import pytest
18-
from sagemaker.lineage import visualizer
1917
import pandas as pd
2018
from collections import OrderedDict
2119

2220

23-
@pytest.fixture
24-
def sagemaker_session():
25-
return unittest.mock.Mock()
26-
27-
28-
@pytest.fixture
29-
def vizualizer(sagemaker_session):
30-
return visualizer.LineageTableVisualizer(sagemaker_session)
31-
32-
33-
def test_friendly_name_short_uri(vizualizer, sagemaker_session):
21+
def test_friendly_name_short_uri(viz, sagemaker_session):
3422
uri = "s3://f-069083975568/train.txt"
3523
arn = "test_arn"
3624
sagemaker_session.sagemaker_client.describe_artifact.return_value = {
3725
"Source": {"SourceUri": uri, "SourceTypes": ""}
3826
}
39-
actual_name = vizualizer._get_friendly_name(name=None, arn=arn, entity_type="artifact")
27+
actual_name = viz._get_friendly_name(name=None, arn=arn, entity_type="artifact")
4028
assert uri == actual_name
4129

4230

43-
def test_friendly_name_long_uri(vizualizer, sagemaker_session):
31+
def test_friendly_name_long_uri(viz, sagemaker_session):
4432
uri = (
4533
"s3://flintstone-end-to-end-tests-gamma-us-west-2-069083975568/results/canary-auto-1608761252626/"
4634
"preprocessed-data/tuning_data/train.txt"
@@ -49,12 +37,12 @@ def test_friendly_name_long_uri(vizualizer, sagemaker_session):
4937
sagemaker_session.sagemaker_client.describe_artifact.return_value = {
5038
"Source": {"SourceUri": uri, "SourceTypes": ""}
5139
}
52-
actual_name = vizualizer._get_friendly_name(name=None, arn=arn, entity_type="artifact")
40+
actual_name = viz._get_friendly_name(name=None, arn=arn, entity_type="artifact")
5341
expected_name = "s3://.../preprocessed-data/tuning_data/train.txt"
5442
assert expected_name == actual_name
5543

5644

57-
def test_trial_component_name(sagemaker_session, vizualizer):
45+
def test_trial_component_name(viz, sagemaker_session):
5846
name = "tc-name"
5947

6048
sagemaker_session.sagemaker_client.describe_trial_component.return_value = {
@@ -90,7 +78,7 @@ def test_trial_component_name(sagemaker_session, vizualizer):
9078
},
9179
]
9280

93-
df = vizualizer.show(trial_component_name=name)
81+
df = viz.show(trial_component_name=name)
9482

9583
sagemaker_session.sagemaker_client.describe_trial_component.assert_called_with(
9684
TrialComponentName=name,
@@ -119,3 +107,71 @@ def test_trial_component_name(sagemaker_session, vizualizer):
119107
)
120108

121109
pd.testing.assert_frame_equal(expected_dataframe, df)
110+
111+
112+
def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
113+
114+
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
115+
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
116+
}
117+
118+
sagemaker_session.sagemaker_client.list_associations.side_effect = [
119+
{
120+
"AssociationSummaries": [
121+
{
122+
"SourceArn": "a:b:c:d:e:artifact/src-arn-1",
123+
"SourceName": "source-name-1",
124+
"SourceType": "source-type-1",
125+
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
126+
"DestinationName": "dest-name-1",
127+
"DestinationType": "dest-type-1",
128+
"AssociationType": "type-1",
129+
}
130+
]
131+
},
132+
{
133+
"AssociationSummaries": [
134+
{
135+
"SourceArn": "a:b:c:d:e:artifact/src-arn-2",
136+
"SourceName": "source-name-2",
137+
"SourceType": "source-type-2",
138+
"DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
139+
"DestinationName": "dest-name-2",
140+
"DestinationType": "dest-type-2",
141+
"AssociationType": "type-2",
142+
}
143+
]
144+
},
145+
]
146+
147+
step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
148+
149+
df = viz.show(pipeline_execution_step=step)
150+
151+
sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
152+
SourceArn="proc-job-arn",
153+
)
154+
155+
expected_calls = [
156+
unittest.mock.call(
157+
DestinationArn="tc-arn",
158+
),
159+
unittest.mock.call(
160+
SourceArn="tc-arn",
161+
),
162+
]
163+
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
164+
165+
expected_dataframe = pd.DataFrame.from_dict(
166+
OrderedDict(
167+
[
168+
("Name/Source", ["source-name-1", "dest-name-2"]),
169+
("Direction", ["Input", "Output"]),
170+
("Type", ["source-type-1", "dest-type-2"]),
171+
("Association Type", ["type-1", "type-2"]),
172+
("Lineage Type", ["artifact", "artifact"]),
173+
]
174+
)
175+
)
176+
177+
pd.testing.assert_frame_equal(expected_dataframe, df)

0 commit comments

Comments
 (0)