Skip to content

Commit 727989a

Browse files
authored
Merge branch 'master' into feat/inference-instance-type-conditioned-on-training-instance-type
2 parents 97a5816 + a4ba730 commit 727989a

20 files changed

+612
-1487
lines changed

CHANGELOG.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,32 @@
11
# Changelog
22

3+
## v2.195.1 (2023-10-26)
4+
5+
### Bug Fixes and Other Changes
6+
7+
* Allow either instance_type or instance_group to be defined in…
8+
* enhance image_uris unit tests
9+
10+
## v2.195.0 (2023-10-25)
11+
12+
### Features
13+
14+
* jumpstart gated model artifacts
15+
* jumpstart extract generated text from response
16+
* jumpstart contruct payload utility
17+
18+
### Bug Fixes and Other Changes
19+
20+
* relax upper bound on urllib in local mode requirements
21+
* bump urllib3 version
22+
* allow smdistributed to be enabled with torch_distributed.
23+
* fix URL links
24+
25+
### Documentation Changes
26+
27+
* remove python 2 reference
28+
* update framework version links
29+
330
## v2.194.0 (2023-10-19)
431

532
### Features

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.194.1.dev0
1+
2.195.2.dev0

src/sagemaker/estimator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3818,6 +3818,7 @@ def _distribution_configuration(self, distribution):
38183818

