
Commit 504b39d

chore: fix docs for export (#2447)

Signed-off-by: Dheeraj Peri <[email protected]>
1 parent da90d61

File tree

3 files changed: +25 −29 lines

docsrc/dynamo/dynamo_export.rst

Lines changed: 2 additions & 5 deletions
@@ -1,6 +1,6 @@
 .. _dynamo_export:

-Compiling ``ExportedPrograms`` with Torch-TensorRT
+Compiling Exported Programs with Torch-TensorRT
 =============================================
 .. currentmodule:: torch_tensorrt.dynamo

@@ -9,8 +9,6 @@ Compiling ``ExportedPrograms`` with Torch-TensorRT
     :undoc-members:
     :show-inheritance:

-Using the Torch-TensorRT Frontend for ``torch.export.ExportedPrograms``
---------------------------------------------------------
 Pytorch 2.1 introduced ``torch.export`` APIs which
 can export graphs from Pytorch programs into ``ExportedProgram`` objects. Torch-TensorRT dynamo
 frontend compiles these ``ExportedProgram`` objects and optimizes them using TensorRT. Here's a simple
@@ -43,8 +41,7 @@ Some of the frequently used options are as follows:

 The complete list of options can be found `here <https://github.com/pytorch/TensorRT/blob/123a486d6644a5bbeeec33e2f32257349acc0b8f/py/torch_tensorrt/dynamo/compile.py#L51-L77>`_

-.. note:: We do not support INT precision currently in Dynamo. Support for this currently exists in
-   our Torchscript IR. We plan to implement similar support for dynamo in our next release.
+.. note:: We do not support INT precision currently in Dynamo. Support for this currently exists in our Torchscript IR. We plan to implement similar support for dynamo in our next release.

 Under the hood
 --------------

docsrc/user_guide/saving_models.rst

Lines changed: 22 additions & 23 deletions
@@ -29,14 +29,14 @@ The following code illustrates this approach.
     import torch_tensorrt

     model = MyModel().eval().cuda()
-    inputs = torch.randn((1, 3, 224, 224)).cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
     trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
-    trt_script_model = torch.jit.trace(trt_gm, inputs)
-    torch.jit.save(trt_script_model, "trt_model.ts")
+    trt_traced_model = torch.jit.trace(trt_gm, inputs)
+    torch.jit.save(trt_traced_model, "trt_model.ts")

     # Later, you can load it and run inference
     model = torch.jit.load("trt_model.ts").cuda()
-    model(inputs)
+    model(*inputs)

 b) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -50,40 +50,39 @@ b) ExportedProgram
     import torch_tensorrt

     model = MyModel().eval().cuda()
-    inputs = torch.randn((1, 3, 224, 224)).cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
     trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
     # Transform and create an exported program
-    trt_gm = torch_tensorrt.dynamo.export(trt_gm, inputs)
-    trt_exp_program = create_exported_program(trt_gm, call_spec, trt_gm.state_dict())
-    torch._export.save(trt_exp_program, "trt_model.ep")
+    trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
+    torch.export.save(trt_exp_program, "trt_model.ep")

     # Later, you can load it and run inference
-    model = torch._export.load("trt_model.ep")
-    model(inputs)
+    model = torch.export.load("trt_model.ep")
+    model(*inputs)

 `torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stiches all the nodes together.
 This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).

-NOTE: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341
+.. note:: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341


 Torchscript IR
 -------------

-In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT is using Torchscript IR.
-This behavior stays the same in 2.X versions as well.
+In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT is using Torchscript IR.
+This behavior stays the same in 2.X versions as well.

-.. code-block:: python
+.. code-block:: python

-    import torch
-    import torch_tensorrt
+    import torch
+    import torch_tensorrt

-    model = MyModel().eval().cuda()
-    inputs = torch.randn((1, 3, 224, 224)).cuda()
-    trt_ts = torch_tensorrt.compile(model, ir="ts", inputs) # Output is a ScriptModule object
-    torch.jit.save(trt_ts, "trt_model.ts")
+    model = MyModel().eval().cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    trt_ts = torch_tensorrt.compile(model, ir="ts", inputs) # Output is a ScriptModule object
+    torch.jit.save(trt_ts, "trt_model.ts")

-    # Later, you can load it and run inference
-    model = torch.jit.load("trt_model.ts").cuda()
-    model(inputs)
+    # Later, you can load it and run inference
+    model = torch.jit.load("trt_model.ts").cuda()
+    model(*inputs)
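The repeated change from `model(inputs)` to `model(*inputs)` in this file follows from making `inputs` a list of example tensors: a traced module expects each tensor as a separate positional argument, so the list must be unpacked with `*`. A minimal pure-Python sketch of the difference (`fake_model` is a hypothetical stand-in; no TensorRT or Torch involved):

```python
# Stand-in for a traced module that expects one positional argument
# per example input (hypothetical; uses plain lists instead of tensors).
def fake_model(x):
    return [v * 2 for v in x]

inputs = [[1, 2, 3]]  # a list holding one example input, as in the docs

# Passing the list itself hands fake_model the whole list as a single
# argument, so the "element" it doubles is the inner list (list
# repetition), not the values:
wrong = fake_model(inputs)   # [[1, 2, 3, 1, 2, 3]]

# Unpacking with * passes each element of `inputs` as its own argument:
right = fake_model(*inputs)  # [2, 4, 6]
```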

py/torch_tensorrt/dynamo/_exporter.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def export(
         return exp_program
     else:
         raise ValueError(
-            "Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
+            f"Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
         )
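The one-character fix above adds the `f` prefix so that `{ir}` is actually interpolated into the error message; without it, users would see the literal text `{ir}` instead of the invalid value they passed. A quick illustration of the difference:

```python
ir = "torchscript"

# Without the f prefix, the braces are kept verbatim in the string.
plain = "Invalid ir : {ir} provided for serialization."

# With the f prefix, {ir} is replaced by the variable's value.
fstr = f"Invalid ir : {ir} provided for serialization."

print(plain)  # Invalid ir : {ir} provided for serialization.
print(fstr)   # Invalid ir : torchscript provided for serialization.
```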
