@@ -260,6 +260,70 @@ auto element_wise_registrations TRTORCH_UNUSED =
260
260
LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
261
261
return true ;
262
262
}})
263
+ .pattern({" aten::ne.Tensor(Tensor self, Tensor other) -> (Tensor)" ,
264
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
265
+ // TODO: Remove with functionalization
266
+ auto self = args[0 ].ITensorOrFreeze (ctx);
267
+ auto other = args[1 ].ITensorOrFreeze (ctx);
268
+ auto equal = add_elementwise (
269
+ ctx,
270
+ nvinfer1::ElementWiseOperation::kEQUAL ,
271
+ self,
272
+ other,
273
+ util::node_info (n) + std::string (" is_equal" ));
274
+ TRTORCH_CHECK (equal, " Unable to create elementwise equal layer from node: " << *n);
275
+ // XOR with ones negates and produces not_equal result
276
+ auto options = torch::TensorOptions ().dtype (torch::kFloat32 );
277
+ auto ones = at::full ({1 }, 1 , {options});
278
+ auto ones_tensor = tensor_to_const (ctx, ones);
279
+ nvinfer1::IIdentityLayer* cast_layer = ctx->net ->addIdentity (*ones_tensor);
280
+ cast_layer->setOutputType (0 , nvinfer1::DataType::kBOOL );
281
+
282
+ auto sub = add_elementwise (
283
+ ctx,
284
+ nvinfer1::ElementWiseOperation::kXOR ,
285
+ cast_layer->getOutput (0 ),
286
+ equal->getOutput (0 ),
287
+ util::node_info (n));
288
+ TRTORCH_CHECK (sub, " Unable to create ne (not equal) layer from node: " << *n);
289
+
290
+ sub->setName (util::node_info (n).c_str ());
291
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], sub->getOutput (0 ));
292
+ LOG_DEBUG (" Not equal layer output tensor shape: " << out->getDimensions ());
293
+ return true ;
294
+ }})
295
+ .pattern({" aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
296
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
297
+ auto self = args[0 ].ITensorOrFreeze (ctx);
298
+ auto scalar = args[1 ].unwrapToScalar ().to <float >();
299
+ auto scalar_tensor = tensor_to_const (ctx, torch::tensor ({scalar}));
300
+ auto equal = add_elementwise (
301
+ ctx,
302
+ nvinfer1::ElementWiseOperation::kEQUAL ,
303
+ self,
304
+ scalar_tensor,
305
+ util::node_info (n) + std::string (" is_equal" ));
306
+ TRTORCH_CHECK (equal, " Unable to create elementwise equal layer from node: " << *n);
307
+ // XOR with ones negates and produces not_equal result
308
+ auto options = torch::TensorOptions ().dtype (torch::kFloat32 );
309
+ auto ones = at::full ({1 }, 1 , {options});
310
+ auto ones_tensor = tensor_to_const (ctx, ones);
311
+ nvinfer1::IIdentityLayer* cast_layer = ctx->net ->addIdentity (*ones_tensor);
312
+ cast_layer->setOutputType (0 , nvinfer1::DataType::kBOOL );
313
+
314
+ auto sub = add_elementwise (
315
+ ctx,
316
+ nvinfer1::ElementWiseOperation::kXOR ,
317
+ cast_layer->getOutput (0 ),
318
+ equal->getOutput (0 ),
319
+ util::node_info (n));
320
+ TRTORCH_CHECK (sub, " Unable to create ne (not equal) layer from node: " << *n);
321
+
322
+ sub->setName (util::node_info (n).c_str ());
323
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], sub->getOutput (0 ));
324
+ LOG_DEBUG (" Not equal layer output tensor shape: " << out->getDimensions ());
325
+ return true ;
326
+ }})
263
327
.pattern({" aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)" ,
264
328
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
265
329
// TODO: Remove with functionalization
0 commit comments