38193819
mpi_enabled = False
38203820
smdataparallel_enabled = False
3821+
p5_enabled = False
38213822
if "instance_groups" in distribution:
38223823
distribution_config["sagemaker_distribution_instance_groups"] = distribution[
38233824
"instance_groups"
@@ -3862,10 +3863,11 @@ def _distribution_configuration(self, distribution):
38623863
elif isinstance(self.instance_type, str):
38633864
p5_enabled = "p5.48xlarge" in self.instance_type
38643865
else:
3865-
raise ValueError(
3866-
"Invalid object type for instance_type argument. Expected "
3867-
f"{type(str)} or {type(ParameterString)} but got {type(self.instance_type)}."
3868-
)
3866+
for instance in self.instance_groups:
3867+
if "p5.48xlarge" in instance._to_request_dict().get("InstanceType", ()):
3868+
p5_enabled = True
3869+
break
3870+
38693871
img_uri = "" if self.image_uri is None else self.image_uri
38703872
for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
38713873
if (

tests/unit/sagemaker/image_uris/conftest.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,25 @@
1717
import pytest
1818

1919

20-
@pytest.fixture(scope="module")
21-
def config_dir():
22-
return "src/sagemaker/image_uri_config/"
20+
CONFIG_DIR = "src/sagemaker/image_uri_config/"
21+
22+
23+
def get_config(config_file_name):
24+
config_file_path = os.path.join(CONFIG_DIR, config_file_name)
25+
with open(config_file_path, "r") as config_file:
26+
return json.load(config_file)
2327

2428

2529
@pytest.fixture(scope="module")
26-
def load_config(config_dir, request):
30+
def load_config(request):
2731
config_file_name = request.param
28-
config_file_path = os.path.join(config_dir, config_file_name)
32+
return get_config(config_file_name)
2933

30-
with open(config_file_path, "r") as config_file:
31-
return json.load(config_file)
34+
35+
@pytest.fixture(scope="module")
36+
def load_config_and_file_name(request):
37+
config_file_name = request.param
38+
return get_config(config_file_name), config_file_name
3239

3340

3441
@pytest.fixture(scope="module")

tests/unit/sagemaker/image_uris/expected_uris.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,8 @@ def graviton_framework_uri(
7474
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
7575

7676

77-
def djl_framework_uri(repo, account, djl_version, primary_framework, region=REGION):
77+
def djl_framework_uri(repo, account, tag, region=REGION):
7878
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
79-
tag = f"{djl_version}-{primary_framework}"
8079
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
8180

8281

tests/unit/sagemaker/image_uris/test_algos.py

Lines changed: 29 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -17,187 +17,33 @@
1717
from sagemaker import image_uris
1818
from tests.unit.sagemaker.image_uris import expected_uris
1919

20-
ALGO_NAMES = (
21-
"blazingtext",
22-
"factorization-machines",
23-
"forecasting-deepar",
24-
"image-classification",
25-
"ipinsights",
26-
"kmeans",
27-
"knn",
28-
"linear-learner",
29-
"ntm",
30-
"object-detection",
31-
"object2vec",
32-
"pca",
33-
"randomcutforest",
34-
"semantic-segmentation",
35-
"seq2seq",
36-
"lda",
37-
)
38-
ALGO_REGIONS_AND_ACCOUNTS = (
39-
{
40-
"algorithms": (
41-
"pca",
42-
"kmeans",
43-
"linear-learner",
44-
"factorization-machines",
45-
"ntm",
46-
"randomcutforest",
47-
"knn",
48-
"object2vec",
49-
"ipinsights",
50-
),
51-
"accounts": {
52-
"af-south-1": "455444449433",
53-
"ap-east-1": "286214385809",
54-
"ap-northeast-1": "351501993468",
55-
"ap-northeast-2": "835164637446",
56-
"ap-northeast-3": "867004704886",
57-
"ap-south-1": "991648021394",
58-
"ap-south-2": "628508329040",
59-
"ap-southeast-1": "475088953585",
60-
"ap-southeast-2": "712309505854",
61-
"ap-southeast-3": "951798379941",
62-
"ap-southeast-4": "106583098589",
63-
"ca-central-1": "469771592824",
64-
"cn-north-1": "390948362332",
65-
"cn-northwest-1": "387376663083",
66-
"eu-central-1": "664544806723",
67-
"eu-central-2": "680994064768",
68-
"eu-north-1": "669576153137",
69-
"eu-west-1": "438346466558",
70-
"eu-west-2": "644912444149",
71-
"eu-west-3": "749696950732",
72-
"eu-south-1": "257386234256",
73-
"eu-south-2": "104374241257",
74-
"il-central-1": "898809789911",
75-
"me-south-1": "249704162688",
76-
"me-central-1": "272398656194",
77-
"sa-east-1": "855470959533",
78-
"us-east-1": "382416733822",
79-
"us-east-2": "404615174143",
80-
"us-gov-west-1": "226302683700",
81-
"us-gov-east-1": "237065988967",
82-
"us-iso-east-1": "490574956308",
83-
"us-isob-east-1": "765400339828",
84-
"us-west-1": "632365934929",
85-
"us-west-2": "174872318107",
86-
},
87-
},
88-
{
89-
"algorithms": ("lda",),
90-
"accounts": {
91-
"ap-northeast-1": "258307448986",
92-
"ap-northeast-2": "293181348795",
93-
"ap-south-1": "991648021394",
94-
"ap-southeast-1": "475088953585",
95-
"ap-southeast-2": "297031611018",
96-
"ca-central-1": "469771592824",
97-
"eu-central-1": "353608530281",
98-
"eu-west-1": "999678624901",
99-
"eu-west-2": "644912444149",
100-
"us-east-1": "766337827248",
101-
"us-east-2": "999911452149",
102-
"us-gov-west-1": "226302683700",
103-
"us-iso-east-1": "490574956308",
104-
"us-isob-east-1": "765400339828",
105-
"us-west-1": "632365934929",
106-
"us-west-2": "266724342769",
107-
},
108-
},
109-
{
110-
"algorithms": ("forecasting-deepar",),
111-
"accounts": {
112-
"af-south-1": "455444449433",
113-
"ap-east-1": "286214385809",
114-
"ap-northeast-1": "633353088612",
115-
"ap-northeast-2": "204372634319",
116-
"ap-northeast-3": "867004704886",
117-
"ap-south-1": "991648021394",
118-
"ap-southeast-1": "475088953585",
119-
"ap-southeast-2": "514117268639",
120-
"ca-central-1": "469771592824",
121-
"cn-north-1": "390948362332",
122-
"cn-northwest-1": "387376663083",
123-
"eu-central-1": "495149712605",
124-
"eu-north-1": "669576153137",
125-
"eu-west-1": "224300973850",
126-
"eu-west-2": "644912444149",
127-
"eu-west-3": "749696950732",
128-
"eu-south-1": "257386234256",
129-
"me-south-1": "249704162688",
130-
"sa-east-1": "855470959533",
131-
"us-east-1": "522234722520",
132-
"us-east-2": "566113047672",
133-
"us-gov-west-1": "226302683700",
134-
"us-iso-east-1": "490574956308",
135-
"us-isob-east-1": "765400339828",
136-
"us-west-1": "632365934929",
137-
"us-west-2": "156387875391",
138-
},
139-
},
140-
{
141-
"algorithms": (
142-
"seq2seq",
143-
"image-classification",
144-
"blazingtext",
145-
"object-detection",
146-
"semantic-segmentation",
147-
),
148-
"accounts": {
149-
"af-south-1": "455444449433",
150-
"ap-east-1": "286214385809",
151-
"ap-northeast-1": "501404015308",
152-
"ap-northeast-2": "306986355934",
153-
"ap-northeast-3": "867004704886",
154-
"ap-south-1": "991648021394",
155-
"ap-south-2": "628508329040",
156-
"ap-southeast-1": "475088953585",
157-
"ap-southeast-2": "544295431143",
158-
"ap-southeast-3": "951798379941",
159-
"ap-southeast-4": "106583098589",
160-
"ca-central-1": "469771592824",
161-
"cn-north-1": "390948362332",
162-
"cn-northwest-1": "387376663083",
163-
"eu-central-1": "813361260812",
164-
"eu-central-2": "680994064768",
165-
"eu-north-1": "669576153137",
166-
"eu-west-1": "685385470294",
167-
"eu-west-2": "644912444149",
168-
"eu-west-3": "749696950732",
169-
"eu-south-1": "257386234256",
170-
"eu-south-2": "104374241257",
171-
"il-central-1": "898809789911",
172-
"me-south-1": "249704162688",
173-
"me-central-1": "272398656194",
174-
"sa-east-1": "855470959533",
175-
"us-east-1": "811284229777",
176-
"us-east-2": "825641698319",
177-
"us-gov-west-1": "226302683700",
178-
"us-gov-east-1": "237065988967",
179-
"us-iso-east-1": "490574956308",
180-
"us-isob-east-1": "765400339828",
181-
"us-west-1": "632365934929",
182-
"us-west-2": "433757028032",
183-
},
184-
},
185-
)
18620

187-
IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:1"
188-
189-
190-
def _accounts_for_algo(algo):
191-
for algo_account_dict in ALGO_REGIONS_AND_ACCOUNTS:
192-
if algo in algo_account_dict["algorithms"]:
193-
return algo_account_dict["accounts"]
194-
195-
return {}
196-
197-
198-
@pytest.mark.parametrize("algo", ALGO_NAMES)
199-
def test_algo_uris(algo):
200-
accounts = _accounts_for_algo(algo)
201-
for region in accounts:
202-
uri = image_uris.retrieve(algo, region)
203-
assert expected_uris.algo_uri(algo, accounts[region], region) == uri
21+
ALGO_NAMES = [
22+
"blazingtext.json",
23+
"factorization-machines.json",
24+
"forecasting-deepar.json",
25+
"image-classification.json",
26+
"ipinsights.json",
27+
"kmeans.json",
28+
"knn.json",
29+
"linear-learner.json",
30+
"ntm.json",
31+
"object-detection.json",
32+
"object2vec.json",
33+
"pca.json",
34+
"randomcutforest.json",
35+
"semantic-segmentation.json",
36+
"seq2seq.json",
37+
"lda.json",
38+
]
39+
40+
41+
@pytest.mark.parametrize("load_config", ALGO_NAMES, indirect=True)
42+
def test_algo_uris(load_config):
43+
VERSIONS = load_config["versions"]
44+
for version in VERSIONS:
45+
ACCOUNTS = load_config["versions"][version]["registries"]
46+
algo_name = load_config["versions"][version]["repository"]
47+
for region in ACCOUNTS.keys():
48+
uri = image_uris.retrieve(algo_name, region)
49+
assert expected_uris.algo_uri(algo_name, ACCOUNTS[region], region) == uri

0 commit comments

Comments
 (0)