Skip to content

Commit c0f4690

Browse files
committed
[MLIR] Fix crash in AffineMap::replace for zero result maps
Fix obvious bug in AffineMap::replace for the case of zero result maps. Extend/complete inferExprsFromList to work with empty expression lists.
1 parent 69a661c commit c0f4690

File tree

17 files changed

+99
-41
lines changed

17 files changed

+99
-41
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def AffineApplyOp : Affine_Op<"apply", [Pure]> {
6767
OpBuilder<(ins "ArrayRef<AffineExpr> ":$exprList,"ValueRange":$mapOperands),
6868
[{
6969
build($_builder, $_state, $_builder.getIndexType(),
70-
AffineMap::inferFromExprList(exprList).front(), mapOperands);
70+
AffineMap::inferFromExprList(exprList, $_builder.getContext())
71+
.front(), mapOperands);
7172
}]>
7273
];
7374

mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ class StructuredGenerator {
121121
}
122122

123123
bool layout(MapList l) {
124-
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
124+
auto infer = [&](MapList m) {
125+
return AffineMap::inferFromExprList(m, ctx);
126+
};
125127
return maps == infer(l);
126128
}
127129

mlir/include/mlir/IR/AffineMap.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,11 @@ class AffineMap {
122122
/// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
123123
/// symbols as the largest symbol in `exprs`.
124124
static SmallVector<AffineMap, 4>
125-
inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList);
125+
inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList,
126+
MLIRContext *context);
126127
static SmallVector<AffineMap, 4>
127-
inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList);
128+
inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList,
129+
MLIRContext *context);
128130

129131
MLIRContext *getContext() const;
130132

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,7 +2010,8 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
20102010
}
20112011

20122012
bool didEncounterError = false;
2013-
auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
2013+
auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
2014+
rewriter.getContext());
20142015
auto linalgOp = rewriter.create<linalg::GenericOp>(
20152016
loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
20162017
ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
@@ -2351,9 +2352,11 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
23512352
createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
23522353

23532354
// Indexing maps for input and output tensors
2354-
auto indexingMaps = AffineMap::inferFromExprList(llvm::ArrayRef{
2355-
affineDimsExpr(rewriter, 0, 3, 4), affineDimsExpr(rewriter, 0, 1, 2),
2356-
affineDimsExpr(rewriter, 0, 1, 2)});
2355+
auto indexingMaps = AffineMap::inferFromExprList(
2356+
llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2357+
affineDimsExpr(rewriter, 0, 1, 2),
2358+
affineDimsExpr(rewriter, 0, 1, 2)},
2359+
rewriter.getContext());
23572360

23582361
// Width and height dimensions of the original input.
23592362
auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
@@ -2463,7 +2466,8 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
24632466
ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
24642467
RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
24652468
RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2466-
RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)});
2469+
RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2470+
rewriter.getContext());
24672471

24682472
// Width and height dimensions of the original input.
24692473
auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
7777
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
7878
bool useNvGpu) {
7979
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80-
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
80+
auto infer = [&](MapList m) {
81+
return AffineMap::inferFromExprList(m, contract.getContext());
82+
};
8183
AffineExpr m, n, k;
8284
bindDims(contract.getContext(), m, n, k);
8385
auto iteratorTypes = contract.getIteratorTypes().getValue();
@@ -394,7 +396,9 @@ struct PrepareContractToGPUMMA
394396

395397
// Set up the parallel/reduction structure in right form.
396398
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
397-
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
399+
auto infer = [&](MapList m) {
400+
return AffineMap::inferFromExprList(m, op.getContext());
401+
};
398402
AffineExpr m, n, k;
399403
bindDims(rewriter.getContext(), m, n, k);
400404
static constexpr std::array<int64_t, 2> perm = {1, 0};

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,7 +1145,9 @@ AffineApplyOp
11451145
mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
11461146
ArrayRef<OpFoldResult> operands) {
11471147
return makeComposedAffineApply(
1148-
b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}).front(),
1148+
b, loc,
1149+
AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}, b.getContext())
1150+
.front(),
11491151
operands);
11501152
}
11511153

@@ -1220,7 +1222,9 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
12201222
AffineExpr expr,
12211223
ArrayRef<OpFoldResult> operands) {
12221224
return makeComposedFoldedAffineApply(
1223-
b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}).front(),
1225+
b, loc,
1226+
AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}, b.getContext())
1227+
.front(),
12241228
operands);
12251229
}
12261230

