Skip to content

Commit 57a2384

Browse files
Danicywang86rui
authored andcommitted
fix: add kwargs to create_model for 1p to work with kms (#1081)
1 parent d12071f commit 57a2384

File tree

7 files changed

+21
-7
lines changed

7 files changed

+21
-7
lines changed

src/sagemaker/amazon/factorization_machines.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def __init__(
235235
self.factors_init_sigma = factors_init_sigma
236236
self.factors_init_value = factors_init_value
237237

238-
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
238+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
239239
"""Return a :class:`~sagemaker.amazon.FactorizationMachinesModel`
240240
referencing the latest s3 model data produced by this Estimator.
241241
@@ -244,12 +244,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
244244
the model. Default: use subnets and security groups from this Estimator.
245245
* 'Subnets' (list[str]): List of subnet ids.
246246
* 'SecurityGroupIds' (list[str]): List of security group ids.
247+
**kwargs: Additional kwargs passed to the FactorizationMachinesModel constructor.
247248
"""
248249
return FactorizationMachinesModel(
249250
self.model_data,
250251
self.role,
251252
sagemaker_session=self.sagemaker_session,
252253
vpc_config=self.get_vpc_config(vpc_config_override),
254+
**kwargs
253255
)
254256

255257

src/sagemaker/amazon/ipinsights.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def __init__(
131131
self.shuffled_negative_sampling_rate = shuffled_negative_sampling_rate
132132
self.weight_decay = weight_decay
133133

134-
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
134+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
135135
"""Create a model for the latest s3 model produced by this estimator.
136136
137137
Args:
@@ -140,6 +140,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
140140
Default: use subnets and security groups from this Estimator.
141141
* 'Subnets' (list[str]): List of subnet ids.
142142
* 'SecurityGroupIds' (list[str]): List of security group ids.
143+
**kwargs: Additional kwargs passed to the IPInsightsModel constructor.
143144
Returns:
144145
:class:`~sagemaker.amazon.IPInsightsModel`: references the latest s3 model
145146
data produced by this estimator.
@@ -149,6 +150,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
149150
self.role,
150151
sagemaker_session=self.sagemaker_session,
151152
vpc_config=self.get_vpc_config(vpc_config_override),
153+
**kwargs
152154
)
153155

154156
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):

src/sagemaker/amazon/knn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def __init__(
145145
'"dimension_reduction_target" is required when "dimension_reduction_type" is set.'
146146
)
147147

148-
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
148+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
149149
"""Return a :class:`~sagemaker.amazon.KNNModel` referencing the latest
150150
s3 model data produced by this Estimator.
151151
@@ -154,12 +154,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
154154
the model. Default: use subnets and security groups from this Estimator.
155155
* 'Subnets' (list[str]): List of subnet ids.
156156
* 'SecurityGroupIds' (list[str]): List of security group ids.
157+
**kwargs: Additional kwargs passed to the KNNModel constructor.
157158
"""
158159
return KNNModel(
159160
self.model_data,
160161
self.role,
161162
sagemaker_session=self.sagemaker_session,
162163
vpc_config=self.get_vpc_config(vpc_config_override),
164+
**kwargs
163165
)
164166

165167
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):

src/sagemaker/amazon/ntm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def __init__(
155155
self.weight_decay = weight_decay
156156
self.learning_rate = learning_rate
157157

158-
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
158+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
159159
"""Return a :class:`~sagemaker.amazon.NTMModel` referencing the latest
160160
s3 model data produced by this Estimator.
161161
@@ -164,12 +164,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
164164
the model. Default: use subnets and security groups from this Estimator.
165165
* 'Subnets' (list[str]): List of subnet ids.
166166
* 'SecurityGroupIds' (list[str]): List of security group ids.
167+
**kwargs: Additional kwargs passed to the NTMModel constructor.
167168
"""
168169
return NTMModel(
169170
self.model_data,
170171
self.role,
171172
sagemaker_session=self.sagemaker_session,
172173
vpc_config=self.get_vpc_config(vpc_config_override),
174+
**kwargs
173175
)
174176

175177
def _prepare_for_training( # pylint: disable=signature-differs

src/sagemaker/amazon/object2vec.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def __init__(
295295
self.enc0_freeze_pretrained_embedding = enc0_freeze_pretrained_embedding
296296
self.enc1_freeze_pretrained_embedding = enc1_freeze_pretrained_embedding
297297

298-
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
298+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
299299
"""Return a :class:`~sagemaker.amazon.Object2VecModel` referencing the
300300
latest s3 model data produced by this Estimator.
301301
@@ -304,12 +304,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
304304
the model. Default: use subnets and security groups from this Estimator.
305305
* 'Subnets' (list[str]): List of subnet ids.
306306
* 'SecurityGroupIds' (list[str]): List of security group ids.
307+
**kwargs: Additional kwargs passed to the Object2VecModel constructor.
307308
"""
308309
return Object2VecModel(
309310
self.model_data,
310311
self.role,
311312
sagemaker_session=self.sagemaker_session,
312313
vpc_config=self.get_vpc_config(vpc_config_override),
314+
**kwargs
313315
)
314316

315317
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):

src/sagemaker/amazon/pca.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def __init__(
120120
self.subtract_mean = subtract_mean
121121
self.extra_components = extra_components
122122

123-
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
123+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
124124
"""Return a :class:`~sagemaker.amazon.pca.PCAModel` referencing the
125125
latest s3 model data produced by this Estimator.
126126
@@ -129,12 +129,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
129129
the model. Default: use subnets and security groups from this Estimator.
130130
* 'Subnets' (list[str]): List of subnet ids.
131131
* 'SecurityGroupIds' (list[str]): List of security group ids.
132+
**kwargs: Additional kwargs passed to the PCAModel constructor.
132133
"""
133134
return PCAModel(
134135
self.model_data,
135136
self.role,
136137
sagemaker_session=self.sagemaker_session,
137138
vpc_config=self.get_vpc_config(vpc_config_override),
139+
**kwargs
138140
)
139141

140142
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):

src/sagemaker/amazon/randomcutforest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __init__(
113113
self.num_trees = num_trees
114114
self.eval_metrics = eval_metrics
115115

116-
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
116+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
117117
"""Return a :class:`~sagemaker.amazon.RandomCutForestModel` referencing
118118
the latest s3 model data produced by this Estimator.
119119
@@ -122,12 +122,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
122122
the model. Default: use subnets and security groups from this Estimator.
123123
* 'Subnets' (list[str]): List of subnet ids.
124124
* 'SecurityGroupIds' (list[str]): List of security group ids.
125+
**kwargs: Additional kwargs passed to the RandomCutForestModel constructor.
125126
"""
126127
return RandomCutForestModel(
127128
self.model_data,
128129
self.role,
129130
sagemaker_session=self.sagemaker_session,
130131
vpc_config=self.get_vpc_config(vpc_config_override),
132+
**kwargs
131133
)
132134

133135
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):

0 commit comments

Comments
 (0)