|
1 | 1 | import math
|
2 | 2 | from typing import Optional
|
3 | 3 |
|
| 4 | +import numpy as np |
4 | 5 | import tensorrt as trt
|
5 |
| -import torch |
6 | 6 | from torch.fx.node import Target
|
7 | 7 | from torch_tensorrt.dynamo._SourceIR import SourceIR
|
8 | 8 | from torch_tensorrt.dynamo.conversion import impl
|
|
13 | 13 | )
|
14 | 14 | from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
|
15 | 15 | from torch_tensorrt.fx.converters.converter_utils import (
|
16 |
| - Frameworks, |
17 | 16 | has_dynamic_shape,
|
18 | 17 | prepend_ones,
|
19 | 18 | set_layer_name,
|
20 |
| - unified_dtype_converter, |
21 | 19 | )
|
22 | 20 | from torch_tensorrt.fx.types import Shape, TRTTensor
|
23 | 21 |
|
@@ -130,18 +128,16 @@ def cumsum(
|
130 | 128 | input_shape = input.shape
|
131 | 129 | dim = get_positive_dim(dim, len(input_shape))
|
132 | 130 | loop = ctx.net.add_loop()
|
133 |
| - axis = torch.tensor(input_shape[dim], dtype=torch.int32) |
| 131 | + axis = np.array(input_shape[dim]) |
134 | 132 | trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")
|
135 | 133 | loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)
|
136 | 134 | iterator = loop.add_iterator(input, dim, reverse=False)
|
137 | 135 | data = iterator.get_output(0)
|
138 | 136 | new_dims = tuple(data.shape)
|
139 |
| - zero_tensor = torch.zeros( |
140 |
| - new_dims, dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH) |
141 |
| - ) |
142 |
| - zero_tensor = get_trt_tensor(ctx, zero_tensor, f"{name}_initial_value") |
| 137 | + zeros = np.zeros(new_dims) |
| 138 | + zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value") |
143 | 139 |
|
144 |
| - running_sum = loop.add_recurrence(zero_tensor) |
| 140 | + running_sum = loop.add_recurrence(zero_trttensor) |
145 | 141 | set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)
|
146 | 142 | running_sum_tensor = running_sum.get_output(0)
|
147 | 143 |
|
|
0 commit comments