Skip to content

Commit 7846a60

Browse files
committed
Add integ test
1 parent 5980e23 commit 7846a60

File tree

3 files changed

+84
-1
lines changed

3 files changed

+84
-1
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import json
2+
import os
3+
import subprocess
4+
import sys
5+
6+
7+
def main():
8+
driver_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
9+
process_count_per_node = driver_config["process_count_per_node"]
10+
assert process_count_per_node != None
11+
12+
hps = json.loads(os.environ["SM_HP"])
13+
assert hps != None
14+
assert isinstance(hps, dict)
15+
16+
source_dir = os.environ["SM_SOURCE_DIR"]
17+
assert source_dir == "/opt/ml/input/data/code"
18+
sm_drivers_dir = os.environ["SM_DRIVER_DIR"]
19+
assert sm_drivers_dir == "/opt/ml/input/data/sm_drivers/drivers"
20+
21+
entry_script = os.environ["SM_ENTRY_SCRIPT"]
22+
assert entry_script != None
23+
24+
python = sys.executable
25+
26+
command = [python, entry_script]
27+
print(f"Running command: {command}")
28+
subprocess.run(command, check=True)
29+
30+
31+
if __name__ == "__main__":
32+
print("Running custom driver script")
33+
main()
34+
print("Finished running custom driver script")
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import json
2+
import os
3+
import time
4+
5+
6+
def main():
7+
hps = json.loads(os.environ["SM_HPS"])
8+
assert hps != None
9+
print(f"Hyperparameters: {hps}")
10+
11+
print("Running pseudo training script")
12+
for epochs in range(hps["epochs"]):
13+
print(f"Epoch: {epochs}")
14+
time.sleep(1)
15+
print("Finished running pseudo training script")
16+
17+
18+
if __name__ == "__main__":
19+
main()

tests/integ/sagemaker/modules/train/test_model_trainer.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from sagemaker.modules.train import ModelTrainer
1919
from sagemaker.modules.configs import SourceCode, Compute
20-
from sagemaker.modules.distributed import MPI, Torchrun
20+
from sagemaker.modules.distributed import MPI, Torchrun, DistributedConfig
2121

2222
EXPECTED_HYPERPARAMETERS = {
2323
"integer": 1,
@@ -106,3 +106,33 @@ def test_hp_contract_torchrun_script(modules_sagemaker_session):
106106
)
107107

108108
model_trainer.train()
109+
110+
111+
def test_custom_distributed_driver(modules_sagemaker_session):
112+
class CustomDriver(DistributedConfig):
113+
process_count_per_node: int = 2
114+
115+
@property
116+
def driver_dir(self) -> str:
117+
return f"{DATA_DIR}/modules/custom_drivers"
118+
119+
@property
120+
def driver_script(self) -> str:
121+
return "driver.py"
122+
123+
source_code = SourceCode(
124+
source_dir=f"{DATA_DIR}/modules/scripts",
125+
entry_script="entry_script.py",
126+
)
127+
128+
hyperparameters = {"epochs": 10}
129+
130+
model_trainer = ModelTrainer(
131+
sagemaker_session=modules_sagemaker_session,
132+
training_image=DEFAULT_CPU_IMAGE,
133+
hyperparameters=hyperparameters,
134+
source_code=source_code,
135+
distributed=CustomDriver(),
136+
base_job_name="custom-distributed-driver",
137+
)
138+
model_trainer.train()

0 commit comments

Comments
 (0)