Skip to content

Commit b66cb98

Browse files
danabensajaykarpur
andauthored
fix: map user context is list associations response (#2238)
Co-authored-by: Ajay Karpur <[email protected]>
1 parent a167396 commit b66cb98

File tree

6 files changed

+60
-13
lines changed

6 files changed

+60
-13
lines changed

src/sagemaker/lineage/_api_types.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,36 @@ class ContextSummary(_base_types.ApiObject):
181181
last_modified_time = None
182182

183183

184+
class UserContext(_base_types.ApiObject):
185+
"""Summary model of a user context.
186+
187+
Attributes:
188+
user_profile_arn (str): User profile ARN.
189+
user_profile_name (str): User profile name.
190+
domain_id (str): DomainId.
191+
"""
192+
193+
user_profile_arn = None
194+
user_profile_name = None
195+
domain_id = None
196+
197+
def __init__(self, user_profile_arn=None, user_profile_name=None, domain_id=None, **kwargs):
198+
"""Initialize UserContext.
199+
200+
Args:
201+
user_profile_arn (str): User profile ARN.
202+
user_profile_name (str): User profile name.
203+
domain_id (str): DomainId.
204+
**kwargs: Arbitrary keyword arguments.
205+
"""
206+
super(UserContext, self).__init__(
207+
user_profile_arn=user_profile_arn,
208+
user_profile_name=user_profile_name,
209+
domain_id=domain_id,
210+
**kwargs
211+
)
212+
213+
184214
class AssociationSummary(_base_types.ApiObject):
185215
"""Summary model of an association.
186216
@@ -196,6 +226,9 @@ class AssociationSummary(_base_types.ApiObject):
196226
created_by (obj): Context on creator.
197227
"""
198228

229+
_custom_boto_types = {
230+
"created_by": (UserContext, False),
231+
}
199232
source_arn = None
200233
source_name = None
201234
destination_arn = None
@@ -204,4 +237,3 @@ class AssociationSummary(_base_types.ApiObject):
204237
destination_type = None
205238
association_type = None
206239
creation_time = None
207-
created_by = None

tests/integ/sagemaker/lineage/test_association.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def test_list(association_objs, sagemaker_session):
5656
# sanity check
5757
assert association_keys_listed
5858

59+
for listed_asscn in listed:
60+
assert listed_asscn.created_by is None
61+
5962

6063
@pytest.mark.timeout(30)
6164
def test_set_tag(association_obj, sagemaker_session):

tests/unit/sagemaker/lineage/test_association.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,11 @@ def test_list(sagemaker_session):
9292
"DestinationType": "D" + str(i),
9393
"AssociationType": "E" + str(i),
9494
"CreationTime": creation_time + datetime.timedelta(hours=i),
95-
"CreatedBy": {},
95+
"CreatedBy": {
96+
"UserProfileArn": "profileArn",
97+
"UserProfileName": "profileName",
98+
"DomainId": "domainId",
99+
},
96100
}
97101
for i in range(10)
98102
],
@@ -109,7 +113,11 @@ def test_list(sagemaker_session):
109113
"DestinationType": "D" + str(i),
110114
"AssociationType": "E" + str(i),
111115
"CreationTime": creation_time + datetime.timedelta(hours=i),
112-
"CreatedBy": {},
116+
"CreatedBy": {
117+
"UserProfileArn": "profileArn",
118+
"UserProfileName": "profileName",
119+
"DomainId": "domainId",
120+
},
113121
}
114122
for i in range(10, 20)
115123
]
@@ -126,7 +134,11 @@ def test_list(sagemaker_session):
126134
destination_type="D" + str(i),
127135
association_type="E" + str(i),
128136
creation_time=creation_time + datetime.timedelta(hours=i),
129-
created_by={},
137+
created_by=_api_types.UserContext(
138+
user_profile_arn="profileArn",
139+
user_profile_name="profileName",
140+
domain_id="domainId",
141+
),
130142
)
131143
for i in range(20)
132144
]

tests/unit/sagemaker/lineage/test_dataset_artifact.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_trained_models(sagemaker_session):
4242
"DestinationType": "ModelDeployment",
4343
"AssociationType": "E1",
4444
"CreationTime": None,
45-
"CreatedBy": {},
45+
"CreatedBy": None,
4646
}
4747
],
4848
},
@@ -57,7 +57,7 @@ def test_trained_models(sagemaker_session):
5757
"DestinationType": "Context",
5858
"AssociationType": "E2",
5959
"CreationTime": None,
60-
"CreatedBy": {},
60+
"CreatedBy": None,
6161
}
6262
]
6363
},
@@ -79,7 +79,7 @@ def test_trained_models(sagemaker_session):
7979
destination_type="Context",
8080
association_type="E2",
8181
creation_time=None,
82-
created_by={},
82+
created_by=None,
8383
)
8484
]
8585
assert expected_model_list == model_list

tests/unit/sagemaker/lineage/test_endpoint_context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_models(sagemaker_session):
3232
"DestinationType": "ModelDeployment",
3333
"AssociationType": "E1",
3434
"CreationTime": None,
35-
"CreatedBy": {},
35+
"CreatedBy": None,
3636
}
3737
],
3838
},
@@ -47,7 +47,7 @@ def test_models(sagemaker_session):
4747
"DestinationType": "Model",
4848
"AssociationType": "E2",
4949
"CreationTime": None,
50-
"CreatedBy": {},
50+
"CreatedBy": None,
5151
}
5252
]
5353
},
@@ -71,7 +71,7 @@ def test_models(sagemaker_session):
7171
destination_type="Model",
7272
association_type="E2",
7373
creation_time=None,
74-
created_by={},
74+
created_by=None,
7575
)
7676
]
7777
assert expected_model_list == model_list

tests/unit/sagemaker/lineage/test_model_artifact.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_trained_models(sagemaker_session):
3434
"DestinationType": "Action",
3535
"AssociationType": "E1",
3636
"CreationTime": None,
37-
"CreatedBy": {},
37+
"CreatedBy": None,
3838
}
3939
],
4040
},
@@ -49,7 +49,7 @@ def test_trained_models(sagemaker_session):
4949
"DestinationType": "Context",
5050
"AssociationType": "E2",
5151
"CreationTime": None,
52-
"CreatedBy": {},
52+
"CreatedBy": None,
5353
}
5454
]
5555
},
@@ -71,7 +71,7 @@ def test_trained_models(sagemaker_session):
7171
destination_type="Context",
7272
association_type="E2",
7373
creation_time=None,
74-
created_by={},
74+
created_by=None,
7575
)
7676
]
7777
assert expected_model_list == endpoint_context_list

0 commit comments

Comments
 (0)