Skip to content

Commit 05d5125

Browse files
[mlir] Generalize OpFoldResult usage in ops with offsets, sizes and operands.
This revision starts evolving the APIs to manipulate ops with offsets, sizes and operands towards a ValueOrAttr abstraction that is already used in folding under the name OpFoldResult. The objective, in the future, is to allow such manipulations all the way to the level of ODS to avoid all the genuflexions involved in distinguishing between values and attributes for generic constant foldings. Once this evolution is accepted, the next step will be a mechanical OpFoldResult -> ValueOrAttr. Differential Revision: https://reviews.llvm.org/D95310
1 parent 7163aa9 commit 05d5125

File tree

15 files changed

+425
-395
lines changed

15 files changed

+425
-395
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 45 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1959,14 +1959,19 @@ def MemRefReinterpretCastOp:
19591959
let builders = [
19601960
// Build a ReinterpretCastOp with mixed static and dynamic entries.
19611961
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source,
1962-
"int64_t":$staticOffset, "ArrayRef<int64_t>":$staticSizes,
1963-
"ArrayRef<int64_t>":$staticStrides, "ValueRange":$offset,
1964-
"ValueRange":$sizes, "ValueRange":$strides,
1962+
"OpFoldResult":$offset, "ArrayRef<OpFoldResult>":$sizes,
1963+
"ArrayRef<OpFoldResult>":$strides,
19651964
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
1966-
// Build a ReinterpretCastOp with all dynamic entries.
1965+
// Build a ReinterpretCastOp with static entries.
19671966
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source,
1968-
"Value":$offset, "ValueRange":$sizes, "ValueRange":$strides,
1967+
"int64_t":$offset, "ArrayRef<int64_t>":$sizes,
1968+
"ArrayRef<int64_t>":$strides,
19691969
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
1970+
// Build a ReinterpretCastOp with dynamic entries.
1971+
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source,
1972+
"Value":$offset, "ValueRange":$sizes,
1973+
"ValueRange":$strides,
1974+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
19701975
];
19711976

19721977
let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -2927,23 +2932,33 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
29272932
let results = (outs AnyMemRef:$result);
29282933

