Skip to content

Feat: Create/Delete Model Collection #3779

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions src/sagemaker/collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""This module contains code related to Amazon SageMaker Collection.

These Classes helps in providing features to maintain and create collections
"""

from __future__ import absolute_import
import json
import time
from typing import List


from botocore.exceptions import ClientError
from sagemaker.session import Session


class Collection(object):
"""Sets up Amazon SageMaker Collection."""

def __init__(self, sagemaker_session):
"""Initializes a Collection instance.

The collection provides a logical grouping for model groups

Args:
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, one is created using
the default AWS configuration chain.
"""
self.sagemaker_session = sagemaker_session or Session()

def _check_access_error(self, err: ClientError):
"""To check if the error is related to the access error and to provide the relavant message

Args:
err: The client error that needs to be checked
"""
error_code = err.response["Error"]["Code"]
if error_code == "AccessDeniedException":
raise Exception(
f"{error_code}: This account needs to attach a custom policy "
"to the user role to gain access to Collections. Refer - "
"https://docs.aws.amazon.com/sagemaker/latest/dg/modelcollections-permissions.html"
)

def create(self, collection_name: str, parent_collection_name: str = None):
"""Creates a collection

Args:
collection_name (str): The name of the collection to be created
parent_collection_name (str): The name of the parent collection.
To be None if the collection is to be created on the root level
"""

tag_rule_key = f"sagemaker:collection-path:{time.time()}"
tags_on_collection = {
"sagemaker:collection": "true",
"sagemaker:collection-path:root": "true",
}
tag_rule_values = [collection_name]

if parent_collection_name is not None:
try:
group_query = self.sagemaker_session.get_resource_group_query(
group=parent_collection_name
)
except ClientError as e:
error_code = e.response["Error"]["Code"]

if error_code == "NotFoundException":
raise ValueError(f"Cannot find collection: {parent_collection_name}")
self._check_access_error(err=e)
raise
if group_query.get("GroupQuery"):
parent_tag_rule_query = json.loads(
group_query["GroupQuery"].get("ResourceQuery", {}).get("Query", "")
)
parent_tag_rule = parent_tag_rule_query.get("TagFilters", [])[0]
if not parent_tag_rule:
raise "Invalid parent_collection_name"
parent_tag_value = parent_tag_rule["Values"][0]
tags_on_collection = {
parent_tag_rule["Key"]: parent_tag_value,
"sagemaker:collection": "true",
}
tag_rule_values = [f"{parent_tag_value}/{collection_name}"]
try:
resource_filters = [
"AWS::SageMaker::ModelPackageGroup",
"AWS::ResourceGroups::Group",
]

tag_filters = [
{
"Key": tag_rule_key,
"Values": tag_rule_values,
}
]
resource_query = {
"Query": json.dumps(
{"ResourceTypeFilters": resource_filters, "TagFilters": tag_filters}
),
"Type": "TAG_FILTERS_1_0",
}
collection_create_response = self.sagemaker_session.create_group(
collection_name, resource_query, tags_on_collection
)
return {
"Name": collection_create_response["Group"]["Name"],
"Arn": collection_create_response["Group"]["GroupArn"],
}

except ClientError as e:
message = e.response["Error"]["Message"]
error_code = e.response["Error"]["Code"]

if error_code == "BadRequestException" and "group already exists" in message:
raise ValueError("Collection with the given name already exists")

self._check_access_error(err=e)
raise

def delete(self, collections: List[str]):
"""Deletes a lits of collection

Args:
collections (List[str]): List of collections to be deleted
Only deletes a collection if it is empty
"""

if len(collections) > 10:
raise ValueError("Can delete upto 10 collections at a time")

delete_collection_failures = []
deleted_collection = []
collection_filter = [
{
"Name": "resource-type",
"Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"],
},
]
for collection in collections:
try:
collection_details = self.sagemaker_session.list_group_resources(
group=collection, filters=collection_filter
)
except ClientError as e:
self._check_access_error(err=e)
delete_collection_failures.append(
{"collection": collection, "message": e.response["Error"]["Message"]}
)
continue
if collection_details.get("Resources") and len(collection_details["Resources"]) > 0:
delete_collection_failures.append(
{"collection": collection, "message": "Validation error: Collection not empty"}
)
else:
try:
self.sagemaker_session.delete_resource_group(group=collection)
deleted_collection.append(collection)
except ClientError as e:
self._check_access_error(err=e)
delete_collection_failures.append(
{"collection": collection, "message": e.response["Error"]["Message"]}
)
return {
"deleted_collections": deleted_collection,
"delete_collection_failures": delete_collection_failures,
}
52 changes: 52 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def __init__(
self._default_bucket_name_override = default_bucket
self.s3_resource = None
self.s3_client = None
self.resource_groups_client = None
self.config = None
self.lambda_client = None
self.settings = settings
Expand Down Expand Up @@ -3961,6 +3962,57 @@ def delete_model(self, model_name):
LOGGER.info("Deleting model with name: %s", model_name)
self.sagemaker_client.delete_model(ModelName=model_name)

def list_group_resources(self, group, filters):
"""To list group resources with given filters

