Skip to content

Commit 321b168

Browse files
chenliu0831knakad
authored andcommitted
fix: Fix linear learner crash when num_class is string and predict type is multiclass_classifier (aws#246)
1 parent 7a7db0a commit 321b168

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/sagemaker/amazon/linear_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def __init__(
366366
self.balance_multiclass_weights = balance_multiclass_weights
367367

368368
if self.predictor_type == "multiclass_classifier" and (
369-
num_classes is None or num_classes < 3
369+
num_classes is None or int(num_classes) < 3
370370
):
371371
raise ValueError(
372372
"For predictor_type 'multiclass_classifier', 'num_classes' should be set to a "

tests/unit/test_linear_learner.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,13 @@ def test_num_classes_is_required_for_multiclass_classifier(sagemaker_session):
210210
)
211211

212212

213+
def test_num_classes_can_be_string_for_multiclass_classifier(sagemaker_session):
214+
test_params = ALL_REQ_ARGS.copy()
215+
test_params["predictor_type"] = "multiclass_classifier"
216+
test_params["num_classes"] = "3"
217+
LinearLearner(sagemaker_session=sagemaker_session, **test_params)
218+
219+
213220
@pytest.mark.parametrize("iterable_hyper_parameters, value", [("eval_metrics", 0)])
214221
def test_iterable_hyper_parameters_type(sagemaker_session, iterable_hyper_parameters, value):
215222
with pytest.raises(TypeError):
@@ -374,7 +381,7 @@ def test_prepare_for_training_multiple_channel_no_train(sagemaker_session):
374381
with pytest.raises(ValueError) as ex:
375382
lr._prepare_for_training([data, data])
376383

377-
assert "Must provide train channel." in str(ex)
384+
assert "Must provide train channel." in str(ex)
378385

379386

380387
@patch("sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit")

0 commit comments

Comments
 (0)