22
22
23
23
import attr
24
24
25
+ from sagemaker import s3
25
26
from sagemaker .clarify import (
26
27
DataConfig ,
27
28
BiasConfig ,
33
34
_set ,
34
35
)
35
36
from sagemaker .model_monitor import BiasAnalysisConfig , ExplainabilityAnalysisConfig
37
+ from sagemaker .model_monitor .model_monitoring import _MODEL_MONITOR_S3_PATH
36
38
from sagemaker .processing import ProcessingInput , ProcessingOutput , ProcessingJob
37
39
from sagemaker .utils import name_from_base
38
40
from sagemaker .workflow import PipelineNonPrimitiveInputTypes , ExecutionVariable , Parameter
44
46
_DATA_BIAS_TYPE = "DATA_BIAS"
45
47
_MODEL_BIAS_TYPE = "MODEL_BIAS"
46
48
_MODEL_EXPLAINABILITY_TYPE = "MODEL_EXPLAINABILITY"
47
- _BIAS_JOB_DEFINITION_BASE_NAME = "model- bias-job-definition "
48
- _EXPLAINABILITY_JOB_DEFINITION_BASE_NAME = "model-explainability-job-definition "
49
+ _BIAS_MONITORING_CFG_BASE_NAME = "bias-monitoring "
50
+ _EXPLAINABILITY_MONITORING_CFG_BASE_NAME = "model-explainability-monitoring "
49
51
50
52
51
53
@attr .s
@@ -248,10 +250,14 @@ def __init__(
248
250
@property
249
251
def arguments (self ) -> RequestType :
250
252
"""The arguments dict that is used to define the ClarifyCheck step."""
253
+ normalized_inputs , normalized_outputs = self ._baselining_processor ._normalize_args (
254
+ inputs = [self ._processing_params ["config_input" ], self ._processing_params ["data_input" ]],
255
+ outputs = [self ._processing_params ["result_output" ]],
256
+ )
251
257
process_args = ProcessingJob ._get_process_args (
252
258
self ._baselining_processor ,
253
- [ self . _processing_params [ "config_input" ], self . _processing_params [ "data_input" ]] ,
254
- [ self . _processing_params [ "result_output" ]] ,
259
+ normalized_inputs ,
260
+ normalized_outputs ,
255
261
experiment_config = dict (),
256
262
)
257
263
request_dict = self ._baselining_processor .sagemaker_session ._get_process_request (
@@ -392,10 +398,8 @@ def _upload_monitoring_analysis_config(self) -> str:
392
398
Returns:
393
399
str: The S3 uri of the uploaded monitoring schedule analysis config
394
400
"""
395
- monitor_schedule_name = self ._model_monitor ._generate_monitoring_schedule_name ()
396
- output_s3_uri = self ._model_monitor ._normalize_monitoring_output (
397
- monitor_schedule_name
398
- ).destination
401
+
402
+ output_s3_uri = self ._get_s3_base_uri_for_monitoring_analysis_config ()
399
403
400
404
if isinstance (self .clarify_check_config , ModelExplainabilityCheckConfig ):
401
405
# Explainability analysis doesn't need label
@@ -410,7 +414,9 @@ def _upload_monitoring_analysis_config(self) -> str:
410
414
analysis_config = explainability_analysis_config ._to_dict ()
411
415
if "predictor" in analysis_config and "model_name" in analysis_config ["predictor" ]:
412
416
analysis_config ["predictor" ].pop ("model_name" )
413
- job_definition_name = name_from_base (_EXPLAINABILITY_JOB_DEFINITION_BASE_NAME )
417
+ job_definition_name = name_from_base (
418
+ f"{ _EXPLAINABILITY_MONITORING_CFG_BASE_NAME } -config"
419
+ )
414
420
415
421
else :
416
422
bias_analysis_config = BiasAnalysisConfig (
@@ -419,8 +425,33 @@ def _upload_monitoring_analysis_config(self) -> str:
419
425
label = self .clarify_check_config .data_config .label ,
420
426
)
421
427
analysis_config = bias_analysis_config ._to_dict ()
422
- job_definition_name = name_from_base (_BIAS_JOB_DEFINITION_BASE_NAME )
428
+ job_definition_name = name_from_base (f" { _BIAS_MONITORING_CFG_BASE_NAME } -config" )
423
429
424
430
return self ._model_monitor ._upload_analysis_config (
425
431
analysis_config , output_s3_uri , job_definition_name
426
432
)
433
+
434
+ def _get_s3_base_uri_for_monitoring_analysis_config (self ) -> str :
435
+ """Generate s3 base uri for monitoring schedule analysis config
436
+
437
+ Returns:
438
+ str: The S3 base uri of the monitoring schedule analysis config
439
+ """
440
+ s3_analysis_config_output_path = (
441
+ self .clarify_check_config .data_config .s3_analysis_config_output_path
442
+ )
443
+ monitoring_cfg_base_name = f"{ _BIAS_MONITORING_CFG_BASE_NAME } -configuration"
444
+ if isinstance (self .clarify_check_config , ModelExplainabilityCheckConfig ):
445
+ monitoring_cfg_base_name = f"{ _EXPLAINABILITY_MONITORING_CFG_BASE_NAME } -configuration"
446
+
447
+ if s3_analysis_config_output_path :
448
+ return s3 .s3_path_join (
449
+ s3_analysis_config_output_path ,
450
+ monitoring_cfg_base_name ,
451
+ )
452
+ return s3 .s3_path_join (
453
+ "s3://" ,
454
+ self ._model_monitor .sagemaker_session .default_bucket (),
455
+ _MODEL_MONITOR_S3_PATH ,
456
+ monitoring_cfg_base_name ,
457
+ )
0 commit comments