Skip to content

Commit 720faab

Browse files
nargokulpintaoz-aws
authored andcommitted
Morpheus tests (#1633)
* Parameterized intelligent defaults tests * Parameterized intelligent defaults tests * Parameterized intelligent defaults tests * Tests for all Model Builder deployment modes * Fix * CodeStyle Fixes * CodeStyle Fixes * Add Deepdiff dependency * Add Deepdiff dependency * Add Codestyle fix
1 parent 4f31237 commit 720faab

File tree

4 files changed

+49
-10
lines changed

4 files changed

+49
-10
lines changed

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,4 @@ uvicorn>=0.30.1
4949
fastapi==0.115.4
5050
nest-asyncio
5151
sagemaker-mlflow>=0.1.0
52+
deepdiff>=8.0.0

src/sagemaker/modules/train/model_trainer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,6 @@ class ModelTrainer(BaseModel):
235235
"role",
236236
"base_job_name",
237237
"source_code",
238-
"distributed",
239238
"compute",
240239
"networking",
241240
"stopping_condition",
@@ -251,7 +250,6 @@ class ModelTrainer(BaseModel):
251250

252251
SERIALIZABLE_CONFIG_ATTRIBUTES: ClassVar[Any] = {
253252
"source_code": SourceCode,
254-
"distributed": DistributedConfig,
255253
"compute": Compute,
256254
"networking": Networking,
257255
"stopping_condition": StoppingCondition,

src/sagemaker/serve/builder/model_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor
3636
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
3737
from sagemaker.s3 import S3Downloader
38-
from sagemaker import Session, utils
38+
from sagemaker import Session
3939
from sagemaker.model import Model
4040
from sagemaker.base_predictor import PredictorBase
4141
from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
@@ -112,7 +112,7 @@
112112
validate_image_uri_and_hardware,
113113
)
114114
from sagemaker.serverless import ServerlessInferenceConfig
115-
from sagemaker.utils import Tags
115+
from sagemaker.utils import Tags, unique_name_from_base
116116
from sagemaker.workflow.entities import PipelineVariable
117117
from sagemaker.huggingface.llm_utils import (
118118
get_huggingface_model_metadata,
@@ -1621,7 +1621,7 @@ def deploy(
16211621
"""
16221622
if not hasattr(self, "built_model"):
16231623
raise ValueError("Model Needs to be built before deploying")
1624-
endpoint_name = utils.unique_name_from_base(endpoint_name)
1624+
endpoint_name = unique_name_from_base(endpoint_name)
16251625
if not inference_config: # Real-time Deployment
16261626
return self.built_model.deploy(
16271627
instance_type=self.instance_type,

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,17 +194,57 @@ def test_train_with_default_params(mock_training_job, model_trainer):
194194
training_job_instance.wait.assert_called_once_with(logs=True)
195195

196196

197+
@pytest.mark.parametrize(
198+
"default_config",
199+
[
200+
{
201+
"path_name": "sourceCode",
202+
"config_value": {"command": "echo 'Hello World' && env"},
203+
},
204+
{
205+
"path_name": "compute",
206+
"config_value": {"volume_size_in_gb": 45},
207+
},
208+
{
209+
"path_name": "networking",
210+
"config_value": {
211+
"enable_network_isolation": True,
212+
"security_group_ids": ["sg-1"],
213+
"subnets": ["subnet-1"],
214+
},
215+
},
216+
{
217+
"path_name": "stoppingCondition",
218+
"config_value": {"max_runtime_in_seconds": 15},
219+
},
220+
{
221+
"path_name": "trainingImageConfig",
222+
"config_value": {"training_repository_access_mode": "private"},
223+
},
224+
{
225+
"path_name": "outputDataConfig",
226+
"config_value": {"s3_output_path": "Sample S3 path"},
227+
},
228+
{
229+
"path_name": "checkpointConfig",
230+
"config_value": {"s3_uri": "sample uri"},
231+
},
232+
],
233+
)
197234
@patch("sagemaker.modules.train.model_trainer.TrainingJob")
198235
@patch("sagemaker.modules.train.model_trainer.resolve_value_from_config")
199236
@patch("sagemaker.modules.train.model_trainer.ModelTrainer.create_input_data_channel")
200237
def test_train_with_intelligent_defaults(
201-
mock_create_input_data_channel, mock_resolve_value_from_config, mock_training_job, model_trainer
238+
mock_create_input_data_channel,
239+
mock_resolve_value_from_config,
240+
mock_training_job,
241+
default_config,
242+
model_trainer,
202243
):
203-
source_code_path = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, MODEL_TRAINER, "sourceCode")
204-
205244
mock_resolve_value_from_config.side_effect = lambda **kwargs: (
206-
{"command": "echo 'Hello World' && env"}
207-
if kwargs["config_path"] == source_code_path
245+
default_config["config_value"]
246+
if kwargs["config_path"]
247+
== _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, MODEL_TRAINER, default_config["path_name"])
208248
else None
209249
)
210250

0 commit comments

Comments
 (0)