Skip to content

Commit f668f3b

Browse files
authored
Merge branch 'master' into master-distributed-config-extensible
2 parents e49095b + 83ce1a0 commit f668f3b

File tree

12 files changed

+335
-25
lines changed

12 files changed

+335
-25
lines changed

src/sagemaker/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,8 @@ def is_repack(self) -> bool:
745745
Returns:
746746
bool: Whether the source needs to be repacked.
747747
"""
748+
if self.source_dir is None or self.entry_point is None:
749+
return False
748750
return self.source_dir and self.entry_point and not self.git_config
749751

750752
def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
@@ -2143,6 +2145,8 @@ def is_repack(self) -> bool:
21432145
Returns:
21442146
bool: Whether the source needs to be repacked.
21452147
"""
2148+
if self.source_dir is None or self.entry_point is None:
2149+
return False
21462150
return self.source_dir and self.entry_point and not (self.key_prefix or self.git_config)
21472151

21482152

src/sagemaker/modules/train/model_trainer.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import json
1919
import shutil
2020
from tempfile import TemporaryDirectory
21-
2221
from typing import Optional, List, Union, Dict, Any, ClassVar
22+
import yaml
2323

2424
from graphene.utils.str_converters import to_camel_case, to_snake_case
2525

@@ -194,8 +194,9 @@ class ModelTrainer(BaseModel):
194194
Defaults to "File".
195195
environment (Optional[Dict[str, str]]):
196196
The environment variables for the training job.
197-
hyperparameters (Optional[Dict[str, Any]]):
198-
The hyperparameters for the training job.
197+
hyperparameters (Optional[Union[Dict[str, Any], str]]):
198+
The hyperparameters for the training job. Can be a dictionary of hyperparameters
199+
or a path to a hyperparameters JSON/YAML file.
199200
tags (Optional[List[Tag]]):
200201
An array of key-value pairs. You can use tags to categorize your AWS resources
201202
in different ways, for example, by purpose, owner, or environment.
@@ -225,7 +226,7 @@ class ModelTrainer(BaseModel):
225226
checkpoint_config: Optional[CheckpointConfig] = None
226227
training_input_mode: Optional[str] = "File"
227228
environment: Optional[Dict[str, str]] = {}
228-
hyperparameters: Optional[Dict[str, Any]] = {}
229+
hyperparameters: Optional[Union[Dict[str, Any], str]] = {}
229230
tags: Optional[List[Tag]] = None
230231
local_container_root: Optional[str] = os.getcwd()
231232

@@ -469,6 +470,29 @@ def model_post_init(self, __context: Any):
469470
f"StoppingCondition not provided. Using default:\n{self.stopping_condition}"
470471
)
471472

473+
if self.hyperparameters and isinstance(self.hyperparameters, str):
474+
if not os.path.exists(self.hyperparameters):
475+
raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}")
476+
logger.info(f"Loading hyperparameters from file: {self.hyperparameters}")
477+
with open(self.hyperparameters, "r") as f:
478+
contents = f.read()
479+
try:
480+
self.hyperparameters = json.loads(contents)
481+
logger.debug("Hyperparameters loaded as JSON")
482+
except json.JSONDecodeError:
483+
try:
484+
logger.info(f"contents: {contents}")
485+
self.hyperparameters = yaml.safe_load(contents)
486+
if not isinstance(self.hyperparameters, dict):
487+
raise ValueError("YAML contents must be a valid mapping")
488+
logger.info(f"hyperparameters: {self.hyperparameters}")
489+
logger.debug("Hyperparameters loaded as YAML")
490+
except (yaml.YAMLError, ValueError):
491+
raise ValueError(
492+
f"Invalid hyperparameters file: {self.hyperparameters}. "
493+
"Must be a valid JSON or YAML file."
494+
)
495+
472496
if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None:
473497
session = self.sagemaker_session
474498
base_job_name = self.base_job_name

