Skip to content

Commit c3640b1

Browse files
rafaelubalmwyuxuanchen1997
authored andcommitted
[mlir] New canonicalization patterns for shape.shape_of and tensor.reshape (#98531)
This PR includes 3 new canonicalization patterns: - Operation `shape.shape_of`: shape of reshape ``` // Before func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> { %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32> %0 = shape.shape_of %reshape : tensor<*xf32> -> tensor<?xindex> return %0 : tensor<?xindex> } // After func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> { return %arg1 : tensor<?xindex> } ``` - Operation `tensor.reshape`: reshape of reshape ``` // Before func.func @fold_tensor_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>) -> tensor<*xf32> { %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32> %1 = tensor.reshape %0(%arg2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32> return %1 : tensor<*xf32> } // After func.func @fold_tensor_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>) -> tensor<*xf32> { %reshape = tensor.reshape %arg0(%arg2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32> return %reshape : tensor<*xf32> } ``` - Operation `tensor.reshape`: reshape 1D to 1D ``` // Before func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> { %0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32> return %0 : tensor<?xf32> } // After func.func @fold_reshape_1d(%arg0: tensor<?xf32>, %arg1: tensor<1xindex>) -> tensor<?xf32> { return %arg0 : tensor<?xf32> } ``` These three canonicalization patterns cooperate to simplify the IR structure emerging from the lowering of certain element-wise ops with unranked tensor inputs. See file `unranked-tensor-lowering.mlir` in the proposed change list for a detailed example and description. For context, this PR is meant to enable code optimizations for the code generated while lowering ops `quant.qcast` and `quant.dcast` with unranked tensors, as proposed in https://discourse.llvm.org/t/rfc-improvements-in-the-quant-dialect/79942 (implementation currently in progress).
1 parent a01788a commit c3640b1

File tree

5 files changed

+196
-10
lines changed

5 files changed

+196
-10
lines changed

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,18 +1702,36 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
17021702
}
17031703
};
17041704

1705-
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
1705+
// Canonicalize
1706+
//
1707+
// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1708+
// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
1709+
//
1710+
// to
1711+
//
1712+
// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1713+
// %1 = %shape
1714+
//
1715+
struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
17061716
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
17071717

17081718
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
17091719
PatternRewriter &rewriter) const override {
1710-
if (!llvm::isa<ShapedType>(op.getArg().getType()))
1711-
return failure();
1712-
if (llvm::isa<ShapedType>(op.getType()))
1713-
return failure();
1714-
1715-
rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
1716-
op.getArg());
1720+
auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
1721+
if (!tensorReshapeOp)
1722+
return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
1723+
if (!isa<TensorType>(op.getType()))
1724+
return rewriter.notifyMatchFailure(op, "result is not a tensor");
1725+
1726+
// Operand 'shape' of 'tensor.reshape' may now be used as the result of
1727+
// 'shape.shape_of'. While its type is guaranteed to be compatible in well-
1728+
// formed IR, it may not be identical (dynamically vs statically shaped),
1729+
// in which case it needs to be cast first.
1730+
Value shape = tensorReshapeOp.getShape();
1731+
if (op.getType() != shape.getType())
1732+
shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
1733+
1734+
rewriter.replaceOp(op, shape);
17171735
return success();
17181736
}
17191737
};
@@ -1753,7 +1771,7 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
17531771

17541772
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
17551773
MLIRContext *context) {
1756-
patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
1774+
patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
17571775
ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
17581776
context);
17591777
}

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1585,13 +1585,25 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
15851585
getResult().getType()))
15861586
return reshapedSource;
15871587

1588+
// If the producer of operand 'source' is another 'tensor.reshape' op, use the
1589+
// producer's input instead as the original tensor to reshape. This could
1590+
// render such producer dead code.
1591+
if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1592+
getSourceMutable().assign(reshapeOpProducer.getSource());
1593+
return getResult();
1594+
}
1595+
15881596
auto source = getSource();
15891597
auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
15901598
auto resultTy = dyn_cast<RankedTensorType>(getType());
1591-
15921599
if (!sourceTy || !resultTy || sourceTy != resultTy)
15931600
return {};
15941601

