Skip to content

Commit f43ff06

Browse files
icywang86ruiRui Wang Napieralski
and
Rui Wang Napieralski
authored
change: refactor _create_model_request (#1963)
Co-authored-by: Rui Wang Napieralski <[email protected]>
1 parent fc90363 commit f43ff06

File tree

1 file changed

+53
-44
lines changed

1 file changed

+53
-44
lines changed

src/sagemaker/session.py

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2314,6 +2314,51 @@ def transform(
23142314
LOGGER.debug("Transform request: %s", json.dumps(transform_request, indent=4))
23152315
self.sagemaker_client.create_transform_job(**transform_request)
23162316

2317+
def _create_model_request(
2318+
self,
2319+
name,
2320+
role,
2321+
container_defs,
2322+
vpc_config=None,
2323+
enable_network_isolation=False,
2324+
primary_container=None,
2325+
tags=None,
2326+
): # pylint: disable=redefined-outer-name
2327+
"""Placeholder docstring"""
2328+
if container_defs and primary_container:
2329+
raise ValueError("Both container_defs and primary_container can not be passed as input")
2330+
2331+
if primary_container:
2332+
msg = (
2333+
"primary_container is going to be deprecated in a future release. Please use "
2334+
"container_defs instead."
2335+
)
2336+
warnings.warn(msg, DeprecationWarning)
2337+
container_defs = primary_container
2338+
2339+
role = self.expand_role(role)
2340+
2341+
if isinstance(container_defs, list):
2342+
container_definition = container_defs
2343+
else:
2344+
container_definition = _expand_container_def(container_defs)
2345+
2346+
request = {"ModelName": name, "ExecutionRoleArn": role}
2347+
if isinstance(container_definition, list):
2348+
request["Containers"] = container_definition
2349+
else:
2350+
request["PrimaryContainer"] = container_definition
2351+
if tags:
2352+
request["Tags"] = tags
2353+
2354+
if vpc_config:
2355+
request["VpcConfig"] = vpc_config
2356+
2357+
if enable_network_isolation:
2358+
request["EnableNetworkIsolation"] = True
2359+
2360+
return request
2361+
23172362
def create_model(
23182363
self,
23192364
name,
@@ -2364,34 +2409,15 @@ def create_model(
23642409
Returns:
23652410
str: Name of the Amazon SageMaker ``Model`` created.
23662411
"""
2367-
if container_defs and primary_container:
2368-
raise ValueError("Both container_defs and primary_container can not be passed as input")
2369-
2370-
if primary_container:
2371-
msg = (
2372-
"primary_container is going to be deprecated in a future release. Please use "
2373-
"container_defs instead."
2374-
)
2375-
warnings.warn(msg, DeprecationWarning)
2376-
container_defs = primary_container
2377-
2378-
role = self.expand_role(role)
2379-
2380-
if isinstance(container_defs, list):
2381-
container_definition = container_defs
2382-
else:
2383-
container_definition = _expand_container_def(container_defs)
2384-
2385-
create_model_request = _create_model_request(
2386-
name=name, role=role, container_def=container_definition, tags=tags
2412+
create_model_request = self._create_model_request(
2413+
name=name,
2414+
role=role,
2415+
container_defs=container_defs,
2416+
vpc_config=vpc_config,
2417+
enable_network_isolation=enable_network_isolation,
2418+
primary_container=primary_container,
2419+
tags=tags,
23872420
)
2388-
2389-
if vpc_config:
2390-
create_model_request["VpcConfig"] = vpc_config
2391-
2392-
if enable_network_isolation:
2393-
create_model_request["EnableNetworkIsolation"] = True
2394-
23952421
LOGGER.info("Creating model with name: %s", name)
23962422
LOGGER.debug("CreateModel request: %s", json.dumps(create_model_request, indent=4))
23972423

@@ -3619,23 +3645,6 @@ def get_execution_role(sagemaker_session=None):
36193645
raise ValueError(message.format(arn))
36203646

36213647

3622-
def _create_model_request(
3623-
name, role, container_def=None, tags=None
3624-
): # pylint: disable=redefined-outer-name
3625-
"""Placeholder docstring"""
3626-
request = {"ModelName": name, "ExecutionRoleArn": role}
3627-
3628-
if isinstance(container_def, list):
3629-
request["Containers"] = container_def
3630-
else:
3631-
request["PrimaryContainer"] = container_def
3632-
3633-
if tags:
3634-
request["Tags"] = tags
3635-
3636-
return request
3637-
3638-
36393648
def _deployment_entity_exists(describe_fn):
36403649
"""Placeholder docstring"""
36413650
try:

0 commit comments

Comments
 (0)