@@ -735,6 +735,41 @@ def pdp_config():
735
735
return PDPConfig (features = ["F1" , "F2" ], grid_resolution = 20 )
736
736
737
737
738
+ def test_model_config_validations ():
739
+ new_model_endpoint_definition = {
740
+ "model_name" : "xgboost-model" ,
741
+ "instance_type" : "ml.c5.xlarge" ,
742
+ "instance_count" : 1 ,
743
+ }
744
+ existing_endpoint_definition = {"endpoint_name" : "existing_endpoint" }
745
+
746
+ with pytest .raises (AssertionError ):
747
+ # should be one of them
748
+ ModelConfig (
749
+ ** new_model_endpoint_definition ,
750
+ ** existing_endpoint_definition ,
751
+ )
752
+
753
+ with pytest .raises (AssertionError ):
754
+ # should be one of them
755
+ ModelConfig (
756
+ endpoint_name_prefix = "prefix" ,
757
+ ** existing_endpoint_definition ,
758
+ )
759
+
760
+ # success path for new model
761
+ assert ModelConfig (** new_model_endpoint_definition ).predictor_config == {
762
+ "initial_instance_count" : 1 ,
763
+ "instance_type" : "ml.c5.xlarge" ,
764
+ "model_name" : "xgboost-model" ,
765
+ }
766
+
767
+ # success path for existing endpoint
768
+ assert (
769
+ ModelConfig (** existing_endpoint_definition ).predictor_config == existing_endpoint_definition
770
+ )
771
+
772
+
738
773
@patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
739
774
def test_pre_training_bias (
740
775
name_from_base ,
@@ -1396,6 +1431,47 @@ def test_analysis_config_generator_for_bias_explainability(
1396
1431
assert actual == expected
1397
1432
1398
1433
1434
+ def test_analysis_config_generator_for_bias_explainability_with_existing_endpoint (
1435
+ data_config , data_bias_config
1436
+ ):
1437
+ model_config = ModelConfig (endpoint_name = "existing_endpoint_name" )
1438
+ model_predicted_label_config = ModelPredictedLabelConfig (
1439
+ probability = "pr" ,
1440
+ label_headers = ["success" ],
1441
+ )
1442
+ actual = _AnalysisConfigGenerator .bias_and_explainability (
1443
+ data_config ,
1444
+ model_config ,
1445
+ model_predicted_label_config ,
1446
+ [SHAPConfig (), PDPConfig ()],
1447
+ data_bias_config ,
1448
+ pre_training_methods = "all" ,
1449
+ post_training_methods = "all" ,
1450
+ )
1451
+ expected = {
1452
+ "dataset_type" : "text/csv" ,
1453
+ "facet" : [{"name_or_index" : "F1" }],
1454
+ "group_variable" : "F2" ,
1455
+ "headers" : ["Label" , "F1" , "F2" , "F3" , "F4" ],
1456
+ "joinsource_name_or_index" : "F4" ,
1457
+ "label" : "Label" ,
1458
+ "label_values_or_threshold" : [1 ],
1459
+ "methods" : {
1460
+ "pdp" : {"grid_resolution" : 15 , "top_k_features" : 10 },
1461
+ "post_training_bias" : {"methods" : "all" },
1462
+ "pre_training_bias" : {"methods" : "all" },
1463
+ "report" : {"name" : "report" , "title" : "Analysis Report" },
1464
+ "shap" : {"save_local_shap_values" : True , "use_logit" : False },
1465
+ },
1466
+ "predictor" : {
1467
+ "label_headers" : ["success" ],
1468
+ "endpoint_name" : "existing_endpoint_name" ,
1469
+ "probability" : "pr" ,
1470
+ },
1471
+ }
1472
+ assert actual == expected
1473
+
1474
+
1399
1475
def test_analysis_config_generator_for_bias_pre_training (data_config , data_bias_config ):
1400
1476
actual = _AnalysisConfigGenerator .bias_pre_training (
1401
1477
data_config , data_bias_config , methods = "all"
0 commit comments