Skip to content

Commit d776346

Browse files
authored
[mlir][linalg] Avoid emitting errors in block pack matmul (llvm#93170)
Tweaks linalg.generic verification in block pack matmul pass to avoid using emitting errors which pollute stderr during operation matching.
1 parent f0b0c02 commit d776346

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,7 @@ struct BlockPackMatmul<linalg::GenericOp>
244244
LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
245245
PatternRewriter &rewriter) const override {
246246
// Match suitable generics.
247-
if (failed(linalg::detail::verifyContractionInterface(
248-
linalgOp.getOperation()))) {
247+
if (!linalg::isaContractionOpInterface(linalgOp)) {
249248
return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
250249
}
251250

mlir/test/Dialect/Linalg/block-pack-matmul.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,3 +476,32 @@ func.func @block_generic_matmul_transpose_b(
476476
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
477477
// CHECK-SAME: into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
478478
// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>
479+
480+
// -----
481+
482+
#map = affine_map<(d0, d1) -> (d0, d1)>
483+
484+
func.func @non_contraction_generic(
485+
%A: tensor<64x128xf32>) -> tensor<64x128xf32> {
486+
%c0 = arith.constant 0.000000e+00 : f32
487+
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]}
488+
outs(%A : tensor<64x128xf32>) {
489+
^bb0(%out: f32):
490+
%1 = arith.maximumf %out, %c0 : f32
491+
linalg.yield %1 : f32
492+
} -> tensor<64x128xf32>
493+
return %0 : tensor<64x128xf32>
494+
}
495+
496+
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
497+
498+
// CHECK-LABEL: func @non_contraction_generic(
499+
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<64x128xf32>
500+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
501+
// CHECK-NOT: tensor.pack
502+
// CHECK: %[[GENERIC:.+]] = linalg.generic
503+
// CHECK-SAME: indexing_maps = [#[[$MAP]]]
504+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
505+
// CHECK-SAME: outs(%[[A]] : tensor<64x128xf32>)
506+
// CHECK-NOT: tensor.unpack
507+
// CHECK: return %[[GENERIC]] : tensor<64x128xf32>

0 commit comments

Comments
 (0)