
Commit e68f4f5

schinmayee authored and pintaoz-aws committed
Feature: Add unit tests for recipes and minor bug fixes. (#1532)
* basic checks and unit test for recipes
* More testing for recipes. Move recipe overrides to top before accessing any recipe fields.
* check that we use customer provided image uri if it is set
* reformat
* test fixes
* update git urls for recipes
* revert to ssh git urls for recipes
1 parent 1f34950 commit e68f4f5

File tree

src/sagemaker/pytorch/estimator.py
tests/unit/test_pytorch.py

2 files changed: +166 −4 lines changed


src/sagemaker/pytorch/estimator.py

Lines changed: 3 additions & 4 deletions
@@ -596,6 +596,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
 
         recipe = OmegaConf.load(temp_local_recipe)
         os.unlink(temp_local_recipe)
+        recipe_overrides.setdefault("run", dict())["results_dir"] = "/opt/ml/model"
+        recipe_overrides.setdefault("exp_manager", dict())["exp_dir"] = "/opt/ml/model/"
+        recipe = OmegaConf.merge(recipe, recipe_overrides)
 
         if "instance_type" not in kwargs:
             raise ValueError("Must pass instance type to estimator when using training recipes.")
@@ -671,10 +674,6 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
                 f"Devices of type {device_type} are not supported with training recipes."
             )
 
-        recipe_overrides.setdefault("run", dict())["results_dir"] = "/opt/ml/model"
-        recipe_overrides.setdefault("exp_manager", dict())["exp_dir"] = "/opt/ml/model/"
-        recipe = OmegaConf.merge(recipe, recipe_overrides)
-
         if "container" in recipe and not recipe["container"]:
            logger.warning(
                "Ignoring container from training_recipe. Use image_uri arg for estimator."

tests/unit/test_pytorch.py

Lines changed: 163 additions & 0 deletions
@@ -35,6 +35,8 @@
 BUCKET_NAME = "mybucket"
 INSTANCE_COUNT = 1
 INSTANCE_TYPE = "ml.c4.4xlarge"
+INSTANCE_TYPE_GPU = "ml.p4d.24xlarge"
+INSTANCE_TYPE_TRAINIUM = "ml.trn1.32xlarge"
 ACCELERATOR_TYPE = "ml.eia.medium"
 IMAGE_URI = "sagemaker-pytorch"
 JOB_NAME = "{}-{}".format(IMAGE_URI, TIMESTAMP)
@@ -59,6 +61,11 @@
 }
 
 DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}}
+NEURON_RECIPE = (
+    "https://raw.githubusercontent.com/aws-neuron/"
+    "neuronx-distributed-training/refs/heads/main/examples/"
+    "conf/hf_llama3_8B_config.yaml"
+)
 
 
 @pytest.fixture(name="sagemaker_session")
@@ -826,3 +833,159 @@ def test_predictor_with_component_name(sagemaker_session, component_name):
     predictor = PyTorchPredictor("endpoint", sagemaker_session, component_name=component_name)
 
     assert predictor._get_component_name() == component_name
+
+
+def test_training_recipe_for_cpu(sagemaker_session):
+    container_log_level = '"logging.INFO"'
+
+    recipe_overrides = {
+        "exp_manager": {
+            "explicit_log_dir": "/opt/ml/output/tensorboard",
+            "checkpoint_dir": "/opt/ml/checkpoints",
+        },
+        "model": {
+            "data": {
+                "train_dir": "/opt/ml/input/data/train",
+                "val_dir": "/opt/ml/input/data/val",
+            },
+        },
+    }
+
+    with pytest.raises(ValueError):
+        PyTorch(
+            output_path="s3://output_path",
+            role=ROLE,
+            sagemaker_session=sagemaker_session,
+            instance_count=INSTANCE_COUNT,
+            instance_type=INSTANCE_TYPE,
+            base_job_name="job",
+            container_log_level=container_log_level,
+            training_recipe="llama/hf_llama3_8b_seq8192_gpu",
+            recipe_overrides=recipe_overrides,
+        )
+
+
+@pytest.mark.parametrize(
+    "recipe, model",
+    [
+        ("hf_llama3_8b_seq8192_gpu", "llama"),
+        ("hf_mistral_gpu", "mistral"),
+        ("hf_mixtral_gpu", "mixtral"),
+    ],
+)
+def test_training_recipe_for_gpu(sagemaker_session, recipe, model):
+    container_log_level = '"logging.INFO"'
+
+    recipe_overrides = {
+        "exp_manager": {
+            "explicit_log_dir": "/opt/ml/output",
+            "checkpoint_dir": "/opt/ml/checkpoints",
+        },
+        "model": {
+            "data": {
+                "train_dir": "/opt/ml/input/data/train",
+                "val_dir": "/opt/ml/input/data/val",
+            },
+        },
+    }
+    pytorch = PyTorch(
+        output_path="s3://output_path",
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        instance_count=INSTANCE_COUNT,
+        instance_type=INSTANCE_TYPE_GPU,
+        base_job_name="job",
+        container_log_level=container_log_level,
+        training_recipe=f"{model}/{recipe}",
+        recipe_overrides=recipe_overrides,
+    )
+
+    assert pytorch.source_dir == os.path.join(pytorch.recipe_train_dir.name, "examples", model)
+    assert pytorch.entry_point == f"{model}_pretrain.py"
+    assert pytorch.image_uri == pytorch.SM_TRAINING_RECIPE_GPU_IMG
+    expected_distribution = {
+        "torch_distributed": {
+            "enabled": True,
+        },
+        "smdistributed": {
+            "modelparallel": {
+                "enabled": True,
+                "parameters": {
+                    "placement_strategy": "cluster",
+                },
+            },
+        },
+    }
+    assert pytorch.distribution.items() == expected_distribution.items()
+
+
+def test_training_recipe_with_override(sagemaker_session):
+    container_log_level = '"logging.INFO"'
+
+    recipe_overrides = {
+        "exp_manager": {
+            "explicit_log_dir": "/opt/ml/output",
+            "checkpoint_dir": "/opt/ml/checkpoints",
+        },
+        "model": {
+            "data": {
+                "train_dir": "/opt/ml/input/data/train",
+                "val_dir": "/opt/ml/input/data/val",
+            },
+            "model_type": "mistral",
+        },
+    }
+    pytorch = PyTorch(
+        output_path="s3://output_path",
+        role=ROLE,
+        image_uri=IMAGE_URI,
+        sagemaker_session=sagemaker_session,
+        instance_count=INSTANCE_COUNT,
+        instance_type=INSTANCE_TYPE_GPU,
+        base_job_name="job",
+        container_log_level=container_log_level,
+        training_recipe="llama/hf_llama3_8b_seq8192_gpu",
+        recipe_overrides=recipe_overrides,
+    )
+
+    assert pytorch.source_dir == os.path.join(pytorch.recipe_train_dir.name, "examples", "mistral")
+    assert pytorch.entry_point == "mistral_pretrain.py"
+    assert pytorch.image_uri == IMAGE_URI
+
+
+def test_training_recipe_for_trainium(sagemaker_session):
+    container_log_level = '"logging.INFO"'
+
+    recipe_overrides = {
+        "exp_manager": {
+            "explicit_log_dir": "/opt/ml/output",
+        },
+        "data": {
+            "train_dir": "/opt/ml/input/data/train",
+        },
+        "model": {
+            "model_config": "/opt/ml/input/data/train/config.json",
+        },
+        "compiler_cache_url": "s3://output_path/neuron-cache",
+    }
+    pytorch = PyTorch(
+        output_path="s3://output_path",
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        instance_count=INSTANCE_COUNT,
+        instance_type=INSTANCE_TYPE_TRAINIUM,
+        base_job_name="job",
+        container_log_level=container_log_level,
+        training_recipe=NEURON_RECIPE,
+        recipe_overrides=recipe_overrides,
+    )
+
+    assert pytorch.source_dir == os.path.join(pytorch.recipe_train_dir.name, "examples")
+    assert pytorch.entry_point == "training_orchestrator.py"
+    assert pytorch.image_uri == pytorch.SM_NEURONX_DIST_IMG
+    expected_distribution = {
+        "torch_distributed": {
+            "enabled": True,
+        },
+    }
+    assert pytorch.distribution == expected_distribution
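
For context, a hedged usage sketch of what these tests exercise: a customer-supplied image_uri takes precedence over the recipe's container image, and CPU instance types are rejected. The role ARN and image URI below are placeholders, and the behavior assumed is the SDK as of this commit:

from sagemaker.pytorch import PyTorch

# Placeholder role and image values; mirrors test_training_recipe_with_override.
estimator = PyTorch(
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role ARN
    instance_count=1,
    instance_type="ml.p4d.24xlarge",  # a CPU type here raises ValueError with recipes
    training_recipe="llama/hf_llama3_8b_seq8192_gpu",
    recipe_overrides={"exp_manager": {"checkpoint_dir": "/opt/ml/checkpoints"}},
    image_uri="111122223333.dkr.ecr.us-west-2.amazonaws.com/custom:latest",  # wins over recipe container
)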
