Skip to content

Commit a975be0

Browse files
committed
[mlir][shape] Make conversion passes more consistent.
- Use select-ops to make the lowering simpler.
- Change the style of FileCheck variable names to be consistent.
- Change some variable names in the code to be more explicit.

Differential Revision: https://reviews.llvm.org/D88258
1 parent 2d657d1 commit a975be0

File tree

4 files changed

+58
-77
lines changed

4 files changed

+58
-77
lines changed

mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,17 @@ class ConvertCstrBroadcastableOp
3939
// Find smaller and greater rank and extent tensor.
4040
Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
4141
Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
42-
Value lhsSmaller =
42+
Value lhsRankULE =
4343
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
4444
Type indexTy = rewriter.getIndexType();
45-
Type extentTensorTy = op.lhs().getType();
46-
auto ifOp = rewriter.create<scf::IfOp>(
47-
loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
48-
lhsSmaller,
49-
[&](OpBuilder &b, Location loc) {
50-
b.create<scf::YieldOp>(
51-
loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
52-
},
53-
[&](OpBuilder &b, Location loc) {
54-
b.create<scf::YieldOp>(
55-
loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
56-
});
57-
Value lesserRank = ifOp.getResult(0);
58-
Value lesserRankOperand = ifOp.getResult(1);
59-
Value greaterRank = ifOp.getResult(2);
60-
Value greaterRankOperand = ifOp.getResult(3);
45+
Value lesserRank =
46+
rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
47+
Value greaterRank =
48+
rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
49+
Value lesserRankOperand =
50+
rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
51+
Value greaterRankOperand =
52+
rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
6153

6254
Value rankDiff =
6355
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -90,39 +90,31 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
9090
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
9191

9292
// Find smaller and greater rank and extent tensor.
93-
Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
94-
Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
95-
Value lhsSmaller =
93+
Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
94+
Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
95+
Value lhsRankULE =
9696
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
9797
Type indexTy = rewriter.getIndexType();
98-
Type extentTensorTy = op.getType();
99-
auto ifOp = rewriter.create<IfOp>(
100-
loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
101-
lhsSmaller,
102-
[&](OpBuilder &b, Location loc) {
103-
b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
104-
rhsRank, transformed.rhs()});
105-
},
106-
[&](OpBuilder &b, Location loc) {
107-
b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
108-
lhsRank, transformed.lhs()});
109-
});
110-
Value smallerRank = ifOp.getResult(0);
111-
Value smallerOperand = ifOp.getResult(1);
112-
Value greaterRank = ifOp.getResult(2);
113-
Value greaterOperand = ifOp.getResult(3);
98+
Value lesserRank =
99+
rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
100+
Value greaterRank =
101+
rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
102+
Value lesserRankOperand =
103+
rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
104+
Value greaterRankOperand =
105+
rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
114106

115107
// Allocate stack memory for the broadcasted extent tensor.
116108
Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
117109
Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
118110

