Skip to content

Commit 9b13f10

Browse files
authored
Merge pull request #2513 from pytorch/cherry_picks_rel_2_1
cherry-pick: Perf + Bugfix PRs
2 parents 5b0e5fc + 0afc619 commit 9b13f10

File tree

5 files changed

+198
-29
lines changed

5 files changed

+198
-29
lines changed

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
import math
22
from typing import Optional
33

4+
import numpy as np
5+
import tensorrt as trt
46
from torch.fx.node import Target
57
from torch_tensorrt.dynamo._SourceIR import SourceIR
8+
from torch_tensorrt.dynamo.conversion import impl
69
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
7-
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
10+
from torch_tensorrt.dynamo.conversion.converter_utils import (
11+
get_positive_dim,
12+
get_trt_tensor,
13+
)
814
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
915
from torch_tensorrt.fx.converters.converter_utils import (
1016
has_dynamic_shape,
@@ -96,3 +102,98 @@ def expand(
96102
layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
97103
set_layer_name(layer, target, name, source_ir)
98104
return layer.get_output(0)
105+
106+
107+
def chunk(
108+
ctx: ConversionContext,
109+
target: Target,
110+
source_ir: Optional[SourceIR],
111+
name: str,
112+
input: TRTTensor,
113+
chunks: int,
114+
dim: int,
115+
) -> TRTTensor:
116+
if chunks <= 0:
117+
raise RuntimeError(
118+
f"chunk expects `chunks` to be greater than 0, got: {chunks}"
119+
)
120+
121+
shape = input.shape
122+
dim = get_positive_dim(dim, len(shape))
123+
124+
if dim >= len(shape):
125+
raise RuntimeError(
126+
f"chunk expects `dim` to be less than the length of input shape, got: {dim}"
127+
)
128+
129+
dynamic_shape = has_dynamic_shape(input.shape)
130+
if dynamic_shape > 0:
131+
# Check whether slice target dim is dynamic shape dim
132+
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
133+
134+
size_dim = shape[dim]
135+
chunk_size = math.ceil(size_dim / chunks)
136+
result = []
137+
start = 0
138+
end = min(start + chunk_size, size_dim)
139+
cnt = 0
140+
141+
while start < end:
142+
result.append(
143+
slice_op(
144+
ctx,
145+
target,
146+
source_ir,
147+
f"{name}_slice_{cnt}",
148+
input,
149+
dim,
150+
start,
151+
end,
152+
1,
153+
)
154+
)
155+
start = end
156+
end = min(start + chunk_size, size_dim)
157+
cnt += 1
158+
159+
return result
160+
161+
162+
def cumsum(
163+
ctx: ConversionContext,
164+
target: Target,
165+
source_ir: Optional[SourceIR],
166+
name: str,
167+
input: TRTTensor,
168+
dim: int,
169+
) -> TRTTensor:
170+
input_shape = input.shape
171+
dim = get_positive_dim(dim, len(input_shape))
172+
loop = ctx.net.add_loop()
173+
axis = np.array(input_shape[dim])
174+
trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")
175+
loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)
176+
iterator = loop.add_iterator(input, dim, reverse=False)
177+
data = iterator.get_output(0)
178+
new_dims = tuple(data.shape)
179+
zeros = np.zeros(new_dims)
180+
zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")
181+
182+
running_sum = loop.add_recurrence(zero_trttensor)
183+
set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)
184+
running_sum_tensor = running_sum.get_output(0)
185+
186+
current_sum = impl.elementwise.add(
187+
ctx,
188+
target,
189+
source_ir,
190+
f"{name}_elementwise_add",
191+
data,
192+
running_sum_tensor,
193+
)
194+
running_sum.set_input(1, current_sum)
195+
196+
loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim)
197+
set_layer_name(loop_output, target, f"{name}_loop_output", source_ir)
198+
loop_output.set_input(1, trip_limit)
199+
return loop_output.get_output(0)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestCumsumConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
((1,), 0),
13+
((2,), 0),
14+
((3,), -1),
15+
]
16+
)
17+
def test_cumsum_1D(self, shape, dim):
18+
class Cumsum(nn.Module):
19+
def forward(self, x):
20+
return torch.ops.aten.cumsum.default(x, dim)
21+
22+
inputs = [torch.randn(shape)]
23+
self.run_test(
24+
Cumsum(),
25+
inputs,
26+
)
27+
28+
@parameterized.expand(
29+
[
30+
((3, 1), 0),
31+
((3, 1), 1),
32+
((2, 3), -1),
33+
((2, 3), -2),
34+
]
35+
)
36+
def test_cumsum_2D(self, shape, dims):
37+
class Cumsum(nn.Module):
38+
def forward(self, x):
39+
return torch.ops.aten.cumsum.default(x, dims)
40+
41+
inputs = [torch.randn(shape)]
42+
self.run_test(
43+
Cumsum(),
44+
inputs,
45+
)
46+
47+
@parameterized.expand(
48+
[
49+
((4, 2, 3), 0),
50+
((4, 2, 3), 1),
51+
((1, 2, 3), 2),
52+
((1, 2, 3), -1),
53+
((1, 2, 3), -2),
54+
]
55+
)
56+
def test_cumsum_3D(self, shape, dims):
57+
class Cumsum(nn.Module):
58+
def forward(self, x):
59+
return torch.ops.aten.cumsum.default(x, dims)
60+
61+
inputs = [torch.randn(shape)]
62+
self.run_test(
63+
Cumsum(),
64+
inputs,
65+
)
66+
67+
68+
if __name__ == "__main__":
69+
run_tests()

