Skip to content

Commit 29ee04a

Browse files
committed
chore: add unit test for deploying jumpstart amt model
1 parent a00de67 commit 29ee04a

File tree

1 file changed

+161
-24
lines changed

1 file changed

+161
-24
lines changed

tests/unit/test_tuner.py

Lines changed: 161 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,19 +1530,74 @@ def test_tags_prefixes_jumpstart_models(
15301530
patched_upload_code, patched_tar_and_upload_dir, sagemaker_session
15311531
):
15321532

1533+
jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz"
1534+
jumpstart_source_dir_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/source_dirs/source.tar.gz"
1535+
jumpstart_source_dir_3 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[2]}/source_dirs/source.tar.gz"
1536+
1537+
estimator_tag = {"Key": "estimator-tag-key", "Value": "estimator-tag-value"}
1538+
hp_tag = {"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"}
1539+
training_model_uri_tag = {
1540+
"Key": JumpStartTag.TRAINING_MODEL_URI.value,
1541+
"Value": jumpstart_source_dir_2,
1542+
}
1543+
training_script_uri_tag = {
1544+
"Key": JumpStartTag.TRAINING_SCRIPT_URI.value,
1545+
"Value": jumpstart_source_dir,
1546+
}
1547+
inference_script_uri_tag = {
1548+
"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value,
1549+
"Value": jumpstart_source_dir_3,
1550+
}
1551+
15331552
patched_tar_and_upload_dir.return_value = UploadedCode(
15341553
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
15351554
)
15361555
sagemaker_session.boto_region_name = REGION
15371556

1557+
sagemaker_session.sagemaker_client.describe_training_job.return_value = {
1558+
"AlgorithmSpecification": {
1559+
"TrainingInputMode": "File",
1560+
"TrainingImage": "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-other:1.0.4",
1561+
},
1562+
"HyperParameters": {
1563+
"sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
1564+
"checkpoint_path": '"s3://other/1508872349"',
1565+
"sagemaker_program": '"iris-dnn-classifier.py"',
1566+
"sagemaker_container_log_level": '"logging.INFO"',
1567+
"sagemaker_job_name": '"neo"',
1568+
"training_steps": "100",
1569+
},
1570+
"RoleArn": "arn:aws:iam::366:role/SageMakerRole",
1571+
"ResourceConfig": {
1572+
"VolumeSizeInGB": 30,
1573+
"InstanceCount": 1,
1574+
"InstanceType": "ml.c4.xlarge",
1575+
},
1576+
"EnableNetworkIsolation": False,
1577+
"StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
1578+
"TrainingJobName": "neo",
1579+
"TrainingJobStatus": "Completed",
1580+
"TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo",
1581+
"OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
1582+
"TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
1583+
"EnableInterContainerTrafficEncryption": False,
1584+
"ModelArtifacts": {"S3ModelArtifacts": "blah"},
1585+
}
1586+
1587+
sagemaker_session.sagemaker_client.list_tags.return_value = {
1588+
"Tags": [
1589+
estimator_tag,
1590+
hp_tag,
1591+
training_model_uri_tag,
1592+
training_script_uri_tag,
1593+
]
1594+
}
1595+
15381596
instance_type = "ml.p2.xlarge"
15391597
instance_count = 1
15401598

15411599
training_data_uri = "s3://bucket/mydata"
15421600

1543-
jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz"
1544-
jumpstart_source_dir_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/source_dirs/source.tar.gz"
1545-
15461601
image_uri = "fake-image-uri"
15471602

