|
19 | 19 | import tarfile
|
20 | 20 | import tempfile
|
21 | 21 |
|
| 22 | +# Repack Model |
| 23 | +# The following script is run via a training job which takes an existing model and a custom |
| 24 | +# entry point script as arguments. The script creates a new model archive with the custom |
| 25 | +# entry point in the "code" directory along with the existing model. Subsequently, when the model |
| 26 | +# is unpacked for inference, the custom entry point will be used. |
| 27 | +# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html |
| 28 | + |
22 | 29 | # distutils.dir_util.copy_tree works way better than the half-baked
|
23 | 30 | # shutil.copytree which bombs on previously existing target dirs...
|
24 | 31 | # alas ... https://bugs.python.org/issue10948
|
|
33 | 40 | parser.add_argument("--model_archive", type=str, default="model.tar.gz")
|
34 | 41 | args = parser.parse_args()
|
35 | 42 |
|
| 43 | + # the data directory contains a model archive generated by a previous training job |
36 | 44 | data_directory = "/opt/ml/input/data/training"
|
37 | 45 | model_path = os.path.join(data_directory, args.model_archive)
|
38 | 46 |
|
| 47 | + # create a temporary directory |
39 | 48 | with tempfile.TemporaryDirectory() as tmp:
|
40 | 49 | local_path = os.path.join(tmp, "local.tar.gz")
|
| 50 | + # copy the previous training job's model archive to the temporary directory |
41 | 51 | shutil.copy2(model_path, local_path)
|
42 | 52 | src_dir = os.path.join(tmp, "src")
|
| 53 | + # create the "code" directory which will contain the inference script |
| 54 | + os.makedirs(os.path.join(src_dir, "code")) |
| 55 | + # extract the contents of the previous training job's model archive to the "src" |
| 56 | + # directory of this training job |
43 | 57 | with tarfile.open(name=local_path, mode="r:gz") as tf:
|
44 | 58 | tf.extractall(path=src_dir)
|
45 | 59 |
|
| 60 | + # generate a path to the custom inference script |
46 | 61 | entry_point = os.path.join("/opt/ml/code", args.inference_script)
|
47 |
| - shutil.copy2(entry_point, os.path.join(src_dir, args.inference_script)) |
| 62 | + # copy the custom inference script to the "src" dir |
| 63 | + shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script)) |
48 | 64 |
|
| 65 | + # copy the "src" dir, which includes the previous training job's model and the |
| 66 | + # custom inference script, to the output of this training job |
49 | 67 | copy_tree(src_dir, "/opt/ml/model")
|
0 commit comments