|
11 | 11 | # ANY KIND, either express or implied. See the License for the specific
|
12 | 12 | # language governing permissions and limitations under the License.
|
13 | 13 | from __future__ import absolute_import
|
| 14 | +from platform import python_version |
14 | 15 |
|
15 | 16 | import pytest
|
16 | 17 | from mock import Mock, patch, MagicMock
|
|
32 | 33 | from sagemaker.sklearn.processing import SKLearnProcessor
|
33 | 34 | from sagemaker.pytorch.processing import PyTorchProcessor
|
34 | 35 | from sagemaker.xgboost.processing import XGBoostEstimator
|
| 36 | +from sagemaker.mxnet.processing import MXNetProcessor |
35 | 37 | from sagemaker.network import NetworkConfig
|
36 | 38 | from sagemaker.processing import FeatureStoreOutput
|
37 | 39 | from sagemaker.fw_utils import UploadedCode
|
@@ -375,6 +377,44 @@ def test_xgboost_with_required_parameters(
|
375 | 377 |
|
376 | 378 | sagemaker_session.process.assert_called_with(**expected_args)
|
377 | 379 |
|
| 380 | +@patch("sagemaker.utils._botocore_resolver") |
| 381 | +@patch("os.path.exists", return_value=True) |
| 382 | +@patch("os.path.isfile", return_value=True) |
| 383 | +def test_mxnet_with_required_parameters( |
| 384 | + exists_mock, isfile_mock, botocore_resolver, sagemaker_session, mxnet_training_version, mxnet_training_py_version |
| 385 | +): |
| 386 | + botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME} |
| 387 | + |
| 388 | + processor = MXNetProcessor( |
| 389 | + role=ROLE, |
| 390 | + instance_type="ml.m4.xlarge", |
| 391 | + framework_version=mxnet_training_version, |
| 392 | + py_version=mxnet_training_py_version, |
| 393 | + instance_count=1, |
| 394 | + sagemaker_session=sagemaker_session, |
| 395 | + ) |
| 396 | + |
| 397 | + processor.run(code="/local/path/to/processing_code.py") |
| 398 | + |
| 399 | + expected_args = _get_expected_args_modular_code(processor._current_job_name) |
| 400 | + |
| 401 | + if (mxnet_training_py_version == 'py3') & (mxnet_training_version == "1.4"): #probably there is a better way to handle this |
| 402 | + mxnet_image_uri = ( |
| 403 | + "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:{}-cpu-{}" |
| 404 | + ).format(mxnet_training_version, mxnet_training_py_version) |
| 405 | + elif version.parse(mxnet_training_version) > version.parse("1.4.1" if mxnet_training_py_version == 'py2' else "1.4"): |
| 406 | + mxnet_image_uri = ( |
| 407 | + "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:{}-cpu-{}" |
| 408 | + ).format(mxnet_training_version, mxnet_training_py_version) |
| 409 | + else: |
| 410 | + mxnet_image_uri = ( |
| 411 | + "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-cpu-{}" |
| 412 | + ).format(mxnet_training_version, mxnet_training_py_version) |
| 413 | + |
| 414 | + expected_args["app_specification"]["ImageUri"] = mxnet_image_uri |
| 415 | + |
| 416 | + sagemaker_session.process.assert_called_with(**expected_args) |
| 417 | + |
378 | 418 |
|
379 | 419 | @patch("os.path.exists", return_value=False)
|
380 | 420 | def test_script_processor_errors_with_nonexistent_local_code(exists_mock, sagemaker_session):
|
|
0 commit comments