|
| 1 | +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""Example workflow pipeline script for CustomerChurn pipeline. |
| 14 | +
|
| 15 | + . -RegisterModel |
| 16 | + . |
| 17 | + Process-> Train -> Evaluate -> Condition . |
| 18 | + . |
| 19 | + . -(stop) |
| 20 | +
|
| 21 | +Implements a get_pipeline(**kwargs) method. |
| 22 | +""" |
| 23 | + |
| 24 | +import os |
| 25 | + |
| 26 | +import boto3 |
| 27 | +import sagemaker |
| 28 | +import sagemaker.session |
| 29 | + |
| 30 | +from sagemaker.estimator import Estimator |
| 31 | +from sagemaker.inputs import TrainingInput |
| 32 | +from sagemaker.processing import ( |
| 33 | + ProcessingInput, |
| 34 | + ProcessingOutput, |
| 35 | + ScriptProcessor, |
| 36 | +) |
| 37 | +from sagemaker.sklearn.processing import SKLearnProcessor |
| 38 | +from sagemaker.workflow.conditions import ( |
| 39 | + ConditionGreaterThanOrEqualTo, |
| 40 | +) |
| 41 | +from sagemaker.workflow.condition_step import ( |
| 42 | + ConditionStep, |
| 43 | + JsonGet, |
| 44 | +) |
| 45 | +from sagemaker.model_metrics import ( |
| 46 | + MetricsSource, |
| 47 | + ModelMetrics, |
| 48 | +) |
| 49 | +from sagemaker.workflow.parameters import ( |
| 50 | + ParameterInteger, |
| 51 | + ParameterString, |
| 52 | +) |
| 53 | +from sagemaker.workflow.pipeline import Pipeline |
| 54 | +from sagemaker.workflow.properties import PropertyFile |
| 55 | +from sagemaker.workflow.steps import ( |
| 56 | + ProcessingStep, |
| 57 | + TrainingStep, |
| 58 | +) |
| 59 | +from sagemaker.workflow.step_collections import RegisterModel |
| 60 | + |
| 61 | + |
| 62 | +BASE_DIR = os.path.dirname(os.path.realpath(__file__)) |
| 63 | + |
| 64 | + |
| 65 | +def get_session(region, default_bucket): |
| 66 | + """Gets the sagemaker session based on the region. |
| 67 | +
|
| 68 | + Args: |
| 69 | + region: the aws region to start the session |
| 70 | + default_bucket: the bucket to use for storing the artifacts |
| 71 | +
|
| 72 | + Returns: |
| 73 | + `sagemaker.session.Session instance |
| 74 | + """ |
| 75 | + |
| 76 | + boto_session = boto3.Session(region_name=region) |
| 77 | + |
| 78 | + sagemaker_client = boto_session.client("sagemaker") |
| 79 | + runtime_client = boto_session.client("sagemaker-runtime") |
| 80 | + return sagemaker.session.Session( |
| 81 | + boto_session=boto_session, |
| 82 | + sagemaker_client=sagemaker_client, |
| 83 | + sagemaker_runtime_client=runtime_client, |
| 84 | + default_bucket=default_bucket, |
| 85 | + ) |
| 86 | + |
| 87 | + |
| 88 | +def get_pipeline( |
| 89 | + region, |
| 90 | + role=None, |
| 91 | + default_bucket=None, |
| 92 | + model_package_group_name="CustomerChurnPackageGroup", # Choose any name |
| 93 | + pipeline_name="CustomerChurnDemo-p-ewf8t7lvhivm", # You can find your pipeline name in the Studio UI (project -> Pipelines -> name) |
| 94 | + base_job_prefix="CustomerChurn", # Choose any name |
| 95 | +): |
| 96 | + """Gets a SageMaker ML Pipeline instance working with on CustomerChurn data. |
| 97 | +
|
| 98 | + Args: |
| 99 | + region: AWS region to create and run the pipeline. |
| 100 | + role: IAM role to create and run steps and pipeline. |
| 101 | + default_bucket: the bucket to use for storing the artifacts |
| 102 | +
|
| 103 | + Returns: |
| 104 | + an instance of a pipeline |
| 105 | + """ |
| 106 | + sagemaker_session = get_session(region, default_bucket) |
| 107 | + if role is None: |
| 108 | + role = sagemaker.session.get_execution_role(sagemaker_session) |
| 109 | + |
| 110 | + # Parameters for pipeline execution |
| 111 | + processing_instance_count = ParameterInteger( |
| 112 | + name="ProcessingInstanceCount", default_value=1 |
| 113 | + ) |
| 114 | + processing_instance_type = ParameterString( |
| 115 | + name="ProcessingInstanceType", default_value="ml.m5.xlarge" |
| 116 | + ) |
| 117 | + training_instance_type = ParameterString( |
| 118 | + name="TrainingInstanceType", default_value="ml.m5.xlarge" |
| 119 | + ) |
| 120 | + model_approval_status = ParameterString( |
| 121 | + name="ModelApprovalStatus", |
| 122 | + default_value="PendingManualApproval", # ModelApprovalStatus can be set to a default of "Approved" if you don't want manual approval. |
| 123 | + ) |
| 124 | + input_data = ParameterString( |
| 125 | + name="InputDataUrl", |
| 126 | + default_value=f"s3://sm-pipelines-demo-data-123456789/churn.txt", # Change this to point to the s3 location of your raw input data. |
| 127 | + ) |
| 128 | + |
| 129 | + # Processing step for feature engineering |
| 130 | + sklearn_processor = SKLearnProcessor( |
| 131 | + framework_version="0.23-1", |
| 132 | + instance_type=processing_instance_type, |
| 133 | + instance_count=processing_instance_count, |
| 134 | + base_job_name=f"{base_job_prefix}/sklearn-CustomerChurn-preprocess", # choose any name |
| 135 | + sagemaker_session=sagemaker_session, |
| 136 | + role=role, |
| 137 | + ) |
| 138 | + step_process = ProcessingStep( |
| 139 | + name="CustomerChurnProcess", # choose any name |
| 140 | + processor=sklearn_processor, |
| 141 | + outputs=[ |
| 142 | + ProcessingOutput(output_name="train", source="/opt/ml/processing/train"), |
| 143 | + ProcessingOutput( |
| 144 | + output_name="validation", source="/opt/ml/processing/validation" |
| 145 | + ), |
| 146 | + ProcessingOutput(output_name="test", source="/opt/ml/processing/test"), |
| 147 | + ], |
| 148 | + code=os.path.join(BASE_DIR, "preprocess.py"), |
| 149 | + job_arguments=["--input-data", input_data], |
| 150 | + ) |
| 151 | + |
| 152 | + # Training step for generating model artifacts |
| 153 | + model_path = f"s3://{sagemaker_session.default_bucket()}/{base_job_prefix}/CustomerChurnTrain" |
| 154 | + image_uri = sagemaker.image_uris.retrieve( |
| 155 | + framework="xgboost", # we are using the Sagemaker built in xgboost algorithm |
| 156 | + region=region, |
| 157 | + version="1.0-1", |
| 158 | + py_version="py3", |
| 159 | + instance_type=training_instance_type, |
| 160 | + ) |
| 161 | + xgb_train = Estimator( |
| 162 | + image_uri=image_uri, |
| 163 | + instance_type=training_instance_type, |
| 164 | + instance_count=1, |
| 165 | + output_path=model_path, |
| 166 | + base_job_name=f"{base_job_prefix}/CustomerChurn-train", |
| 167 | + sagemaker_session=sagemaker_session, |
| 168 | + role=role, |
| 169 | + ) |
| 170 | + xgb_train.set_hyperparameters( |
| 171 | + objective="binary:logistic", |
| 172 | + num_round=50, |
| 173 | + max_depth=5, |
| 174 | + eta=0.2, |
| 175 | + gamma=4, |
| 176 | + min_child_weight=6, |
| 177 | + subsample=0.7, |
| 178 | + silent=0, |
| 179 | + ) |
| 180 | + step_train = TrainingStep( |
| 181 | + name="CustomerChurnTrain", |
| 182 | + estimator=xgb_train, |
| 183 | + inputs={ |
| 184 | + "train": TrainingInput( |
| 185 | + s3_data=step_process.properties.ProcessingOutputConfig.Outputs[ |
| 186 | + "train" |
| 187 | + ].S3Output.S3Uri, |
| 188 | + content_type="text/csv", |
| 189 | + ), |
| 190 | + "validation": TrainingInput( |
| 191 | + s3_data=step_process.properties.ProcessingOutputConfig.Outputs[ |
| 192 | + "validation" |
| 193 | + ].S3Output.S3Uri, |
| 194 | + content_type="text/csv", |
| 195 | + ), |
| 196 | + }, |
| 197 | + ) |
| 198 | + |
| 199 | + # Processing step for evaluation |
| 200 | + script_eval = ScriptProcessor( |
| 201 | + image_uri=image_uri, |
| 202 | + command=["python3"], |
| 203 | + instance_type=processing_instance_type, |
| 204 | + instance_count=1, |
| 205 | + base_job_name=f"{base_job_prefix}/script-CustomerChurn-eval", |
| 206 | + sagemaker_session=sagemaker_session, |
| 207 | + role=role, |
| 208 | + ) |
| 209 | + evaluation_report = PropertyFile( |
| 210 | + name="EvaluationReport", |
| 211 | + output_name="evaluation", |
| 212 | + path="evaluation.json", |
| 213 | + ) |
| 214 | + step_eval = ProcessingStep( |
| 215 | + name="CustomerChurnEval", |
| 216 | + processor=script_eval, |
| 217 | + inputs=[ |
| 218 | + ProcessingInput( |
| 219 | + source=step_train.properties.ModelArtifacts.S3ModelArtifacts, |
| 220 | + destination="/opt/ml/processing/model", |
| 221 | + ), |
| 222 | + ProcessingInput( |
| 223 | + source=step_process.properties.ProcessingOutputConfig.Outputs[ |
| 224 | + "test" |
| 225 | + ].S3Output.S3Uri, |
| 226 | + destination="/opt/ml/processing/test", |
| 227 | + ), |
| 228 | + ], |
| 229 | + outputs=[ |
| 230 | + ProcessingOutput( |
| 231 | + output_name="evaluation", source="/opt/ml/processing/evaluation" |
| 232 | + ), |
| 233 | + ], |
| 234 | + code=os.path.join(BASE_DIR, "evaluate.py"), |
| 235 | + property_files=[evaluation_report], |
| 236 | + ) |
| 237 | + |
| 238 | + # Register model step that will be conditionally executed |
| 239 | + model_metrics = ModelMetrics( |
| 240 | + model_statistics=MetricsSource( |
| 241 | + s3_uri="{}/evaluation.json".format( |
| 242 | + step_eval.arguments["ProcessingOutputConfig"]["Outputs"][0]["S3Output"][ |
| 243 | + "S3Uri" |
| 244 | + ] |
| 245 | + ), |
| 246 | + content_type="application/json", |
| 247 | + ) |
| 248 | + ) |
| 249 | + |
| 250 | + # Register model step that will be conditionally executed |
| 251 | + step_register = RegisterModel( |
| 252 | + name="CustomerChurnRegisterModel", |
| 253 | + estimator=xgb_train, |
| 254 | + model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts, |
| 255 | + content_types=["text/csv"], |
| 256 | + response_types=["text/csv"], |
| 257 | + inference_instances=["ml.t2.medium", "ml.m5.large"], |
| 258 | + transform_instances=["ml.m5.large"], |
| 259 | + model_package_group_name=model_package_group_name, |
| 260 | + approval_status=model_approval_status, |
| 261 | + model_metrics=model_metrics, |
| 262 | + ) |
| 263 | + |
| 264 | + # Condition step for evaluating model quality and branching execution |
| 265 | + cond_lte = ConditionGreaterThanOrEqualTo( # You can change the condition here |
| 266 | + left=JsonGet( |
| 267 | + step=step_eval, |
| 268 | + property_file=evaluation_report, |
| 269 | + json_path="binary_classification_metrics.accuracy.value", # This should follow the structure of your report_dict defined in the evaluate.py file. |
| 270 | + ), |
| 271 | + right=0.8, # You can change the threshold here |
| 272 | + ) |
| 273 | + step_cond = ConditionStep( |
| 274 | + name="CustomerChurnAccuracyCond", |
| 275 | + conditions=[cond_lte], |
| 276 | + if_steps=[step_register], |
| 277 | + else_steps=[], |
| 278 | + ) |
| 279 | + |
| 280 | + # Pipeline instance |
| 281 | + pipeline = Pipeline( |
| 282 | + name=pipeline_name, |
| 283 | + parameters=[ |
| 284 | + processing_instance_type, |
| 285 | + processing_instance_count, |
| 286 | + training_instance_type, |
| 287 | + model_approval_status, |
| 288 | + input_data, |
| 289 | + ], |
| 290 | + steps=[step_process, step_train, step_eval, step_cond], |
| 291 | + sagemaker_session=sagemaker_session, |
| 292 | + ) |
| 293 | + return pipeline |
0 commit comments