[mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp #68526
Conversation
AviadCo commented on Oct 8, 2023
- [mlir][linalg] Enable CollapseLinalgDimensions to collapse memref based operations
- [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Changes

Patch is 23.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68526.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 07a192f7b8606d3..0b0be116ce1c1d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1047,16 +1047,18 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
ArrayRef<ReassociationIndices> dimSequences);
-/// Collapses dimensions of linalg.generic operation. A precondition to
-/// calling this method is that for each list in `foldedIterationDim`, the
+/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
+/// to calling this method is that for each list in `foldedIterationDim`, the
/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
-/// the `genericOp`. This can be checked using `areDimSequencePreserved` method.
+/// the `linalgOp`. This can be checked using `areDimSequencePreserved` method.
/// When valid, the method also collapses the operands of the op. Returns
-/// replacement values of the results of the original `genericOp` by inserting
+/// replacement values of the results of the original `linalgOp` by inserting
/// reshapes to get back values of compatible types.
-FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
- GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
- RewriterBase &rewriter);
+template <typename LinalgType>
+FailureOr<SmallVector<Value>>
+collapseOpIterationDims(LinalgType op,
+ ArrayRef<ReassociationIndices> foldedIterationDims,
+ RewriterBase &rewriter);
struct LowerPackResult {
tensor::PadOp padOp;
@@ -1507,7 +1509,7 @@ void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
/// to return an array of `ReassociationIndices` representing dimensions that
/// should be merged.
using GetCollapsableDimensionsFn =
- std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>;
+ std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>;
/// Pattern to collapse dimensions in a linalg.generic op. This will collapse
/// tensor operands when needed and expand back the result tensors.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 069c613cc246d6a..3e5f0ec24ffde99 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1373,24 +1373,31 @@ getOperandReassociation(AffineMap indexingMap,
}
/// Get the new value to use for a given `OpOperand` in the collapsed operation.
-static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
+static Value getCollapsedOpOperand(Location loc, LinalgOp op,
OpOperand *opOperand,
const CollapsingInfo &collapsingInfo,
OpBuilder &builder) {
- AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
+ AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
SmallVector<ReassociationIndices> operandReassociation =
getOperandReassociation(indexingMap, collapsingInfo);
- // If the number of entries in the reassocation for the operand is same as the
- // number of results of the indexing map, then nothing to do for this operand.
+ // If the number of entries in the reassociation for the operand is same as
+ // the number of results of the indexing map, then nothing to do for this
+ // operand.
Value operand = opOperand->get();
if (operandReassociation.size() == indexingMap.getNumResults())
return operand;
// Insert a reshape to collapse the dimensions.
- auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
- loc, operand, operandReassociation);
- return reshapeOp.getResult();
+ if (isa<MemRefType>(operand.getType())) {
+ return builder
+ .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
+ .getResult();
+ } else {
+ return builder
+ .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
+ .getResult();
+ }
}
/// Modify the `linalg.index` operations in the original generic op, to its
@@ -1434,27 +1441,43 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
}
/// Implementation of fusion with reshape operation by collapsing dimensions.
-FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
- GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
+template <typename LinalgType>
+FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
+ LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter) {
+ static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
+ "unsupported linalg op type to collapse");
+
// Bail on trivial no-op cases.
- if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
+ if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
return foldedDims.size() <= 1;
}))
return failure();
+ bool hasBufferSemantics = op.hasBufferSemantics();
+ if (hasBufferSemantics &&
+ !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
+ MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
+ if (!memRefToCollapse)
+ return true;
+
+ return memref::CollapseShapeOp::isGuaranteedCollapsible(
+ memRefToCollapse, foldedIterationDims);
+ }))
+ return rewriter.notifyMatchFailure(op,
+ "memref is not guaranteed collapsible");
+
CollapsingInfo collapsingInfo;
- if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
- foldedIterationDims))) {
+ if (failed(
+ collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
return rewriter.notifyMatchFailure(
- genericOp, "illegal to collapse specified dimensions");
+ op, "illegal to collapse specified dimensions");
}
// Bail on non-canonical ranges.
SmallVector<Range> loopRanges =
- cast<LinalgOp>(genericOp.getOperation())
- .createLoopRanges(rewriter, genericOp.getLoc());
+ cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
return cast<IntegerAttr>(attr).getInt() == value;
@@ -1467,80 +1490,97 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
opFoldIsConstantValue(range.stride, 1);
})) {
return rewriter.notifyMatchFailure(
- genericOp,
- "expected all loop ranges to have zero start and unit stride");
+ op, "expected all loop ranges to have zero start and unit stride");
}
// Get the iterator types for the operand.
- SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
- genericOp.getIteratorTypesArray(), collapsingInfo);
+ SmallVector<utils::IteratorType> iteratorTypes =
+ getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
// Get the indexing maps.
auto indexingMaps = llvm::to_vector(
- llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
+ llvm::map_range(op.getIndexingMapsArray(), [&](AffineMap map) {
return getCollapsedOpIndexingMap(map, collapsingInfo);
}));
- Location loc = genericOp->getLoc();
+ Location loc = op->getLoc();
// Get the input operands.
- auto inputOperands = llvm::to_vector(llvm::map_range(
- genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
- return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
+ auto inputOperands = llvm::to_vector(
+ llvm::map_range(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
+ return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
rewriter);
}));
// Get the output operands and result types.
SmallVector<Type> resultTypes;
SmallVector<Value> outputOperands;
- resultTypes.reserve(genericOp.getNumDpsInits());
- outputOperands.reserve(genericOp.getNumDpsInits());
- for (OpOperand &output : genericOp.getDpsInitsMutable()) {
- Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
- collapsingInfo, rewriter);
+ resultTypes.reserve(op.getNumDpsInits());
+ outputOperands.reserve(op.getNumDpsInits());
+ for (OpOperand &output : op.getDpsInitsMutable()) {
+ Value newOutput =
+ getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
outputOperands.push_back(newOutput);
- resultTypes.push_back(newOutput.getType());
+ // If the op has "buffer semantics", then the init operands are ranked
+ // memrefs and the op has no results.
+ if (!hasBufferSemantics)
+ resultTypes.push_back(newOutput.getType());
}
// Create the generic op.
- auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
- loc, resultTypes, inputOperands, outputOperands, indexingMaps,
- iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
- Block *origOpBlock = &genericOp->getRegion(0).front();
- Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
- rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
- collapsedOpBlock->getArguments());
-
- if (collapsedGenericOp.hasIndexSemantics()) {
+ Operation *collapsedOp;
+ if (isa<linalg::GenericOp>(op)) {
+ collapsedOp = rewriter.create<linalg::GenericOp>(
+ loc, resultTypes, inputOperands, outputOperands, indexingMaps,
+ iteratorTypes,
+ [](OpBuilder &builder, Location loc, ValueRange args) {});
+ Block *origOpBlock = &op->getRegion(0).front();
+ Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
+ rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
+ collapsedOpBlock->getArguments());
+ } else {
+ assert(isa<linalg::CopyOp>(op));
+ collapsedOp = rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
+ outputOperands[0]);
+ }
+ LinalgType collapsedLinalgOp = cast<LinalgType>(collapsedOp);
+
+ if (collapsedLinalgOp.hasIndexSemantics()) {
// Collect the loop range of the generic op.
OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(collapsedGenericOp);
+ rewriter.setInsertionPoint(collapsedLinalgOp);
SmallVector<Value> loopBound =
llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) {
return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
}));
generateCollapsedIndexingRegion(loc,
- &collapsedGenericOp->getRegion(0).front(),
+ &collapsedLinalgOp->getRegion(0).front(),
collapsingInfo, loopBound, rewriter);
}
// Insert expanding reshape for the result to get back the original result
// type.
SmallVector<Value> results;
- for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
+ for (const auto &originalResult : llvm::enumerate(op->getResults())) {
Value collapsedOpResult =
- collapsedGenericOp->getResult(originalResult.index());
+ collapsedLinalgOp->getResult(originalResult.index());
auto originalResultType =
cast<ShapedType>(originalResult.value().getType());
auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
AffineMap indexingMap =
- genericOp.getIndexingMapMatchingResult(originalResult.value());
+ op.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
- Value result = rewriter.create<tensor::ExpandShapeOp>(
- loc, originalResultType, collapsedOpResult, reassociation);
- results.push_back(result);
+ if (isa<MemRefType>(collapsedOpResult.getType())) {
+ Value result = rewriter.create<memref::ExpandShapeOp>(
+ loc, originalResultType, collapsedOpResult, reassociation);
+ results.push_back(result);
+ } else {
+ Value result = rewriter.create<tensor::ExpandShapeOp>(
+ loc, originalResultType, collapsedOpResult, reassociation);
+ results.push_back(result);
+ }
} else {
results.push_back(collapsedOpResult);
}
@@ -1578,8 +1618,8 @@ class FoldWithProducerReshapeOpByCollapsing
}
std::optional<SmallVector<Value>> replacements =
- collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
- rewriter);
+ collapseOpIterationDims<linalg::GenericOp>(
+ genericOp, collapsableIterationDims, rewriter);
if (!replacements) {
return rewriter.notifyMatchFailure(
genericOp, "failed to do the fusion by collapsing transformation");
@@ -1596,36 +1636,36 @@ class FoldWithProducerReshapeOpByCollapsing
};
/// Pattern to collapse dimensions.
-class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
+template <typename LinalgType>
+class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
public:
CollapseLinalgDimensions(MLIRContext *context,
GetCollapsableDimensionsFn collapseDimensions,
PatternBenefit benefit = 1)
- : OpRewritePattern<GenericOp>(context, benefit),
+ : OpRewritePattern<LinalgType>(context, benefit),
controlCollapseDimension(std::move(collapseDimensions)) {}
- LogicalResult matchAndRewrite(GenericOp genericOp,
+ LogicalResult matchAndRewrite(LinalgType op,
PatternRewriter &rewriter) const override {
SmallVector<ReassociationIndices> collapsableIterationDims =
- controlCollapseDimension(genericOp);
+ controlCollapseDimension(op);
if (collapsableIterationDims.empty())
return failure();
// Check if the specified list of dimensions to collapse is a valid list.
- if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(),
+ if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
collapsableIterationDims)) {
return rewriter.notifyMatchFailure(
- genericOp, "specified dimensions cannot be collapsed");
+ op, "specified dimensions cannot be collapsed");
}
std::optional<SmallVector<Value>> replacements =
- collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
- rewriter);
+ collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
+ rewriter);
if (!replacements) {
- return rewriter.notifyMatchFailure(genericOp,
- "failed to collapse dimensions");
+ return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
}
- rewriter.replaceOp(genericOp, *replacements);
+ rewriter.replaceOp(op, *replacements);
return success();
}
@@ -1856,8 +1896,10 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
void mlir::linalg::populateCollapseDimensions(
RewritePatternSet &patterns,
const GetCollapsableDimensionsFn &controlCollapseDimensions) {
- patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
- controlCollapseDimensions);
+ patterns.add<CollapseLinalgDimensions<linalg::GenericOp>>(
+ patterns.getContext(), controlCollapseDimensions);
+ patterns.add<CollapseLinalgDimensions<linalg::CopyOp>>(
+ patterns.getContext(), controlCollapseDimensions);
}
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 6737a6e15da5afe..547320f53387477 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -70,3 +70,86 @@ func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41
// CHECK-LABEL: func @uncollapsable(
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+// -----
+
+// CHECK-LABEL: func.func private @collapsable_memref(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x24x32x8xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x24x32x8xf32>) -> memref<1x24x32x8xf32> {
+// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK: %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK: %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_3]], %[[VAL_4]] : memref<1x24x256xf32>, memref<1x24x256xf32>) outs(%[[VAL_5]] : memref<1x24x256xf32>) {
+// CHECK: ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK: linalg.yield %[[VAL_9]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_2]] : memref<1x24x32x8xf32>
+// CHECK: }
+
+func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>) -> (memref<1x24x32x8xf32>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+ linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%alloc : memref<1x24x32x8xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ return %alloc : memref<1x24x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @uncollapsable_strided_memref(
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) -> (memref<2x6x24x48xi32>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x6x24x48xi32>
+ %subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ %subview1 = memref.subview %alloc[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>, memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) outs(%subview1 : memref<1x...
[truncated]
Overall change looks reasonable. Some refactoring might make this better, but this throws up some design issues. Why is `CopyOp` treated specially? This should work for other Linalg named ops as well. An alternative would be to use matchers to match a `linalg.generic` and convert to a `linalg.copy` after it is collapsed. That way we could keep the two concerns separate (and having a matcher for `linalg.generic` -> `linalg.copy` might be helpful in other places as well).
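For illustration, a minimal sketch of what such a `linalg.generic` -> `linalg.copy` matcher could look like, using only upstream MLIR/LLVM APIs; the helper name `isCopyLikeGeneric` is hypothetical and not part of this patch:

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Hypothetical matcher (sketch): returns true if the generic op is pure
// element-wise identity data movement, i.e. behaviorally a linalg.copy.
static bool isCopyLikeGeneric(linalg::GenericOp genericOp) {
  // Exactly one input and one init.
  if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
    return false;
  // All indexing maps are identities and all loops are parallel.
  if (!llvm::all_of(genericOp.getIndexingMapsArray(),
                    [](AffineMap map) { return map.isIdentity(); }))
    return false;
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
    return false;
  // The body is a lone linalg.yield of the input block argument.
  Block &body = genericOp.getRegion().front();
  if (!llvm::hasSingleElement(body.getOperations()))
    return false;
  auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
  return yieldOp && yieldOp->getNumOperands() == 1 &&
         yieldOp->getOperand(0) == body.getArgument(0);
}
```

A pattern built on such a matcher could rewrite copy-like generics back to `linalg.copy` after collapsing, keeping the collapsing logic itself generic-only.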
Force-pushed from 17609b4 to f2c94f6
Actually, I am fully convinced by the current design as well. I am afraid that by the time I will need to collapse the …
Force-pushed from f2c94f6 to 8a3290a
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from 8a3290a to 78d4218
I don't have concerns with the current implementation.
Force-pushed from 78d4218 to 144b497
@MaheshRavishankar @nicolasvasilache thanks for the code review.
@MaheshRavishankar ping |
Ok, I will relent on allowing `linalg.copy` to be handled as a special case, though I'd strongly prefer not special-casing `linalg.copy` this way. We should really just be collapsing on `linalg.generic` and "matching" back to `linalg.copy` if needed. But that might be a preference, so this path is OK.

Please do remove the `static_assert`s though... they cause deployment issues downstream.
Operation *createCollapsedOp(LinalgType op,
                             const CollapsingInfo &collapsingInfo,
                             RewriterBase &rewriter) {
  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
`static_assert`s create problems in deployment. This is an optimization, not required for correctness. It is completely reasonable for an operation not to collapse and for the optimization to return a failure; the caller can then handle it appropriately. Please change this to fail gracefully.
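A minimal sketch of the graceful-failure alternative being requested here, assuming the entry point keeps its current signature (this illustrates the suggestion and is not the code in this patch):

```cpp
template <typename LinalgType>
FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
    LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
    RewriterBase &rewriter) {
  // Runtime check instead of a static_assert: an unsupported op kind becomes
  // a match failure the caller can handle, not a compile-time error.
  if (!isa<GenericOp, CopyOp>(op.getOperation()))
    return rewriter.notifyMatchFailure(
        op, "unsupported linalg op type to collapse");
  // ... the rest of the collapsing logic from this patch would follow ...
  return failure(); // placeholder for the elided body
}
```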
Sorry, I don't understand. Whoever plans to call this function can know at compile time whether the type it uses as the template parameter is supported. If we can catch a bug at compile time, that is better than catching it at runtime. Can you elaborate on how a `static_assert` can cause deployment issues downstream?
@@ -1884,8 +1902,10 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
 void mlir::linalg::populateCollapseDimensions(
     RewritePatternSet &patterns,
     const GetCollapsableDimensionsFn &controlCollapseDimensions) {
-  patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
-                                         controlCollapseDimensions);
+  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>>(
Nit: use a single `patterns.add<...>`.
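For reference, a sketch of the single-call form the nit suggests (`RewritePatternSet::add` accepts multiple pattern types in one variadic call):

```cpp
void mlir::linalg::populateCollapseDimensions(
    RewritePatternSet &patterns,
    const GetCollapsableDimensionsFn &controlCollapseDimensions) {
  // One variadic add<> registers both instantiations at once.
  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
               CollapseLinalgDimensions<linalg::CopyOp>>(
      patterns.getContext(), controlCollapseDimensions);
}
```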
Ack
Force-pushed from 144b497 to 41c5c80
Force-pushed from 41c5c80 to b47dbcb
Ok. I misread what was happening. Looks good. Thanks!
Thanks for the fast response!
@chelini here is a place that can be cleaned up once we have …
Sounds good @nicolasvasilache, thanks. I will start working on it in the coming days.