File tree Expand file tree Collapse file tree 2 files changed +9
-3
lines changed Expand file tree Collapse file tree 2 files changed +9
-3
lines changed Original file line number Diff line number Diff line change @@ -123,6 +123,7 @@ def __init__(
123
123
content_type = None ,
124
124
content_template = None ,
125
125
custom_attributes = None ,
126
+ accelerator_type = None ,
126
127
):
127
128
"""Initializes a configuration of a model and the endpoint to be created for it.
128
129
@@ -151,6 +152,9 @@ def __init__(
151
152
Section 3.3.6. Field Value Components (
152
153
https://tools.ietf.org/html/rfc7230#section-3.2.6) of the Hypertext Transfer
153
154
Protocol (HTTP/1.1).
155
+ accelerator_type (str): The Elastic Inference accelerator type to deploy to the model
156
+ endpoint instance for making inferences to the model, see
157
+ https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
154
158
"""
155
159
self .predictor_config = {
156
160
"model_name" : model_name ,
@@ -178,9 +182,8 @@ def __init__(
178
182
f" Please include a placeholder $features."
179
183
)
180
184
self .predictor_config ["content_template" ] = content_template
181
-
182
- if custom_attributes is not None :
183
- self .predictor_config ["custom_attributes" ] = custom_attributes
185
+ _set (custom_attributes , "custom_attributes" , self .predictor_config )
186
+ _set (accelerator_type , "accelerator_type" , self .predictor_config )
184
187
185
188
def get_predictor_config (self ):
186
189
"""Returns part of the predictor dictionary of the analysis config."""
Original file line number Diff line number Diff line change @@ -92,13 +92,15 @@ def test_model_config():
92
92
accept_type = "text/csv"
93
93
content_type = "application/jsonlines"
94
94
custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4"
95
+ accelerator_type = "ml.eia1.medium"
95
96
model_config = ModelConfig (
96
97
model_name = model_name ,
97
98
instance_type = instance_type ,
98
99
instance_count = instance_count ,
99
100
accept_type = accept_type ,
100
101
content_type = content_type ,
101
102
custom_attributes = custom_attributes ,
103
+ accelerator_type = accelerator_type ,
102
104
)
103
105
expected_config = {
104
106
"model_name" : model_name ,
@@ -107,6 +109,7 @@ def test_model_config():
107
109
"accept_type" : accept_type ,
108
110
"content_type" : content_type ,
109
111
"custom_attributes" : custom_attributes ,
112
+ "accelerator_type" : accelerator_type ,
110
113
}
111
114
assert expected_config == model_config .get_predictor_config ()
112
115
You can’t perform that action at this time.
0 commit comments