@@ -29,14 +29,14 @@ The following code illustrates this approach.
import torch_tensorrt

model = MyModel().eval().cuda()
- inputs = torch.randn((1, 3, 224, 224)).cuda()
+ inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)  # Output is a torch.fx.GraphModule
trt_script_model = torch.jit.trace(trt_gm, inputs)
torch.jit.save(trt_script_model, "trt_model.ts")

# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
- model(inputs)
+ model(*inputs)

b) ExportedProgram
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -50,15 +50,15 @@ b) ExportedProgram
import torch_tensorrt

model = MyModel().eval().cuda()
- inputs = torch.randn((1, 3, 224, 224)).cuda()
+ inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)  # Output is a torch.fx.GraphModule
# Transform and create an exported program
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
torch.export.save(trt_exp_program, "trt_model.ep")

# Later, you can load it and run inference
model = torch.export.load("trt_model.ep")
- model(inputs)
+ model(*inputs)

`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stitches all the nodes together.
This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
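
For illustration, a minimal round-trip sketch of the ExportedProgram path (assuming the placeholder `MyModel` from the examples above returns a single tensor; the tolerance values are arbitrary):

import torch
import torch_tensorrt

# Placeholder model and example inputs, as in the snippets above.
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]

# Compile with the dynamo frontend, export, and serialize.
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
torch.export.save(trt_exp_program, "trt_model.ep")

# Reload the serialized ExportedProgram and check it agrees with the
# in-memory compiled module on the same inputs (assumes a single-tensor output).
reloaded = torch.export.load("trt_model.ep")
assert torch.allclose(reloaded(*inputs), trt_gm(*inputs), rtol=1e-3, atol=1e-3)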
@@ -78,11 +78,11 @@ Torchscript IR
import torch_tensorrt

model = MyModel().eval().cuda()
- inputs = torch.randn((1, 3, 224, 224)).cuda()
+ inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs)  # Output is a ScriptModule object
torch.jit.save(trt_ts, "trt_model.ts")

# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
- model(inputs)
+ model(*inputs)