Skip to content

fix: add kwargs to create_model for 1p to work with kms #1081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/factorization_machines.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def __init__(
self.factors_init_sigma = factors_init_sigma
self.factors_init_value = factors_init_value

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.FactorizationMachinesModel`
referencing the latest s3 model data produced by this Estimator.

Expand All @@ -244,12 +244,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the FactorizationMachinesModel constructor.
"""
return FactorizationMachinesModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)


Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/ipinsights.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(
self.shuffled_negative_sampling_rate = shuffled_negative_sampling_rate
self.weight_decay = weight_decay

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Create a model for the latest s3 model produced by this estimator.

Args:
Expand All @@ -140,6 +140,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the IPInsightsModel constructor.
Returns:
:class:`~sagemaker.amazon.IPInsightsModel`: references the latest s3 model
data produced by this estimator.
Expand All @@ -149,6 +150,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(
'"dimension_reduction_target" is required when "dimension_reduction_type" is set.'
)

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.KNNModel` referencing the latest
s3 model data produced by this Estimator.

Expand All @@ -154,12 +154,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the KNNModel constructor.
"""
return KNNModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/ntm.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __init__(
self.weight_decay = weight_decay
self.learning_rate = learning_rate

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.NTMModel` referencing the latest
s3 model data produced by this Estimator.

Expand All @@ -164,12 +164,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the NTMModel constructor.
"""
return NTMModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training( # pylint: disable=signature-differs
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/object2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def __init__(
self.enc0_freeze_pretrained_embedding = enc0_freeze_pretrained_embedding
self.enc1_freeze_pretrained_embedding = enc1_freeze_pretrained_embedding

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.Object2VecModel` referencing the
latest s3 model data produced by this Estimator.

Expand All @@ -304,12 +304,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the Object2VecModel constructor.
"""
return Object2VecModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(
self.subtract_mean = subtract_mean
self.extra_components = extra_components

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.pca.PCAModel` referencing the
latest s3 model data produced by this Estimator.

Expand All @@ -129,12 +129,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the PCAModel constructor.
"""
return PCAModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/randomcutforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
self.num_trees = num_trees
self.eval_metrics = eval_metrics

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.RandomCutForestModel` referencing
the latest s3 model data produced by this Estimator.

Expand All @@ -122,12 +122,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the RandomCutForestModel constructor.
"""
return RandomCutForestModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
Expand Down