
Commit 1ca1197

Author: Peiming Liu
[mlir][scf] support 1:N type conversion for scf.if/while/condition
Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D137100
1 parent 96a74c4 commit 1ca1197

2 files changed: +114, -61 lines

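For context: a 1:N type conversion maps one source type to zero or more target types; in the sparse-tensor tests below, one sparse tensor value becomes six memrefs. A TypeConverter expresses such a rule through the addConversion overload that appends to a SmallVectorImpl<Type>. The sketch below is illustrative only and not part of this commit: the helper name and exact buffer layout are assumptions, while addConversion and getSparseTensorEncoding are existing MLIR APIs.

// Illustrative 1:N conversion rule (assumed buffer layout, not the exact one
// used by the sparse-tensor codegen pass).
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;

static void addSparse1ToNConversion(TypeConverter &converter) {
  // Types without a dedicated rule convert to themselves.
  converter.addConversion([](Type type) { return type; });
  // One sparse tensor value -> N buffers.
  converter.addConversion(
      [](RankedTensorType type, SmallVectorImpl<Type> &result)
          -> Optional<LogicalResult> {
        if (!sparse_tensor::getSparseTensorEncoding(type))
          return llvm::None; // Not sparse; let other rules handle it.
        auto indexType = IndexType::get(type.getContext());
        auto dyn = ShapedType::kDynamicSize;
        result.push_back(MemRefType::get({type.getRank()}, indexType));  // sizes
        result.push_back(MemRefType::get({dyn}, indexType));             // indices
        result.push_back(MemRefType::get({dyn}, type.getElementType())); // values
        return success();
      });
}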

mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp

Lines changed: 49 additions & 61 deletions
@@ -155,44 +155,57 @@ class ConvertForOpTypes
 } // namespace
 
 namespace {
-class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
+class ConvertIfOpTypes
+    : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> {
 public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(IfOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // TODO: Generalize this to any type conversion, not just 1:1.
-    //
-    // We need to implement something more sophisticated here that tracks
-    // which types convert to which other types and does the appropriate
-    // materialization logic.
-    // For example, it's possible that one result type converts to 0 types and
-    // another to 2 types, so newResultTypes would at least be the right size
-    // to not crash in the llvm::zip call below, but then we would set the the
-    // wrong type on the SSA values! These edge cases are also why we cannot
-    // safely use the TypeConverter::convertTypes helper here.
-    SmallVector<Type, 6> newResultTypes;
-    for (auto type : op.getResultTypes()) {
-      Type newType = typeConverter->convertType(type);
-      if (!newType)
-        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
-      newResultTypes.push_back(newType);
-    }
+  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
 
-    // See comments in the ForOp pattern for why we clone without regions and
-    // then inline.
-    IfOp newOp = cast<IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
+  Optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 TypeRange dstTypes) const {
+
+    IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
+                                       adaptor.getCondition(), true);
+    newOp->setAttrs(op->getAttrs());
+
+    // We do not need the empty blocks created by rewriter.
+    rewriter.eraseBlock(newOp.elseBlock());
+    rewriter.eraseBlock(newOp.thenBlock());
+
+    // Inlines block from the original operation.
     rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
                                 newOp.getThenRegion().end());
     rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
                                 newOp.getElseRegion().end());
 
-    // Update the operands and types.
-    newOp->setOperands(adaptor.getOperands());
-    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
-      std::get<0>(t).setType(std::get<1>(t));
-    rewriter.replaceOp(op, newOp.getResults());
-    return success();
+    return newOp;
+  }
+};
+} // namespace
+
+namespace {
+class ConvertWhileOpTypes
+    : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> {
+public:
+  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
+
+  Optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter,
+                                    TypeRange dstTypes) const {
+    // Unpacked the iteration arguments.
+    SmallVector<Value> flatArgs;
+    for (Value arg : adaptor.getOperands())
+      unpackUnrealizedConversionCast(arg, flatArgs);
+
+    auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
+
+    for (auto i : {0u, 1u}) {
+      if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
+        return llvm::None;
+      auto &dstRegion = newOp.getRegion(i);
+      rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
+    }
+    return newOp;
   }
 };
 } // namespace
@@ -217,43 +230,18 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
 };
 } // namespace
 
-namespace {
-class ConvertWhileOpTypes : public OpConversionPattern<WhileOp> {
-public:
-  using OpConversionPattern<WhileOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(WhileOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto *converter = getTypeConverter();
-    assert(converter);
-    SmallVector<Type> newResultTypes;
-    if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes)))
-      return failure();
-
-    auto newOp = rewriter.create<WhileOp>(op.getLoc(), newResultTypes,
-                                          adaptor.getOperands());
-    for (auto i : {0u, 1u}) {
-      auto &dstRegion = newOp.getRegion(i);
-      rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
-      if (failed(rewriter.convertRegionTypes(&dstRegion, *converter)))
-        return rewriter.notifyMatchFailure(op, "could not convert body types");
-    }
-    rewriter.replaceOp(op, newOp.getResults());
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
 public:
   using OpConversionPattern<ConditionOp>::OpConversionPattern;
   LogicalResult
   matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.updateRootInPlace(
-        op, [&]() { op->setOperands(adaptor.getOperands()); });
+    SmallVector<Value> unpackedYield;
+    for (Value operand : adaptor.getOperands())
+      unpackUnrealizedConversionCast(operand, unpackedYield);
+
+    rewriter.updateRootInPlace(op, [&]() { op->setOperands(unpackedYield); });
     return success();
   }
 };
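
Both converted patterns above derive from Structural1ToNConversionPattern and call unpackUnrealizedConversionCast; both are defined earlier in this same file and fall outside the hunks shown here. Roughly, the base class converts each result type (possibly into several types), asks the derived pattern to build the new op via convertSourceOp, and then packs each group of new results back into a single value of the original type so the old op can be replaced; the helper undoes that packing on the operand side. A plausible sketch of the helper, consistent with its uses above but not copied verbatim from the commit:

// File-local helper (sketch). Inside StructuralTypeConversions.cpp the needed
// headers are already included; UnrealizedConversionCastOp is a builtin op.
static void unpackUnrealizedConversionCast(Value v,
                                           SmallVectorImpl<Value> &unpacked) {
  if (auto cast =
          dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
    if (cast.getInputs().size() != 1) {
      // 1:N conversion: the value is a pack of converted values; forward them.
      llvm::append_range(unpacked, cast.getInputs());
      return;
    }
  }
  // 1:1 conversion (or a block argument): forward the value unchanged.
  unpacked.push_back(v);
}

The 1:1 fallback in this sketch leaves ordinary converted values untouched, which is why ConvertConditionOpTypes can run the same loop over all adaptor operands.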

mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir

Lines changed: 65 additions & 0 deletions
@@ -30,3 +30,68 @@ func.func @for(%in: tensor<1024xf32, #SparseVector>,
   return %1 : tensor<1024xf32, #SparseVector>
 }
 
+
+// CHECK-LABEL:  func @if(
+// CHECK-SAME:     %[[DIM_SIZE:.*0]]: memref<1xindex>,
+// CHECK-SAME:     %[[DIM_CURSOR:.*1]]: memref<1xindex>,
+// CHECK-SAME:     %[[MEM_SIZE:.*2]]: memref<3xindex>,
+// CHECK-SAME:     %[[POINTER:.*3]]: memref<?xindex>,
+// CHECK-SAME:     %[[INDICES:.*4]]: memref<?xindex>,
+// CHECK-SAME:     %[[VALUE:.*5]]: memref<?xf32>,
+// CHECK-SAME:     %[[DIM_SIZE_1:.*6]]: memref<1xindex>,
+// CHECK-SAME:     %[[DIM_CURSOR_1:.*7]]: memref<1xindex>,
+// CHECK-SAME:     %[[MEM_SIZE_1:.*8]]: memref<3xindex>,
+// CHECK-SAME:     %[[POINTER_1:.*9]]: memref<?xindex>,
+// CHECK-SAME:     %[[INDICES_1:.*10]]: memref<?xindex>,
+// CHECK-SAME:     %[[VALUE_1:.*11]]: memref<?xf32>,
+// CHECK-SAME:     %[[TMP_arg12:.*12]]: i1) ->
+// CHECK-SAME:     (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+// CHECK:          %[[SV:.*]]:6 = scf.if %[[TMP_arg12]] -> (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+// CHECK:            scf.yield %[[DIM_SIZE]], %[[DIM_CURSOR]], %[[MEM_SIZE]], %[[POINTER]], %[[INDICES]], %[[VALUE]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK:          } else {
+// CHECK:            scf.yield %[[DIM_SIZE_1]], %[[DIM_CURSOR_1]], %[[MEM_SIZE_1]], %[[POINTER_1]], %[[INDICES_1]], %[[VALUE_1]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK:          }
+// CHECK:          return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+func.func @if(%t: tensor<1024xf32, #SparseVector>,
+              %f: tensor<1024xf32, #SparseVector>,
+              %c: i1) -> tensor<1024xf32, #SparseVector> {
+  %1 = scf.if %c -> tensor<1024xf32, #SparseVector> {
+    scf.yield %t : tensor<1024xf32, #SparseVector>
+  } else {
+    scf.yield %f : tensor<1024xf32, #SparseVector>
+  }
+
+  return %1 : tensor<1024xf32, #SparseVector>
+}
+
+// CHECK-LABEL:  func @while(
+// CHECK-SAME:     %[[DIM_SIZE:.*0]]: memref<1xindex>,
+// CHECK-SAME:     %[[DIM_CURSOR:.*1]]: memref<1xindex>,
+// CHECK-SAME:     %[[MEM_SIZE:.*2]]: memref<3xindex>,
+// CHECK-SAME:     %[[POINTER:.*3]]: memref<?xindex>,
+// CHECK-SAME:     %[[INDICES:.*4]]: memref<?xindex>,
+// CHECK-SAME:     %[[VALUE:.*5]]: memref<?xf32>,
+// CHECK-SAME:     %[[TMP_arg6:.*6]]: i1) ->
+// CHECK-SAME:     (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+// CHECK:          %[[SV:.*]]:6 = scf.while (
+// CHECK-SAME:       %[[TMP_arg7:.*]] = %[[DIM_SIZE]],
+// CHECK-SAME:       %[[TMP_arg8:.*]] = %[[DIM_CURSOR]],
+// CHECK-SAME:       %[[TMP_arg9:.*]] = %[[MEM_SIZE]],
+// CHECK-SAME:       %[[TMP_arg10:.*]] = %[[POINTER]],
+// CHECK-SAME:       %[[TMP_arg11:.*]] = %[[INDICES]],
+// CHECK-SAME:       %[[TMP_arg12:.*]] = %[[VALUE]])
+// CHECK:            scf.condition(%[[TMP_arg6]]) %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK:          } do {
+// CHECK:          ^bb0(%[[TMP_arg7]]: memref<1xindex>, %[[TMP_arg8]]: memref<1xindex>, %[[TMP_arg9]]: memref<3xindex>, %[[TMP_arg10]]: memref<?xindex>, %[[TMP_arg11]]: memref<?xindex>, %[[TMP_arg12]]: memref<?xf32>):
+// CHECK:            scf.yield %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK:          }
+// CHECK:          return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> {
+  %0 = scf.while (%arg4 = %arg0) : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> {
+    scf.condition(%c) %arg4 : tensor<1024xf32, #SparseVector>
+  } do {
+  ^bb0(%arg7: tensor<1024xf32, #SparseVector>):
+    scf.yield %arg7 : tensor<1024xf32, #SparseVector>
+  }
+  return %0: tensor<1024xf32, #SparseVector>
+}
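
For completeness, these structural patterns are normally wired into a conversion pass together with a 1:N TypeConverter such as the one sketched near the top of this page. The driver below is a hedged sketch of that wiring, not code from this commit: populateSCFStructuralTypeConversionsAndLegality and applyPartialConversion are existing MLIR entry points, while the surrounding function and the legality policy are assumptions.

// Sketch of a driver that applies the SCF structural patterns with a 1:N
// converter. Header location of the populate function may vary by LLVM version.
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;

static LogicalResult runStructural1ToN(ModuleOp module,
                                       TypeConverter &converter) {
  MLIRContext *ctx = module.getContext();
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);
  // Registers the structural patterns and marks scf.if/scf.while/etc.
  // dynamically illegal until their operand/result types are converted.
  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                       target);
  // Leave all other operations alone in this sketch.
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
  return applyPartialConversion(module, target, std::move(patterns));
}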
