1
1
from typing import Optional , Union
2
2
3
3
import numpy as np
4
- import tensorrt as trt
5
4
import torch
6
5
from torch .fx .node import Target
7
6
from torch_tensorrt .dynamo ._SourceIR import SourceIR
10
9
broadcastable ,
11
10
get_trt_tensor ,
12
11
)
13
- from torch_tensorrt .dynamo .conversion .impl .slice import expand
14
- from torch_tensorrt .fx .converters .converter_utils import set_layer_name
12
+ from torch_tensorrt .fx .converters .converter_utils import prepend_ones , set_layer_name
15
13
from torch_tensorrt .fx .types import TRTTensor
16
14
17
15
@@ -30,73 +28,30 @@ def where(
30
28
x_shape = list (input .shape )
31
29
y_shape = list (other .shape )
32
30
condition_shape = list (condition .shape )
31
+ max_shape_len = max (len (x_shape ), len (y_shape ), len (condition_shape ))
33
32
34
- output_shape = list (torch .broadcast_shapes (condition_shape , x_shape , y_shape ))
35
-
36
- # expand shape
37
33
if not isinstance (condition , TRTTensor ):
38
34
assert condition .dtype in (torch .bool , np .bool_ ), "condition dtype is not bool"
39
- if condition_shape != output_shape :
40
- condition = (
41
- condition .expand (output_shape )
42
- if isinstance (condition , torch .Tensor )
43
- else np .broadcast_to (condition , output_shape )
44
- )
45
- condition_val = get_trt_tensor (ctx , condition , f"{ name } _condition" )
46
- else :
47
- assert condition .dtype == trt .bool , "mask dtype is not bool!"
48
- if condition_shape != output_shape :
49
- condition_val = expand (
50
- ctx , target , source_ir , f"{ name } _expand" , condition , output_shape
51
- )
52
- else :
53
- condition_val = condition
35
+ condition = get_trt_tensor (ctx , condition , f"{ name } _condition" )
36
+ diff = max_shape_len - len (condition_shape )
37
+ if diff > 0 :
38
+ condition = prepend_ones (
39
+ ctx .net , condition , f"{ name } _condition_broadcast" , diff
40
+ )
54
41
55
42
if not isinstance (input , TRTTensor ):
56
- if x_shape != output_shape :
57
- # special case where 1 element in input
58
- if len (input .shape ) == 0 :
59
- input = (
60
- input .unsqueeze (0 )
61
- if isinstance (input , torch .Tensor )
62
- else np .expand_dims (input , axis = 0 )
63
- )
64
- input = (
65
- input .expand (output_shape )
66
- if isinstance (input , torch .Tensor )
67
- else np .broadcast_to (input , output_shape )
68
- )
69
- x_val = get_trt_tensor (ctx , input , f"{ name } _x" )
70
- else :
71
- x_val = input
72
- if x_shape != output_shape :
73
- x_val = expand (
74
- ctx , target , source_ir , f"{ name } _x_expand" , input , output_shape
75
- )
43
+ input = get_trt_tensor (ctx , input , f"{ name } _x" )
44
+ diff = max_shape_len - len (x_shape )
45
+ if diff > 0 :
46
+ input = prepend_ones (ctx .net , input , f"{ name } _input_broadcast" , diff )
76
47
77
48
if not isinstance (other , TRTTensor ):
78
- if y_shape != output_shape :
79
- # special case where 1 element in other
80
- if len (other .shape ) == 0 :
81
- other = (
82
- other .unsqueeze (0 )
83
- if isinstance (other , torch .Tensor )
84
- else np .expand_dims (other , axis = 0 )
85
- )
86
- other = (
87
- other .expand (output_shape )
88
- if isinstance (other , torch .Tensor )
89
- else np .broadcast_to (other , output_shape )
90
- )
91
- y_val = get_trt_tensor (ctx , other , f"{ name } _y" )
92
- else :
93
- y_val = other
94
- if y_shape != output_shape :
95
- y_val = expand (
96
- ctx , target , source_ir , f"{ name } _y_expand" , y_val , output_shape
97
- )
49
+ other = get_trt_tensor (ctx , other , f"{ name } _y" )
50
+ diff = max_shape_len - len (y_shape )
51
+ if diff > 0 :
52
+ other = prepend_ones (ctx .net , other , f"{ name } _other_broadcast" , diff )
98
53
99
- return select (ctx , target , source_ir , name , x_val , y_val , condition_val )
54
+ return select (ctx , target , source_ir , name , input , other , condition )
100
55
101
56
102
57
def select (
0 commit comments