
Commit 9e60482

support dds and nonzero op
1 parent 2368e63 commit 9e60482

File tree

4 files changed: +280 −43 lines


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 17 additions & 0 deletions
@@ -3582,3 +3582,20 @@ def aten_ops_full(
         fill_value=args[1],
         dtype=kwargs.get("dtype", None),
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default)
+def aten_ops_nonzero(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.nonzero(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
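For reference, aten.nonzero is a data-dependent-shape (DDS) op: the number of rows in its output equals the number of nonzero elements, which is only known at runtime. A minimal sketch of the eager semantics the converter must reproduce (the tensor values are illustrative):

import torch

x = torch.tensor([[0, 5], [7, 0]])
# One row of indices per nonzero element, shape (N, D):
print(torch.ops.aten.nonzero.default(x))
# tensor([[0, 1],
#         [1, 0]])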

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 15 additions & 0 deletions
@@ -624,3 +624,18 @@ def native_dropout(
     mask = np.ones(input_val.shape, dtype=bool)
     mask = get_trt_tensor(ctx, mask, f"{name}_mask")
     return identity_layer.get_output(0), mask
+
+
+def nonzero(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    non_zero_layer = ctx.net.add_non_zero(input_val)
+    set_layer_name(non_zero_layer, target, f"{name}_non_zero", source_ir)
+    shuffle_layer = ctx.net.add_shuffle(non_zero_layer.get_output(0))
+    shuffle_layer.first_transpose = trt.Permutation([1, 0])
+    set_layer_name(shuffle_layer, target, f"{name}_transpose", source_ir)
+    return shuffle_layer.get_output(0)
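TensorRT's INonZeroLayer emits indices with one column per nonzero element (shape (D, N)), while aten.nonzero returns one row per element (shape (N, D)); the shuffle layer's [1, 0] permutation converts between the two layouts. A quick eager-mode illustration of the relationship:

import torch

x = torch.tensor([[0, 5], [7, 9]])
print(torch.nonzero(x))    # shape (3, 2): PyTorch layout, one row per nonzero element
# tensor([[0, 1],
#         [1, 0],
#         [1, 1]])
print(torch.nonzero(x).T)  # shape (2, 3): the (D, N) layout INonZeroLayer produces
# tensor([[0, 1, 1],
#         [1, 0, 1]])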

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 174 additions & 43 deletions
@@ -23,6 +23,41 @@
 logger = logging.getLogger(__name__)
 
 
+class DynamicOutputAllocator(trt.IOutputAllocator):  # type: ignore[misc]
+    def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None:
+        trt.IOutputAllocator.__init__(self)
+        self.buffers: Dict[str, torch.Tensor] = {}
+        self.shapes: Dict[str, Tuple[int, ...]] = {}
+        self.dtypes: Dict[str, torch.dtype] = output_dtypes
+
+    def reallocate_output_async(
+        self,
+        tensor_name: str,
+        memory: int,
+        size: int,
+        alignment: int,
+        stream: torch.cuda.Stream,
+    ) -> Any:
+        shape = (size,)
+        if tensor_name not in self.buffers:
+            self.buffers[tensor_name] = torch.empty(
+                shape,
+                dtype=self.dtypes[tensor_name],
+                device=torch.cuda.current_device(),
+            )
+        else:
+            if self.buffers[tensor_name].shape != shape:
+                self.buffers[tensor_name] = torch.empty(
+                    shape,
+                    dtype=self.dtypes[tensor_name],
+                    device=torch.cuda.current_device(),
+                )
+        return self.buffers[tensor_name].data_ptr()
+
+    def notify_shape(self, tensor_name: str, shape: Tuple[int, ...]) -> None:
+        self.shapes[tensor_name] = tuple(shape)
+
+
 class TorchTRTRuntimeStates:
     def __init__(self, new_cudagraphs: bool):
         # Indicates whether CUDAGraphs were enabled in the previous execute_engine
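This class implements TensorRT's IOutputAllocator callback interface: during execute_async_v3, the runtime calls reallocate_output_async to request a device buffer of at least size bytes for each data-dependent output, then calls notify_shape once the final dimensions are known. A minimal sketch of that lifecycle, assuming a deserialized TensorRT 10 engine with a single DDS output named "output0" (the engine object, tensor name, and input setup are illustrative assumptions, not part of this commit):

import torch
import tensorrt as trt

allocator = DynamicOutputAllocator({"output0": torch.float32})
context = engine.create_execution_context()  # `engine` assumed to exist
context.set_output_allocator("output0", allocator)  # TRT calls back into the allocator
# ... bind inputs with context.set_tensor_address(...) ...
context.execute_async_v3(torch.cuda.current_stream().cuda_stream)
torch.cuda.current_stream().synchronize()
# notify_shape() has recorded the data-dependent output shape by now:
shape = allocator.shapes["output0"]
n = int(torch.prod(torch.tensor(shape)))
result = allocator.buffers["output0"].reshape(-1)[:n].reshape(shape)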
@@ -164,8 +199,11 @@ def __init__(
         self.runtime_states = TorchTRTRuntimeStates(
             torch_tensorrt.runtime.get_cudagraphs_mode()
         )
+
+        self.contains_dds_layer = False
         self.pre_allocated_outputs: List[torch.Tensor] = []
         self.use_pre_allocated_outputs = False
+        self.output_allocator: Optional[DynamicOutputAllocator] = None
 
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
@@ -238,9 +276,19 @@ def setup_engine(self) -> None:
             for output_name in self.output_names
         ]
 
+        self.contains_dds_layer = self._check_dds_layer()
+        if self.contains_dds_layer:
+            self.setup_output_allocator()
+
         if torch_tensorrt.runtime.get_cudagraphs_mode():
             self.cudagraph = torch.cuda.CUDAGraph()
 
+    def _check_dds_layer(self) -> bool:
+        layer_info = self.get_layer_info()
+        if "trainStation" in layer_info:  # contains dds layer
+            return True
+        return False
+
     def _check_initialized(self) -> None:
         if not self.initialized:
             raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
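The check is a heuristic rather than an explicit DDS flag: when an engine contains data-dependent-shape ops, the layer information returned by get_layer_info() appears to include TensorRT's internal "trainStation" layers that bracket the data-dependent regions, so their presence is used as a proxy for "this engine needs an output allocator".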
@@ -358,19 +406,22 @@ def create_output_tensors(self) -> List[torch.Tensor]:
     def set_pre_allocated_outputs(self, enable: bool) -> None:
         self.use_pre_allocated_outputs = enable
 
+    def setup_output_allocator(self) -> None:
+        if self.output_allocator is None:
+            output_dtypes_dict = {}
+            for o, output_name in enumerate(self.output_names):
+                output_dtypes_dict[output_name] = self.output_dtypes[o]
+            self.output_allocator = DynamicOutputAllocator(output_dtypes_dict)
+
+        for output_name in self.output_names:
+            if not self.context.set_output_allocator(
+                output_name, self.output_allocator
+            ):
+                raise RuntimeError(f"Failed to set output allocator for {output_name}")
+
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
-        contiguous_inputs: List[torch.Tensor] = [
-            (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
-            for i in inputs
-        ]
-        with (
-            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
-            if self.profiling_enabled
-            else nullcontext()
-        ):
-            self._check_initialized()
 
+        def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
             shape_changed = self.validate_input_shapes(inputs)
             (
@@ -389,38 +440,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             self._input_buffers = [None] * len(self.input_names)
             self._output_buffers = [None] * len(self.output_names)
 
-            # If in safe mode, check at each iteration for whether a switch is required
-            if (
-                torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
-            ):
-                curr_device_id = torch.cuda.current_device()
-                curr_device_properties = torch.cuda.get_device_properties(
-                    curr_device_id
-                )
-                logger.debug(f"Current Device: cuda:{curr_device_id}")
-
-                # If a switch is required, move all inputs to new device and set as active device
-                if _is_switch_required(
-                    curr_device_id,
-                    self.target_device_id,
-                    curr_device_properties,
-                    self.target_device_properties,
-                ):
-                    device_id, _ = _select_rt_device(
-                        curr_device_id,
-                        self.target_device_id,
-                        self.target_device_properties,
-                    )
-
-                    # Update current device
-                    device = torch.device(device_id)
-                    torch.cuda.set_device(device_id)
-
-                    contiguous_inputs = [
-                        tensor.to(device) for tensor in contiguous_inputs
-                    ]
-                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")
-
             with (
                 torch.autograd.profiler.record_function(
                     "PythonTorchTensorRTModule:ProcessInputs"
@@ -536,6 +555,118 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             return outputs
 
+        def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
+                assert len(contiguous_inputs) == len(
+                    self.input_names
+                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."
+
+                self.setup_input_tensors(contiguous_inputs, False, False)
+
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
+                self._caller_stream = torch.cuda.current_stream()
+                if (
+                    self._engine_stream == torch.cuda.default_stream()
+                    or self._engine_stream is None
+                ):
+                    self._engine_stream = torch.cuda.Stream()
+
+                self._engine_stream.wait_stream(self._caller_stream)
+
+                with torch.cuda.stream(self._engine_stream):
+                    self.context.execute_async_v3(
+                        self._engine_stream.cuda_stream
+                    )  # The OutputAllocator is called by execute_async_v3()
+
+                self._caller_stream.wait_stream(self._engine_stream)
+
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
+                outputs = []
+                assert self.output_allocator is not None
+                for o, output_name in enumerate(self.output_names):
+                    shape = self.output_allocator.shapes.get(output_name, None)
+                    dtype = self.output_dtypes[o]
+                    output = (
+                        self.output_allocator.buffers.get(output_name, None)
+                        .clone()
+                        .detach()
+                    )
+                    prod = int(torch.prod(torch.tensor(shape)))
+                    output = output.reshape(-1).view(dtype)[:prod].reshape(shape)
+                    outputs.append(output)
+
+            if len(outputs) == 1:
+                return outputs[0]
+
+            return outputs
+
+        # Run forward function
+        contiguous_inputs: List[torch.Tensor] = [
+            (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
+            for i in inputs
+        ]
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
+            self._check_initialized()
+
+            # If in safe mode, check at each iteration for whether a switch is required
+            if (
+                torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
+            ):
+                curr_device_id = torch.cuda.current_device()
+                curr_device_properties = torch.cuda.get_device_properties(
+                    curr_device_id
+                )
+                logger.debug(f"Current Device: cuda:{curr_device_id}")
+
+                # If a switch is required, move all inputs to new device and set as active device
+                if _is_switch_required(
+                    curr_device_id,
+                    self.target_device_id,
+                    curr_device_properties,
+                    self.target_device_properties,
+                ):
+                    device_id, _ = _select_rt_device(
+                        curr_device_id,
+                        self.target_device_id,
+                        self.target_device_properties,
+                    )
+
+                    # Update current device
+                    device = torch.device(device_id)
+                    torch.cuda.set_device(device_id)
+
+                    contiguous_inputs = [
+                        tensor.to(device) for tensor in contiguous_inputs
+                    ]
+                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")
+
+            if self.contains_dds_layer:
+                return run_output_allocator()
+            else:
+                return run_cuda_graph()
+
     def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None:
         """
         Enable TensorRT profiling. After calling this function, TensorRT will report
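Taken together, forward() now dispatches on contains_dds_layer: engines without DDS layers keep the existing CUDA-graph-compatible path, while DDS engines go through the output-allocator path (CUDA graphs are not applicable there, since output sizes can change from run to run). A minimal end-to-end sketch that would exercise the new path, assuming a CUDA build of Torch-TensorRT with this change (the module and compile arguments are illustrative):

import torch
import torch_tensorrt

class NonZeroModule(torch.nn.Module):
    def forward(self, x):
        return torch.nonzero(x)  # lowers to aten.nonzero.default, a DDS op

mod = NonZeroModule().eval().cuda()
x = torch.randint(0, 3, (4, 4), device="cuda").float()
trt_mod = torch_tensorrt.compile(mod, ir="dynamo", inputs=[x], min_block_size=1)
out = trt_mod(x)
print(out.shape)  # (N, 2), where N is the number of nonzero elements at runtime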
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestNonZeroConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((10,), torch.int),
+            ((1, 20), torch.int32),
+            ((2, 3), torch.int64),
+            ((2, 3, 4), torch.float),
+            ((2, 3, 4, 5), torch.float),
+        ]
+    )
+    def test_non_zero(self, input_shape, dtype):
+        class NonZero(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.nonzero.default(input)
+
+        inputs = [torch.randint(low=0, high=3, size=input_shape, dtype=dtype)]
+        self.run_test(
+            NonZero(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                "1d",
+                (1,),
+                (10,),
+                (100,),
+                torch.int32,
+            ),
+            (
+                "2d",
+                (1, 2),
+                (5, 10),
+                (20, 40),
+                torch.float16,
+            ),
+            (
+                "3d",
+                (1, 2, 3),
+                (5, 10, 20),
+                (30, 40, 50),
+                torch.float,
+            ),
+        ]
+    )
+    def test_nonzero_dynamic_shape(self, _, min_shape, opt_shape, max_shape, dtype):
+        class NonZero(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.nonzero.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=dtype,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(NonZero(), input_specs)
+
+
+if __name__ == "__main__":
+    run_tests()
