Skip to content

change: Enhance unit-tests to automatically consume image URIs config registries from config JSONs #4126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions tests/unit/sagemaker/image_uris/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import json
import pytest


@pytest.fixture(scope="module")
def config_dir():
return "src/sagemaker/image_uri_config/"


@pytest.fixture(scope="module")
def load_config(config_dir, request):
config_file_name = request.param
config_file_path = os.path.join(config_dir, config_file_name)

with open(config_file_path, "r") as config_file:
return json.load(config_file)


@pytest.fixture(scope="module")
def extract_versions_for_image_scope(load_config, request):
scope_val = request.param
return load_config[scope_val]["versions"]
34 changes: 3 additions & 31 deletions tests/unit/sagemaker/image_uris/test_base_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,11 @@
from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris

REGISTRIES = {
"us-east-2": "429704687514",
"me-south-1": "117516905037",
"us-west-2": "236514542706",
"ca-central-1": "310906938811",
"ap-east-1": "493642496378",
"us-east-1": "081325390199",
"ap-northeast-2": "806072073708",
"eu-west-2": "712779665605",
"ap-southeast-2": "52832661640",
"cn-northwest-1": "390780980154",
"eu-north-1": "243637512696",
"cn-north-1": "390048526115",
"ap-south-1": "394103062818",
"eu-west-3": "615547856133",
"ap-southeast-3": "276181064229",
"af-south-1": "559312083959",
"eu-west-1": "470317259841",
"eu-central-1": "936697816551",
"sa-east-1": "782484402741",
"ap-northeast-3": "792733760839",
"eu-south-1": "592751261982",
"ap-northeast-1": "102112518831",
"us-west-1": "742091327244",
"ap-southeast-1": "492261229750",
"me-central-1": "103105715889",
"us-gov-east-1": "107072934176",
"us-gov-west-1": "107173498710",
}


@pytest.mark.parametrize("load_config", ["sagemaker-base-python.json"], indirect=True)
@pytest.mark.parametrize("py_version", ["310", "38"])
def test_get_base_python_image_uri(py_version):
def test_get_base_python_image_uri(py_version, load_config):
REGISTRIES = load_config["versions"]["1.0"]["registries"]
for region in REGISTRIES.keys():
uri = image_uris.get_base_python_image_uri(
region=region,
Expand Down
52 changes: 10 additions & 42 deletions tests/unit/sagemaker/image_uris/test_data_wrangler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,10 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest
from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris

DATA_WRANGLER_ACCOUNTS = {
"af-south-1": "143210264188",
"ap-east-1": "707077482487",
"ap-northeast-1": "649008135260",
"ap-northeast-2": "131546521161",
"ap-northeast-3": "913387583493",
"ap-south-1": "089933028263",
"ap-southeast-1": "119527597002",
"ap-southeast-2": "422173101802",
"ca-central-1": "557239378090",
"eu-central-1": "024640144536",
"eu-north-1": "054986407534",
"eu-south-1": "488287956546",
"eu-west-1": "245179582081",
"eu-west-2": "894491911112",
"eu-west-3": "807237891255",
"me-south-1": "376037874950",
"sa-east-1": "424196993095",
"us-east-1": "663277389841",
"us-east-2": "415577184552",
"us-west-1": "926135532090",
"us-west-2": "174368400705",
"cn-north-1": "245909111842",
"cn-northwest-1": "249157047649",
}

# Accounts only supported in DW 3.x and beyond
DATA_WRANGLER_3X_ACCOUNTS = {
"il-central-1": "406833011540",
}

VERSIONS = ["1.x", "2.x", "3.x"]


def _test_ecr_uri(account, region, version):
actual_uri = image_uris.retrieve("data-wrangler", region=region, version=version)
Expand All @@ -60,23 +28,23 @@ def _test_ecr_uri(account, region, version):
return expected_uri == actual_uri


def test_data_wrangler_ecr_uri():
@pytest.mark.parametrize("load_config", ["data-wrangler.json"], indirect=True)
@pytest.mark.parametrize("extract_versions_for_image_scope", ["processing"], indirect=True)
def test_data_wrangler_ecr_uri(load_config, extract_versions_for_image_scope):
VERSIONS = extract_versions_for_image_scope
for version in VERSIONS:
DATA_WRANGLER_ACCOUNTS = load_config["processing"]["versions"][version]["registries"]
for region in DATA_WRANGLER_ACCOUNTS.keys():
assert _test_ecr_uri(
account=DATA_WRANGLER_ACCOUNTS[region], region=region, version=version
)


def test_data_wrangler_ecr_uri_3x():
for region in DATA_WRANGLER_3X_ACCOUNTS.keys():
assert _test_ecr_uri(
account=DATA_WRANGLER_3X_ACCOUNTS[region], region=region, version="3.x"
)


def test_data_wrangler_ecr_uri_none():
@pytest.mark.parametrize("load_config", ["data-wrangler.json"], indirect=True)
def test_data_wrangler_ecr_uri_none(load_config):
region = "us-west-2"
VERSIONS = ["1.x", "2.x", "3.x"]
DATA_WRANGLER_ACCOUNTS = load_config["processing"]["versions"]["1.x"]["registries"]
actual_uri = image_uris.retrieve("data-wrangler", region=region)
expected_uri = expected_uris.algo_uri(
"sagemaker-data-wrangler-container",
Expand Down
32 changes: 4 additions & 28 deletions tests/unit/sagemaker/image_uris/test_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,14 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest
from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris

ACCOUNTS = {
"af-south-1": "314341159256",
"ap-east-1": "199566480951",
"ap-northeast-1": "430734990657",
"ap-northeast-2": "578805364391",
"ap-northeast-3": "479947661362",
"ap-south-1": "904829902805",
"ap-southeast-1": "972752614525",
"ap-southeast-2": "184798709955",
"ca-central-1": "519511493484",
"cn-north-1": "618459771430",
"cn-northwest-1": "658757709296",
"eu-central-1": "482524230118",
"eu-north-1": "314864569078",
"eu-south-1": "563282790590",
"eu-west-1": "929884845733",
"eu-west-2": "250201462417",
"eu-west-3": "447278800020",
"me-south-1": "986000313247",
"sa-east-1": "818342061345",
"us-east-1": "503895931360",
"us-east-2": "915447279597",
"us-gov-west-1": "515509971035",
"us-west-1": "685455198987",
"us-west-2": "895741380848",
}


def test_debugger():
@pytest.mark.parametrize("load_config", ["debugger.json"], indirect=True)
def test_debugger(load_config):
ACCOUNTS = load_config["versions"]["latest"]["registries"]
for region in ACCOUNTS.keys():
uri = image_uris.retrieve("debugger", region=region)
expected = expected_uris.algo_uri(
Expand Down