1602+
// If the source and result are both 1D tensors and have the same type, the
1603+
// reshape has no effect, even if the tensor is dynamically shaped.
1604+
if (sourceTy.getRank() == 1)
1605+
return source;
1606+
15951607
if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
15961608
auto elements = fromElements.getElements();
15971609
bool dynamicNoop =

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,6 +1361,45 @@ func.func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape
13611361

13621362
// -----
13631363

1364+
// CHECK-LABEL: func @shape_of_from_reshape
1365+
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
1366+
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
1367+
func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
1368+
// CHECK: return %[[SHAPE]] : tensor<?xindex>
1369+
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1370+
%1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
1371+
return %1 : tensor<?xindex>
1372+
}
1373+
1374+
// -----
1375+
1376+
// CHECK-LABEL: func @shape_of_from_reshape_compatible_types
1377+
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
1378+
// CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex>
1379+
func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
1380+
// CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<5xindex> to tensor<?xindex>
1381+
// CHECK: return %[[CAST_SHAPE]] : tensor<?xindex>
1382+
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32>
1383+
%1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
1384+
return %1 : tensor<?xindex>
1385+
}
1386+
1387+
// -----
1388+
1389+
// CHECK-LABEL: func @shape_of_from_reshape_nofold
1390+
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
1391+
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
1392+
func.func @shape_of_from_reshape_nofold(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> !shape.shape {
1393+
// CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1394+
// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor<*xf32> -> !shape.shape
1395+
// CHECK: return %[[SHAPE_OF]] : !shape.shape
1396+
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1397+
%1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
1398+
return %1 : !shape.shape
1399+
}
1400+
1401+
// -----
1402+
13641403
// CHECK-LABEL: @cast_extent_tensor
13651404
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
13661405
func.func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// RUN: mlir-opt -split-input-file -canonicalize -cse %s | FileCheck %s
2+
3+
// This test verifies the simplification of IR patterns that emerge when
4+
// lowering high-level element-wise ops with unranked tensor inputs. Consider
5+
// the following function incrementing and doubling the value of an input
6+
// unranked tensor using ops in a hypothetical high-level dialect called 'hl':
7+
//
8+
// func.func @f(%input: tensor<*xf32>) -> tensor<*xf32> {
9+
// %0 = hl.inc %input : tensor<*xf32>
10+
// %1 = hl.double %0 : tensor<*xf32>
11+
// return %1 : tensor<*xf32>
12+
// }
13+
//
14+
// A possible strategy to lower 'hl.inc' consists in reshaping its operand into
15+
// a 1D tensor, creating a 1D tensor splat with the same total size as the input
16+
// operand and with value 1.0, adding both 1D tensors using 'arith.addf', and
17+
// reshaping the result back into the original input shape. A similar process
18+
// applies for 'hl.double', except with a tensor splat with value 2.0 and an
19+
// 'arith.mulf' op. The body of the function in the test below contains the full
20+
// sequence.
21+
//
22+
// Since such lowering process would operate on individual 'hl' ops in a
23+
// context-oblivious manner, the emitted code produces a redundant IR pattern
24+
// where the result of 'arith.addf' is reshaped into an unranked tensor, just
25+
// for it to be immediately reshaped back into the 1D tensor consumed by
26+
// 'arith.mulf'. This entails the overhead of re-computing the unranked tensor
27+
// shape ('shape.shape_of') and size ('shape.num_elements').
28+
//
29+
// This test verifies that the consecutive application of a canonicalization and
30+
// a CSE pass successfully simplifies this emerging pattern, leading to a
31+
// version of the code in which the result of the emitted 'arith.addf' op
32+
// associated with 'hl.inc' is directly consumed by the 'arith.mulf' op
33+
// associated with 'hl.double', as observed in the FileCheck directives. The
34+
// main rewrite patterns at play are 'shape.shape_of' canonicalization,
35+
// 'tensor.reshape' canonicalization, and 'shape.num_elements' subexpression
36+
// elimination.
37+
//
38+
39+
// CHECK-LABEL: @unranked_tensor_lowering
40+
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
41+
42+
// CHECK-DAG: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
43+
// CHECK-DAG: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
44+
45+
// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[INPUT]] : tensor<*xf32> -> tensor<?xindex>
46+
// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor<?xindex> -> index
47+
// CHECK: %[[INPUT_COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex>
48+
// CHECK: %[[INPUT_COLLAPSED:.*]] = tensor.reshape %[[INPUT]](%[[INPUT_COLLAPSED_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
49+
50+
// CHECK: %[[ONE_SPLAT:.*]] = tensor.splat %[[ONE]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
51+
// CHECK: %[[SUM_COLLAPSED:.*]] = arith.addf %[[INPUT_COLLAPSED]], %[[ONE_SPLAT]] : tensor<?xf32>
52+
53+
// CHECK: %[[TWO_SPLAT:.*]] = tensor.splat %[[TWO]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
54+
// CHECK: %[[PRODUCT_COLLAPSED:.*]] = arith.mulf %[[SUM_COLLAPSED]], %[[TWO_SPLAT]] : tensor<?xf32>
55+
56+
// CHECK: %[[PRODUCT:.*]] = tensor.reshape %[[PRODUCT_COLLAPSED]](%[[INPUT_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
57+
// CHECK: return %[[PRODUCT]] : tensor<*xf32>
58+
59+
func.func @unranked_tensor_lowering(%input: tensor<*xf32>) -> tensor<*xf32> {
60+
61+
// Collapse input
62+
%input_shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
63+
%input_size = shape.num_elements %input_shape : tensor<?xindex> -> index
64+
%input_collapsed_shape = tensor.from_elements %input_size : tensor<1xindex>
65+
%input_collapsed = tensor.reshape %input(%input_collapsed_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
66+
67+
// Second operand for sum
68+
%one = arith.constant 1.0 : f32
69+
%one_splat = tensor.splat %one[%input_size] : tensor<?xf32>
70+
71+
// Compute sum and expand it
72+
%sum_collapsed = arith.addf %input_collapsed, %one_splat : tensor<?xf32>
73+
%sum = tensor.reshape %sum_collapsed(%input_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
74+
75+
// Collapse sum
76+
%sum_shape = shape.shape_of %sum : tensor<*xf32> -> tensor<?xindex>
77+
%sum_size = shape.num_elements %sum_shape : tensor<?xindex> -> index
78+
%sum_collapsed_shape = tensor.from_elements %sum_size : tensor<1xindex>
79+
%sum_collapsed_0 = tensor.reshape %sum(%sum_collapsed_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
80+
81+
// Second operand for product
82+
%two = arith.constant 2.0 : f32
83+
%two_splat = tensor.splat %two[%sum_size] : tensor<?xf32>
84+
85+
// Compute product and expand it
86+
%product_collapsed = arith.mulf %sum_collapsed_0, %two_splat : tensor<?xf32>
87+
%product = tensor.reshape %product_collapsed(%sum_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
88+
89+
return %product : tensor<*xf32>
90+
}

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,33 @@ func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32>
847847

848848
// -----
849849

850+
// CHECK-LABEL: func @fold_reshape_chain
851+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<*xf32>
852+
// CHECK-SAME: %[[SHAPE_0:[a-zA-Z0-9_]+]]: tensor<?xindex>
853+
// CHECK-SAME: %[[SHAPE_1:[a-zA-Z0-9_]+]]: tensor<?xindex>
854+
// CHECK-SAME: %[[SHAPE_2:[a-zA-Z0-9_]+]]: tensor<?xindex>
855+
// CHECK: %[[RESULT:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE_2]])
856+
// CHECK: return %[[RESULT]]
857+
func.func @fold_reshape_chain(%input: tensor<*xf32>, %shape_0: tensor<?xindex>, %shape_1: tensor<?xindex>, %shape_2: tensor<?xindex>) -> tensor<*xf32> {
858+
%0 = tensor.reshape %input(%shape_0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
859+
%1 = tensor.reshape %0(%shape_1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
860+
%2 = tensor.reshape %1(%shape_2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
861+
return %2 : tensor<*xf32>
862+
}
863+
864+
// -----
865+
866+
// CHECK-LABEL: func @fold_reshape_1d
867+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?xf32>
868+
// CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<1xindex>
869+
// CHECK: return %[[INPUT]]
870+
func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> {
871+
%0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
872+
return %0 : tensor<?xf32>
873+
}
874+
875+
// -----
876+
850877
// CHECK-LABEL: func @fold_extract_constant_splat
851878
// CHECK-NOT: tensor.extract_slice
852879
// CHECK: arith.constant dense<42> : tensor<4x4xi32>

0 commit comments

Comments
 (0)