29292934
let builders = [
2930-
// Build a SubViewOp with mixed static and dynamic entries.
2931-
OpBuilderDAG<(ins "Value":$source, "ArrayRef<int64_t>":$staticOffsets,
2932-
"ArrayRef<int64_t>":$staticSizes, "ArrayRef<int64_t>":$staticStrides,
2933-
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
2935+
// Build a SubViewOp with mixed static and dynamic entries and custom
2936+
// result type. If the type passed is nullptr, it is inferred.
2937+
OpBuilderDAG<(ins "Value":$source, "ArrayRef<OpFoldResult>":$offsets,
2938+
"ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
29342939
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
2935-
// Build a SubViewOp with all dynamic entries.
2936-
OpBuilderDAG<(ins "Value":$source, "ValueRange":$offsets,
2937-
"ValueRange":$sizes, "ValueRange":$strides,
2940+
// Build a SubViewOp with mixed static and dynamic entries and inferred
2941+
// result type.
2942+
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source,
2943+
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
2944+
"ArrayRef<OpFoldResult>":$strides,
29382945
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
2939-
// Build a SubViewOp with mixed static and dynamic entries
2940-
// and custom result type.
2946+
// Build a SubViewOp with static entries and custom result type. If the
2947+
// type passed is nullptr, it is inferred.
2948+
OpBuilderDAG<(ins "Value":$source, "ArrayRef<int64_t>":$offsets,
2949+
"ArrayRef<int64_t>":$sizes, "ArrayRef<int64_t>":$strides,
2950+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
2951+
// Build a SubViewOp with static entries and inferred result type.
29412952
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source,
2942-
"ArrayRef<int64_t>":$staticOffsets, "ArrayRef<int64_t>":$staticSizes,
2943-
"ArrayRef<int64_t>":$staticStrides, "ValueRange":$offsets,
2953+
"ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$sizes,
2954+
"ArrayRef<int64_t>":$strides,
2955+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
2956+
// Build a SubViewOp with dynamic entries and custom result type. If the
2957+
// type passed is nullptr, it is inferred.
2958+
OpBuilderDAG<(ins "Value":$source, "ValueRange":$offsets,
29442959
"ValueRange":$sizes, "ValueRange":$strides,
29452960
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
2946-
// Build a SubViewOp with all dynamic entries and custom result type.
2961+
// Build a SubViewOp with dynamic entries and inferred result type.
29472962
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source,
29482963
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
29492964
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
@@ -3039,26 +3054,6 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
30393054
let results = (outs AnyRankedTensor:$result);
30403055

30413056
let builders = [
3042-
// Build a SubTensorOp with mixed static and dynamic entries.
3043-
OpBuilderDAG<(ins "Value":$source, "ArrayRef<int64_t>":$staticOffsets,
3044-
"ArrayRef<int64_t>":$staticSizes, "ArrayRef<int64_t>":$staticStrides,
3045-
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
3046-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
3047-
// Build a SubTensorOp with all dynamic entries.
3048-
OpBuilderDAG<(ins "Value":$source, "ValueRange":$offsets,
3049-
"ValueRange":$sizes, "ValueRange":$strides,
3050-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
3051-
// Build a SubTensorOp with mixed static and dynamic entries
3052-
// and custom result type.
3053-
OpBuilderDAG<(ins "RankedTensorType":$resultType, "Value":$source,
3054-
"ArrayRef<int64_t>":$staticOffsets, "ArrayRef<int64_t>":$staticSizes,
3055-
"ArrayRef<int64_t>":$staticStrides, "ValueRange":$offsets,
3056-
"ValueRange":$sizes, "ValueRange":$strides,
3057-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
3058-
// Build a SubTensorOp with all dynamic entries and custom result type.
3059-
OpBuilderDAG<(ins "RankedTensorType":$resultType, "Value":$source,
3060-
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
3061-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
30623057
// Build a SubTensorOp with mixed static and dynamic entries and inferred
30633058
// result type.
30643059
OpBuilderDAG<(ins "Value":$source, "ArrayRef<OpFoldResult>":$offsets,
@@ -3069,6 +3064,15 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
30693064
OpBuilderDAG<(ins "RankedTensorType":$resultType, "Value":$source,
30703065
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
30713066
"ArrayRef<OpFoldResult>":$strides,
3067+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
3068+
// Build a SubTensorOp with dynamic entries and custom result type. If the
3069+
// type passed is nullptr, it is inferred.
3070+
OpBuilderDAG<(ins "Value":$source, "ValueRange":$offsets,
3071+
"ValueRange":$sizes, "ValueRange":$strides,
3072+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
3073+
// Build a SubTensorOp with dynamic entries and inferred result type.
3074+
OpBuilderDAG<(ins "RankedTensorType":$resultType, "Value":$source,
3075+
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
30723076
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
30733077
];
30743078

@@ -3157,19 +3161,13 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
31573161

31583162
let builders = [
31593163
// Build a SubTensorInsertOp with mixed static and dynamic entries.
3160-
OpBuilderDAG<(ins "Value":$source, "Value":$dest,
3161-
"ArrayRef<int64_t>":$staticOffsets, "ArrayRef<int64_t>":$staticSizes,
3162-
"ArrayRef<int64_t>":$staticStrides, "ValueRange":$offsets,
3163-
"ValueRange":$sizes, "ValueRange":$strides,
3164-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
3165-
// Build a SubTensorInsertOp with all dynamic entries.
3166-
OpBuilderDAG<(ins "Value":$source, "Value":$dest, "ValueRange":$offsets,
3167-
"ValueRange":$sizes, "ValueRange":$strides,
3168-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
3169-
// Build a SubTensorInsertOp with mixed static and dynamic entries.
31703164
OpBuilderDAG<(ins "Value":$source, "Value":$dest,
31713165
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
31723166
"ArrayRef<OpFoldResult>":$strides,
3167+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
3168+
// Build a SubTensorInsertOp with dynamic entries.
3169+
OpBuilderDAG<(ins "Value":$source, "Value":$dest,
3170+
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
31733171
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
31743172
];
31753173

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,11 +213,25 @@ inline bool operator!=(OpState lhs, OpState rhs) {
213213
return lhs.getOperation() != rhs.getOperation();
214214
}
215215

216+
raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr);
217+
216218
/// This class represents a single result from folding an operation.
217219
class OpFoldResult : public PointerUnion<Attribute, Value> {
218220
using PointerUnion<Attribute, Value>::PointerUnion;
221+
222+
public:
223+
void dump() { llvm::errs() << *this << "\n"; }
219224
};
220225

226+
/// Allow printing to a stream.
227+
inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
228+
if (Value value = ofr.dyn_cast<Value>())
229+
value.print(os);
230+
else
231+
ofr.dyn_cast<Attribute>().print(os);
232+
return os;
233+
}
234+
221235
/// Allow printing to a stream.
222236
inline raw_ostream &operator<<(raw_ostream &os, OpState &op) {
223237
op.print(os, OpPrintingFlags().useLocalScope());

mlir/include/mlir/Interfaces/ViewLikeInterface.td

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -108,28 +108,6 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
108108
return $_op.sizes();
109109
}]
110110
>,
111-
InterfaceMethod<
112-
/*desc=*/[{
113-
Return a vector of all the static or dynamic sizes of the op.
114-
}],
115-
/*retTy=*/"SmallVector<OpFoldResult, 4>",
116-
/*methodName=*/"getMixedSizes",
117-
/*args=*/(ins),
118-
/*methodBody=*/"",
119-
/*defaultImplementation=*/[{
120-
SmallVector<OpFoldResult, 4> res;
121-
std::array<unsigned, 3> ranks = $_op.getArrayAttrRanks();
122-
unsigned numDynamic = 0;
123-
unsigned count = ranks[getOffsetOperandGroupPosition()];
124-
for (unsigned idx = 0; idx < count; ++idx) {
125-
if (isDynamicSize(idx))
126-
res.push_back($_op.sizes()[numDynamic++]);
127-
else
128-
res.push_back($_op.static_sizes()[idx]);
129-
}
130-
return res;
131-
}]
132-
>,
133111
InterfaceMethod<
134112
/*desc=*/[{
135113
Return the dynamic stride operands.
@@ -178,6 +156,72 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
178156
return $_op.static_strides();
179157
}]
180158
>,
159+
InterfaceMethod<
160+
/*desc=*/[{
161+
Return a vector of all the static or dynamic sizes of the op.
162+
}],
163+
/*retTy=*/"SmallVector<OpFoldResult, 4>",
164+
/*methodName=*/"getMixedOffsets",
165+
/*args=*/(ins),
166+
/*methodBody=*/"",
167+
/*defaultImplementation=*/[{
168+
SmallVector<OpFoldResult, 4> res;
169+
std::array<unsigned, 3> ranks = $_op.getArrayAttrRanks();
170+
unsigned numDynamic = 0;
171+
unsigned count = ranks[getOffsetOperandGroupPosition()];
172+
for (unsigned idx = 0; idx < count; ++idx) {
173+
if (isDynamicOffset(idx))
174+
res.push_back($_op.offsets()[numDynamic++]);
175+
else
176+
res.push_back($_op.static_offsets()[idx]);
177+
}
178+
return res;
179+
}]
180+
>,
181+
InterfaceMethod<
182+
/*desc=*/[{
183+
Return a vector of all the static or dynamic sizes of the op.
184+
}],
185+
/*retTy=*/"SmallVector<OpFoldResult, 4>",
186+
/*methodName=*/"getMixedSizes",
187+
/*args=*/(ins),
188+
/*methodBody=*/"",
189+
/*defaultImplementation=*/[{
190+
SmallVector<OpFoldResult, 4> res;
191+
std::array<unsigned, 3> ranks = $_op.getArrayAttrRanks();
192+
unsigned numDynamic = 0;
193+
unsigned count = ranks[getSizeOperandGroupPosition()];
194+
for (unsigned idx = 0; idx < count; ++idx) {
195+
if (isDynamicSize(idx))
196+
res.push_back($_op.sizes()[numDynamic++]);
197+
else
198+
res.push_back($_op.static_sizes()[idx]);
199+
}
200+
return res;
201+
}]
202+
>,
203+
InterfaceMethod<
204+
/*desc=*/[{
205+
Return a vector of all the static or dynamic strides of the op.
206+
}],
207+
/*retTy=*/"SmallVector<OpFoldResult, 4>",
208+
/*methodName=*/"getMixedStrides",
209+
/*args=*/(ins),
210+
/*methodBody=*/"",
211+
/*defaultImplementation=*/[{
212+
SmallVector<OpFoldResult, 4> res;
213+
std::array<unsigned, 3> ranks = $_op.getArrayAttrRanks();
214+
unsigned numDynamic = 0;
215+
unsigned count = ranks[getStrideOperandGroupPosition()];
216+
for (unsigned idx = 0; idx < count; ++idx) {
217+
if (isDynamicStride(idx))
218+
res.push_back($_op.strides()[numDynamic++]);
219+
else
220+
res.push_back($_op.static_strides()[idx]);
221+
}
222+
return res;
223+
}]
224+
>,
181225

182226
InterfaceMethod<
183227
/*desc=*/[{

mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,10 +241,8 @@ class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
241241
Value alloc =
242242
rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
243243
Value subView = rewriter.create<SubViewOp>(
244-
op.getLoc(), sourceMemref, extractFromI64ArrayAttr(op.static_offsets()),
245-
extractFromI64ArrayAttr(op.static_sizes()),
246-
extractFromI64ArrayAttr(op.static_strides()), op.offsets(), op.sizes(),
247-
op.strides());
244+
op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
245+
op.getMixedStrides());
248246
rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
249247
rewriter.replaceOp(op, alloc);
250248
return success();
@@ -283,10 +281,8 @@ class SubTensorInsertOpConverter
283281

284282
// Take a subview to copy the small memref.
285283
Value subview = rewriter.create<SubViewOp>(
286-
op.getLoc(), destMemRef, extractFromI64ArrayAttr(op.static_offsets()),
287-
extractFromI64ArrayAttr(op.static_sizes()),
288-
extractFromI64ArrayAttr(op.static_strides()), adaptor.offsets(),
289-
adaptor.sizes(), adaptor.strides());
284+
op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
285+
op.getMixedStrides());
290286
// Copy the small memref.
291287
rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
292288
rewriter.replaceOp(op, destMemRef);

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ using llvm::dbgs;
6161
// by `permutationMap`.
6262
static void inferShapeComponents(AffineMap permutationMap,
6363
ArrayRef<Range> loopRanges,
64-
SmallVectorImpl<Value> &offsets,
65-
SmallVectorImpl<Value> &sizes,
66-
SmallVectorImpl<Value> &strides) {
64+
SmallVectorImpl<OpFoldResult> &offsets,
65+
SmallVectorImpl<OpFoldResult> &sizes,
66+
SmallVectorImpl<OpFoldResult> &strides) {
6767
assert(permutationMap.isProjectedPermutation() &&
6868
"expected some subset of a permutation map");
6969
SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
@@ -101,7 +101,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
101101
AffineMap map = op.getIndexingMap(shapedOperandIdx);
102102
LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
103103
<< " with indexingMap: " << map << "\n");
104-
SmallVector<Value, 4> offsets, sizes, strides;
104+
SmallVector<OpFoldResult, 4> offsets, sizes, strides;
105105
inferShapeComponents(map, loopRanges, offsets, sizes, strides);
106106
Value shape = en.value();
107107
Value sub = shape.getType().isa<MemRefType>()

mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,8 @@ Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
214214
ScopedContext scopedContext(b, loc);
215215
auto viewType = subView.getType();
216216
auto rank = viewType.getRank();
217-
SmallVector<Value, 4> fullSizes, partialSizes;
217+
SmallVector<Value, 4> fullSizes;
218+
SmallVector<OpFoldResult> partialSizes;
218219
fullSizes.reserve(rank);
219220
partialSizes.reserve(rank);
220221
for (auto en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
@@ -226,18 +227,16 @@ Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
226227
(!sizeAttr) ? rangeValue.size : b.create<ConstantOp>(loc, sizeAttr);
227228
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
228229
fullSizes.push_back(size);
229-
partialSizes.push_back(folded_std_dim(folder, subView, en.index()));
230+
partialSizes.push_back(folded_std_dim(folder, subView, en.index()).value);
230231
}
231232
SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
232233
// If a callback is not specified, then use the default implementation for
233234
// allocating the promoted buffer.
234235
Optional<Value> fullLocalView = allocationFn(b, subView, fullSizes, folder);
235236
if (!fullLocalView)
236237
return {};
237-
auto zero = folded_std_constant_index(folder, 0);
238-
auto one = folded_std_constant_index(folder, 1);
239-
SmallVector<Value, 4> zeros(fullSizes.size(), zero);
240-
SmallVector<Value, 4> ones(fullSizes.size(), one);
238+
SmallVector<OpFoldResult, 4> zeros(fullSizes.size(), b.getIndexAttr(0));
239+
SmallVector<OpFoldResult, 4> ones(fullSizes.size(), b.getIndexAttr(1));
241240
auto partialLocalView =
242241
folded_std_subview(folder, *fullLocalView, zeros, partialSizes, ones);
243242
return PromotionInfo{*fullLocalView, partialLocalView};

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,15 +255,15 @@ makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
255255
}
256256

257257
// Construct a new subview / subtensor for the tile.
258-
SmallVector<Value, 4> offsets, sizes, strides;
258+
SmallVector<OpFoldResult, 4> offsets, sizes, strides;
259259
offsets.reserve(rank);
260260
sizes.reserve(rank);
261261
strides.reserve(rank);
262262
for (unsigned r = 0; r < rank; ++r) {
263263
if (!isTiled(map.getSubMap({r}), tileSizes)) {
264-
offsets.push_back(std_constant_index(0));
265-
sizes.push_back(std_dim(shapedOp, r));
266-
strides.push_back(std_constant_index(1));
264+
offsets.push_back(b.getIndexAttr(0));
265+
sizes.push_back(std_dim(shapedOp, r).value);
266+
strides.push_back(b.getIndexAttr(1));
267267
continue;
268268
}
269269

@@ -297,7 +297,7 @@ makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
297297
}
298298

299299
sizes.push_back(size);
300-
strides.push_back(std_constant_index(1));
300+
strides.push_back(b.getIndexAttr(1));
301301
}
302302

303303
if (shapedType.isa<MemRefType>())

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,19 +192,17 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
192192
// This later folds away.
193193
SmallVector<Value> paddedSubviewResults;
194194
paddedSubviewResults.reserve(opToPad->getNumResults());
195-
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
196-
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
197195
llvm::SetVector<Operation *> newUsersOfOpToPad;
198196
for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
199197
auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
200-
SmallVector<Value> offsets(rank, zero);
201-
auto sizes = llvm::to_vector<4>(
202-
llvm::map_range(llvm::seq<unsigned>(0, rank), [&](unsigned d) -> Value {
198+
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
199+
auto sizes = llvm::to_vector<4>(llvm::map_range(
200+
llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
203201
auto dimOp = rewriter.create<DimOp>(loc, std::get<0>(it), d);
204202
newUsersOfOpToPad.insert(dimOp);
205-
return dimOp;
203+
return dimOp.getResult();
206204
}));
207-
SmallVector<Value> strides(rank, one);
205+
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
208206
paddedSubviewResults.push_back(rewriter.create<SubTensorOp>(
209207
loc, std::get<1>(it), offsets, sizes, strides));
210208
}

0 commit comments

Comments
 (0)