@@ -1632,17 +1632,29 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
1632
1632
return neo_bucket
1633
1633
1634
1634
1635
- def remove_env_var_from_estimator_kwargs_if_accept_eula_present (
1636
- init_kwargs : dict , accept_eula : Optional [bool ]
1635
+ def remove_env_var_from_estimator_kwargs_if_model_access_config_present (
1636
+ init_kwargs : dict , model_access_config : Optional [dict ]
1637
1637
):
1638
- """Remove env vars if access configs are used
1638
+ """Remove env vars if ModelAccessConfig is used
1639
1639
1640
1640
Args:
1641
1641
init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated.
1642
1642
accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
1643
1643
"""
1644
- if accept_eula is not None and init_kwargs ["environment" ]:
1645
- del init_kwargs ["environment" ][constants .SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY ]
1644
+ if (
1645
+ model_access_config is not None
1646
+ and init_kwargs .get ("environment" ) is not None
1647
+ and init_kwargs .get ("model_uri" ) is not None
1648
+ ):
1649
+ if (
1650
+ constants .SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY
1651
+ in init_kwargs ["environment" ]
1652
+ ):
1653
+ del init_kwargs ["environment" ][
1654
+ constants .SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY
1655
+ ]
1656
+ if "accept_eula" in init_kwargs ["environment" ]:
1657
+ del init_kwargs ["environment" ]["accept_eula" ]
1646
1658
1647
1659
1648
1660
def get_hub_access_config (hub_content_arn : Optional [str ]):
@@ -1659,16 +1671,24 @@ def get_hub_access_config(hub_content_arn: Optional[str]):
1659
1671
return hub_access_config
1660
1672
1661
1673
1662
- def get_model_access_config (accept_eula : Optional [bool ]):
1674
+ def get_model_access_config (accept_eula : Optional [bool ], environment : Optional [ dict ] ):
1663
1675
"""Get access configs
1664
1676
1665
1677
Args:
1666
1678
accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
1667
1679
"""
1680
+ env_var_eula = environment .get ("accept_eula" ) if environment else None
1681
+ if env_var_eula is not None and accept_eula is not None :
1682
+ raise ValueError (
1683
+ "Cannot pass in both accept_eula and environment variables. "
1684
+ "Please remove the environment variable and pass in the accept_eula parameter."
1685
+ )
1686
+
1687
+ model_access_config = None
1688
+ if env_var_eula is not None :
1689
+ model_access_config = {"AcceptEula" : env_var_eula == "true" }
1668
1690
if accept_eula is not None :
1669
1691
model_access_config = {"AcceptEula" : accept_eula }
1670
- else :
1671
- model_access_config = None
1672
1692
1673
1693
return model_access_config
1674
1694
0 commit comments