
Add benchmarking script #86


Merged · 2 commits · Oct 23, 2018
35 changes: 35 additions & 0 deletions benchmarks/README.md
@@ -0,0 +1,35 @@
# TensorFlow benchmarking scripts

This folder contains the TensorFlow training scripts from https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks.

## Basic usage
**execute_tensorflow_training.py** uses the SageMaker Python SDK to start a training job. It takes the following parameters:

- role: SageMaker role used for training
- region: SageMaker region
- py-versions: Python versions to test, e.g. 'py2', 'py3', or 'py2 py3'
- instance-types: A list of SageMaker instance types, for example 'ml.p2.xlarge ml.c4.xlarge'. Use 'local' for local mode training.
- checkpoint-path: The S3 location where the model checkpoints and TensorBoard events are saved after training

Any unknown arguments will be passed to the training script as additional arguments.
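
For instance, the launcher collects flags it does not recognize with `argparse.parse_known_args` and forwards them to the training script as hyperparameters. A minimal sketch of that mechanism (the flag values below are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-r', '--role', required=True)

# parse_known_args returns (known namespace, list of unrecognized tokens).
known, unknown = parser.parse_known_args(
    ['-r', 'SageMakerRole', '--model', 'resnet50', '--num_epochs', '1'])

# Pair leftover flags with their values and strip the leading dashes.
keys = [flag.lstrip('-') for flag in unknown[::2]]
hyperparameters = dict(zip(keys, unknown[1::2]))
print(hyperparameters)  # {'model': 'resnet50', 'num_epochs': '1'}
```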

## Examples:

```bash
./execute_tensorflow_training.py -t local -r SageMakerRole --num_epochs 1 --wait

./execute_tensorflow_training.py -t ml.c4.xlarge ml.c5.xlarge -r SageMakerRole --model resnet50

```

## Using other models, datasets, and benchmark configurations

`python tf_cnn_benchmarks/tf_cnn_benchmarks.py --help` shows all the options the script supports.


## TensorBoard events and checkpoints

TensorBoard events are saved during training to the S3 location defined by the hyperparameter `checkpoint_path`. That location can be overridden by setting the script argument `--checkpoint-path`:

```bash
python execute_tensorflow_training.py ... --checkpoint-path s3://my/bucket/output/data
```
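
Each job then gets its own checkpoint directory under that location. A minimal sketch of the derivation (the bucket below is hypothetical; the slicing mirrors the base-name construction in execute_tensorflow_training.py and assumes instance types of the form `ml.<family>.<size>`):

```python
import os

# Hypothetical values; checkpoint_path normally defaults to
# <default bucket>/benchmarks/checkpoints.
checkpoint_path = 's3://my/bucket/output/data'
instance_type, py_version = 'ml.p2.xlarge', 'py3'

# e.g. 'py3-p2-xlarge' for py3 on ml.p2.xlarge
base_name = '%s-%s-%s' % (py_version, instance_type[3:5], instance_type[6:])
model_dir = os.path.join(checkpoint_path, base_name)
print(model_dir)  # s3://my/bucket/output/data/py3-p2-xlarge
```
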
98 changes: 98 additions & 0 deletions benchmarks/execute_tensorflow_training.py
@@ -0,0 +1,98 @@
#!/usr/bin/env python

from __future__ import absolute_import

import argparse
import itertools
import os

from sagemaker import Session
from sagemaker.estimator import Framework
from sagemaker.tensorflow import TensorFlow

# Bound method; calling it later resolves the account's default SageMaker bucket.
default_bucket = Session().default_bucket
dir_path = os.path.dirname(os.path.realpath(__file__))

# Defaults passed to tf_cnn_benchmarks.py; extra CLI arguments override these.
_DEFAULT_HYPERPARAMETERS = {
    'batch_size': 32,
    'model': 'resnet32',
    'num_epochs': 10,
    'data_format': 'NHWC',
    'summary_verbosity': 1,
    'save_summaries_steps': 10,
    'data_name': 'cifar10'
}


class ScriptModeTensorFlow(Framework):
    """This class is temporary until the final version of Script Mode is released."""

    __framework_name__ = "tensorflow-scriptmode-beta"

    # Reuse the stock TensorFlow estimator's model creation.
    create_model = TensorFlow.create_model

    def __init__(self, py_version='py3', **kwargs):
        super(ScriptModeTensorFlow, self).__init__(**kwargs)
        self.py_version = py_version
        self.image_name = None
        self.framework_version = '1.10.0'


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--instance-types', nargs='+', required=True,
                        help='<Required> One or more SageMaker instance types, or "local"')
    parser.add_argument('-r', '--role', required=True)
    parser.add_argument('-w', '--wait', action='store_true')
    parser.add_argument('--region', default='us-west-2')
    parser.add_argument('--py-versions', nargs='+', default=['py3'],
                        help='Python versions to benchmark, e.g. "py2 py3"')
    parser.add_argument('--checkpoint-path',
                        default=os.path.join(default_bucket(), 'benchmarks', 'checkpoints'),
                        help='The S3 location where the model checkpoints and TensorBoard events are saved after training')

    # parse_known_args: unknown arguments are forwarded to the training script.
    return parser.parse_known_args()


def main(args, script_args):
    # Launch one training job per (instance type, Python version) combination.
    for instance_type, py_version in itertools.product(args.instance_types, args.py_versions):
        # e.g. 'py3-p2-xlarge' for py3 on ml.p2.xlarge
        base_name = '%s-%s-%s' % (py_version, instance_type[3:5], instance_type[6:])
        model_dir = os.path.join(args.checkpoint_path, base_name)

        job_hps = create_hyperparameters(model_dir, script_args)

        print('hyperparameters:')
        print(job_hps)

        estimator = ScriptModeTensorFlow(
            entry_point='tf_cnn_benchmarks.py',
            role=args.role,
            source_dir=os.path.join(dir_path, 'tf_cnn_benchmarks'),
            base_job_name=base_name,
            train_instance_count=1,
            hyperparameters=job_hps,
            train_instance_type=instance_type,
        )

        input_dir = 's3://sagemaker-sample-data-%s/spark/mnist/train/' % args.region
        estimator.fit({'train': input_dir}, wait=args.wait)

        print('To use TensorBoard, execute the following command:')
        cmd = 'S3_USE_HTTPS=0 S3_VERIFY_SSL=0 AWS_REGION=%s tensorboard --host localhost --port 6006 --logdir %s'
        print(cmd % (args.region, args.checkpoint_path))


def create_hyperparameters(model_dir, script_args):
    job_hps = _DEFAULT_HYPERPARAMETERS.copy()

    job_hps.update({'train_dir': model_dir, 'eval_dir': model_dir})

    # script_args is assumed to alternate flag/value pairs, e.g. ['--model', 'resnet50'];
    # strip the leading dashes so each flag becomes a hyperparameter name.
    script_arg_keys_without_dashes = [key[2:] if key.startswith('--') else key[1:] for key in script_args[::2]]
    script_arg_values = script_args[1::2]
    job_hps.update(dict(zip(script_arg_keys_without_dashes, script_arg_values)))

    return job_hps


if __name__ == '__main__':
    args, script_args = get_args()
    main(args, script_args)
89 changes: 89 additions & 0 deletions benchmarks/tf_cnn_benchmarks/README.md
@@ -0,0 +1,89 @@
# tf_cnn_benchmarks: High performance benchmarks

tf_cnn_benchmarks contains implementations of several popular convolutional
models, and is designed to be as fast as possible. tf_cnn_benchmarks supports
running both on a single machine and in distributed mode across multiple
hosts. See the [High-Performance models
guide](https://www.tensorflow.org/performance/performance_models) for more
information.

These models utilize many of the strategies in the [TensorFlow Performance
Guide](https://www.tensorflow.org/performance/performance_guide). Benchmark
results can be found [here](https://www.tensorflow.org/performance/benchmarks).

These models are designed for performance. For models that have clean and
easy-to-read implementations, see the [TensorFlow Official
Models](https://github.com/tensorflow/models/tree/master/official).

## Getting Started

To run ResNet50 on a single GPU with synthetic data and no distortions, run

```bash
python tf_cnn_benchmarks.py --num_gpus=1 --batch_size=32 --model=resnet50 --variable_update=parameter_server
```

Note that the master branch of tf_cnn_benchmarks requires the latest nightly
version of TensorFlow. You can install the nightly version by running `pip
install tf-nightly-gpu` in a clean environment, or by installing TensorFlow from
source. We will sometimes create a branch of tf_cnn_benchmarks, in the form
cnn_tf_vX.Y_compatible, that is compatible with TensorFlow version X.Y. For
example, branch
[cnn_tf_v1.9_compatible](https://github.com/tensorflow/benchmarks/tree/cnn_tf_v1.9_compatible/scripts/tf_cnn_benchmarks)
works with TensorFlow 1.9.

Some important flags are

* model: Model to use, e.g. resnet50, inception3, vgg16, or alexnet.
* num_gpus: Number of GPUs to use.
* data_dir: Path to data to process. If not set, synthetic data is used. To
  use ImageNet data, use these
  [instructions](https://github.com/tensorflow/models/tree/master/research/inception#getting-started)
  as a starting point.
* batch_size: Batch size for each GPU.
* variable_update: The method for managing variables: parameter_server,
  replicated, distributed_replicated, or independent.
* local_parameter_device: Device to use as parameter server: cpu or gpu.

To see the full list of flags, run `python tf_cnn_benchmarks.py --help`.

To run ResNet50 on 8 GPUs with real data, run:

```bash
python tf_cnn_benchmarks.py --data_format=NCHW --batch_size=256 \
--model=resnet50 --optimizer=momentum --variable_update=replicated \
--nodistortions --gradient_repacking=8 --num_gpus=8 \
--num_epochs=90 --weight_decay=1e-4 --data_dir=${DATA_DIR} --use_fp16 \
--train_dir=${CKPT_DIR}
```
This will train a ResNet-50 model on ImageNet with a global batch size of 2048
across 8 GPUs (256 per GPU). The model should train to around 76% accuracy.
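
As a quick sanity check on that figure (a sketch only; per the flag list above, `--batch_size` is the batch size for each GPU):

```python
# Global batch size = per-GPU batch size x number of GPUs.
batch_size_per_gpu = 256  # from --batch_size=256
num_gpus = 8              # from --num_gpus=8
assert batch_size_per_gpu * num_gpus == 2048
```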

## Running the tests

To run the tests, run

```bash
pip install portpicker
python run_tests.py && python run_tests.py --run_distributed_tests
```

Note the tests require portpicker.

The command above runs a subset of tests that is both fast and fairly
comprehensive. Alternatively, all the tests can be run, but this will take a
long time:

```bash
python run_tests.py --full_tests && python run_tests.py --full_tests --run_distributed_tests
```

We will run all tests on every PR before merging them, so it is not necessary
to pass `--full_tests` when running tests yourself.

To run an individual test, such as method `testParameterServer` of test class
`TfCnnBenchmarksTest` of module `benchmark_cnn_test`, run

```bash
python -m unittest -v benchmark_cnn_test.TfCnnBenchmarksTest.testParameterServer
```