Skip to content

Commit 48feb0c

Browse files
committed
fix bugs
1 parent ea0a3fd commit 48feb0c

File tree

1 file changed

+5
-9
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl/slice

1 file changed

+5
-9
lines changed

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import math
22
from typing import Optional
33

4+
import numpy as np
45
import tensorrt as trt
5-
import torch
66
from torch.fx.node import Target
77
from torch_tensorrt.dynamo._SourceIR import SourceIR
88
from torch_tensorrt.dynamo.conversion import impl
@@ -13,11 +13,9 @@
1313
)
1414
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
1515
from torch_tensorrt.fx.converters.converter_utils import (
16-
Frameworks,
1716
has_dynamic_shape,
1817
prepend_ones,
1918
set_layer_name,
20-
unified_dtype_converter,
2119
)
2220
from torch_tensorrt.fx.types import Shape, TRTTensor
2321

@@ -130,18 +128,16 @@ def cumsum(
130128
input_shape = input.shape
131129
dim = get_positive_dim(dim, len(input_shape))
132130
loop = ctx.net.add_loop()
133-
axis = torch.tensor(input_shape[dim], dtype=torch.int32)
131+
axis = np.array(input_shape[dim])
134132
trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")
135133
loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)
136134
iterator = loop.add_iterator(input, dim, reverse=False)
137135
data = iterator.get_output(0)
138136
new_dims = tuple(data.shape)
139-
zero_tensor = torch.zeros(
140-
new_dims, dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH)
141-
)
142-
zero_tensor = get_trt_tensor(ctx, zero_tensor, f"{name}_initial_value")
137+
zeros = np.zeros(new_dims)
138+
zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")
143139

144-
running_sum = loop.add_recurrence(zero_tensor)
140+
running_sum = loop.add_recurrence(zero_trttensor)
145141
set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)
146142
running_sum_tensor = running_sum.get_output(0)
147143

0 commit comments

Comments
 (0)