Skip to content

Commit 4312db9

Browse files
authored
Merge pull request aws#21 from verdimrc/fp-xgboost-unit-test
added unit test for xgboost processor
2 parents d06ff5e + 8b9ff09 commit 4312db9

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/unit/test_processing.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from sagemaker.sklearn.processing import SKLearnProcessor
3333
from sagemaker.pytorch.processing import PyTorchProcessor
34+
from sagemaker.xgboost.processing import XGBoostEstimator
3435
from sagemaker.network import NetworkConfig
3536
from sagemaker.processing import FeatureStoreOutput
3637
from sagemaker.fw_utils import UploadedCode
@@ -341,6 +342,40 @@ def test_pytorch_processor_with_required_parameters(
341342

342343
sagemaker_session.process.assert_called_with(**expected_args)
343344

345+
@patch("sagemaker.utils._botocore_resolver")
346+
@patch("os.path.exists", return_value=True)
347+
@patch("os.path.isfile", return_value=True)
348+
def test_xgboost_with_required_parameters(
349+
exists_mock, isfile_mock, botocore_resolver, sagemaker_session, xgboost_framework_version
350+
):
351+
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}
352+
353+
processor = XGBoostEstimator(
354+
role=ROLE,
355+
instance_type="ml.m4.xlarge",
356+
framework_version=xgboost_framework_version,
357+
instance_count=1,
358+
sagemaker_session=sagemaker_session,
359+
)
360+
361+
processor.run(code="/local/path/to/processing_code.py")
362+
363+
expected_args = _get_expected_args_modular_code(processor._current_job_name)
364+
365+
if version.parse(xgboost_framework_version) < version.parse("1.2-1"):
366+
xgboost_image_uri = (
367+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:{}-cpu-py3"
368+
).format(xgboost_framework_version)
369+
else:
370+
xgboost_image_uri = (
371+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:{}"
372+
).format(xgboost_framework_version)
373+
374+
expected_args["app_specification"]["ImageUri"] = xgboost_image_uri
375+
376+
sagemaker_session.process.assert_called_with(**expected_args)
377+
378+
344379
@patch("os.path.exists", return_value=False)
345380
def test_script_processor_errors_with_nonexistent_local_code(exists_mock, sagemaker_session):
346381
processor = _get_script_processor(sagemaker_session)

0 commit comments

Comments
 (0)