@@ -8,7 +8,10 @@ namespace conversion {
8
8
namespace converters {
9
9
namespace impl {
10
10
namespace {
11
- auto reduced_registrations = RegisterNodeConversionPatterns()
11
+
12
+
13
+
14
+ auto reduce_registrations = RegisterNodeConversionPatterns()
12
15
.pattern({
13
16
" aten::mean(Tensor self, *, ScalarType? dtype=None) -> (Tensor)" ,
14
17
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -36,7 +39,7 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
36
39
LOG_DEBUG (" Dim to reduce:" << util::toDims (dims)); // Some abuse of toDim but just for debug info
37
40
38
41
uint32_t axis_mask = 0 ;
39
- for (int d = 0 ; d < dims.size (); d++) {
42
+ for (size_t d = 0 ; d < dims.size (); d++) {
40
43
axis_mask |= 1 << dims[d];
41
44
}
42
45
LOG_DEBUG (" Axis Mask" << std::bitset<32 >(axis_mask));
@@ -52,6 +55,131 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
52
55
mean_layer->setName (util::node_info (n).c_str ());
53
56
auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], mean_layer->getOutput (0 ));
54
57
58
+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
59
+ return true ;
60
+ }
61
+ }).pattern({
62
+ " aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor" ,
63
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
64
+ auto in_tensor = args[0 ].ITensor ();
65
+ auto in_dims = util::toVec (in_tensor->getDimensions ());
66
+ LOG_WARNING (" Sum Converter disregards dtype" );
67
+
68
+ uint32_t axis_mask = (uint32_t )(((uint64_t )1 << in_dims.size ()) - 1 );
69
+
70
+ auto sum_layer = ctx->net ->addReduce (*in_tensor, nvinfer1::ReduceOperation::kSUM , axis_mask, false );
71
+
72
+ TRTORCH_CHECK (sum_layer, " Unable to create sum layer from node: " << *n);
73
+
74
+ sum_layer->setName (util::node_info (n).c_str ());
75
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], sum_layer->getOutput (0 ));
76
+
77
+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
78
+ return true ;
79
+ }
80
+ }).pattern({
81
+ " aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor" ,
82
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
83
+ auto in_tensor = args[0 ].ITensor ();
84
+ auto dims = args[1 ].unwrapToIntList ();
85
+ LOG_DEBUG (" Dim to reduce:" << util::toDims (dims)); // Some abuse of toDim but just for debug info
86
+
87
+ uint32_t axis_mask = 0 ;
88
+ for (size_t d = 0 ; d < dims.size (); d++) {
89
+ axis_mask |= 1 << dims[d];
90
+ }
91
+ LOG_DEBUG (" Axis Mask" << std::bitset<32 >(axis_mask));
92
+
93
+ auto keepdim = args[2 ].unwrapToBool ();
94
+ LOG_DEBUG (" Keep dims :" << keepdim);
95
+
96
+ LOG_WARNING (" Sum converter disregards dtype" );
97
+ auto sum_layer = ctx->net ->addReduce (*in_tensor, nvinfer1::ReduceOperation::kSUM , axis_mask, keepdim);
98
+
99
+ TRTORCH_CHECK (sum_layer, " Unable to create sum layer from node: " << *n);
100
+
101
+ sum_layer->setName (util::node_info (n).c_str ());
102
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], sum_layer->getOutput (0 ));
103
+
104
+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
105
+ return true ;
106
+ }
107
+ }).pattern({
108
+ " aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor" ,
109
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
110
+ auto in_tensor = args[0 ].ITensor ();
111
+ auto in_dims = util::toVec (in_tensor->getDimensions ());
112
+ LOG_WARNING (" Prod Converter disregards dtype" );
113
+
114
+ uint32_t axis_mask = (uint32_t )(((uint64_t )1 << in_dims.size ()) - 1 );
115
+
116
+ auto prod_layer = ctx->net ->addReduce (*in_tensor, nvinfer1::ReduceOperation::kPROD , axis_mask, false );
117
+
118
+ TRTORCH_CHECK (prod_layer, " Unable to create sum layer from node: " << *n);
119
+
120
+ prod_layer->setName (util::node_info (n).c_str ());
121
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], prod_layer->getOutput (0 ));
122
+
123
+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
124
+ return true ;
125
+ }
126
+ }).pattern({
127
+ " aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor" ,
128
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
129
+ auto in_tensor = args[0 ].ITensor ();
130
+ auto dim = args[1 ].unwrapToInt ();
131
+ LOG_DEBUG (" Dim to reduce:" << dim); // Some abuse of toDim but just for debug info
132
+
133
+ uint32_t axis_mask = 1 << dim;
134
+ LOG_DEBUG (" Axis Mask" << std::bitset<32 >(axis_mask));
135
+
136
+ auto keepdim = args[2 ].unwrapToBool ();
137
+ LOG_DEBUG (" Keep dims :" << keepdim);
138
+
139
+ LOG_WARNING (" Prod converter disregards dtype" );
140
+ auto prod_layer = ctx->net ->addReduce (*in_tensor, nvinfer1::ReduceOperation::kPROD , axis_mask, keepdim);
141
+
142
+ TRTORCH_CHECK (prod_layer, " Unable to create mean layer from node: " << *n);
143
+
144
+ prod_layer->setName (util::node_info (n).c_str ());
145
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], prod_layer->getOutput (0 ));
146
+
147
+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
148
+ return true ;
149
+ }
150
+ }).pattern({
151
+ " aten::max(Tensor self) -> Tensor" ,
152
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
153
+ auto in_tensor = args[0 ].ITensor ();
154
+ auto in_dims = util::toVec (in_tensor->getDimensions ());
155
+
156
+ uint32_t axis_mask = (uint32_t )(((uint64_t )1 << in_dims.size ()) - 1 );
157
+
158
+ auto max_layer = ctx->net ->addReduce (*in_tensor, nvinfer1::ReduceOperation::kMAX , axis_mask, false );
159
+
160
+ TRTORCH_CHECK (max_layer, " Unable to create max layer from node: " << *n);
161
+
162
+ max_layer->setName (util::node_info (n).c_str ());
163
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], max_layer->getOutput (0 ));
164
+
165
+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
166
+ return true ;
167
+ }
168
+ }).pattern({
169
+ " aten::min(Tensor self) -> Tensor" ,
170
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
171
+ auto in_tensor = args[0 ].ITensor ();
172
+ auto in_dims = util::toVec (in_tensor->getDimensions ());
173
+
174
+ uint32_t axis_mask = (uint32_t )(((uint64_t )1 << in_dims.size ()) - 1 );
175
+
176
+ auto min_layer = ctx->net ->addReduce (*in_tensor, nvinfer1::ReduceOperation::kMIN , axis_mask, false );
177
+
178
+ TRTORCH_CHECK (min_layer, " Unable to create min layer from node: " << *n);
179
+
180
+ min_layer->setName (util::node_info (n).c_str ());
181
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], min_layer->getOutput (0 ));
182
+
55
183
LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
56
184
return true ;
57
185
}
@@ -62,63 +190,3 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
62
190
} // namespace conversion
63
191
} // namespace core
64
192
} // namespace trtorch
65
-
66
- // #include "core/util/prelude.h"
67
- // #include "core/conversion/converters/converters.h"
68
-
69
- // namespace trtorch {
70
- // namespace core {
71
- // namespace conversion {
72
- // namespace converters {
73
- // namespace impl {
74
- // namespace {
75
-
76
- // #define convert(unary, trt_type) \
77
- // auto unary##_registrations TRTORCH_UNUSED = \
78
- // RegisterNodeConversionPatterns().pattern( \
79
- // {"aten::" #unary "(Tensor self) -> Tensor", \
80
- // [](ConversionCtx *ctx, const torch::jit::Node *n, \
81
- // args &args) -> bool { \
82
- // auto in = args[0].ITensor(); \
83
- // auto unary = \
84
- // ctx->net->addUnary(*in, nvinfer1::UnaryOperation::trt_type); \
85
- // \
86
- // TRTORCH_CHECK( \
87
- // unary, \
88
- // "Unable to create " #unary " layer from node: " << *n); \
89
- // \
90
- // unary->setName(util::node_info(n).c_str()); \
91
- // auto out_tensor = ctx->AssociateValueAndTensor( \
92
- // n->outputs()[0], \
93
- // unary->getOutput(0)); \
94
- // LOG_DEBUG( \
95
- // "Output tensor shape: " << out_tensor->getDimensions()); \
96
- // \
97
- // return true; \
98
- // }});
99
-
100
- // convert(cos, kCOS);
101
- // convert(acos, kACOS);
102
- // convert(cosh, kCOSH);
103
- // convert(sin, kSIN);
104
- // convert(asin, kASIN);
105
- // convert(sinh, kSINH);
106
- // convert(tan, kTAN);
107
- // convert(atan, kATAN);
108
- // convert(abs, kABS);
109
- // convert(floor, kFLOOR);
110
- // convert(reciprocal, kRECIP);
111
- // convert(log, kLOG);
112
- // convert(ceil, kCEIL);
113
- // convert(sqrt, kSQRT);
114
- // convert(exp, kEXP);
115
- // convert(neg, kNEG);
116
-
117
- // #undef convert
118
-
119
- // } // namespace
120
- // } // namespace impl
121
- // } // namespace converters
122
- // } // namespace conversion
123
- // } // namespace core
124
- // } // namespace trtorch
0 commit comments