Skip to content

Commit 67810d7

Browse files
authored
breaking: deprecate sagemaker.amazon.amazon_estimator.get_image_uri() (#1725)
This also deprecates sagemaker.amazon.amazon_estimator.registry()
1 parent d7dd857 commit 67810d7

28 files changed

+129
-312
lines changed

doc/overview.rst

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -993,12 +993,17 @@ the ML Pipeline.
993993
994994
.. code:: python
995995
996-
xgb_image = get_image_uri(sess.boto_region_name, 'xgboost', repo_version="latest")
997-
xgb_model = Model(model_data='s3://path/to/model.tar.gz', image_uri=xgb_image)
998-
sparkml_model = SparkMLModel(model_data='s3://path/to/model.tar.gz', env={'SAGEMAKER_SPARKML_SCHEMA': schema})
996+
from sagemaker import image_uris, session
997+
from sagemaker.model import Model
998+
from sagemaker.pipeline import PipelineModel
999+
from sagemaker.sparkml import SparkMLModel
9991000
1000-
model_name = 'inference-pipeline-model'
1001-
endpoint_name = 'inference-pipeline-endpoint'
1001+
xgb_image = image_uris.retrieve("xgboost", session.Session().boto_region_name, repo_version="latest")
1002+
xgb_model = Model(model_data="s3://path/to/model.tar.gz", image_uri=xgb_image)
1003+
sparkml_model = SparkMLModel(model_data="s3://path/to/model.tar.gz", env={"SAGEMAKER_SPARKML_SCHEMA": schema})
1004+
1005+
model_name = "inference-pipeline-model"
1006+
endpoint_name = "inference-pipeline-endpoint"
10021007
sm_model = PipelineModel(name=model_name, role=sagemaker_role, models=[sparkml_model, xgb_model])
10031008
10041009
This defines a ``PipelineModel`` consisting of SparkML model and an XGBoost model stacked sequentially.

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 4 additions & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919

2020
from six.moves.urllib.parse import urlparse
2121

22+
from sagemaker import image_uris
2223
from sagemaker.amazon import validation
2324
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
2425
from sagemaker.amazon.common import write_numpy_to_dense_tensor
2526
from sagemaker.estimator import EstimatorBase, _TrainingJob
2627
from sagemaker.inputs import FileSystemInput, TrainingInput
27-
from sagemaker.model import NEO_IMAGE_ACCOUNT
28-
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
28+
from sagemaker.utils import sagemaker_timestamp
2929

3030
logger = logging.getLogger(__name__)
3131

@@ -93,8 +93,8 @@ def __init__(
9393

9494
def train_image(self):
9595
"""Placeholder docstring"""
96-
return get_image_uri(
97-
self.sagemaker_session.boto_region_name, type(self).repo_name, type(self).repo_version
96+
return image_uris.retrieve(
97+
self.repo_name, self.sagemaker_session.boto_region_name, version=self.repo_version,
9898
)
9999

100100
def hyperparameters(self):
@@ -452,167 +452,3 @@ def upload_numpy_to_s3_shards(
452452
s3.Object(bucket, key_prefix + file).delete()
453453
finally:
454454
raise ex
455-
456-
457-
def registry(region_name, algorithm=None):
458-
"""Return docker registry for the given AWS region
459-
460-
Note: Not all the algorithms listed below have an Amazon Estimator
461-
implemented. For full list of pre-implemented Estimators, look at:
462-
463-
https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/amazon
464-
465-
Args:
466-
region_name (str): The region name for the account.
467-
algorithm (str): The algorithm for the account.
468-
469-
Raises:
470-
ValueError: If invalid algorithm passed in or if mapping does not exist for given algorithm
471-
and region.
472-
"""
473-
region_to_accounts = {}
474-
if algorithm in [
475-
None,
476-
"pca",
477-
"kmeans",
478-
"linear-learner",
479-
"factorization-machines",
480-
"ntm",
481-
"randomcutforest",
482-
"knn",
483-
"object2vec",
484-
"ipinsights",
485-
]:
486-
region_to_accounts = {
487-
"us-east-1": "382416733822",
488-
"us-east-2": "404615174143",
489-
"us-west-2": "174872318107",
490-
"eu-west-1": "438346466558",
491-
"eu-central-1": "664544806723",
492-
"ap-northeast-1": "351501993468",
493-
"ap-northeast-2": "835164637446",
494-
"ap-southeast-2": "712309505854",
495-
"us-gov-west-1": "226302683700",
496-
"ap-southeast-1": "475088953585",
497-
"ap-south-1": "991648021394",
498-
"ca-central-1": "469771592824",
499-
"eu-west-2": "644912444149",
500-
"us-west-1": "632365934929",
501-
"us-iso-east-1": "490574956308",
502-
"ap-east-1": "286214385809",
503-
"eu-north-1": "669576153137",
504-
"eu-west-3": "749696950732",
505-
"sa-east-1": "855470959533",
506-
"me-south-1": "249704162688",
507-
"cn-north-1": "390948362332",
508-
"cn-northwest-1": "387376663083",
509-
}
510-
elif algorithm in ["lda"]:
511-
region_to_accounts = {
512-
"us-east-1": "766337827248",
513-
"us-east-2": "999911452149",
514-
"us-west-2": "266724342769",
515-
"eu-west-1": "999678624901",
516-
"eu-central-1": "353608530281",
517-
"ap-northeast-1": "258307448986",
518-
"ap-northeast-2": "293181348795",
519-
"ap-southeast-2": "297031611018",
520-
"us-gov-west-1": "226302683700",
521-
"ap-southeast-1": "475088953585",
522-
"ap-south-1": "991648021394",
523-
"ca-central-1": "469771592824",
524-
"eu-west-2": "644912444149",
525-
"us-west-1": "632365934929",
526-
"us-iso-east-1": "490574956308",
527-
}
528-
elif algorithm in ["forecasting-deepar"]:
529-
region_to_accounts = {
530-
"us-east-1": "522234722520",
531-
"us-east-2": "566113047672",
532-
"us-west-2": "156387875391",
533-
"eu-west-1": "224300973850",
534-
"eu-central-1": "495149712605",
535-
"ap-northeast-1": "633353088612",
536-
"ap-northeast-2": "204372634319",
537-
"ap-southeast-2": "514117268639",
538-
"us-gov-west-1": "226302683700",
539-
"ap-southeast-1": "475088953585",
540-
"ap-south-1": "991648021394",
541-
"ca-central-1": "469771592824",
542-
"eu-west-2": "644912444149",
543-
"us-west-1": "632365934929",
544-
"us-iso-east-1": "490574956308",
545-
"ap-east-1": "286214385809",
546-
"eu-north-1": "669576153137",
547-
"eu-west-3": "749696950732",
548-
"sa-east-1": "855470959533",
549-
"me-south-1": "249704162688",
550-
"cn-north-1": "390948362332",
551-
"cn-northwest-1": "387376663083",
552-
}
553-
elif algorithm in [
554-
"xgboost",
555-
"seq2seq",
556-
"image-classification",
557-
"blazingtext",
558-
"object-detection",
559-
"semantic-segmentation",
560-
]:
561-
region_to_accounts = {
562-
"us-east-1": "811284229777",
563-
"us-east-2": "825641698319",
564-
"us-west-2": "433757028032",
565-
"eu-west-1": "685385470294",
566-
"eu-central-1": "813361260812",
567-
"ap-northeast-1": "501404015308",
568-
"ap-northeast-2": "306986355934",
569-
"ap-southeast-2": "544295431143",
570-
"us-gov-west-1": "226302683700",
571-
"ap-southeast-1": "475088953585",
572-
"ap-south-1": "991648021394",
573-
"ca-central-1": "469771592824",
574-
"eu-west-2": "644912444149",
575-
"us-west-1": "632365934929",
576-
"us-iso-east-1": "490574956308",
577-
"ap-east-1": "286214385809",
578-
"eu-north-1": "669576153137",
579-
"eu-west-3": "749696950732",
580-
"sa-east-1": "855470959533",
581-
"me-south-1": "249704162688",
582-
"cn-north-1": "390948362332",
583-
"cn-northwest-1": "387376663083",
584-
}
585-
elif algorithm in ["image-classification-neo", "xgboost-neo"]:
586-
region_to_accounts = NEO_IMAGE_ACCOUNT
587-
else:
588-
raise ValueError(
589-
"Algorithm class:{} does not have mapping to account_id with images".format(algorithm)
590-
)
591-
592-
if region_name in region_to_accounts:
593-
account_id = region_to_accounts[region_name]
594-
return get_ecr_image_uri_prefix(account_id, region_name)
595-
596-
raise ValueError(
597-
"Algorithm ({algorithm}) is unsupported for region ({region_name}).".format(
598-
algorithm=algorithm, region_name=region_name
599-
)
600-
)
601-
602-
603-
def get_image_uri(region_name, repo_name, repo_version=1):
604-
"""Return algorithm image URI for the given AWS region, repository name, and
605-
repository version
606-
607-
Args:
608-
region_name:
609-
repo_name:
610-
repo_version:
611-
"""
612-
logger.warning(
613-
"'get_image_uri' method will be deprecated in favor of 'ImageURIProvider' class "
614-
"in SageMaker Python SDK v2."
615-
)
616-
617-
repo = "{}:{}".format(repo_name, repo_version)
618-
return "{}/{}".format(registry(region_name, repo_name), repo)

src/sagemaker/amazon/factorization_machines.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
16+
from sagemaker import image_uris
17+
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1718
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1819
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1920
from sagemaker.amazon.validation import gt, isin, ge
@@ -309,8 +310,11 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
309310
**kwargs:
310311
"""
311312
sagemaker_session = sagemaker_session or Session()
312-
repo = "{}:{}".format(FactorizationMachines.repo_name, FactorizationMachines.repo_version)
313-
image_uri = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
313+
image_uri = image_uris.retrieve(
314+
FactorizationMachines.repo_name,
315+
sagemaker_session.boto_region_name,
316+
version=FactorizationMachines.repo_version,
317+
)
314318
super(FactorizationMachinesModel, self).__init__(
315319
image_uri,
316320
model_data,

src/sagemaker/amazon/ipinsights.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
16+
from sagemaker import image_uris
17+
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1718
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1819
from sagemaker.amazon.validation import ge, le
1920
from sagemaker.deserializers import JSONDeserializer
@@ -219,11 +220,11 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
219220
**kwargs:
220221
"""
221222
sagemaker_session = sagemaker_session or Session()
222-
repo = "{}:{}".format(IPInsights.repo_name, IPInsights.repo_version)
223-
image_uri = "{}/{}".format(
224-
registry(sagemaker_session.boto_session.region_name, IPInsights.repo_name), repo
223+
image_uri = image_uris.retrieve(
224+
IPInsights.repo_name,
225+
sagemaker_session.boto_region_name,
226+
version=IPInsights.repo_version,
225227
)
226-
227228
super(IPInsightsModel, self).__init__(
228229
image_uri,
229230
model_data,

src/sagemaker/amazon/kmeans.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
16+
from sagemaker import image_uris
17+
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1718
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1819
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1920
from sagemaker.amazon.validation import gt, isin, ge, le
@@ -242,8 +243,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
242243
**kwargs:
243244
"""
244245
sagemaker_session = sagemaker_session or Session()
245-
repo = "{}:{}".format(KMeans.repo_name, KMeans.repo_version)
246-
image_uri = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
246+
image_uri = image_uris.retrieve(
247+
KMeans.repo_name, sagemaker_session.boto_region_name, version=KMeans.repo_version,
248+
)
247249
super(KMeansModel, self).__init__(
248250
image_uri,
249251
model_data,

src/sagemaker/amazon/knn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
16+
from sagemaker import image_uris
17+
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1718
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1819
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1920
from sagemaker.amazon.validation import ge, isin
@@ -230,12 +231,11 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
230231
**kwargs:
231232
"""
232233
sagemaker_session = sagemaker_session or Session()
233-
repo = "{}:{}".format(KNN.repo_name, KNN.repo_version)
234-
image = "{}/{}".format(
235-
registry(sagemaker_session.boto_session.region_name, KNN.repo_name), repo
234+
image_uri = image_uris.retrieve(
235+
KNN.repo_name, sagemaker_session.boto_region_name, version=KNN.repo_version,
236236
)
237237
super(KNNModel, self).__init__(
238-
image,
238+
image_uri,
239239
model_data,
240240
role,
241241
predictor_cls=KNNPredictor,

src/sagemaker/amazon/lda.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
16+
from sagemaker import image_uris
17+
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1718
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1819
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1920
from sagemaker.amazon.validation import gt
@@ -214,9 +215,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
214215
**kwargs:
215216
"""
216217
sagemaker_session = sagemaker_session or Session()
217-
repo = "{}:{}".format(LDA.repo_name, LDA.repo_version)
218-
image_uri = "{}/{}".format(
219-
registry(sagemaker_session.boto_session.region_name, LDA.repo_name), repo
218+
image_uri = image_uris.retrieve(
219+
LDA.repo_name, sagemaker_session.boto_region_name, version=LDA.repo_version,
220220
)
221221
super(LDAModel, self).__init__(
222222
image_uri,

src/sagemaker/amazon/linear_learner.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
16+
from sagemaker import image_uris
17+
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1718
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1819
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1920
from sagemaker.amazon.validation import isin, gt, lt, ge, le
@@ -473,8 +474,11 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
473474
**kwargs:
474475
"""
475476
sagemaker_session = sagemaker_session or Session()
476-
repo = "{}:{}".format(LinearLearner.repo_name, LinearLearner.repo_version)
477-
image_uri = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
477+
image_uri = image_uris.retrieve(
478+
LinearLearner.repo_name,
479+
sagemaker_session.boto_region_name,
480+
version=LinearLearner.repo_version,
481+
)
478482
super(LinearLearnerModel, self).__init__(
479483
image_uri,
480484
model_data,

src/sagemaker/amazon/ntm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
16+
from sagemaker import image_uris
17+
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1718
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1819
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1920
from sagemaker.amazon.validation import ge, le, isin
@@ -244,9 +245,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
244245
**kwargs:
245246
"""
246247
sagemaker_session = sagemaker_session or Session()
247-
repo = "{}:{}".format(NTM.repo_name, NTM.repo_version)
248-
image_uri = "{}/{}".format(
249-
registry(sagemaker_session.boto_session.region_name, NTM.repo_name), repo
248+
image_uri = image_uris.retrieve(
249+
NTM.repo_name, sagemaker_session.boto_region_name, version=NTM.repo_version,
250250
)
251251
super(NTMModel, self).__init__(
252252
image_uri,

0 commit comments

Comments
 (0)