15481603
generic_estimator = Estimator(
@@ -1555,35 +1610,58 @@ def test_tags_prefixes_jumpstart_models(
15551610
source_dir=jumpstart_source_dir,
15561611
image_uri=image_uri,
15571612
model_uri=jumpstart_source_dir_2,
1558-
tags=[{"Key": "estimator-tag-key", "Value": "estimator-tag-value"}],
1613+
tags=[estimator_tag],
15591614
)
15601615

15611616
hp_tuner = HyperparameterTuner(
15621617
generic_estimator,
15631618
OBJECTIVE_METRIC_NAME,
15641619
HYPERPARAMETER_RANGES,
1565-
tags=[{"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"}],
1620+
tags=[hp_tag],
15661621
)
15671622

15681623
hp_tuner.fit({"training": training_data_uri})
15691624

15701625
assert [
1571-
{"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"},
1572-
{"Key": "estimator-tag-key", "Value": "estimator-tag-value"},
1573-
{
1574-
"Key": JumpStartTag.TRAINING_MODEL_URI.value,
1575-
"Value": jumpstart_source_dir_2,
1576-
},
1577-
{
1578-
"Key": JumpStartTag.TRAINING_SCRIPT_URI.value,
1579-
"Value": jumpstart_source_dir,
1580-
},
1626+
hp_tag,
1627+
estimator_tag,
1628+
training_model_uri_tag,
1629+
training_script_uri_tag,
15811630
] == sagemaker_session.create_tuning_job.call_args_list[0][1]["tags"]
15821631

15831632
assert sagemaker_session.create_tuning_job.call_args_list[0][1]["job_name"].startswith(
15841633
JUMPSTART_RESOURCE_BASE_NAME
15851634
)
15861635

1636+
hp_tuner.deploy(
1637+
initial_instance_count=INSTANCE_COUNT,
1638+
instance_type=INSTANCE_TYPE,
1639+
image_uri=image_uri,
1640+
source_dir=jumpstart_source_dir_3,
1641+
entry_point="inference.py",
1642+
role=ROLE,
1643+
)
1644+
1645+
assert sagemaker_session.create_model.call_args_list[0][0][0].startswith(
1646+
JUMPSTART_RESOURCE_BASE_NAME
1647+
)
1648+
1649+
assert sagemaker_session.endpoint_from_production_variants.call_args_list[0].startswith(
1650+
JUMPSTART_RESOURCE_BASE_NAME
1651+
)
1652+
1653+
assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [
1654+
training_model_uri_tag,
1655+
training_script_uri_tag,
1656+
inference_script_uri_tag,
1657+
]
1658+
1659+
assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == [
1660+
training_model_uri_tag,
1661+
training_script_uri_tag,
1662+
inference_script_uri_tag,
1663+
]
1664+
15871665

15881666
@patch("time.time", return_value=510006209.073025)
15891667
@patch("sagemaker.estimator.tar_and_upload_dir")
@@ -1592,19 +1670,58 @@ def test_no_tags_prefixes_non_jumpstart_models(
15921670
patched_upload_code, patched_tar_and_upload_dir, sagemaker_session
15931671
):
15941672

1673+
non_jumpstart_source_dir = "s3://blah1/source_dirs/source.tar.gz"
1674+
non_jumpstart_source_dir_2 = "s3://blah2/source_dirs/source.tar.gz"
1675+
non_jumpstart_source_dir_3 = "s3://blah3/source_dirs/source.tar.gz"
1676+
1677+
estimator_tag = {"Key": "estimator-tag-key", "Value": "estimator-tag-value"}
1678+
hp_tag = {"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"}
1679+
15951680
patched_tar_and_upload_dir.return_value = UploadedCode(
15961681
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
15971682
)
15981683
sagemaker_session.boto_region_name = REGION
15991684

1685+
sagemaker_session.sagemaker_client.describe_training_job.return_value = {
1686+
"AlgorithmSpecification": {
1687+
"TrainingInputMode": "File",
1688+
"TrainingImage": "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-other:1.0.4",
1689+
},
1690+
"HyperParameters": {
1691+
"sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
1692+
"checkpoint_path": '"s3://other/1508872349"',
1693+
"sagemaker_program": '"iris-dnn-classifier.py"',
1694+
"sagemaker_container_log_level": '"logging.INFO"',
1695+
"sagemaker_job_name": '"neo"',
1696+
"training_steps": "100",
1697+
},
1698+
"RoleArn": "arn:aws:iam::366:role/SageMakerRole",
1699+
"ResourceConfig": {
1700+
"VolumeSizeInGB": 30,
1701+
"InstanceCount": 1,
1702+
"InstanceType": "ml.c4.xlarge",
1703+
},
1704+
"EnableNetworkIsolation": False,
1705+
"StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
1706+
"TrainingJobName": "neo",
1707+
"TrainingJobStatus": "Completed",
1708+
"TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo",
1709+
"OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
1710+
"TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
1711+
"EnableInterContainerTrafficEncryption": False,
1712+
"ModelArtifacts": {"S3ModelArtifacts": "blah"},
1713+
}
1714+
1715+
sagemaker_session.sagemaker_client.list_tags.return_value = {"Tags": []}
1716+
1717+
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job.return_value = {
1718+
"BestTrainingJob": {"TrainingJobName": "some-name"}
1719+
}
16001720
instance_type = "ml.p2.xlarge"
16011721
instance_count = 1
16021722

16031723
training_data_uri = "s3://bucket/mydata"
16041724

1605-
non_jumpstart_source_dir = "s3://non-js-bucket/sdfsdfs"
1606-
non_jumpstart_source_dir_2 = "s3://non-js-bucket/sdfsdsfsdfsddfs"
1607-
16081725
image_uri = "fake-image-uri"
16091726

16101727
generic_estimator = Estimator(
@@ -1617,23 +1734,43 @@ def test_no_tags_prefixes_non_jumpstart_models(
16171734
source_dir=non_jumpstart_source_dir,
16181735
image_uri=image_uri,
16191736
model_uri=non_jumpstart_source_dir_2,
1620-
tags=[{"Key": "estimator-tag-key", "Value": "estimator-tag-value"}],
1737+
tags=[estimator_tag],
16211738
)
16221739

16231740
hp_tuner = HyperparameterTuner(
16241741
generic_estimator,
16251742
OBJECTIVE_METRIC_NAME,
16261743
HYPERPARAMETER_RANGES,
1627-
tags=[{"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"}],
1744+
tags=[hp_tag],
16281745
)
16291746

16301747
hp_tuner.fit({"training": training_data_uri})
16311748

1632-
assert [
1633-
{"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"},
1634-
{"Key": "estimator-tag-key", "Value": "estimator-tag-value"},
1635-
] == sagemaker_session.create_tuning_job.call_args_list[0][1]["tags"]
1749+
assert [hp_tag, estimator_tag] == sagemaker_session.create_tuning_job.call_args_list[0][1][
1750+
"tags"
1751+
]
16361752

16371753
assert not sagemaker_session.create_tuning_job.call_args_list[0][1]["job_name"].startswith(
16381754
JUMPSTART_RESOURCE_BASE_NAME
16391755
)
1756+
1757+
hp_tuner.deploy(
1758+
initial_instance_count=INSTANCE_COUNT,
1759+
instance_type=INSTANCE_TYPE,
1760+
image_uri=image_uri,
1761+
source_dir=non_jumpstart_source_dir_3,
1762+
entry_point="inference.py",
1763+
role=ROLE,
1764+
)
1765+
1766+
assert not sagemaker_session.create_model.call_args_list[0][0][0].startswith(
1767+
JUMPSTART_RESOURCE_BASE_NAME
1768+
)
1769+
1770+
assert not sagemaker_session.endpoint_from_production_variants.call_args_list[0][1][
1771+
"name"
1772+
].startswith(JUMPSTART_RESOURCE_BASE_NAME)
1773+
1774+
assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == []
1775+
1776+
assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == []

0 commit comments

Comments
 (0)