mlir/lib/Dialect/Linalg/Transforms/Split.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
8383
bindDims(rewriter.getContext(), d0, d1, d2);
8484
OpFoldResult minSplitPoint = affine::makeComposedFoldedAffineMin(
8585
rewriter, op.getLoc(),
86-
AffineMap::inferFromExprList(ArrayRef<AffineExpr>{d0, d1 + d2}).front(),
86+
AffineMap::inferFromExprList(ArrayRef<AffineExpr>{d0, d1 + d2},
87+
rewriter.getContext())
88+
.front(),
8789
{splitPoint, offsets[dimension], sizes[dimension]});
8890

8991
// Compute the size of the second part. Return early if the second part would

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -670,20 +670,21 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
670670
<< ": make sure in bound with affine.min\n");
671671

672672
AffineExpr dim0, dim1, dim2;
673-
bindDims(builder.getContext(), dim0, dim1, dim2);
673+
MLIRContext *context = builder.getContext();
674+
bindDims(context, dim0, dim1, dim2);
674675

675676
// Get the dimension size for this dimension. We need to first calculate
676677
// the max index and then plus one. This is important because for
677678
// convolution ops, we have its input window dimension's affine map of the
678679
// form `(d0 * s0 + d1)`, where `d0`/`d1 is an output/filter window
679680
// dimension and `s0` is stride. Directly use the dimension size of
680681
// output/filer window dimensions will cause incorrect calculation.
681-
AffineMap minusOneMap =
682-
AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 - 1}})
683-
.front();
684-
AffineMap plusOneMap =
685-
AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 + 1}})
686-
.front();
682+
AffineMap minusOneMap = AffineMap::inferFromExprList(
683+
{ArrayRef<AffineExpr>{dim0 - 1}}, context)
684+
.front();
685+
AffineMap plusOneMap = AffineMap::inferFromExprList(
686+
{ArrayRef<AffineExpr>{dim0 + 1}}, context)
687+
.front();
687688
SmallVector<OpFoldResult> maxIndices =
688689
llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
689690
return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
@@ -696,7 +697,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
696697

697698
// Compute min(dim - offset, size) to avoid out-of-bounds accesses.
698699
AffineMap minMap = AffineMap::inferFromExprList(
699-
{ArrayRef<AffineExpr>{dim1 - dim2, dim0}})
700+
{ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
700701
.front();
701702
size =
702703
makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});

mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1263,7 +1263,9 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
12631263
SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
12641264