src/sagemaker/pipeline.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import sagemaker
1919
from sagemaker import ModelMetrics, Model
20+
from sagemaker import local
21+
from sagemaker import session
2022
from sagemaker.config import (
2123
ENDPOINT_CONFIG_KMS_KEY_ID_PATH,
2224
MODEL_VPC_CONFIG_PATH,
@@ -560,3 +562,16 @@ def delete_model(self):
560562
raise ValueError("The SageMaker model must be created before attempting to delete.")
561563

562564
self.sagemaker_session.delete_model(self.name)
565+
566+
def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
567+
"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it is not already set.
568+
569+
The type of session object is determined by the instance type.
570+
"""
571+
if self.sagemaker_session:
572+
return
573+
574+
if instance_type in ("local", "local_gpu"):
575+
self.sagemaker_session = local.LocalSession(sagemaker_config=self._sagemaker_config)
576+
else:
577+
self.sagemaker_session = session.Session(sagemaker_config=self._sagemaker_config)

src/sagemaker/workflow/steps.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,7 @@ def arguments(self) -> RequestType:
645645
request_dict = self.step_args
646646
else:
647647
if isinstance(self.model, PipelineModel):
648+
self.model._init_sagemaker_session_if_does_not_exist()
648649
request_dict = self.model.sagemaker_session._create_model_request(
649650
name="",
650651
role=self.model.role,
@@ -653,6 +654,7 @@ def arguments(self) -> RequestType:
653654
enable_network_isolation=self.model.enable_network_isolation,
654655
)
655656
else:
657+
self.model._init_sagemaker_session_if_does_not_exist()
656658
request_dict = self.model.sagemaker_session._create_model_request(
657659
name="",
658660
role=self.model.role,
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{
2+
"integer": 1,
3+
"boolean": true,
4+
"float": 3.14,
5+
"string": "Hello World",
6+
"list": [1, 2, 3],
7+
"dict": {
8+
"string": "value",
9+
"integer": 3,
10+
"float": 3.14,
11+
"list": [1, 2, 3],
12+
"dict": {"key": "value"},
13+
"boolean": true
14+
}
15+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
integer: 1
2+
boolean: true
3+
float: 3.14
4+
string: "Hello World"
5+
list:
6+
- 1
7+
- 2
8+
- 3
9+
dict:
10+
string: value
11+
integer: 3
12+
float: 3.14
13+
list:
14+
- 1
15+
- 2
16+
- 3
17+
dict:
18+
key: value
19+
boolean: true
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
omegaconf

tests/data/modules/params_script/train.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
import argparse
1717
import json
1818
import os
19+
from typing import List, Dict, Any
20+
from dataclasses import dataclass
21+
from omegaconf import OmegaConf
1922

2023
EXPECTED_HYPERPARAMETERS = {
2124
"integer": 1,
@@ -26,6 +29,7 @@
2629
"dict": {
2730
"string": "value",
2831
"integer": 3,
32+
"float": 3.14,
2933
"list": [1, 2, 3],
3034
"dict": {"key": "value"},
3135
"boolean": True,
@@ -117,7 +121,7 @@ def main():
117121
assert isinstance(params["dict"], dict)
118122

119123
params = json.loads(os.environ["SM_TRAINING_ENV"])["hyperparameters"]
120-
print(params)
124+
print(f"SM_TRAINING_ENV -> hyperparameters: {params}")
121125
assert params["string"] == EXPECTED_HYPERPARAMETERS["string"]
122126
assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"]
123127
assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"]
@@ -132,9 +136,96 @@ def main():
132136
assert isinstance(params["float"], float)
133137
assert isinstance(params["list"], list)
134138
assert isinstance(params["dict"], dict)
135-
print(f"SM_TRAINING_ENV -> hyperparameters: {params}")
136139

137-
print("Test passed.")
140+
# Local JSON - DictConfig OmegaConf
141+
params = OmegaConf.load("hyperparameters.json")
142+
143+
print(f"Local hyperparameters.json: {params}")
144+
assert params.string == EXPECTED_HYPERPARAMETERS["string"]
145+
assert params.integer == EXPECTED_HYPERPARAMETERS["integer"]
146+
assert params.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
147+
assert params.float == EXPECTED_HYPERPARAMETERS["float"]
148+
assert params.list == EXPECTED_HYPERPARAMETERS["list"]
149+
assert params.dict == EXPECTED_HYPERPARAMETERS["dict"]
150+
assert params.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
151+
assert params.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
152+
assert params.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
153+
assert params.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
154+
assert params.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
155+
assert params.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
156+
157+
@dataclass
158+
class DictConfig:
159+
string: str
160+
integer: int
161+
boolean: bool
162+
float: float
163+
list: List[int]
164+
dict: Dict[str, Any]
165+
166+
@dataclass
167+
class HPConfig:
168+
string: str
169+
integer: int
170+
boolean: bool
171+
float: float
172+
list: List[int]
173+
dict: DictConfig
174+
175+
# Local JSON - Structured OmegaConf
176+
hp_config: HPConfig = OmegaConf.merge(
177+
OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.json")
178+
)
179+
print(f"Local hyperparameters.json - Structured: {hp_config}")
180+
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
181+
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
182+
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
183+
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
184+
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
185+
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
186+
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
187+
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
188+
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
189+
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
190+
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
191+
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
192+
193+
# Local YAML - Structured OmegaConf
194+
hp_config: HPConfig = OmegaConf.merge(
195+
OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.yaml")
196+
)
197+
print(f"Local hyperparameters.yaml - Structured: {hp_config}")
198+
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
199+
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
200+
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
201+
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
202+
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
203+
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
204+
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
205+
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
206+
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
207+
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
208+
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
209+
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
210+
print(f"hyperparameters.yaml -> hyperparameters: {hp_config}")
211+
212+
# HP Dict - Structured OmegaConf
213+
hp_dict = json.loads(os.environ["SM_HPS"])
214+
hp_config: HPConfig = OmegaConf.merge(OmegaConf.structured(HPConfig), OmegaConf.create(hp_dict))
215+
print(f"SM_HPS - Structured: {hp_config}")
216+
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
217+
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
218+
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
219+
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
220+
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
221+
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
222+
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
223+
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
224+
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
225+
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
226+
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
227+
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
228+
print(f"SM_HPS -> hyperparameters: {hp_config}")
138229

139230

140231
if __name__ == "__main__":

0 commit comments

Comments
 (0)