12
12
# language governing permissions and limitations under the License.
13
13
"""This module contains helper methods for tests of SageMaker Lineage"""
14
14
from __future__ import absolute_import
15
+ from urllib import response
15
16
16
17
import uuid
17
18
from datetime import datetime
@@ -86,9 +87,12 @@ def visit(arn, visited: set):
86
87
87
88
88
89
class LineageResourceHelper :
89
- def __init__ (self ):
90
- self .client = boto3 . client ( "sagemaker" , config = Config ( connect_timeout = 5 , read_timeout = 60 , retries = { 'max_attempts' : 20 }))
90
+ def __init__ (self , sagemaker_session ):
91
+ self .client = sagemaker_session . sagemaker_client
91
92
self .artifacts = []
93
+ self .actions = []
94
+ self .contexts = []
95
+ self .trialComponents = []
92
96
self .associations = []
93
97
94
98
def create_artifact (self , artifact_name , artifact_type = "Dataset" ):
@@ -106,6 +110,42 @@ def create_artifact(self, artifact_name, artifact_type="Dataset"):
106
110
107
111
return response ["ArtifactArn" ]
108
112
113
+ def create_action (self , action_name , action_type = "ModelDeployment" ):
114
+ response = self .client .create_action (
115
+ ActionName = action_name ,
116
+ Source = {
117
+ "SourceUri" : "Test-action-" + action_name ,
118
+ "SourceTypes" : [
119
+ {"SourceIdType" : "S3ETag" , "Value" : "Test-action-sourceId-value" },
120
+ ],
121
+ },
122
+ ActionType = action_type
123
+ )
124
+ self .actions .append (response ["ActionArn" ])
125
+
126
+ return response ["ActionArn" ]
127
+
128
+ def create_context (self , context_name , context_type = "Endpoint" ):
129
+ response = self .client .create_context (
130
+ ContextName = context_name ,
131
+ Source = {
132
+ "SourceUri" : "Test-context-" + context_name ,
133
+ "SourceTypes" : [
134
+ {"SourceIdType" : "S3ETag" , "Value" : "Test-context-sourceId-value" },
135
+ ],
136
+ },
137
+ ContextType = context_type
138
+ )
139
+ self .contexts .append (response ["ContextArn" ])
140
+
141
+ return response ["ContextArn" ]
142
+
143
+ def create_trialComponent (self , trialComponent_name , trialComponent_type = "TrainingJob" ):
144
+ response = self .client .create_trial_component (
145
+ TrialComponentName = trialComponent_name ,
146
+
147
+ )
148
+
109
149
def create_association (self , source_arn , dest_arn , association_type = "AssociatedWith" ):
110
150
response = self .client .add_association (
111
151
SourceArn = source_arn , DestinationArn = dest_arn , AssociationType = association_type
@@ -130,3 +170,10 @@ def clean_all(self):
130
170
time .sleep (0.5 )
131
171
except Exception as e :
132
172
print ("skipped " + str (e ))
173
+
174
+ for action_arn in self .actions :
175
+ try :
176
+ self .client .delete_action (ActionArn = action_arn )
177
+ time .sleep (0.5 )
178
+ except Exception as e :
179
+ print ("skipped " + str (e ))
0 commit comments