Skip to content

Commit 3b070ac

Browse files
qidewenwhen and Dewen Qi
authored and committed
change: Update s3 path of scheduling analysis config on ClarifyCheckStep
* change: Update s3 path of scheduling analysis config in ClarifyCheckStep Co-authored-by: Dewen Qi <[email protected]>
1 parent 4f1d329 commit 3b070ac

File tree

6 files changed

+190
-77
lines changed

6 files changed

+190
-77
lines changed

src/sagemaker/session.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4186,7 +4186,8 @@ def get_model_package_args(
41864186
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
41874187
or "PendingManualApproval" (default: "PendingManualApproval").
41884188
description (str): Model Package description (default: None).
4189-
tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs.
4189+
tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs
4190+
(default: None).
41904191
container_def_list (list): A list of container definitions (default: None).
41914192
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
41924193
Returns:
@@ -4267,7 +4268,8 @@ def get_create_model_package_request(
42674268
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
42684269
or "PendingManualApproval" (default: "PendingManualApproval").
42694270
description (str): Model Package description (default: None).
4270-
tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs.
4271+
tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs
4272+
(default: None).
42714273
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
42724274
"""
42734275

src/sagemaker/workflow/check_job_config.py

Lines changed: 1 addition & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -32,10 +32,8 @@ class CheckJobConfig:
3232
def __init__(
3333
self,
3434
role,
35-
image_uri=None,
3635
instance_count=1,
3736
instance_type="ml.m5.xlarge",
38-
entrypoint=None,
3937
volume_size_in_gb=30,
4038
volume_kms_key=None,
4139
output_kms_key=None,
@@ -50,14 +48,9 @@ def __init__(
5048
5149
Args:
5250
role (str): An AWS IAM role. The Amazon SageMaker jobs use this role.
53-
image_uri (str): The uri of the image to use for the jobs
54-
started by the QualityCheckStep/ClarifyCheckStep (default: None).
55-
If not specified, the default auto-generated image_uri will be used.
5651
instance_count (int): The number of instances to run the jobs with (default: 1).
5752
instance_type (str): Type of EC2 instance to use for the job
5853
(default: 'ml.m5.xlarge').
59-
entrypoint ([str]): The entrypoint for the job (default: None).
60-
Only the QualityCheckStep will take this input.
6154
volume_size_in_gb (int): Size in GB of the EBS volume
6255
to use for storing data during processing (default: 30).
6356
volume_kms_key (str): A KMS key for the processing volume (default: None).
@@ -77,12 +70,11 @@ def __init__(
7770
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
7871
object that configures network isolation, encryption of
7972
inter-container traffic, security group IDs, and subnets (default: None).
73+
8074
"""
8175
self.role = role
82-
self.image_uri = image_uri
8376
self.instance_count = instance_count
8477
self.instance_type = instance_type
85-
self.entrypoint = entrypoint
8678
self.volume_size_in_gb = volume_size_in_gb
8779
self.volume_kms_key = volume_kms_key
8880
self.output_kms_key = output_kms_key
@@ -174,7 +166,4 @@ def _generate_model_monitor(self, mm_type: str) -> Optional[ModelMonitor]:
174166
'"ModelBiasMonitor", "ModelExplainabilityMonitor"'
175167
)
176168
return None
177-
178-
monitor.image_uri = self.image_uri or monitor.image_uri
179-
monitor.entrypoint = self.entrypoint or monitor.entrypoint
180169
return monitor

src/sagemaker/workflow/clarify_check_step.py

Lines changed: 41 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222

2323
import attr
2424

25+
from sagemaker import s3
2526
from sagemaker.clarify import (
2627
DataConfig,
2728
BiasConfig,
@@ -33,6 +34,7 @@
3334
_set,
3435
)
3536
from sagemaker.model_monitor import BiasAnalysisConfig, ExplainabilityAnalysisConfig
37+
from sagemaker.model_monitor.model_monitoring import _MODEL_MONITOR_S3_PATH
3638
from sagemaker.processing import ProcessingInput, ProcessingOutput, ProcessingJob
3739
from sagemaker.utils import name_from_base
3840
from sagemaker.workflow import PipelineNonPrimitiveInputTypes, ExecutionVariable, Parameter
@@ -44,8 +46,8 @@
4446
_DATA_BIAS_TYPE = "DATA_BIAS"
4547
_MODEL_BIAS_TYPE = "MODEL_BIAS"
4648
_MODEL_EXPLAINABILITY_TYPE = "MODEL_EXPLAINABILITY"
47-
_BIAS_JOB_DEFINITION_BASE_NAME = "model-bias-job-definition"
48-
_EXPLAINABILITY_JOB_DEFINITION_BASE_NAME = "model-explainability-job-definition"
49+
_BIAS_MONITORING_CFG_BASE_NAME = "bias-monitoring"
50+
_EXPLAINABILITY_MONITORING_CFG_BASE_NAME = "model-explainability-monitoring"
4951

5052

5153
@attr.s
@@ -248,10 +250,14 @@ def __init__(
248250
@property
249251
def arguments(self) -> RequestType:
250252
"""The arguments dict that is used to define the ClarifyCheck step."""
253+
normalized_inputs, normalized_outputs = self._baselining_processor._normalize_args(
254+
inputs=[self._processing_params["config_input"], self._processing_params["data_input"]],
255+
outputs=[self._processing_params["result_output"]],
256+
)
251257
process_args = ProcessingJob._get_process_args(
252258
self._baselining_processor,
253-
[self._processing_params["config_input"], self._processing_params["data_input"]],
254-
[self._processing_params["result_output"]],
259+
normalized_inputs,
260+
normalized_outputs,
255261
experiment_config=dict(),
256262
)
257263
request_dict = self._baselining_processor.sagemaker_session._get_process_request(
@@ -392,10 +398,8 @@ def _upload_monitoring_analysis_config(self) -> str:
392398
Returns:
393399
str: The S3 uri of the uploaded monitoring schedule analysis config
394400
"""
395-
monitor_schedule_name = self._model_monitor._generate_monitoring_schedule_name()
396-
output_s3_uri = self._model_monitor._normalize_monitoring_output(
397-
monitor_schedule_name
398-
).destination
401+
402+
output_s3_uri = self._get_s3_base_uri_for_monitoring_analysis_config()
399403

400404
if isinstance(self.clarify_check_config, ModelExplainabilityCheckConfig):
401405
# Explainability analysis doesn't need label
@@ -410,7 +414,9 @@ def _upload_monitoring_analysis_config(self) -> str:
410414
analysis_config = explainability_analysis_config._to_dict()
411415
if "predictor" in analysis_config and "model_name" in analysis_config["predictor"]:
412416
analysis_config["predictor"].pop("model_name")
413-
job_definition_name = name_from_base(_EXPLAINABILITY_JOB_DEFINITION_BASE_NAME)
417+
job_definition_name = name_from_base(
418+
f"{_EXPLAINABILITY_MONITORING_CFG_BASE_NAME}-config"
419+
)
414420

415421
else:
416422
bias_analysis_config = BiasAnalysisConfig(
@@ -419,8 +425,33 @@ def _upload_monitoring_analysis_config(self) -> str:
419425
label=self.clarify_check_config.data_config.label,
420426
)
421427
analysis_config = bias_analysis_config._to_dict()
422-
job_definition_name = name_from_base(_BIAS_JOB_DEFINITION_BASE_NAME)
428+
job_definition_name = name_from_base(f"{_BIAS_MONITORING_CFG_BASE_NAME}-config")
423429

424430
return self._model_monitor._upload_analysis_config(
425431
analysis_config, output_s3_uri, job_definition_name
426432
)
433+
434+
def _get_s3_base_uri_for_monitoring_analysis_config(self) -> str:
435+
"""Generate s3 base uri for monitoring schedule analysis config
436+
437+
Returns:
438+
str: The S3 base uri of the monitoring schedule analysis config
439+
"""
440+
s3_analysis_config_output_path = (
441+
self.clarify_check_config.data_config.s3_analysis_config_output_path
442+
)
443+
monitoring_cfg_base_name = f"{_BIAS_MONITORING_CFG_BASE_NAME}-configuration"
444+
if isinstance(self.clarify_check_config, ModelExplainabilityCheckConfig):
445+
monitoring_cfg_base_name = f"{_EXPLAINABILITY_MONITORING_CFG_BASE_NAME}-configuration"
446+
447+
if s3_analysis_config_output_path:
448+
return s3.s3_path_join(
449+
s3_analysis_config_output_path,
450+
monitoring_cfg_base_name,
451+
)
452+
return s3.s3_path_join(
453+
"s3://",
454+
self._model_monitor.sagemaker_session.default_bucket(),
455+
_MODEL_MONITOR_S3_PATH,
456+
monitoring_cfg_base_name,
457+
)

src/sagemaker/workflow/quality_check_step.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -222,16 +222,21 @@ def __init__(
222222
@property
223223
def arguments(self) -> RequestType:
224224
"""The arguments dict that is used to define the QualityCheck step."""
225+
normalized_inputs, normalized_outputs = self._baselining_processor._normalize_args(
226+
inputs=self._baseline_job_inputs,
227+
outputs=[self._baseline_output],
228+
)
225229
process_args = ProcessingJob._get_process_args(
226230
self._baselining_processor,
227-
self._baseline_job_inputs,
228-
[self._baseline_output],
231+
normalized_inputs,
232+
normalized_outputs,
229233
experiment_config=dict(),
230234
)
231235
request_dict = self._baselining_processor.sagemaker_session._get_process_request(
232236
**process_args
233237
)
234-
request_dict.pop("ProcessingJobName")
238+
if "ProcessingJobName" in request_dict:
239+
request_dict.pop("ProcessingJobName")
235240

236241
return request_dict
237242

0 commit comments

Comments
 (0)