Skip to content

Commit 8062a54

Browse files
authored
Fix predictor __init__ staggering (#1591)
1 parent b02b951 commit 8062a54

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

pkg/workloads/cortex/lib/api/predictor.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,6 @@
5252
ModelsTree, # only when num workers = 1
5353
)
5454

55-
# concurrency
56-
from cortex.lib.concurrency import FileLock
57-
5855
# model validation
5956
from cortex.lib.model import validate_model_paths
6057

@@ -245,10 +242,9 @@ def class_impl(self, project_dir):
245242
validations = PYTHON_CLASS_VALIDATION
246243

247244
try:
248-
with FileLock("/run/init_stagger.lock"):
249-
predictor_class = self._get_class_impl(
250-
"cortex_predictor", os.path.join(project_dir, self.path), target_class_name
251-
)
245+
predictor_class = self._get_class_impl(
246+
"cortex_predictor", os.path.join(project_dir, self.path), target_class_name
247+
)
252248
except CortexException as e:
253249
e.wrap("error in " + self.path)
254250
raise

pkg/workloads/cortex/serve/serve.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from cortex.lib import util
3636
from cortex.lib.api import API, get_api
3737
from cortex.lib.log import cx_logger as logger
38-
from cortex.lib.concurrency import LockedFile
38+
from cortex.lib.concurrency import FileLock, LockedFile
3939
from cortex.lib.storage import S3, LocalStorage
4040
from cortex.lib.exceptions import UserRuntimeException
4141

@@ -304,8 +304,10 @@ def start_fn():
304304
client = api.predictor.initialize_client(
305305
tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port
306306
)
307-
logger().info("loading the predictor from {}".format(api.predictor.path))
308-
predictor_impl = api.predictor.initialize_impl(project_dir, client)
307+
308+
with FileLock("/run/init_stagger.lock"):
309+
logger().info("loading the predictor from {}".format(api.predictor.path))
310+
predictor_impl = api.predictor.initialize_impl(project_dir, client)
309311

310312
local_cache["api"] = api
311313
local_cache["provider"] = provider

0 commit comments

Comments
 (0)