Skip to content

Commit 8505140

Browse files
apivovarovknikure
authored andcommitted
Add context to predict_fn example
1 parent 67d8faa commit 8505140

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

doc/frameworks/pytorch/using_pytorch.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,7 @@ The following example is for use cases with multiple GPUs and shows an overridde
772772
import torch
773773
import numpy as np
774774
775-
def predict_fn(input_data, model):
775+
def predict_fn(input_data, model, context):
776776
device = torch.device("cuda:" + str(context.system_properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
777777
model.to(device)
778778
model.eval()

0 commit comments

Comments
 (0)