@@ -384,6 +384,147 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
384
384
let hasVerifier = 1;
385
385
}
386
386
387
//===----------------------------------------------------------------------===//
// GroupedConvNDOp ops.
//===----------------------------------------------------------------------===//

def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
    [AttrSizedOperandSegments, LinalgGroupedConvolutionOpInterface]> {

  let summary = [{
    Performs N-D grouped convolution with switchable channel position; either first or last.
  }];
  let description = [{
    Allows any number of spatial dimensions but treats all of them as contiguous. Throughout, `S`
    will represent all spatial dimensions. Operand layouts are determined by the `layouts`
    `StrArrayAttr` attribute. Each element of the array is a string representing the layout of the
    corresponding operand and should be mappable to a `GroupedConvDim` enum, i.e. one of
      n: (batch dim)
      g: (group dim)
      f: (feature or output channel dim)
      s: (all spatial dims)
      c: (input channel dim).

    The domain will always be in the order `(N, G, F, S, C, KS)`.

  }];

  let arguments = (ins
    Variadic<TensorOrMemref>:$inputs,
    Variadic<TensorOrMemref>:$inits,
    // One layout string per operand, in operand order (input, filter, init);
    // each character must symbolize to a GroupedConvDim (see getLayoutsEnums).
    DefaultValuedAttr<StrArrayAttr, "{\"ngcs\", \"gfcs\", \"ngfs\"}">:$layouts,
    OptionalAttr<I64ElementsAttr>:$strides,
    OptionalAttr<I64ElementsAttr>:$dilations
  );
  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
  let regions = (region AnyRegion:$region);

  let skipDefaultBuilders = 1;
  let builders = [
    // Builder without explicit result types; strides/dilations default to
    // all-ones vectors sized by the spatial rank.
    OpBuilder<
      (ins "Value":$input, "Value":$filter, "Value":$init,
       CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        // Input rank is (n, g, c) + spatial dims, so spatial rank = rank - 3.
        int64_t numSpatialDims = cast<ShapedType>(input.getType()).getRank() - 3;
        if (strides.empty())
          strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
        if (dilations.empty())
          dilations = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
        $_state.addAttribute(getStridesAttrName($_state.name),
            ::mlir::DenseElementsAttr::get(
                ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
        $_state.addAttribute(getDilationsAttrName($_state.name),
            ::mlir::DenseElementsAttr::get(
                ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
        buildStructuredOp($_builder, $_state, std::nullopt, {input, filter}, init,
                          attributes, GroupedConvNDOp::getRegionBuilder());
      }]>,
    // Builder with explicit result types; strides/dilations default to
    // all-ones vectors sized by the spatial rank.
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
       "Value":$init,
       CArg<"ArrayRef<int64_t>", "{}">:$strides, CArg<"ArrayRef<int64_t>", "{}">:$dilations,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        // Input rank is (n, g, c) + spatial dims, so spatial rank = rank - 3.
        int64_t numSpatialDims = cast<ShapedType>(input.getType()).getRank() - 3;
        if (strides.empty())
          strides = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
        if (dilations.empty())
          dilations = ::llvm::SmallVector<int64_t, 2>(numSpatialDims, 1);
        $_state.addAttribute(getStridesAttrName($_state.name),
            ::mlir::DenseElementsAttr::get(
                ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), strides));
        $_state.addAttribute(getDilationsAttrName($_state.name),
            ::mlir::DenseElementsAttr::get(
                ::mlir::RankedTensorType::get(numSpatialDims, $_builder.getI64Type()), dilations));
        buildStructuredOp($_builder, $_state, resultTensorTypes,
                          {input, filter}, init, attributes, GroupedConvNDOp::getRegionBuilder());
      }]>,
    // Builder taking already-built strides/dilations attributes verbatim.
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
       "Value":$init, "Attribute":$strides, "Attribute":$dilations,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addAttribute(getStridesAttrName($_state.name), strides);
        $_state.addAttribute(getDilationsAttrName($_state.name), dilations);
        buildStructuredOp($_builder, $_state, resultTensorTypes, {input, filter}, init,
                          attributes, GroupedConvNDOp::getRegionBuilder());
      }]>
  ];

  // TODO: Figure out how to move this to the interface
  let extraClassDeclaration = structuredOpsBaseDecls # [{
    // Printing/parsing delegate to the shared convolution implementation.
    void print(::mlir::OpAsmPrinter &printer) {
      return detail::convolution_impl::print(*this, printer);
    }
    static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
                                     ::mlir::OperationState &result) {
      return detail::convolution_impl::parse(parser, result);
    }
    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
                              mlir::ArrayRef<mlir::NamedAttribute>)>
    getRegionBuilder() {
      return detail::convolution_impl::regionBuilder;
    }
    // Implement functions necessary for DestinationStyleOpInterface.
    MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }

    // Implement functions necessary for LinalgOp.
    ArrayAttr getIndexingMaps();

    // Implement functions necessary for GroupedConvolutionOpInterface
    int64_t getSpatialRank() {
      return detail::grouped_convolution_impl::getSpatialRank(*this);
    }

    // Decodes the `layouts` attribute into one GroupedConvDim vector per
    // operand, one enum value per character of the layout string.
    SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> getLayoutsEnums() {
      SmallVector<SmallVector<::mlir::utils::GroupedConvDim>> layouts;
      for (auto attr : (*this).getLayoutsAttr().getValue()) {
        std::string layoutStr = cast<StringAttr>(attr).getValue().str();
        SmallVector<::mlir::utils::GroupedConvDim> layout(layoutStr.size());
        for (size_t i = 0; i < layoutStr.size(); i++) {
          // substr yields a one-char std::string; it converts to StringRef and
          // outlives the call, so no c_str() round-trip is needed.
          auto maybeDimEnum =
              ::mlir::utils::symbolizeGroupedConvDim(layoutStr.substr(i, 1));
          assert(maybeDimEnum && "expected a valid GroupedConvDim character");
          layout[i] = maybeDimEnum.value();
        }
        layouts.push_back(layout);
      }
      return layouts;
    }

    // Positions below index into the default layouts ("ngcs"/"gfcs"/"ngfs"):
    // f is dim 2 of the output, c is dim 2 of the input, g is dim 1.
    int64_t getOutputChannelPosition() {
      return 2;
    }

    int64_t getInputChannelPosition() {
      return 2;
    }

    int64_t getInputGroupsPosition() {
      return 1;
    }
  }];
}
387
528
388
529
//===----------------------------------------------------------------------===//
389
530
// Transpose op.
0 commit comments