Skip to content

Commit d3d41b8

Browse files
author
Panigrahi
committed
added a unit test for mxnet processor
1 parent 4312db9 commit d3d41b8

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

tests/unit/test_processing.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
from platform import python_version
1415

1516
import pytest
1617
from mock import Mock, patch, MagicMock
@@ -32,6 +33,7 @@
3233
from sagemaker.sklearn.processing import SKLearnProcessor
3334
from sagemaker.pytorch.processing import PyTorchProcessor
3435
from sagemaker.xgboost.processing import XGBoostEstimator
36+
from sagemaker.mxnet.processing import MXNetProcessor
3537
from sagemaker.network import NetworkConfig
3638
from sagemaker.processing import FeatureStoreOutput
3739
from sagemaker.fw_utils import UploadedCode
@@ -375,6 +377,44 @@ def test_xgboost_with_required_parameters(
375377

376378
sagemaker_session.process.assert_called_with(**expected_args)
377379

380+
@patch("sagemaker.utils._botocore_resolver")
381+
@patch("os.path.exists", return_value=True)
382+
@patch("os.path.isfile", return_value=True)
383+
def test_mxnet_with_required_parameters(
384+
exists_mock, isfile_mock, botocore_resolver, sagemaker_session, mxnet_training_version, mxnet_training_py_version
385+
):
386+
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}
387+
388+
processor = MXNetProcessor(
389+
role=ROLE,
390+
instance_type="ml.m4.xlarge",
391+
framework_version=mxnet_training_version,
392+
py_version=mxnet_training_py_version,
393+
instance_count=1,
394+
sagemaker_session=sagemaker_session,
395+
)
396+
397+
processor.run(code="/local/path/to/processing_code.py")
398+
399+
expected_args = _get_expected_args_modular_code(processor._current_job_name)
400+
401+
if (mxnet_training_py_version == 'py3') & (mxnet_training_version == "1.4"): #probably there is a better way to handle this
402+
mxnet_image_uri = (
403+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:{}-cpu-{}"
404+
).format(mxnet_training_version, mxnet_training_py_version)
405+
elif version.parse(mxnet_training_version) > version.parse("1.4.1" if mxnet_training_py_version == 'py2' else "1.4"):
406+
mxnet_image_uri = (
407+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:{}-cpu-{}"
408+
).format(mxnet_training_version, mxnet_training_py_version)
409+
else:
410+
mxnet_image_uri = (
411+
"520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-cpu-{}"
412+
).format(mxnet_training_version, mxnet_training_py_version)
413+
414+
expected_args["app_specification"]["ImageUri"] = mxnet_image_uri
415+
416+
sagemaker_session.process.assert_called_with(**expected_args)
417+
378418

379419
@patch("os.path.exists", return_value=False)
380420
def test_script_processor_errors_with_nonexistent_local_code(exists_mock, sagemaker_session):

0 commit comments

Comments
 (0)