119111
// Copy extents from greater operand that are not challenged.
120112
Value rankDiff =
121-
rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
113+
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
122114
rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
123115
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
124116
Value extent = b.create<ExtractElementOp>(
125-
loc, greaterOperand, ValueRange{iv});
117+
loc, greaterRankOperand, ValueRange{iv});
126118
b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
127119
b.create<scf::YieldOp>(loc);
128120
});
@@ -132,16 +124,16 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
132124
loc, rankDiff, greaterRank, one, llvm::None,
133125
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
134126
Value greaterOperandExtent =
135-
b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
127+
b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
136128
Value greaterOperandExtentIsOne =
137129
b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
138130
auto ifOp = b.create<IfOp>(
139131
loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
140132
[&](OpBuilder &b, Location loc) {
141133
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
142-
Value smallerOperandExtent = b.create<ExtractElementOp>(
143-
loc, smallerOperand, ValueRange{ivShifted});
144-
b.create<scf::YieldOp>(loc, smallerOperandExtent);
134+
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
135+
loc, lesserRankOperand, ValueRange{ivShifted});
136+
b.create<scf::YieldOp>(loc, lesserRankOperandExtent);
145137
},
146138
[&](OpBuilder &b, Location loc) {
147139
b.create<scf::YieldOp>(loc, greaterOperandExtent);

mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,24 @@
77
// CHECK: %[[C0:.*]] = constant 0 : index
88
// CHECK: %[[C1:.*]] = constant 1 : index
99
// CHECK: %[[RET:.*]] = shape.const_witness true
10-
// CHECK: %[[LHSRANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
11-
// CHECK: %[[RHSRANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
12-
// CHECK: %[[LESSEQUAL:.*]] = cmpi "ule", %[[LHSRANK]], %[[RHSRANK]] : index
13-
// CHECK: %[[IFRESULTS:.*]]:4 = scf.if %[[LESSEQUAL]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
14-
// CHECK: scf.yield %[[LHSRANK]], %[[LHS]], %[[RHSRANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
15-
// CHECK: } else {
16-
// CHECK: scf.yield %[[RHSRANK]], %[[RHS]], %[[LHSRANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
17-
// CHECK: }
18-
// CHECK: %[[RANKDIFF:.*]] = subi %[[IFRESULTS:.*]]#2, %[[IFRESULTS]]#0 : index
19-
// CHECK: scf.for %[[IV:.*]] = %[[RANKDIFF]] to %[[IFRESULTS]]#2 step %[[C1]] {
20-
// CHECK: %[[GREATERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#3{{\[}}%[[IV]]] : tensor<?xindex>
21-
// CHECK: %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANKDIFF]] : index
22-
// CHECK: %[[LESSERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#1{{\[}}%[[IVSHIFTED]]] : tensor<?xindex>
23-
// CHECK: %[[GREATERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[C1]] : index
24-
// CHECK: %[[LESSERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[LESSERRANKOPERANDEXTENT]], %[[C1]] : index
25-
// CHECK: %[[EXTENTSAGREE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[LESSERRANKOPERANDEXTENT]] : index
26-
// CHECK: %[[OR_TMP:.*]] = or %[[GREATERRANKOPERANDEXTENTISONE]], %[[LESSERRANKOPERANDEXTENTISONE]] : i1
27-
// CHECK: %[[BROADCASTISVALID:.*]] = or %[[EXTENTSAGREE]], %[[OR_TMP]] : i1
28-
// CHECK: assert %[[BROADCASTISVALID]], "invalid broadcast"
10+
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
11+
// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
12+
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
13+
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
14+
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
15+
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
16+
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
17+
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
18+
// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
19+
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
20+
// CHECK: %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
21+
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IVSHIFTED]]] : tensor<?xindex>
22+
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
23+
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LESSER_RANK_OPERAND_EXTENT]], %[[C1]] : index
24+
// CHECK: %[[EXTENTS_AGREE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[LESSER_RANK_OPERAND_EXTENT]] : index
25+
// CHECK: %[[OR_TMP:.*]] = or %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE]] : i1
26+
// CHECK: %[[BROADCAST_IS_VALID:.*]] = or %[[EXTENTS_AGREE]], %[[OR_TMP]] : i1
27+
// CHECK: assert %[[BROADCAST_IS_VALID]], "invalid broadcast"
2928
// CHECK: }
3029
// CHECK: return %[[RET]] : !shape.witness
3130
// CHECK: }

mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -312,27 +312,26 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
312312
// CHECK: %[[C1:.*]] = constant 1 : index
313313
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
314314
// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
315-
// CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]]
316-
// CHECK: %[[ARG:.*]]:4 = scf.if %[[LHS_SMALLER]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
317-
// CHECK: scf.yield %[[LHS_RANK]], %[[LHS]], %[[RHS_RANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
318-
// CHECK: } else {
319-
// CHECK: scf.yield %[[RHS_RANK]], %[[RHS]], %[[LHS_RANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
320-
// CHECK: }
321-
// CHECK: %[[MEM:.*]] = alloca(%[[ARG]]#2) : memref<?xindex>
322-
// CHECK: %[[RANK_DIFF:.*]] = subi %[[ARG]]#2, %[[ARG]]#0 : index
315+
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
316+
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
317+
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
318+
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
319+
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
320+
// CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
321+
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
323322
// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
324-
// CHECK: %[[EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
323+
// CHECK: %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
325324
// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
326325
// CHECK: }
327-
// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[ARG]]#2 step %[[C1]] {
328-
// CHECK: %[[GREATER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
329-
// CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_OPERAND_EXTENT]], %[[C1]] : index
326+
// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
327+
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
328+
// CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
330329
// CHECK: %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
331330
// CHECK: %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
332-
// CHECK: %[[SMALLER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#1[%[[IV_SHIFTED]]] : tensor<?xindex>
333-
// CHECK: scf.yield %[[SMALLER_OPERAND_EXTENT]] : index
331+
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
332+
// CHECK: scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
334333
// CHECK: } else {
335-
// CHECK: scf.yield %[[GREATER_OPERAND_EXTENT]] : index
334+
// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
336335
// CHECK: }
337336
// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
338337
// CHECK: }
@@ -341,4 +340,3 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
341340
: tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
342341
return
343342
}
344-

0 commit comments

Comments (0)