@@ -393,6 +393,89 @@ auto expand_registrations TORCHTRT_UNUSED =
393
393
auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], collapse->getOutput (0 ));
394
394
LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
395
395
396
+ return true ;
397
+ }})
398
+ .pattern(
399
+ {" aten::meshgrid(Tensor[] tensors) -> (Tensor[])" ,
400
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
401
+ // torch.meshgrid only supports 1D or 0D input tensors
402
+ auto arg_tensors = args[0 ].IValue ()->toListRef ();
403
+ std::vector<nvinfer1::ITensor*> tensors;
404
+ for (auto t : arg_tensors) {
405
+ if (t.isTensor ()) {
406
+ auto torch_tensor = t.toTensor ();
407
+ tensors.push_back (tensor_to_const (ctx, torch_tensor));
408
+ } else {
409
+ auto cont = t.toCustomClass <TensorContainer>();
410
+ tensors.push_back (cont->tensor ());
411
+ }
412
+ }
413
+
414
+ // build the output shape for all tensors in the output list
415
+ nvinfer1::Dims output_dims;
416
+ output_dims.nbDims = tensors.size ();
417
+ for (size_t idx = 0UL ; idx < tensors.size (); ++idx) {
418
+ auto dims = tensors[idx]->getDimensions ();
419
+ output_dims.d [idx] = dims.nbDims == 0 ? 1 : dims.d [0 ];
420
+ }
421
+ std::vector<nvinfer1::ITensor*> out_tensors;
422
+ // Reshape tensors into output shape (reshape, expand)
423
+ for (size_t idx = 0UL ; idx < tensors.size (); ++idx) {
424
+ auto t = tensors[idx];
425
+ auto dims = t->getDimensions ();
426
+ nvinfer1::Dims reshape_dims;
427
+ reshape_dims.nbDims = tensors.size ();
428
+ for (size_t reshape_idx = 0UL ; reshape_idx < tensors.size (); ++reshape_idx) {
429
+ if (reshape_idx == idx) {
430
+ reshape_dims.d [reshape_idx] = dims.nbDims == 0 ? 1 : dims.d [0 ];
431
+ } else {
432
+ reshape_dims.d [reshape_idx] = 1 ;
433
+ }
434
+ }
435
+ // Add a reshape layer before expanding dims
436
+ auto reshape_layer = ctx->net ->addShuffle (*t);
437
+ reshape_layer->setReshapeDimensions (reshape_dims);
438
+ std::stringstream reshape_layer_name;
439
+ reshape_layer_name << util::node_info (n) << " _meshgrid_reshape_" << std::to_string (idx);
440
+ reshape_layer->setName (reshape_layer_name.str ().c_str ());
441
+ auto reshaped = reshape_layer->getOutput (0 );
442
+ LOG_DEBUG (" Tensor " << idx << " reshaped to : " << reshaped->getDimensions () << " from " << dims);
443
+
444
+ // Add slice layer for expansion
445
+ std::vector<int64_t > start_vec (output_dims.nbDims , 0 );
446
+ auto start_offset = util::toDims (c10::IntArrayRef (start_vec));
447
+
448
+ std::vector<int64_t > strides_vec (output_dims.nbDims , 0 );
449
+ for (int64_t i = 0 ; i < output_dims.nbDims ; i++) {
450
+ strides_vec[i] = (reshaped->getDimensions ().d [i] != 1 );
451
+ }
452
+
453
+ auto strides = util::toDims (c10::IntArrayRef (strides_vec));
454
+
455
+ auto slice_layer = ctx->net ->addSlice (*reshaped, start_offset, output_dims, strides);
456
+ std::stringstream slice_layer_name;
457
+ slice_layer_name << util::node_info (n) << " _meshgrid_slice_" << std::to_string (idx);
458
+ slice_layer->setName (slice_layer_name.str ().c_str ());
459
+ auto slice_output = slice_layer->getOutput (0 );
460
+ LOG_DEBUG (" Tensor " << idx << " expanded to : " << slice_output->getDimensions ());
461
+ out_tensors.push_back (slice_output);
462
+ }
463
+
464
+ // Pack output tensors into list
465
+ c10::ListTypePtr lt = n->output ()->type ()->expect <c10::ListType>();
466
+ c10::TypePtr elementType = lt->getElementType ();
467
+ auto list = c10::impl::GenericList (elementType);
468
+ list.reserve (out_tensors.size ());
469
+
470
+ for (auto t : out_tensors) {
471
+ auto tensor_holder = TensorContainer ();
472
+ tensor_holder.hold_tensor (t);
473
+ auto ival = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
474
+ list.emplace_back (ival);
475
+ }
476
+
477
+ auto output_list = std::move (torch::jit::IValue (list));
478
+ ctx->AssociateValueAndIValue (n->outputs ()[0 ], output_list);
396
479
return true ;
397
480
}});
398
481
0 commit comments