Commit 67cde37

[mlir] Match before rewrite in BubbleUpPackOpThroughGenericOp (llvm#126946)
The BubbleUpPackOpThroughGenericOp pattern performed some unsafe rewrites before matching was fully complete, which caused the pattern rewriter to fail to converge. This PR fixes the bug by moving all matching logic ahead of the rewrite logic.

Signed-off-by: Max Dawkins <[email protected]>
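For context on why the ordering matters: in MLIR's greedy rewrite driver, a matchAndRewrite implementation must either complete its match and then rewrite, or return failure() without touching the IR. A pattern that mutates the IR and then returns failure() hands the driver modified IR with no successful match, so the driver keeps reprocessing and never converges. Below is a minimal sketch of that contract, not the upstream pattern: the class name and body are illustrative, assuming only the standard OpRewritePattern / notifyMatchFailure APIs and the tensor::PackOp accessors visible in this diff.

// Sketch only; illustrative, not the upstream BubbleUpPackOpThroughGenericOp.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

namespace {
struct BubbleUpPackSketch
    : public mlir::OpRewritePattern<mlir::tensor::PackOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::tensor::PackOp packOp,
                  mlir::PatternRewriter &rewriter) const override {
    // Match phase: check every precondition before any mutation.
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "padding not supported");

    // Rewrite phase: reached only once the match is certain. All
    // rewriter.create<...>() / rewriter.replaceOp(...) calls belong here;
    // returning failure() after this point is exactly the hazard the
    // commit removes.
    // ... perform the actual rewrite ...
    return mlir::success();
  }
};
} // namespace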

2 files changed, 40 insertions(+), 12 deletions(-)

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 12 additions & 12 deletions
@@ -399,6 +399,18 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
   if (!genericOp->getResult(0).hasOneUse())
     return failure();
 
+  // TODO: Add an option for allowing padding values. It could introduce
+  // undefined behavior if we unconditionally propagate pack op through all
+  // the ops. E.g., if the padding value is zero and there are division ops in
+  // a generic op. Some values of padding area could be NaN (0/0).
+  if (packOp.getPaddingValue())
+    return failure();
+
+  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
+  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
+  if (failed(packInfo))
+    return failure();
+
   // We want to move the pack not the generic.
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(genericOp);
@@ -422,18 +434,6 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
       return failure();
   }
 
-  // TODO: Add an option for allowing padding values. It could introduce
-  // undefined behavior if we unconditionally propagate pack op through all
-  // the ops. E.g., if the padding value is zero and there are division ops in
-  // a generic op. Some values of padding area could be NaN (0/0).
-  if (packOp.getPaddingValue())
-    return failure();
-
-  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
-  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
-  if (failed(packInfo))
-    return failure();
-
   // Rebuild the indexing map for the corresponding init operand.
   auto [packedOutOperand, packedOutIndexingMap] =
       getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
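Note that the hoisted block is moved verbatim: the padding-value bail-out and the getPackingInfoFromOperand query now run before the insertion guard and any other rewrite work, so every failure() path exits while the IR is still untouched and the greedy driver can converge.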

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 28 additions & 0 deletions
@@ -46,6 +46,34 @@ func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>)
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @dynamic_elem_pack_padding_value(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 3.000000e+00 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %3 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xf32>)
+      outs(%2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):
+      %4 = arith.addf %arg3, %arg3 : f32
+      linalg.yield %4 : f32
+  } -> tensor<?x?xf32>
+  %4 = tensor.pack %3 padding_value(%cst : f32)
+      inner_dims_pos = [0, 1]
+      inner_tiles = [8, 2]
+      into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+  return %4 : tensor<?x?x8x2xf32>
+}
+// CHECK-LABEL: func.func @dynamic_elem_pack_padding_value
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         tensor.pack %[[GENERIC]]
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func.func @elem_pack_transpose_inner_dims(%arg0: tensor<128x256xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{
   %init = tensor.empty() : tensor<128x256xi32>
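The new test pins down the bail-out path: because the tensor.pack carries a padding_value, the pattern must decline to bubble it up through the linalg.generic. The CHECK lines assert that the generic survives and the pack still consumes its result, i.e. the input is left unchanged rather than partially rewritten. (The file's RUN line is outside this hunk; as with the neighboring tests it presumably runs mlir-opt with the test data-layout-propagation pass and pipes the output into FileCheck.)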
