|
31 | 31 | )
|
32 | 32 | from sagemaker.sklearn.processing import SKLearnProcessor
|
33 | 33 | from sagemaker.pytorch.processing import PyTorchProcessor
|
| 34 | +from sagemaker.xgboost.processing import XGBoostEstimator |
34 | 35 | from sagemaker.network import NetworkConfig
|
35 | 36 | from sagemaker.processing import FeatureStoreOutput
|
36 | 37 | from sagemaker.fw_utils import UploadedCode
|
@@ -341,6 +342,40 @@ def test_pytorch_processor_with_required_parameters(
|
341 | 342 |
|
342 | 343 | sagemaker_session.process.assert_called_with(**expected_args)
|
343 | 344 |
|
| 345 | +@patch("sagemaker.utils._botocore_resolver") |
| 346 | +@patch("os.path.exists", return_value=True) |
| 347 | +@patch("os.path.isfile", return_value=True) |
| 348 | +def test_xgboost_with_required_parameters( |
| 349 | + exists_mock, isfile_mock, botocore_resolver, sagemaker_session, xgboost_framework_version |
| 350 | +): |
| 351 | + botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME} |
| 352 | + |
| 353 | + processor = XGBoostEstimator( |
| 354 | + role=ROLE, |
| 355 | + instance_type="ml.m4.xlarge", |
| 356 | + framework_version=xgboost_framework_version, |
| 357 | + instance_count=1, |
| 358 | + sagemaker_session=sagemaker_session, |
| 359 | + ) |
| 360 | + |
| 361 | + processor.run(code="/local/path/to/processing_code.py") |
| 362 | + |
| 363 | + expected_args = _get_expected_args_modular_code(processor._current_job_name) |
| 364 | + |
| 365 | + if version.parse(xgboost_framework_version) < version.parse("1.2-1"): |
| 366 | + xgboost_image_uri = ( |
| 367 | + "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:{}-cpu-py3" |
| 368 | + ).format(xgboost_framework_version) |
| 369 | + else: |
| 370 | + xgboost_image_uri = ( |
| 371 | + "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:{}" |
| 372 | + ).format(xgboost_framework_version) |
| 373 | + |
| 374 | + expected_args["app_specification"]["ImageUri"] = xgboost_image_uri |
| 375 | + |
| 376 | + sagemaker_session.process.assert_called_with(**expected_args) |
| 377 | + |
| 378 | + |
344 | 379 | @patch("os.path.exists", return_value=False)
|
345 | 380 | def test_script_processor_errors_with_nonexistent_local_code(exists_mock, sagemaker_session):
|
346 | 381 | processor = _get_script_processor(sagemaker_session)
|
|
0 commit comments