@@ -68,8 +68,9 @@ def prepare(
68
68
+ "session to be created or supply `sagemaker_session` into @serve.invoke."
69
69
) from e
70
70
71
+ upload_artifacts = None
71
72
if self .model_server == ModelServer .TORCHSERVE :
72
- return self ._upload_torchserve_artifacts (
73
+ upload_artifacts = self ._upload_torchserve_artifacts (
73
74
model_path = model_path ,
74
75
sagemaker_session = sagemaker_session ,
75
76
secret_key = secret_key ,
@@ -78,7 +79,7 @@ def prepare(
78
79
)
79
80
80
81
if self .model_server == ModelServer .TRITON :
81
- return self ._upload_triton_artifacts (
82
+ upload_artifacts = self ._upload_triton_artifacts (
82
83
model_path = model_path ,
83
84
sagemaker_session = sagemaker_session ,
84
85
secret_key = secret_key ,
@@ -87,15 +88,15 @@ def prepare(
87
88
)
88
89
89
90
if self .model_server == ModelServer .DJL_SERVING :
90
- return self ._upload_djl_artifacts (
91
+ upload_artifacts = self ._upload_djl_artifacts (
91
92
model_path = model_path ,
92
93
sagemaker_session = sagemaker_session ,
93
94
s3_model_data_url = s3_model_data_url ,
94
95
image = image ,
95
96
)
96
97
97
98
if self .model_server == ModelServer .TGI :
98
- return self ._upload_tgi_artifacts (
99
+ upload_artifacts = self ._upload_tgi_artifacts (
99
100
model_path = model_path ,
100
101
sagemaker_session = sagemaker_session ,
101
102
s3_model_data_url = s3_model_data_url ,
@@ -104,15 +105,15 @@ def prepare(
104
105
)
105
106
106
107
if self .model_server == ModelServer .MMS :
107
- return self ._upload_server_artifacts (
108
+ upload_artifacts = self ._upload_server_artifacts (
108
109
model_path = model_path ,
109
110
sagemaker_session = sagemaker_session ,
110
111
s3_model_data_url = s3_model_data_url ,
111
112
image = image ,
112
113
)
113
114
114
115
if self .model_server == ModelServer .TENSORFLOW_SERVING :
115
- return self ._upload_tensorflow_serving_artifacts (
116
+ upload_artifacts = self ._upload_tensorflow_serving_artifacts (
116
117
model_path = model_path ,
117
118
sagemaker_session = sagemaker_session ,
118
119
secret_key = secret_key ,
@@ -121,11 +122,14 @@ def prepare(
121
122
)
122
123
123
124
if self .model_server == ModelServer .TEI :
124
- return self ._upload_tei_artifacts (
125
+ upload_artifacts = self ._upload_tei_artifacts (
125
126
model_path = model_path ,
126
127
sagemaker_session = sagemaker_session ,
127
128
s3_model_data_url = s3_model_data_url ,
128
129
image = image ,
129
130
)
130
131
132
+ if isinstance (self .model_server , ModelServer ) and upload_artifacts :
133
+ return upload_artifacts
134
+
131
135
raise ValueError ("%s model server is not supported" % self .model_server )
0 commit comments