Args:
group (str): The name or the ARN of the group.
filters (list): Filters that needs to be applied to the list operation.
"""
self.resource_groups_client = self.resource_groups_client or self.boto_session.client(
"resource-groups"
)
return self.resource_groups_client.list_group_resources(Group=group, Filters=filters)

def delete_resource_group(self, group):
"""To delete a resource group

Args:
group (str): The name or the ARN of the resource group to delete.
"""
self.resource_groups_client = self.resource_groups_client or self.boto_session.client(
"resource-groups"
)
return self.resource_groups_client.delete_group(Group=group)

def get_resource_group_query(self, group):
"""To get the group query for an AWS Resource Group

Args:
group (str): The name or the ARN of the resource group to query.
"""
self.resource_groups_client = self.resource_groups_client or self.boto_session.client(
"resource-groups"
)
return self.resource_groups_client.get_group_query(Group=group)

def create_group(self, name, resource_query, tags):
"""To create a AWS Resource Group

Args:
name (str): The name of the group, which is also the identifier of the group.
resource_query (str): The resource query that determines
which AWS resources are members of this group
tags (dict): The Tags to be attached to the Resource Group
"""
self.resource_groups_client = self.resource_groups_client or self.boto_session.client(
"resource-groups"
)

return self.resource_groups_client.create_group(
Name=name, ResourceQuery=resource_query, Tags=tags
)

def list_tags(self, resource_arn, max_results=50):
"""List the tags given an Amazon Resource Name.

Expand Down
62 changes: 62 additions & 0 deletions tests/integ/test_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker.utils import unique_name_from_base
from sagemaker.collection import Collection


def test_create_collection_root_success(sagemaker_session):
collection = Collection(sagemaker_session)
collection_name = unique_name_from_base("test-collection")
collection.create(collection_name)
collection_filter = [
{
"Name": "resource-type",
"Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"],
},
]
collection_details = sagemaker_session.list_group_resources(
group=collection_name, filters=collection_filter
)
assert collection_details["ResponseMetadata"]["HTTPStatusCode"] == 200
delete_response = collection.delete([collection_name])
assert len(delete_response["deleted_collections"]) == 1
assert len(delete_response["delete_collection_failures"]) == 0


def test_create_collection_nested_success(sagemaker_session):
collection = Collection(sagemaker_session)
collection_name = unique_name_from_base("test-collection")
child_collection_name = unique_name_from_base("test-collection-2")
collection.create(collection_name)
collection.create(collection_name=child_collection_name, parent_collection_name=collection_name)
collection_filter = [
{
"Name": "resource-type",
"Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"],
},
]
collection_details = sagemaker_session.list_group_resources(
group=collection_name, filters=collection_filter
)
# has one child i.e child collection
assert len(collection_details["Resources"]) == 1

collection_details = sagemaker_session.list_group_resources(
group=child_collection_name, filters=collection_filter
)
collection_details["ResponseMetadata"]["HTTPStatusCode"]
delete_response = collection.delete([child_collection_name, collection_name])
assert len(delete_response["deleted_collections"]) == 2
assert len(delete_response["delete_collection_failures"]) == 0
83 changes: 83 additions & 0 deletions tests/unit/test_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest
import json
from mock import Mock

from sagemaker.collection import Collection

REGION = "us-west-2"
COLLECTION_NAME = "test-collection"
QUERY = {
"ResourceTypeFilters": ["AWS::SageMaker::ModelPackageGroup", "AWS::ResourceGroups::Group"],
"TagFilters": [
{"Key": "sagemaker:collection-path:1676120428.4811652", "Values": ["test-collection-k"]}
],
}
CREATE_COLLECTION_RESPONSE = {
"Group": {
"GroupArn": f"arn:aws:resource-groups:us-west-2:205984106344:group/{COLLECTION_NAME}",
"Name": COLLECTION_NAME,
},
"ResourceQuery": {
"Type": "TAG_FILTERS_1_0",
"Query": json.dumps(QUERY),
},
"Tags": {"sagemaker:collection-path:root": "true"},
}


@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name="boto_session", region_name=REGION)
session_mock = Mock(
name="sagemaker_session",
boto_session=boto_mock,
boto_region_name=REGION,
config=None,
local_mode=False,
)

session_mock.create_group = Mock(
name="create_collection", return_value=CREATE_COLLECTION_RESPONSE
)
session_mock.delete_resource_group = Mock(name="delete_resource_group", return_value=True)
session_mock.list_group_resources = Mock(name="list_group_resources", return_value={})

return session_mock


def test_create_collection_success(sagemaker_session):
collection = Collection(sagemaker_session)
create_response = collection.create(collection_name=COLLECTION_NAME)
assert create_response["Name"] is COLLECTION_NAME
assert create_response["Arn"] is not None


def test_delete_collection_success(sagemaker_session):
collection = Collection(sagemaker_session)
delete_response = collection.delete(collections=[COLLECTION_NAME])
assert len(delete_response["deleted_collections"]) == 1
assert len(delete_response["delete_collection_failures"]) == 0


def test_delete_collection_failure_when_collection_is_not_empty(sagemaker_session):
collection = Collection(sagemaker_session)
sagemaker_session.list_group_resources = Mock(
name="list_group_resources", return_value={"Resources": [{}]}
)
delete_response = collection.delete(collections=[COLLECTION_NAME])
assert len(delete_response["deleted_collections"]) == 0
assert len(delete_response["delete_collection_failures"]) == 1