|
22 | 22 | from sagemaker import Predictor, TrainingInput, utils
|
23 | 23 | from sagemaker.amazon.amazon_estimator import RecordSet
|
24 | 24 | from sagemaker.estimator import Framework
|
| 25 | +from sagemaker.fw_utils import UploadedCode |
| 26 | +from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME |
| 27 | +from sagemaker.jumpstart.enums import JumpStartTag |
25 | 28 | from sagemaker.mxnet import MXNet
|
26 | 29 | from sagemaker.parameter import ParameterRange
|
27 | 30 | from sagemaker.tuner import (
|
@@ -1518,3 +1521,65 @@ def _convert_tuning_job_details(job_details, estimator_name):
|
1518 | 1521 | job_details_copy["TrainingJobDefinitions"] = [training_details]
|
1519 | 1522 |
|
1520 | 1523 | return job_details_copy
|
| 1524 | + |
| 1525 | + |
| 1526 | +@patch("time.time", return_value=510006209.073025) |
| 1527 | +@patch("sagemaker.estimator.tar_and_upload_dir") |
| 1528 | +@patch("sagemaker.model.Model._upload_code") |
| 1529 | +def test_tags_prefixes_jumpstart_models( |
| 1530 | + patched_upload_code, patched_tar_and_upload_dir, sagemaker_session |
| 1531 | +): |
| 1532 | + |
| 1533 | + patched_tar_and_upload_dir.return_value = UploadedCode( |
| 1534 | + s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" |
| 1535 | + ) |
| 1536 | + sagemaker_session.boto_region_name = REGION |
| 1537 | + |
| 1538 | + instance_type = "ml.p2.xlarge" |
| 1539 | + instance_count = 1 |
| 1540 | + |
| 1541 | + training_data_uri = "s3://bucket/mydata" |
| 1542 | + |
| 1543 | + jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz" |
| 1544 | + jumpstart_source_dir_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/source_dirs/source.tar.gz" |
| 1545 | + |
| 1546 | + image_uri = "fake-image-uri" |
| 1547 | + |
| 1548 | + generic_estimator = Estimator( |
| 1549 | + entry_point="transfer_learning.py", |
| 1550 | + role=ROLE, |
| 1551 | + region=REGION, |
| 1552 | + sagemaker_session=sagemaker_session, |
| 1553 | + instance_count=instance_count, |
| 1554 | + instance_type=instance_type, |
| 1555 | + source_dir=jumpstart_source_dir, |
| 1556 | + image_uri=image_uri, |
| 1557 | + model_uri=jumpstart_source_dir_2, |
| 1558 | + tags=[{"Key": "estimator-tag-key", "Value": "estimator-tag-value"}], |
| 1559 | + ) |
| 1560 | + |
| 1561 | + hp_tuner = HyperparameterTuner( |
| 1562 | + generic_estimator, |
| 1563 | + OBJECTIVE_METRIC_NAME, |
| 1564 | + HYPERPARAMETER_RANGES, |
| 1565 | + tags=[{"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"}], |
| 1566 | + ) |
| 1567 | + |
| 1568 | + hp_tuner.fit({"training": training_data_uri}) |
| 1569 | + |
| 1570 | + assert [ |
| 1571 | + {"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"}, |
| 1572 | + {"Key": "estimator-tag-key", "Value": "estimator-tag-value"}, |
| 1573 | + { |
| 1574 | + "Key": JumpStartTag.TRAINING_MODEL_URI.value, |
| 1575 | + "Value": jumpstart_source_dir_2, |
| 1576 | + }, |
| 1577 | + { |
| 1578 | + "Key": JumpStartTag.TRAINING_SCRIPT_URI.value, |
| 1579 | + "Value": jumpstart_source_dir, |
| 1580 | + }, |
| 1581 | + ] == sagemaker_session.create_tuning_job.call_args_list[0][1]["tags"] |
| 1582 | + |
| 1583 | + assert sagemaker_session.create_tuning_job.call_args_list[0][1]["job_name"].startswith( |
| 1584 | + JUMPSTART_RESOURCE_BASE_NAME |
| 1585 | + ) |
0 commit comments