Skip to content

Commit 987bbe6

Browse files
authored
fix: prevent multiple values error in sklearn.transformer() (#978)
1 parent b7a2b9c commit 987bbe6

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

src/sagemaker/sklearn/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,12 @@ def create_model(
151151
object. See :func:`~sagemaker.sklearn.model.SKLearnModel` for full details.
152152
"""
153153
role = role or self.role
154+
155+
# remove unwanted entry_point kwarg
156+
if "entry_point" in kwargs:
157+
logger.debug("removing unused entry_point argument: %s", str(kwargs["entry_point"]))
158+
kwargs = {k: v for k, v in kwargs.items() if k != "entry_point"}
159+
154160
return SKLearnModel(
155161
self.model_data,
156162
role,

tests/unit/test_sklearn.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,26 @@ def test_sklearn(strftime, sagemaker_session, sklearn_version):
302302
assert isinstance(predictor, SKLearnPredictor)
303303

304304

305+
def test_transform_multiple_values_for_entry_point_issue(sagemaker_session, sklearn_version):
306+
# https://github.com/aws/sagemaker-python-sdk/issues/974
307+
sklearn = SKLearn(
308+
entry_point=SCRIPT_PATH,
309+
role=ROLE,
310+
sagemaker_session=sagemaker_session,
311+
train_instance_type=INSTANCE_TYPE,
312+
py_version=PYTHON_VERSION,
313+
framework_version=sklearn_version,
314+
)
315+
316+
inputs = "s3://mybucket/train"
317+
318+
sklearn.fit(inputs=inputs)
319+
320+
transformer = sklearn.transformer(instance_count=1, instance_type="ml.m4.xlarge")
321+
# if we got here, we didn't get a "multiple values" error
322+
assert transformer is not None
323+
324+
305325
def test_fail_distributed_training(sagemaker_session, sklearn_version):
306326
with pytest.raises(AttributeError) as error:
307327
SKLearn(

0 commit comments

Comments
 (0)