|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "## Introduction\n", |
| 8 | + "\n", |
| 9 | +    "In this notebook, we demonstrate how BlazingText can host pre-trained Text Classification and Word2Vec models from [FastText](https://fasttext.cc/docs/en/english-vectors.html). BlazingText is a GPU-accelerated version of FastText, a shallow neural network model used for both word embedding generation (unsupervised) and text classification (supervised). BlazingText uses custom CUDA kernels to speed up training, but the underlying algorithm is the same, so the two model formats are compatible. Therefore, if you have a model trained with FastText, or if one of the pre-trained models published by the FastText team covers your use case, you can take advantage of BlazingText's hosting support to set up SageMaker endpoints for real-time predictions with FastText models, without having to train with the BlazingText algorithm yourself." |
| 10 | + ] |
| 11 | + }, |
| 12 | + { |
| 13 | + "cell_type": "markdown", |
| 14 | + "metadata": {}, |
| 15 | + "source": [ |
| 16 | +    "To get started, we specify a few important parameters that SageMaker needs for model hosting: the IAM role and the S3 bucket location. The SageMaker Python SDK helps us retrieve the IAM role and makes it easy to work with S3 resources." |
| 17 | + ] |
| 18 | + }, |
| 19 | + { |
| 20 | + "cell_type": "code", |
| 21 | + "execution_count": null, |
| 22 | + "metadata": {}, |
| 23 | + "outputs": [], |
| 24 | + "source": [ |
| 25 | + "import sagemaker\n", |
| 26 | + "from sagemaker import get_execution_role\n", |
| 27 | + "import boto3\n", |
| 28 | + "import json\n", |
| 29 | + "\n", |
| 30 | + "sess = sagemaker.Session()\n", |
| 31 | + "\n", |
| 32 | + "role = get_execution_role()\n", |
| 33 | + "print(role) # This is the role that SageMaker would use to leverage AWS resources (S3, CloudWatch) on your behalf\n", |
| 34 | + "\n", |
| 35 | + "bucket = sess.default_bucket() # Replace with your own bucket name if needed\n", |
| 36 | + "print(bucket)\n", |
| 37 | +    "prefix = 'fasttext/pretrained'  # Replace with the prefix under which you want to store the data if needed" |
| 38 | + ] |
| 39 | + }, |
| 40 | + { |
| 41 | + "cell_type": "code", |
| 42 | + "execution_count": null, |
| 43 | + "metadata": { |
| 44 | + "collapsed": true |
| 45 | + }, |
| 46 | + "outputs": [], |
| 47 | + "source": [ |
| 48 | + "region_name = boto3.Session().region_name" |
| 49 | + ] |
| 50 | + }, |
| 51 | + { |
| 52 | + "cell_type": "code", |
| 53 | + "execution_count": null, |
| 54 | + "metadata": {}, |
| 55 | + "outputs": [], |
| 56 | + "source": [ |
| 57 | + "container = sagemaker.amazon.amazon_estimator.get_image_uri(region_name, \"blazingtext\", \"latest\")\n", |
| 58 | + "print('Using SageMaker BlazingText container: {} ({})'.format(container, region_name))" |
| 59 | + ] |
| 60 | + }, |
| 61 | + { |
| 62 | + "cell_type": "markdown", |
| 63 | + "metadata": {}, |
| 64 | + "source": [ |
| 65 | +    "### Hosting the [Language Identification model](https://fasttext.cc/docs/en/language-identification.html) from FastText\n", |
| 66 | + "\n", |
| 67 | +    "For this example, we will use the pre-trained Language Identification model published by FastText. Language identification is the first step of many NLP applications: once the language of the input text is identified, language-specific models can be applied for the various downstream tasks. Under the hood, language identification is a text classification problem that uses language IDs as the class labels, so FastText can be used for it directly. The pre-trained FastText language identification model can distinguish 176 different languages." |
| 68 | + ] |
| 69 | + }, |
| 70 | + { |
| 71 | + "cell_type": "markdown", |
| 72 | + "metadata": {}, |
| 73 | + "source": [ |
| 74 | +    "Here we will download the Language Identification (Text Classification) model [1] from the [FastText website](https://fasttext.cc/docs/en/language-identification.html).\n", |
| 75 | + "\n", |
| 76 | + "[1] A. Joulin, E. Grave, P. Bojanowski, T. Mikolov, Bag of Tricks for Efficient Text Classification" |
| 77 | + ] |
| 78 | + }, |
| 79 | + { |
| 80 | + "cell_type": "code", |
| 81 | + "execution_count": null, |
| 82 | + "metadata": {}, |
| 83 | + "outputs": [], |
| 84 | + "source": [ |
| 85 | +    "!wget -O model.bin https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin" |
| 86 | + ] |
| 87 | + }, |
| 88 | + { |
| 89 | + "cell_type": "markdown", |
| 90 | + "metadata": {}, |
| 91 | + "source": [ |
| 92 | +    "Next we will `tar` the model and upload it to S3 using the utilities available in the SageMaker Python SDK. We then delete the local copies, since they are no longer needed." |
| 93 | + ] |
| 94 | + }, |
| 95 | + { |
| 96 | + "cell_type": "code", |
| 97 | + "execution_count": null, |
| 98 | + "metadata": {}, |
| 99 | + "outputs": [], |
| 100 | + "source": [ |
| 101 | + "!tar -czvf langid.tar.gz model.bin\n", |
| 102 | + "model_location = sess.upload_data(\"langid.tar.gz\", bucket=bucket, key_prefix=prefix)\n", |
| 103 | + "!rm langid.tar.gz model.bin" |
| 104 | + ] |
| 105 | + }, |
| 106 | + { |
| 107 | + "cell_type": "markdown", |
| 108 | + "metadata": {}, |
| 109 | + "source": [ |
| 110 | + "## Creating SageMaker Inference Endpoint\n", |
| 111 | + "\n", |
| 112 | +    "Next we'll create a SageMaker inference endpoint with the BlazingText container. The endpoint is compatible with the pre-trained models published by FastText and can serve them directly, without any modification. It accepts requests with content type `application/json`." |
| 113 | + ] |
| 114 | + }, |
| 115 | + { |
| 116 | + "cell_type": "code", |
| 117 | + "execution_count": null, |
| 118 | + "metadata": {}, |
| 119 | + "outputs": [], |
| 120 | + "source": [ |
| 121 | +    "lang_id = sagemaker.Model(model_data=model_location, image=container, role=role, sagemaker_session=sess)\n", |
| 122 | +    "lang_id.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')\n", |
| 123 | +    "predictor = sagemaker.RealTimePredictor(endpoint=lang_id.endpoint_name,\n", |
| 124 | +    "                                         sagemaker_session=sess,\n", |
| 125 | +    "                                         content_type='application/json',\n", |
| 126 | +    "                                         serializer=json.dumps,\n", |
| 127 | +    "                                         deserializer=sagemaker.predictor.json_deserializer)" |
| 127 | + ] |
| 128 | + }, |
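| | +  { |
| | +   "cell_type": "markdown", |
| | +   "metadata": {}, |
| | +   "source": [ |
| | +    "Optionally, we can confirm that the endpoint is in service before sending requests. `deploy()` already waits for the endpoint to be ready by default, so the next cell is only a sanity check; it reuses the `boto3` session and the `region_name` defined earlier." |
| | +   ] |
| | +  }, |
| | +  { |
| | +   "cell_type": "code", |
| | +   "execution_count": null, |
| | +   "metadata": {}, |
| | +   "outputs": [], |
| | +   "source": [ |
| | +    "# Optional sanity check: query the endpoint status via the SageMaker API\n", |
| | +    "sm_client = boto3.client('sagemaker', region_name=region_name)\n", |
| | +    "status = sm_client.describe_endpoint(EndpointName=lang_id.endpoint_name)['EndpointStatus']\n", |
| | +    "print('Endpoint {} is {}'.format(lang_id.endpoint_name, status))" |
| | +   ] |
| | +  }, |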
| 129 | + { |
| 130 | + "cell_type": "markdown", |
| 131 | + "metadata": { |
| 132 | + "collapsed": true |
| 133 | + }, |
| 134 | + "source": [ |
| 135 | +    "Next we'll pass a few sentences in various languages to the endpoint to verify that language identification works as expected." |
| 136 | + ] |
| 137 | + }, |
| 138 | + { |
| 139 | + "cell_type": "code", |
| 140 | + "execution_count": null, |
| 141 | + "metadata": { |
| 142 | + "collapsed": true |
| 143 | + }, |
| 144 | + "outputs": [], |
| 145 | + "source": [ |
| 146 | + "sentences = [\"hi which language is this?\",\n", |
| 147 | + " \"mon nom est Pierre\",\n", |
| 148 | + " \"Dem Jungen gab ich einen Ball.\",\n", |
| 149 | + " \"আমি বাড়ি যাবো.\"]\n", |
| 150 | + "payload = {\"instances\" : sentences}" |
| 151 | + ] |
| 152 | + }, |
| 153 | + { |
| 154 | + "cell_type": "code", |
| 155 | + "execution_count": null, |
| 156 | + "metadata": {}, |
| 157 | + "outputs": [], |
| 158 | + "source": [ |
| 159 | + "predictions = predictor.predict(payload)\n", |
| 160 | + "print(predictions)" |
| 161 | + ] |
| 162 | + }, |
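| | +  { |
| | +   "cell_type": "markdown", |
| | +   "metadata": {}, |
| | +   "source": [ |
| | +    "By default the endpoint returns only the highest-scoring label for each sentence. According to the BlazingText documentation, the request payload can also carry an optional `configuration` object with a `k` parameter to request the top-k labels and their probabilities. The next cell is a small example of this; the value `k = 2` is only an illustration." |
| | +   ] |
| | +  }, |
| | +  { |
| | +   "cell_type": "code", |
| | +   "execution_count": null, |
| | +   "metadata": {}, |
| | +   "outputs": [], |
| | +   "source": [ |
| | +    "# Request the top 2 labels (with probabilities) for each sentence\n", |
| | +    "payload_topk = {\"instances\": sentences,\n", |
| | +    "                \"configuration\": {\"k\": 2}}\n", |
| | +    "predictions_topk = predictor.predict(payload_topk)\n", |
| | +    "print(predictions_topk)" |
| | +   ] |
| | +  }, |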
| 163 | + { |
| 164 | + "cell_type": "markdown", |
| 165 | + "metadata": {}, |
| 166 | + "source": [ |
| 167 | +    "FastText expects each class label to be prefixed with `__label__`, so the labels returned by the pre-trained model carry that prefix as well. With a little post-processing, we can strip the `__label__` prefix from the response." |
| 168 | + ] |
| 169 | + }, |
| 170 | + { |
| 171 | + "cell_type": "code", |
| 172 | + "execution_count": null, |
| 173 | + "metadata": {}, |
| 174 | + "outputs": [], |
| 175 | + "source": [ |
| 176 | + "import copy\n", |
| 177 | +    "predictions_copy = copy.deepcopy(predictions)  # work on a copy so the original predictions stay unmodified\n", |
| 178 | +    "for output in predictions_copy:\n", |
| 179 | +    "    output['label'] = output['label'][0][9:].upper()  # strip the '__label__' prefix (9 characters)\n", |
| 180 | + "\n", |
| 181 | + "print(predictions_copy)" |
| 182 | + ] |
| 183 | + }, |
| 184 | + { |
| 185 | + "cell_type": "markdown", |
| 186 | + "metadata": {}, |
| 187 | + "source": [ |
| 188 | + "### Stop / Close the Endpoint (Optional)\n", |
| 189 | +    "Finally, we should delete the endpoint before closing the notebook, unless we want to keep it running to serve real-time predictions." |
| 190 | + ] |
| 191 | + }, |
| 192 | + { |
| 193 | + "cell_type": "code", |
| 194 | + "execution_count": null, |
| 195 | + "metadata": {}, |
| 196 | + "outputs": [], |
| 197 | + "source": [ |
| 198 | + "sess.delete_endpoint(predictor.endpoint)" |
| 199 | + ] |
| 200 | + }, |
| 201 | + { |
| 202 | + "cell_type": "markdown", |
| 203 | + "metadata": {}, |
| 204 | + "source": [ |
| 205 | +    "Similarly, we can host any pre-trained [FastText word2vec model](https://fasttext.cc/docs/en/pretrained-vectors.html) using SageMaker BlazingText hosting. The workflow is the same as above: package the `.bin` model as a tarball, upload it to S3, and deploy it with the BlazingText container. A sketch of these steps is shown below." |
| 206 | +   ] |
| 207 | +  }, |
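| | +  { |
| | +   "cell_type": "markdown", |
| | +   "metadata": {}, |
| | +   "source": [ |
| | +    "The cells below are a minimal sketch of that workflow, not a tested recipe. They assume the `wiki.simple` vectors from the FastText website as an example download, reuse the `sess`, `bucket`, `prefix`, `container` and `role` variables defined earlier, and assume the BlazingText container picks up the `.bin` file inside the tarball just as it did for the classification model above. Adjust the URL, file names and instance type for your own use case." |
| | +   ] |
| | +  }, |
| | +  { |
| | +   "cell_type": "code", |
| | +   "execution_count": null, |
| | +   "metadata": {}, |
| | +   "outputs": [], |
| | +   "source": [ |
| | +    "# Sketch only: the URL and file names below are examples and may need to be changed\n", |
| | +    "!wget -O vectors.zip https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.simple.zip\n", |
| | +    "!unzip -o vectors.zip wiki.simple.bin\n", |
| | +    "!tar -czvf word2vec.tar.gz wiki.simple.bin\n", |
| | +    "word2vec_location = sess.upload_data(\"word2vec.tar.gz\", bucket=bucket, key_prefix=prefix)\n", |
| | +    "!rm vectors.zip wiki.simple.bin word2vec.tar.gz\n", |
| | +    "\n", |
| | +    "# Deploy the pre-trained vectors with the same BlazingText container used above\n", |
| | +    "word2vec_model = sagemaker.Model(model_data=word2vec_location, image=container,\n", |
| | +    "                                 role=role, sagemaker_session=sess)\n", |
| | +    "word2vec_model.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')\n", |
| | +    "word_vectors = sagemaker.RealTimePredictor(endpoint=word2vec_model.endpoint_name,\n", |
| | +    "                                           sagemaker_session=sess,\n", |
| | +    "                                           content_type='application/json',\n", |
| | +    "                                           serializer=json.dumps,\n", |
| | +    "                                           deserializer=sagemaker.predictor.json_deserializer)\n", |
| | +    "\n", |
| | +    "# The endpoint is expected to return one embedding vector per input word\n", |
| | +    "print(word_vectors.predict({\"instances\": [\"awesome\", \"blazing\"]}))\n", |
| | +    "\n", |
| | +    "# Clean up when done\n", |
| | +    "sess.delete_endpoint(word2vec_model.endpoint_name)" |
| | +   ] |
| | +  } |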
| 208 | + ], |
| 209 | + "metadata": { |
| 210 | + "kernelspec": { |
| 211 | + "display_name": "conda_python3", |
| 212 | + "language": "python", |
| 213 | + "name": "conda_python3" |
| 214 | + }, |
| 215 | + "language_info": { |
| 216 | + "codemirror_mode": { |
| 217 | + "name": "ipython", |
| 218 | + "version": 3 |
| 219 | + }, |
| 220 | + "file_extension": ".py", |
| 221 | + "mimetype": "text/x-python", |
| 222 | + "name": "python", |
| 223 | + "nbconvert_exporter": "python", |
| 224 | + "pygments_lexer": "ipython3", |
| 225 | + "version": "3.6.2" |
| 226 | + }, |
| 227 | + "notice": "Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." |
| 228 | + }, |
| 229 | + "nbformat": 4, |
| 230 | + "nbformat_minor": 2 |
| 231 | +} |