Skip to content

S3 Algos #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ examples/tensorflow/distributed_mnist/data
doc/_build
**/.DS_Store
venv/
*.rec
5 changes: 4 additions & 1 deletion src/sagemaker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from sagemaker.amazon.kmeans import KMeans, KMeansModel, KMeansPredictor
from sagemaker.amazon.pca import PCA, PCAModel, PCAPredictor
from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerModel, LinearLearnerPredictor
from sagemaker.amazon.image_classification import ImageClassification, ImageClassificationModel
from sagemaker.amazon.image_classification import ImageClassificationPredictor
from sagemaker.amazon.factorization_machines import FactorizationMachines, FactorizationMachinesModel
from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor

Expand All @@ -32,4 +34,5 @@
LinearLearnerModel, LinearLearnerPredictor,
FactorizationMachines, FactorizationMachinesModel, FactorizationMachinesPredictor,
Model, RealTimePredictor, Session,
container_def, s3_input, production_variant, get_execution_role]
ImageClassification, ImageClassificationModel, ImageClassificationPredictor,
container_def, s3_input, production_variant, get_execution_role]
79 changes: 66 additions & 13 deletions src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

logger = logging.getLogger(__name__)


class AmazonAlgorithmEstimatorBase(EstimatorBase):
"""Base class for Amazon first-party Estimator implementations. This class isn't intended
to be instantiated directly."""
Expand Down Expand Up @@ -126,9 +125,53 @@ def record_set(self, train, labels=None, channel="train"):
return RecordSet(manifest_s3_file, num_records=train.shape[0], feature_dim=train.shape[1], channel=channel)


class AmazonS3AlgorithmEstimatorBase(AmazonAlgorithmEstimatorBase):
"""Base class for Amazon first-party Estimator implementations. This class isn't
intended to be instantiated directly. This is difference from the base class
because this class handles S3 data"""

def fit(self, records, mini_batch_size=None, distribution='ShardedByS3Key', **kwargs):
"""Fit this Estimator on serialized Record objects, stored in S3.

``records`` should be a list of instances of :class:`~RecordSet`. This defines a collection of
s3 data files to train this ``Estimator`` on.

More information on the Amazon Record format is available at:
https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html

See :meth:`~AmazonS3AlgorithmEstimatorBase.s3_record_set` to construct a ``RecordSet`` object
from :class:`~numpy.ndarray` arrays.

Args:
records (list): This is a list of :class:`~RecordSet` items The list of records to train
this ``Estimator`` will depend on each algorithm and type of input data.
mini_batch_size (int or None): The size of each mini-batch to use when training. If None, a
default value will be used.
"""
default_mini_batch_size = 32
self.mini_batch_size = mini_batch_size or default_mini_batch_size
data = {}
for record in records:
data[record.channel] = s3_input(record.s3_data, distribution=distribution,
s3_data_type=record.s3_data_type)
super(AmazonAlgorithmEstimatorBase, self).fit(data, **kwargs)

def s3_record_set(self, s3_loc, channel="train"):
"""Build a :class:`~RecordSet` from a S3 location with data in it.

Args:
s3_loc (str): A s3 bucket where data is located
channel (str): The SageMaker TrainingJob channel this RecordSet should be assigned to.

Returns:
RecordSet: A RecordSet referencing the encoded, uploading training and label data.
"""
return RecordSet(self.data_location + '/' + s3_loc, channel=channel)

# Re-write a new recordset class for s3 objects.
class RecordSet(object):

def __init__(self, s3_data, num_records, feature_dim, s3_data_type='ManifestFile', channel='train'):
def __init__(self, s3_data, num_records = None, feature_dim = None, s3_data_type='ManifestFile', channel='train'):
"""A collection of Amazon :class:~`Record` objects serialized and stored in S3.

Args:
Expand Down Expand Up @@ -163,7 +206,6 @@ def _build_shards(num_shards, array):
shards.append(array[(num_shards - 1) * shard_size:])
return shards


def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=None):
"""Upload the training ``array`` and ``labels`` arrays to ``num_shards`` s3 objects,
stored in "s3://``bucket``/``key_prefix``/"."""
Expand Down Expand Up @@ -199,13 +241,24 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=
finally:
raise ex


def registry(region_name):
"""Return docker registry for the given AWS region"""
account_id = {
"us-east-1": "382416733822",
"us-east-2": "404615174143",
"us-west-2": "174872318107",
"eu-west-1": "438346466558"
}[region_name]
return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)
def registry(region_name, algorithm = None):
"""Return docker registry for the given AWS region

Args:
algorithm (str): Provide the algorithm to get the docker back"""
if algorithm is None:
account_id = {
"us-east-1": "382416733822",
"us-east-2": "404615174143",
"us-west-2": "174872318107",
"eu-west-1": "438346466558"
}[region_name]
return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)
elif algorithm in ['image_classification']:
account_id = {
"us-east-1": "811284229777",
"us-east-2": "825641698319",
"us-west-2": "433757028032",
"eu-west-1": "685385470294"
}[region_name]
return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)
Loading