Skip to content

Commit 68b2149

Browse files
committed
chore: updates
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 3947c2a commit 68b2149

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

docsrc/user_guide/saving_models.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ The following code illustrates this approach.
2929
import torch_tensorrt
3030
3131
model = MyModel().eval().cuda()
32-
inputs = torch.randn((1, 3, 224, 224)).cuda()
32+
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
3333
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
3434
trt_script_model = torch.jit.trace(trt_gm, inputs)
3535
torch.jit.save(trt_script_model, "trt_model.ts")
3636
3737
# Later, you can load it and run inference
3838
model = torch.jit.load("trt_model.ts").cuda()
39-
model(inputs)
39+
model(*inputs)
4040
4141
b) ExportedProgram
4242
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -50,15 +50,15 @@ b) ExportedProgram
5050
import torch_tensorrt
5151
5252
model = MyModel().eval().cuda()
53-
inputs = torch.randn((1, 3, 224, 224)).cuda()
53+
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
5454
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
5555
# Transform and create an exported program
5656
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
5757
torch.export.save(trt_exp_program, "trt_model.ep")
5858
5959
# Later, you can load it and run inference
6060
model = torch.export.load("trt_model.ep")
61-
model(inputs)
61+
model(*inputs)
6262
6363
`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stiches all the nodes together.
6464
This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
@@ -78,11 +78,11 @@ Torchscript IR
7878
import torch_tensorrt
7979
8080
model = MyModel().eval().cuda()
81-
inputs = torch.randn((1, 3, 224, 224)).cuda()
81+
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
8282
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs) # Output is a ScriptModule object
8383
torch.jit.save(trt_ts, "trt_model.ts")
8484
8585
# Later, you can load it and run inference
8686
model = torch.jit.load("trt_model.ts").cuda()
87-
model(inputs)
87+
model(*inputs)
8888

0 commit comments

Comments
 (0)