[mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp #68526


Merged (1 commit, Oct 23, 2023)

Conversation

@AviadCo (Contributor) commented Oct 8, 2023

  • [mlir][linalg] Enable CollapseLinalgDimensions to collapse memref based operations
  • [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp

@llvmbot (Member) commented Oct 8, 2023

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Changes
  • [mlir][linalg] Enable CollapseLinalgDimensions to collapse memref based operations
  • [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp

Patch is 23.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68526.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+10-8)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+104-62)
  • (modified) mlir/test/Dialect/Linalg/collapse-dim.mlir (+83)
  • (modified) mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 07a192f7b8606d3..0b0be116ce1c1d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1047,16 +1047,18 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
 bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
                               ArrayRef<ReassociationIndices> dimSequences);
 
-/// Collapses dimensions of linalg.generic operation. A precondition to
-/// calling this method is that for each list in `foldedIterationDim`, the
+/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
+/// to calling this method is that for each list in `foldedIterationDim`, the
 /// sequence of dimensions is contiguous in domains of all `indexing_maps` of
-/// the `genericOp`. This can be checked using `areDimSequencePreserved` method.
+/// the `linalgOp`. This can be checked using `areDimSequencePreserved` method.
 /// When valid, the method also collapses the operands of the op. Returns
-/// replacement values of the results of the original `genericOp` by inserting
+/// replacement values of the results of the original `linalgOp` by inserting
 /// reshapes to get back values of compatible types.
-FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
-    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
-    RewriterBase &rewriter);
+template <typename LinalgType>
+FailureOr<SmallVector<Value>>
+collapseOpIterationDims(LinalgType op,
+                        ArrayRef<ReassociationIndices> foldedIterationDims,
+                        RewriterBase &rewriter);
 
 struct LowerPackResult {
   tensor::PadOp padOp;
@@ -1507,7 +1509,7 @@ void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
 /// to return an array of `ReassociationIndices` representing dimensions that
 /// should be merged.
 using GetCollapsableDimensionsFn =
-    std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>;
+    std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>;
 
 /// Pattern to collapse dimensions in a linalg.generic op. This will collapse
 /// tensor operands when needed and expand back the result tensors.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 069c613cc246d6a..3e5f0ec24ffde99 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1373,24 +1373,31 @@ getOperandReassociation(AffineMap indexingMap,
 }
 
 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
-static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
+static Value getCollapsedOpOperand(Location loc, LinalgOp op,
                                    OpOperand *opOperand,
                                    const CollapsingInfo &collapsingInfo,
                                    OpBuilder &builder) {
-  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
+  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
   SmallVector<ReassociationIndices> operandReassociation =
       getOperandReassociation(indexingMap, collapsingInfo);
 
-  // If the number of entries in the reassocation for the operand is same as the
-  // number of results of the indexing map, then nothing to do for this operand.
+  // If the number of entries in the reassociation for the operand is same as
+  // the number of results of the indexing map, then nothing to do for this
+  // operand.
   Value operand = opOperand->get();
   if (operandReassociation.size() == indexingMap.getNumResults())
     return operand;
 
   // Insert a reshape to collapse the dimensions.
-  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
-      loc, operand, operandReassociation);
-  return reshapeOp.getResult();
+  if (isa<MemRefType>(operand.getType())) {
+    return builder
+        .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
+        .getResult();
+  } else {
+    return builder
+        .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
+        .getResult();
+  }
 }
 
 /// Modify the `linalg.index` operations in the original generic op, to its
@@ -1434,27 +1441,43 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
 }
 
 /// Implementation of fusion with reshape operation by collapsing dimensions.
-FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
-    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
+template <typename LinalgType>
+FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
+    LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter) {
+  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
+                "unsupported linalg op type to collapse");
+
   // Bail on trivial no-op cases.
-  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
+  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
       llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
         return foldedDims.size() <= 1;
       }))
     return failure();
 
+  bool hasBufferSemantics = op.hasBufferSemantics();
+  if (hasBufferSemantics &&
+      !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
+        MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
+        if (!memRefToCollapse)
+          return true;
+
+        return memref::CollapseShapeOp::isGuaranteedCollapsible(
+            memRefToCollapse, foldedIterationDims);
+      }))
+    return rewriter.notifyMatchFailure(op,
+                                       "memref is not guaranteed collapsible");
+
   CollapsingInfo collapsingInfo;
-  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
-                                       foldedIterationDims))) {
+  if (failed(
+          collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
     return rewriter.notifyMatchFailure(
-        genericOp, "illegal to collapse specified dimensions");
+        op, "illegal to collapse specified dimensions");
   }
 
   // Bail on non-canonical ranges.
   SmallVector<Range> loopRanges =
-      cast<LinalgOp>(genericOp.getOperation())
-          .createLoopRanges(rewriter, genericOp.getLoc());
+      cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
   auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
       return cast<IntegerAttr>(attr).getInt() == value;
@@ -1467,80 +1490,97 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
                opFoldIsConstantValue(range.stride, 1);
       })) {
     return rewriter.notifyMatchFailure(
-        genericOp,
-        "expected all loop ranges to have zero start and unit stride");
+        op, "expected all loop ranges to have zero start and unit stride");
   }
 
   // Get the iterator types for the operand.
-  SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
-      genericOp.getIteratorTypesArray(), collapsingInfo);
+  SmallVector<utils::IteratorType> iteratorTypes =
+      getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
 
   // Get the indexing maps.
   auto indexingMaps = llvm::to_vector(
-      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
+      llvm::map_range(op.getIndexingMapsArray(), [&](AffineMap map) {
         return getCollapsedOpIndexingMap(map, collapsingInfo);
       }));
 
-  Location loc = genericOp->getLoc();
+  Location loc = op->getLoc();
 
   // Get the input operands.
-  auto inputOperands = llvm::to_vector(llvm::map_range(
-      genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
-        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
+  auto inputOperands = llvm::to_vector(
+      llvm::map_range(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
+        return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
                                      rewriter);
       }));
 
   // Get the output operands and result types.
   SmallVector<Type> resultTypes;
   SmallVector<Value> outputOperands;
-  resultTypes.reserve(genericOp.getNumDpsInits());
-  outputOperands.reserve(genericOp.getNumDpsInits());
-  for (OpOperand &output : genericOp.getDpsInitsMutable()) {
-    Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
-                                            collapsingInfo, rewriter);
+  resultTypes.reserve(op.getNumDpsInits());
+  outputOperands.reserve(op.getNumDpsInits());
+  for (OpOperand &output : op.getDpsInitsMutable()) {
+    Value newOutput =
+        getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
     outputOperands.push_back(newOutput);
-    resultTypes.push_back(newOutput.getType());
+    // If the op has "buffer semantics", then the init operands are ranked
+    // memrefs and the op has no results.
+    if (!hasBufferSemantics)
+      resultTypes.push_back(newOutput.getType());
   }
 
   // Create the generic op.
-  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
-      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
-      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
-  Block *origOpBlock = &genericOp->getRegion(0).front();
-  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
-  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
-                       collapsedOpBlock->getArguments());
-
-  if (collapsedGenericOp.hasIndexSemantics()) {
+  Operation *collapsedOp;
+  if (isa<linalg::GenericOp>(op)) {
+    collapsedOp = rewriter.create<linalg::GenericOp>(
+        loc, resultTypes, inputOperands, outputOperands, indexingMaps,
+        iteratorTypes,
+        [](OpBuilder &builder, Location loc, ValueRange args) {});
+    Block *origOpBlock = &op->getRegion(0).front();
+    Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
+    rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
+                         collapsedOpBlock->getArguments());
+  } else {
+    assert(isa<linalg::CopyOp>(op));
+    collapsedOp = rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
+                                                  outputOperands[0]);
+  }
+  LinalgType collapsedLinalgOp = cast<LinalgType>(collapsedOp);
+
+  if (collapsedLinalgOp.hasIndexSemantics()) {
     // Collect the loop range of the generic op.
     OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(collapsedGenericOp);
+    rewriter.setInsertionPoint(collapsedLinalgOp);
     SmallVector<Value> loopBound =
         llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) {
           return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
         }));
     generateCollapsedIndexingRegion(loc,
-                                    &collapsedGenericOp->getRegion(0).front(),
+                                    &collapsedLinalgOp->getRegion(0).front(),
                                     collapsingInfo, loopBound, rewriter);
   }
 
   // Insert expanding reshape for the result to get back the original result
   // type.
   SmallVector<Value> results;
-  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
+  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
     Value collapsedOpResult =
-        collapsedGenericOp->getResult(originalResult.index());
+        collapsedLinalgOp->getResult(originalResult.index());
     auto originalResultType =
         cast<ShapedType>(originalResult.value().getType());
     auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
     if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
       AffineMap indexingMap =
-          genericOp.getIndexingMapMatchingResult(originalResult.value());
+          op.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
-      Value result = rewriter.create<tensor::ExpandShapeOp>(
-          loc, originalResultType, collapsedOpResult, reassociation);
-      results.push_back(result);
+      if (isa<MemRefType>(collapsedOpResult.getType())) {
+        Value result = rewriter.create<memref::ExpandShapeOp>(
+            loc, originalResultType, collapsedOpResult, reassociation);
+        results.push_back(result);
+      } else {
+        Value result = rewriter.create<tensor::ExpandShapeOp>(
+            loc, originalResultType, collapsedOpResult, reassociation);
+        results.push_back(result);
+      }
     } else {
       results.push_back(collapsedOpResult);
     }
@@ -1578,8 +1618,8 @@ class FoldWithProducerReshapeOpByCollapsing
       }
 
       std::optional<SmallVector<Value>> replacements =
-          collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
-                                         rewriter);
+          collapseOpIterationDims<linalg::GenericOp>(
+              genericOp, collapsableIterationDims, rewriter);
       if (!replacements) {
         return rewriter.notifyMatchFailure(
             genericOp, "failed to do the fusion by collapsing transformation");
@@ -1596,36 +1636,36 @@ class FoldWithProducerReshapeOpByCollapsing
 };
 
 /// Pattern to collapse dimensions.
-class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
+template <typename LinalgType>
+class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
 public:
   CollapseLinalgDimensions(MLIRContext *context,
                            GetCollapsableDimensionsFn collapseDimensions,
                            PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
+      : OpRewritePattern<LinalgType>(context, benefit),
         controlCollapseDimension(std::move(collapseDimensions)) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgType op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<ReassociationIndices> collapsableIterationDims =
-        controlCollapseDimension(genericOp);
+        controlCollapseDimension(op);
     if (collapsableIterationDims.empty())
       return failure();
 
     // Check if the specified list of dimensions to collapse is a valid list.
-    if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(),
+    if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
                                   collapsableIterationDims)) {
       return rewriter.notifyMatchFailure(
-          genericOp, "specified dimensions cannot be collapsed");
+          op, "specified dimensions cannot be collapsed");
     }
 
     std::optional<SmallVector<Value>> replacements =
-        collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
-                                       rewriter);
+        collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
+                                            rewriter);
     if (!replacements) {
-      return rewriter.notifyMatchFailure(genericOp,
-                                         "failed to collapse dimensions");
+      return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
     }
-    rewriter.replaceOp(genericOp, *replacements);
+    rewriter.replaceOp(op, *replacements);
     return success();
   }
 
@@ -1856,8 +1896,10 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
 void mlir::linalg::populateCollapseDimensions(
     RewritePatternSet &patterns,
     const GetCollapsableDimensionsFn &controlCollapseDimensions) {
-  patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
-                                         controlCollapseDimensions);
+  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>>(
+      patterns.getContext(), controlCollapseDimensions);
+  patterns.add<CollapseLinalgDimensions<linalg::CopyOp>>(
+      patterns.getContext(), controlCollapseDimensions);
 }
 
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 6737a6e15da5afe..547320f53387477 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -70,3 +70,86 @@ func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41
 // CHECK-LABEL: func @uncollapsable(
 //       CHECK:   linalg.generic
 //  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+// -----
+
+// CHECK-LABEL:   func.func private @collapsable_memref(
+// CHECK-SAME:                                          %[[VAL_0:.*]]: memref<1x24x32x8xf32>,
+// CHECK-SAME:                                          %[[VAL_1:.*]]: memref<1x24x32x8xf32>) -> memref<1x24x32x8xf32> {
+// CHECK:           %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK:           %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK:           %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK:           linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_3]], %[[VAL_4]] : memref<1x24x256xf32>, memref<1x24x256xf32>) outs(%[[VAL_5]] : memref<1x24x256xf32>) {
+// CHECK:           ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+// CHECK:             %[[VAL_9:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK:             linalg.yield %[[VAL_9]] : f32
+// CHECK:           }
+// CHECK:           return %[[VAL_2]] : memref<1x24x32x8xf32>
+// CHECK:         }
+
+func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>) -> (memref<1x24x32x8xf32>) {
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+  linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%alloc : memref<1x24x32x8xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %0 = arith.addf %in, %in_0 : f32
+    linalg.yield %0 : f32
+  }
+  return %alloc : memref<1x24x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @uncollapsable_strided_memref(
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) -> (memref<2x6x24x48xi32>) {
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x6x24x48xi32>
+  %subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  %subview1 = memref.subview %alloc[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>, memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) outs(%subview1 : memref<1x...
[truncated]

@MaheshRavishankar (Contributor) left a comment

Overall change looks reasonable. Some refactoring might make this better, but this throws up some design issues. Why is CopyOp treated specially? This should work for other Linalg named ops as well. An alternative would be to use matchers to match a linalg.generic and convert it to a linalg.copy after it is collapsed. That way we could keep the two concerns separate (and having a matcher for linalg.generic -> linalg.copy might be helpful in other places as well).

@AviadCo (Contributor Author) commented Oct 10, 2023

Adding @aniragil @amrami @amirBish as subscribers

@AviadCo (Contributor Author) commented Oct 11, 2023

> Overall change looks reasonable. Some refactoring might make this better, but this throws up some design issues. Why is CopyOp treated specially? This should work for other Linalg named ops as well. An alternative would be to use matchers to match a linalg.generic and convert it to a linalg.copy after it is collapsed. That way we could keep the two concerns separate (and having a matcher for linalg.generic -> linalg.copy might be helpful in other places as well).

Actually, I am fully content with the current design as well. I am afraid that by the time I need to collapse the linalg.copy, it would be hard to identify that it was a linalg.copy at the beginning (i.e. after fusion). Any idea how to solve that?

@AviadCo force-pushed the linalg/enable-collapse-on-copy branch from f2c94f6 to 8a3290a on October 11, 2023 at 08:38
@github-actions (bot) commented Oct 11, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@AviadCo force-pushed the linalg/enable-collapse-on-copy branch from 8a3290a to 78d4218 on October 11, 2023 at 08:52
@nicolasvasilache (Contributor) commented

> Overall change looks reasonable. Some refactoring might make this better, but this throws up some design issues. Why is CopyOp treated specially? This should work for other Linalg named ops as well. An alternative would be to use matchers to match a linalg.generic and convert it to a linalg.copy after it is collapsed. That way we could keep the two concerns separate (and having a matcher for linalg.generic -> linalg.copy might be helpful in other places as well).

> Actually, I am fully content with the current design as well. I am afraid that by the time I need to collapse the linalg.copy, it would be hard to identify that it was a linalg.copy at the beginning (i.e. after fusion). Any idea how to solve that?

I don't have concerns with the current implementation.
In the future, when we finally have "specialize" patterns available, the changes will be very minimal.

@AviadCo force-pushed the linalg/enable-collapse-on-copy branch from 78d4218 to 144b497 on October 11, 2023 at 19:10
@AviadCo requested review from @MaheshRavishankar and removed the request for @amrami on October 11, 2023 at 19:29
@AviadCo (Contributor Author) commented Oct 11, 2023

@MaheshRavishankar @nicolasvasilache thanks for the code review.
@MaheshRavishankar I answered the comments, feel free to review again.

@AviadCo (Contributor Author) commented Oct 17, 2023

@MaheshRavishankar ping

@MaheshRavishankar (Contributor) left a comment

Ok, I will relent on allowing linalg.copy to be handled as a special case, though I'd strongly prefer not special-casing linalg.copy this way. We should really just be collapsing on linalg.generic and "matching" back to linalg.copy if needed. But that might be a preference, so this path is OK.

Please do remove the static_asserts though... they cause deployment issues downstream.

Operation *createCollapsedOp(LinalgType op,
const CollapsingInfo &collapsingInfo,
RewriterBase &rewriter) {
static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
@MaheshRavishankar (Contributor):

static_asserts create problems in deployment. This is an optimization, not required for correctness. It is completely reasonable for an operation not to collapse and for the optimization to return a failure. The caller can then handle it appropriately. Please change this to fail gracefully.

@AviadCo (Contributor Author):

Sorry, I don't understand.
Whoever plans to call this function can know at compile time whether the type it uses as the template argument is supported.
Catching a bug at compile time is better than catching it at runtime.
Can you elaborate on how static_assert can cause deployment issues downstream?
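The trade-off being debated can be illustrated with a small standalone C++17 sketch. The op types below are hypothetical stand-ins, not the real MLIR classes: `static_assert` (the patch's approach) rejects unsupported instantiations at build time, while an `if constexpr` fallback (closer to the reviewer's "fail gracefully" suggestion) returns a runtime failure the caller can handle.

```cpp
#include <type_traits>

// Hypothetical stand-ins for linalg::GenericOp / linalg::CopyOp;
// the real MLIR op classes are not used here.
struct GenericOp {};
struct CopyOp {};
struct FillOp {};

// Equivalent of llvm::is_one_of, written with a C++17 fold expression.
template <typename T, typename... Ts>
inline constexpr bool is_one_of_v = (std::is_same_v<T, Ts> || ...);

// The patch's approach: instantiating with an unsupported op type
// is a hard compile-time error.
template <typename LinalgType>
bool collapseOpIterationDims() {
  static_assert(is_one_of_v<LinalgType, GenericOp, CopyOp>,
                "unsupported linalg op type to collapse");
  return true; // the collapse rewrite would go here
}

// The reviewer's preference: unsupported types fail gracefully at
// runtime (analogous to returning rewriter.notifyMatchFailure), so
// downstream code can instantiate the template without build breaks.
template <typename LinalgType>
bool collapseOpIterationDimsGraceful() {
  if constexpr (!is_one_of_v<LinalgType, GenericOp, CopyOp>)
    return false; // match failure, caller decides what to do
  else
    return true;  // the collapse rewrite would go here
}
```

Note that `collapseOpIterationDims<FillOp>()` would not compile at all, which is exactly the deployment concern: any downstream instantiation with a new op type breaks the build instead of simply not matching.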

@@ -1884,8 +1902,10 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
void mlir::linalg::populateCollapseDimensions(
RewritePatternSet &patterns,
const GetCollapsableDimensionsFn &controlCollapseDimensions) {
patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
controlCollapseDimensions);
patterns.add<CollapseLinalgDimensions<linalg::GenericOp>>(
@MaheshRavishankar (Contributor):

Nit: Use a single patterns.add<...>.

@AviadCo (Contributor Author):

Ack
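The nit refers to the variadic form of `RewritePatternSet::add`, which registers several pattern types in one call, e.g. `patterns.add<CollapseLinalgDimensions<linalg::GenericOp>, CollapseLinalgDimensions<linalg::CopyOp>>(ctx, fn)`. A minimal mock (not the real MLIR class) shows the mechanism with a C++17 fold expression:

```cpp
#include <string>
#include <vector>

// Minimal mock of RewritePatternSet; the real MLIR class constructs
// pattern objects with a context and benefit, which is elided here.
struct PatternSet {
  std::vector<std::string> names;

  // Variadic add<...>: a single call registers every listed pattern
  // type, expanding the pack with a comma fold.
  template <typename... PatternTs>
  void add() {
    (names.push_back(PatternTs::name()), ...);
  }
};

// Mock collapse patterns parameterized by op type, echoing
// CollapseLinalgDimensions<linalg::GenericOp> / <linalg::CopyOp>.
struct GenericCollapse {
  static std::string name() { return "GenericOp"; }
};
struct CopyCollapse {
  static std::string name() { return "CopyOp"; }
};
```

With this shape, the two `patterns.add` calls in the patch collapse into one statement listing both instantiations.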

@AviadCo force-pushed the linalg/enable-collapse-on-copy branch from 144b497 to 41c5c80 on October 22, 2023 at 15:07
@AviadCo force-pushed the linalg/enable-collapse-on-copy branch from 41c5c80 to b47dbcb on October 22, 2023 at 15:41
@MaheshRavishankar (Contributor) left a comment

Ok. I misread what was happening. Looks good. Thanks!

@AviadCo (Contributor Author) commented Oct 23, 2023

Thanks for the fast response!

@AviadCo merged commit 5c3ed39 into llvm:main on Oct 23, 2023
@AviadCo deleted the linalg/enable-collapse-on-copy branch on October 23, 2023 at 06:42
@nicolasvasilache (Contributor) commented

@chelini here is a place that can be cleaned up once we have transform.specialize working for just linalg.generic -> linalg.copy.

@chelini (Contributor) commented Oct 24, 2023

Sounds good @nicolasvasilache, thanks. I will start working on it in the coming days.
