Skip to content

Commit 75e7645

Browse files
authored
Merge pull request aws#23 from verdimrc/fp-mxnet-uni-test
added a unit test for mxnet processor
2 parents e9cdf75 + d3d41b8 commit 75e7645

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

tests/unit/test_processing.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
from platform import python_version
1415

1516
import pytest
1617
from mock import Mock, patch, MagicMock
@@ -33,6 +34,7 @@
3334
from sagemaker.pytorch.processing import PyTorchProcessor
3435
from sagemaker.utils import get_config_value
3536
from sagemaker.xgboost.processing import XGBoostEstimator
37+
from sagemaker.mxnet.processing import MXNetProcessor
3638
from sagemaker.network import NetworkConfig
3739
from sagemaker.processing import FeatureStoreOutput
3840
from sagemaker.fw_utils import UploadedCode
@@ -390,6 +392,44 @@ def test_xgboost_with_required_parameters(
390392

391393
sagemaker_session.process.assert_called_with(**expected_args)
392394

395+
@patch("sagemaker.utils._botocore_resolver")
396+
@patch("os.path.exists", return_value=True)
397+
@patch("os.path.isfile", return_value=True)
398+
def test_mxnet_with_required_parameters(
399+
exists_mock, isfile_mock, botocore_resolver, sagemaker_session, mxnet_training_version, mxnet_training_py_version
400+
):
401+
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}
402+
403+
processor = MXNetProcessor(
404+
role=ROLE,
405+
instance_type="ml.m4.xlarge",
406+
framework_version=mxnet_training_version,
407+
py_version=mxnet_training_py_version,
408+
instance_count=1,
409+
sagemaker_session=sagemaker_session,
410+
)
411+
412+
processor.run(code="/local/path/to/processing_code.py")
413+
414+
expected_args = _get_expected_args_modular_code(processor._current_job_name)
415+
416+
if (mxnet_training_py_version == 'py3') & (mxnet_training_version == "1.4"): #probably there is a better way to handle this
417+
mxnet_image_uri = (
418+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:{}-cpu-{}"
419+
).format(mxnet_training_version, mxnet_training_py_version)
420+
elif version.parse(mxnet_training_version) > version.parse("1.4.1" if mxnet_training_py_version == 'py2' else "1.4"):
421+
mxnet_image_uri = (
422+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:{}-cpu-{}"
423+
).format(mxnet_training_version, mxnet_training_py_version)
424+
else:
425+
mxnet_image_uri = (
426+
"520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-cpu-{}"
427+
).format(mxnet_training_version, mxnet_training_py_version)
428+
429+
expected_args["app_specification"]["ImageUri"] = mxnet_image_uri
430+
431+
sagemaker_session.process.assert_called_with(**expected_args)
432+
393433

394434
@patch("os.path.exists", return_value=False)
395435
def test_script_processor_errors_with_nonexistent_local_code(exists_mock, sagemaker_session):

0 commit comments

Comments
 (0)