tools/perf/benchmark.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ python hub.py
77

88
batch_sizes=(1 2 4 8 16 32 64 128 256)
99
large_model_batch_sizes=(1 2 4 8 16 32 64)
10-
backends=("torch" "ts_trt" "dynamo" "torch_compile" "inductor")
11-
backends_no_torchscript=("torch" "dynamo" "torch_compile" "inductor")
10+
backends=("torch" "ts_trt" "dynamo" "torch_compile" "inductor" "tensorrt")
11+
backends_no_torchscript=("torch" "dynamo" "torch_compile" "inductor" "tensorrt")
1212

1313

1414
# Benchmark VGG16 model

tools/perf/perf_run.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -293,29 +293,30 @@ def run_tensorrt(
293293
input_tensors,
294294
params,
295295
precision,
296-
is_trt_engine=False,
297296
batch_size=1,
298297
):
299-
engine = None
300-
301-
# If the model file is a TensorRT engine then directly deserialize and run inference
302-
# else convert the torch module to a TensorRT engine first and then run inference
303-
if not is_trt_engine:
304-
compile_settings = {
305-
"inputs": input_tensors,
306-
"enabled_precisions": {precision_to_dtype(precision)},
307-
"truncate_long_and_double": params.get("truncate", False),
308-
}
309-
310-
print("Converting method to TensorRT engine...")
311-
with torch.no_grad(), torchtrt.logging.errors():
312-
model = torchtrt.ts.convert_method_to_trt_engine(
313-
model, "forward", **compile_settings
314-
)
315-
298+
# Export an ONNX model and convert to TRT
299+
torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
300+
logger = trt.Logger(trt.Logger.WARNING)
301+
builder = trt.Builder(logger)
302+
network = builder.create_network(
303+
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
304+
)
305+
parser = trt.OnnxParser(network, logger)
306+
success = parser.parse_from_file("./tmp.onnx")
307+
if not success:
308+
raise ValueError("ONNX conversion failed")
309+
310+
config = builder.create_builder_config()
311+
if precision == "fp16":
312+
config.set_flag(trt.BuilderFlag.FP16)
313+
start_compile = time.time_ns()
314+
serialized_engine = builder.build_serialized_network(network, config)
315+
end_compile = time.time_ns()
316+
compile_time_s = (end_compile - start_compile) / 1e9
316317
# Deserialize the TensorRT engine
317-
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
318-
engine = runtime.deserialize_cuda_engine(model)
318+
with trt.Runtime(logger) as runtime:
319+
engine = runtime.deserialize_cuda_engine(serialized_engine)
319320

320321
print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
321322
iters = params.get("iterations", 20)
@@ -350,7 +351,7 @@ def run_tensorrt(
350351
meas_time = end_time - start_time
351352
timings.append(meas_time)
352353

353-
recordStats("TensorRT", timings, precision, batch_size)
354+
recordStats("TensorRT", timings, precision, batch_size, compile_time_s)
354355

355356

356357
# Deploys inference run for different backend configurations
@@ -426,11 +427,10 @@ def run(
426427
)
427428
elif backend == "tensorrt":
428429
run_tensorrt(
429-
model,
430+
model_torch,
430431
input_tensors,
431432
params,
432433
precision,
433-
is_trt_engine,
434434
batch_size,
435435
)
436436
elif backend == "dynamo":
@@ -439,9 +439,6 @@ def run(
439439
elif backend == "torch_compile":
440440
run_torch_compile(model_torch, input_tensors, params, precision, batch_size)
441441

442-
elif backend == "torch_compile":
443-
run_torch_compile(model_torch, input_tensors, params, precision, batch_size)
444-
445442
elif backend == "inductor":
446443
run_inductor(model_torch, input_tensors, params, precision, batch_size)
447444

tools/perf/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
numpy
22
argparse
33
pyyaml
4+
onnx
45
transformers==4.33.2
56
diffusers==0.21.4
67
pandas==2.0.1
78
timm==0.9.8
9+

0 commit comments

Comments
 (0)