Skip to content

Commit cbc9074

Browse files
feature add Neo image uri config for Pytorch 1.12
1 parent 5d4c3e2 commit cbc9074

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/data/pytorch_neo/code/inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def model_fn(model_dir):
7171
logger.info("model_fn")
7272
neopytorch.config(model_dir=model_dir, neo_runtime=True)
7373
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74-
# The compiled model is saved as "model.pth"
75-
model = torch.jit.load(os.path.join(model_dir, "model.pth"), map_location=device)
74+
# The compiled model is saved as "model.pth" or "model.pt"
75+
model = torch.jit.load(os.path.join(model_dir, "model.pt"), map_location=device)
7676

7777
# It is recommended to run warm-up inference during model load
7878
sample_input_path = os.path.join(model_dir, "sample_input.pkl")

0 commit comments

Comments
 (0)