
Commit 42fa663

Author: Keshav Chandak
feature: Added transform with monitoring pipeline step in transformer
1 parent 1c55297 commit 42fa663

2 files changed: +219 additions, -4 deletions

src/sagemaker/transformer.py

Lines changed: 154 additions & 3 deletions
@@ -15,12 +15,14 @@
 
 from typing import Union, Optional, List, Dict
 from botocore import exceptions
-
+import logging
+import copy
+import time
 from sagemaker.job import _Job
-from sagemaker.session import Session
+from sagemaker.session import Session, get_execution_role
 from sagemaker.inputs import BatchDataCaptureConfig
 from sagemaker.workflow.entities import PipelineVariable
-from sagemaker.workflow.pipeline_context import runnable_by_pipeline
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.utils import base_name_from_image, name_from_base

@@ -247,6 +249,155 @@ def transform(
         if wait:
             self.latest_transform_job.wait(logs=logs)
 
+    def transform_with_monitoring(
+        self,
+        monitoring_config,
+        monitoring_resource_config,
+        data: str,
+        data_type: str = "S3Prefix",
+        content_type: str = None,
+        compression_type: str = None,
+        split_type: str = None,
+        input_filter: str = None,
+        output_filter: str = None,
+        join_source: str = None,
+        model_client_config: Dict[str, str] = None,
+        batch_data_capture_config: BatchDataCaptureConfig = None,
+        monitor_before_transform: bool = False,
+        supplied_baseline_statistics: str = None,
+        supplied_baseline_constraints: str = None,
+        wait: bool = True,
+        pipeline_name: str = None,
+        role: str = None,
+    ):
+        """Runs a transform job together with a monitoring job.
+
+        Note that this function does not start a transform job immediately;
+        instead, it creates a SageMaker Pipeline and executes it.
+        If you provide an existing pipeline_name, no new pipeline is created; otherwise,
+        each transform_with_monitoring call creates a new pipeline and executes it.
+
+        Args:
+            monitoring_config (Union[
+                `sagemaker.workflow.quality_check_step.QualityCheckConfig`,
+                `sagemaker.workflow.quality_check_step.ClarifyCheckConfig`
+            ]): the monitoring configuration used to run model monitoring.
+            monitoring_resource_config (`sagemaker.workflow.check_job_config.CheckJobConfig`):
+                the check job (processing job) cluster resource configuration.
+            data (str): Input data location in S3 for the transform job.
+            data_type (str): What the S3 location defines (default: 'S3Prefix').
+                Valid values:
+                * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix
+                  will be used as inputs for the transform job.
+                * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3
+                  object to use as an input for the transform job.
+            content_type (str): MIME type of the input data (default: None).
+            compression_type (str): Compression type of the input data, if
+                compressed (default: None). Valid values: 'Gzip', None.
+            split_type (str): The record delimiter for the input object
+                (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and
+                'TFRecord'.
+            input_filter (str): A JSONPath to select a portion of the input to
+                pass to the algorithm container for inference. If you omit the
+                field, it gets the value '$', representing the entire input.
+                For CSV data, each row is taken as a JSON array,
+                so only index-based JSONPaths can be applied, e.g. $[0], $[1:].
+                CSV data should follow the `RFC format <https://tools.ietf.org/html/rfc4180>`_.
+                See `Supported JSONPath Operators
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html#data-processing-operators>`_
+                for a table of supported JSONPath operators.
+                For more information, see the SageMaker API documentation for
+                `CreateTransformJob
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
+                Some examples: "$[1:]", "$.features" (default: None).
+            output_filter (str): A JSONPath to select a portion of the
+                joined/original output to return as the output.
+                For more information, see the SageMaker API documentation for
+                `CreateTransformJob
+                <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
+                Some examples: "$[1:]", "$.prediction" (default: None).
+            join_source (str): The source of data to be joined to the transform
+                output. It can be set to 'Input' meaning the entire input record
+                will be joined to the inference result. You can use OutputFilter
+                to select the useful portion before uploading to S3. (default:
+                None). Valid values: Input, None.
+            model_client_config (dict[str, str]): Model configuration.
+                Dictionary contains two optional keys,
+                'InvocationsTimeoutInSeconds' and 'InvocationsMaxRetries'
+                (default: ``None``).
+            batch_data_capture_config (BatchDataCaptureConfig): Configuration object which
+                specifies the configurations related to the batch data capture for the transform job
+                (default: ``None``).
+            monitor_before_transform (bool): For the data quality
+                or model explainability monitoring types,
+                a true value of this flag runs the check step before the transform job.
+            supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path
+                to the supplied statistics object representing the statistics JSON file
+                which will be used for the drift check (default: None).
+            supplied_baseline_constraints (Union[str, PipelineVariable]): The S3 path
+                to the supplied constraints object representing the constraints JSON file
+                which will be used for the drift check (default: None).
+            wait (bool): Whether to wait for the pipeline execution to complete (default: True).
+            pipeline_name (str): The name of the Pipeline used for the monitoring and transform steps.
+            role (str): Execution role.
+        """
+
+        # The transform call below must return pipeline step arguments rather than start
+        # a job, so it runs against a copy of this transformer bound to a PipelineSession.
+        transformer = self
+        if not isinstance(self.sagemaker_session, PipelineSession):
+            sagemaker_session = self.sagemaker_session
+            self.sagemaker_session = None
+            transformer = copy.deepcopy(self)
+            transformer.sagemaker_session = PipelineSession()
+            self.sagemaker_session = sagemaker_session
+
+        transform_step_args = transformer.transform(
+            data=data,
+            data_type=data_type,
+            content_type=content_type,
+            compression_type=compression_type,
+            split_type=split_type,
+            input_filter=input_filter,
+            output_filter=output_filter,
+            batch_data_capture_config=batch_data_capture_config,
+            join_source=join_source,
+            model_client_config=model_client_config,
+        )
+
+        from sagemaker.workflow.monitor_batch_transform_step import MonitorBatchTransformStep
+
+        monitoring_batch_step = MonitorBatchTransformStep(
+            name="MonitorBatchTransformStep",
+            display_name="MonitorBatchTransformStep",
+            description="",
+            transform_step_args=transform_step_args,
+            monitor_configuration=monitoring_config,
+            check_job_configuration=monitoring_resource_config,
+            monitor_before_transform=monitor_before_transform,
+            supplied_baseline_constraints=supplied_baseline_constraints,
+            supplied_baseline_statistics=supplied_baseline_statistics,
+        )
+
+        pipeline_name = (
+            pipeline_name if pipeline_name else f"TransformWithMonitoring{int(time.time())}"
+        )
+        # if the pipeline already exists, upsert updates it and a new execution is started
+        from sagemaker.workflow.pipeline import Pipeline
+
+        pipeline = Pipeline(
+            name=pipeline_name,
+            steps=[monitoring_batch_step],
+            sagemaker_session=transformer.sagemaker_session,
+        )
+        pipeline.upsert(role_arn=role if role else get_execution_role())
+        execution = pipeline.start()
+        if wait:
+            logging.info("Waiting for transform with monitoring to execute ...")
+            execution.wait()
+        return execution
+
     def delete_model(self):
         """Delete the corresponding SageMaker model for this Transformer."""
         self.sagemaker_session.delete_model(self.model_name)
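
For orientation, here is a minimal usage sketch of the new transform_with_monitoring method, running a data quality check before the transform job. This sketch is not part of the commit: the model name, S3 URIs, and role ARN are placeholders, and DataQualityCheckConfig, CheckJobConfig, and DatasetFormat are the existing SDK classes the docstring points to.

    from sagemaker.model_monitor import DatasetFormat
    from sagemaker.transformer import Transformer
    from sagemaker.workflow.check_job_config import CheckJobConfig
    from sagemaker.workflow.quality_check_step import DataQualityCheckConfig

    # Transformer over an existing SageMaker model (placeholder names and S3 paths).
    transformer = Transformer(
        model_name="my-model",
        instance_count=1,
        instance_type="ml.m5.xlarge",
        output_path="s3://my-bucket/transform-output",
    )

    # Cluster configuration for the monitoring (processing) job.
    check_job_config = CheckJobConfig(
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role ARN
        instance_count=1,
        instance_type="ml.c5.xlarge",
    )

    # Data quality monitoring configuration; baseline files are supplied to the call below.
    data_quality_config = DataQualityCheckConfig(
        baseline_dataset="s3://my-bucket/baseline/dataset.csv",
        dataset_format=DatasetFormat.csv(header=False),
        output_s3_uri="s3://my-bucket/monitoring-output",
    )

    # Upserts a pipeline containing a MonitorBatchTransformStep and starts an execution,
    # with the data quality check running before the batch transform job.
    execution = transformer.transform_with_monitoring(
        monitoring_config=data_quality_config,
        monitoring_resource_config=check_job_config,
        data="s3://my-bucket/transform-input",
        content_type="text/csv",
        monitor_before_transform=True,
        supplied_baseline_statistics="s3://my-bucket/baseline/statistics.json",
        supplied_baseline_constraints="s3://my-bucket/baseline/constraints.json",
        role="arn:aws:iam::111122223333:role/SageMakerRole",
        wait=True,
    )

Because no pipeline_name is passed, the call creates a pipeline named TransformWithMonitoring<timestamp> and waits for its execution to finish.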

tests/integ/test_transformer.py

Lines changed: 65 additions & 1 deletion
@@ -25,6 +25,7 @@
 from sagemaker.transformer import Transformer
 from sagemaker.estimator import Estimator
 from sagemaker.inputs import BatchDataCaptureConfig
+from sagemaker.xgboost import XGBoostModel
 from sagemaker.utils import unique_name_from_base
 from tests.integ import (
     datasets,
@@ -36,7 +37,7 @@
 from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer
 from tests.integ.vpc_test_utils import get_or_create_vpc_resources
 
-from sagemaker.model_monitor import DatasetFormat, Statistics
+from sagemaker.model_monitor import DatasetFormat, Statistics, Constraints
 
 from sagemaker.workflow.check_job_config import CheckJobConfig
 from sagemaker.workflow.quality_check_step import (
@@ -645,3 +646,66 @@ def _create_transformer_and_transform_job(
         job_name=unique_name_from_base("test-transform"),
     )
     return transformer
+
+
+def test_transformer_and_monitoring_job(
+    pipeline_session,
+    sagemaker_session,
+    role,
+    pipeline_name,
+    check_job_config,
+    data_bias_check_config,
+):
+    xgb_model_data_s3 = pipeline_session.upload_data(
+        path=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "xgb_model.tar.gz"),
+        key_prefix="integ-test-data/xgboost/model",
+    )
+    data_bias_supplied_baseline_constraints = Constraints.from_file_path(
+        constraints_file_path=os.path.join(
+            DATA_DIR, "pipeline/clarify_check_step/data_bias/good_cases/analysis.json"
+        ),
+        sagemaker_session=sagemaker_session,
+    ).file_s3_uri
+
+    xgb_model = XGBoostModel(
+        model_data=xgb_model_data_s3,
+        framework_version="1.3-1",
+        role=role,
+        sagemaker_session=sagemaker_session,
+        entry_point=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "inference.py"),
+        enable_network_isolation=True,
+    )
+
+    xgb_model.deploy(_INSTANCE_COUNT, _INSTANCE_TYPE)
+
+    transform_output = f"s3://{sagemaker_session.default_bucket()}/{pipeline_name}Transform"
+    transformer = Transformer(
+        model_name=xgb_model.name,
+        strategy="SingleRecord",
+        instance_type="ml.m5.xlarge",
+        instance_count=1,
+        output_path=transform_output,
+        sagemaker_session=pipeline_session,
+    )
+
+    transform_input = pipeline_session.upload_data(
+        path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"),
+        key_prefix="integ-test-data/xgboost_abalone/abalone",
+    )
+
+    execution = transformer.transform_with_monitoring(
+        monitoring_config=data_bias_check_config,
+        monitoring_resource_config=check_job_config,
+        data=transform_input,
+        content_type="text/libsvm",
+        supplied_baseline_constraints=data_bias_supplied_baseline_constraints,
+        role=role,
+    )
+
+    execution_steps = execution.list_steps()
+    assert len(execution_steps) == 2
+
+    for execution_step in execution_steps:
+        assert execution_step["StepStatus"] == "Succeeded"
+
+    xgb_model.delete_model()
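
Because each call without pipeline_name creates a fresh TransformWithMonitoring<timestamp> pipeline, repeated runs (for example in an integration suite) may prefer to pass a fixed name and delete the pipeline afterwards. A minimal sketch reusing the transformer and fixtures from the test above; the fixed pipeline name is a hypothetical placeholder:

    # Reuse one pipeline across calls: upsert() updates the existing definition
    # instead of creating a new TransformWithMonitoring<timestamp> pipeline each time.
    execution = transformer.transform_with_monitoring(
        monitoring_config=data_bias_check_config,
        monitoring_resource_config=check_job_config,
        data=transform_input,
        content_type="text/libsvm",
        supplied_baseline_constraints=data_bias_supplied_baseline_constraints,
        role=role,
        pipeline_name="transform-with-monitoring-integ",  # hypothetical fixed name
    )

    # Optional cleanup once the execution has finished; Pipeline.delete() removes the
    # pipeline definition (the transform and monitoring outputs remain in S3).
    from sagemaker.workflow.pipeline import Pipeline

    Pipeline(
        name="transform-with-monitoring-integ", sagemaker_session=sagemaker_session
    ).delete()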