12651265
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1266-
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1266+
auto infer = [&](MapList m) {
1267+
return AffineMap::inferFromExprList(m, op.getContext());
1268+
};
12671269
AffineExpr i, j, k;
12681270
bindDims(getContext(), i, j, k);
12691271

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -675,9 +675,10 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
675675
ArrayRef<IteratorType> iteratorTypes) {
676676
result.addOperands({lhs, rhs, acc});
677677
result.addTypes(acc.getType());
678-
result.addAttribute(getIndexingMapsAttrName(result.name),
679-
builder.getAffineMapArrayAttr(
680-
AffineMap::inferFromExprList(indexingExprs)));
678+
result.addAttribute(
679+
getIndexingMapsAttrName(result.name),
680+
builder.getAffineMapArrayAttr(
681+
AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
681682
result.addAttribute(
682683
getIteratorTypesAttrName(result.name),
683684
builder.getArrayAttr(llvm::to_vector(llvm::map_range(

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,9 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
695695
Value lhs = op.getLhs(), rhs = op.getRhs();
696696

697697
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
698-
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
698+
auto infer = [&](MapList m) {
699+
return AffineMap::inferFromExprList(m, op.getContext());
700+
};
699701
AffineExpr m, n, k;
700702
bindDims(rewriter.getContext(), m, n, k);
701703
SmallVector<AffineMap> maps = op.getIndexingMapsArray();

mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
209209
AffineExpr i, j, k;
210210
bindDims(xferOp.getContext(), i, j, k);
211211
SmallVector<AffineMap, 4> maps =
212-
AffineMap::inferFromExprList(MapList{{i - j, k}});
212+
AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
213213
// affine_min(%dimMemRef - %index, %dimAlloc)
214214
Value affineMin = b.create<affine::AffineMinOp>(
215215
loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,9 @@ struct MultiReduceToContract
160160
iteratorTypes.push_back(vector::IteratorType::reduction);
161161
}
162162
}
163-
auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
164-
/*symCount=*/0, exprs, reduceOp.getContext());
163+
auto dstMap =
164+
AffineMap::get(/*dimCount=*/reductionMask.size(),
165+
/*symbolCount=*/0, exprs, reduceOp.getContext());
165166
rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
166167
reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
167168
rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
@@ -1399,7 +1400,9 @@ struct CanonicalizeContractMatmulToMMT final
13991400

14001401
// Set up the parallel/reduction structure in right form.
14011402
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1402-
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1403+
auto infer = [&](MapList m) {
1404+
return AffineMap::inferFromExprList(m, op.getContext());
1405+
};
14031406
AffineExpr m;
14041407
AffineExpr n;
14051408
AffineExpr k;

mlir/lib/IR/AffineMap.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,16 @@ AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
272272
return result;
273273
}
274274

275+
/// Creates an affine map each for each list of AffineExpr's in `exprsList`
276+
/// while inferring the right number of dimensional and symbolic inputs needed
277+
/// based on the maximum dimensional and symbolic identifier appearing in the
278+
/// expressions.
275279
template <typename AffineExprContainer>
276280
static SmallVector<AffineMap, 4>
277-
inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
278-
assert(!exprsList.empty());
279-
assert(!exprsList[0].empty());
280-
auto context = exprsList[0][0].getContext();
281+
inferFromExprList(ArrayRef<AffineExprContainer> exprsList,
282+
MLIRContext *context) {
283+
if (exprsList.empty())
284+
return {};
281285
int64_t maxDim = -1, maxSym = -1;
282286
getMaxDimAndSymbol(exprsList, maxDim, maxSym);
283287
SmallVector<AffineMap, 4> maps;
@@ -289,13 +293,15 @@ inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
289293
}
290294

291295
SmallVector<AffineMap, 4>
292-
AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) {
293-
return ::inferFromExprList(exprsList);
296+
AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList,
297+
MLIRContext *context) {
298+
return ::inferFromExprList(exprsList, context);
294299
}
295300

296301
SmallVector<AffineMap, 4>
297-
AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) {
298-
return ::inferFromExprList(exprsList);
302+
AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList,
303+
MLIRContext *context) {
304+
return ::inferFromExprList(exprsList, context);
299305
}
300306

301307
uint64_t AffineMap::getLargestKnownDivisorOfMapExprs() {
@@ -521,7 +527,7 @@ AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
521527
newResults.reserve(getNumResults());
522528
for (AffineExpr e : getResults())
523529
newResults.push_back(e.replace(map));
524-
return AffineMap::inferFromExprList(newResults).front();
530+
return AffineMap::inferFromExprList(newResults, getContext()).front();
525531
}
526532

527533
AffineMap AffineMap::dropResults(const llvm::SmallBitVector &positions) const {

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -921,7 +921,7 @@ AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
921921
return getAffineConstantExpr(0, context);
922922

923923
assert(!exprs.empty() && "expected exprs");
924-
auto maps = AffineMap::inferFromExprList(exprs);
924+
auto maps = AffineMap::inferFromExprList(exprs, context);
925925
assert(!maps.empty() && "Expected one non-empty map");
926926
unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
927927

mlir/unittests/IR/AffineMapTest.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- AffineMapTest.cpp - unit tests for affine map API ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/IR/AffineMap.h"
10+
#include "mlir/IR/Builders.h"
11+
#include "gtest/gtest.h"
12+
13+
using namespace mlir;
14+
15+
// Test AffineMap replace API for the zero result case.
16+
TEST(AffineMapTest, inferMapFromAffineExprs) {
17+
MLIRContext ctx;
18+
OpBuilder b(&ctx);
19+
AffineMap map = b.getEmptyAffineMap();
20+
DenseMap<AffineExpr, AffineExpr> replacements;
21+
map.replace(replacements);
22+
EXPECT_EQ(map, map);
23+
}

mlir/unittests/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_unittest(MLIRIRTests
22
AdaptorTest.cpp
3+
AffineMapTest.cpp
34
AttributeTest.cpp
45
DialectTest.cpp
56
InterfaceTest.cpp

0 commit comments

Comments
 (0)