Skip to content

Commit 5bb29d2

Browse files
committed
[MLIR][Affine] Fix affine.apply verifier and add functionality to demote invalid symbols to dims
Fixes: #120189, #128403 Fix the affine.apply verifier to reject symbolic operands that are valid dims for affine purposes. This doesn't affect other users in other contexts where the operands were neither valid dims nor valid symbols (e.g., in scf.for or other region ops). Without this check, it was possible for `-canonicalize` to generate invalid IR when such affine.apply ops were composed. Introduce a method to demote a symbolic operand to a dimensional one (the inverse of the current canonicalizePromotedSymbols). Demote operands that could/should have been valid affine dimensional values (affine loop IVs or their functions) from symbols to dims. This is a general method that can be used to legalize a map + operands post construction depending on its operands. Use it during `canonicalizeMapOrSetAndOperands` so that pattern rewriter-based passes are able to generate valid IR post folding. Users outside of affine analyses/dialects remain unaffected. In some cases, this change also leads to better-simplified operands, with duplicates eliminated, as shown in one of the test cases where the same operand appeared both as a symbol and as a dim. This PR also fixes test cases where dimensional positions should ideally have been used with affine.apply (for affine loop IVs, for example).
1 parent 99ec6f8 commit 5bb29d2

File tree

4 files changed

+95
-22
lines changed

4 files changed

+95
-22
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,15 @@ LogicalResult AffineApplyOp::verify() {
578578
if (affineMap.getNumResults() != 1)
579579
return emitOpError("mapping must produce one value");
580580

581+
// Do not allow valid dims to be used in symbol positions. We do allow
582+
// affine.apply to use operands for values that may neither qualify as affine
583+
// dims nor affine symbols due to usage outside of affine ops, analyses, etc.
584+
Region *region = getAffineScope(*this);
585+
for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
586+
if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
587+
return emitError("dimensional operand cannot be used as a symbol");
588+
}
589+
581590
return success();
582591
}
583592

@@ -1359,13 +1368,62 @@ static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
13591368

13601369
resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
13611370
*operands = resultOperands;
1362-
*mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1363-
oldNumSyms + nextSym);
1371+
*mapOrSet = mapOrSet->replaceDimsAndSymbols(
1372+
dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);
13641373

13651374
assert(mapOrSet->getNumInputs() == operands->size() &&
13661375
"map/set inputs must match number of operands");
13671376
}
13681377

1378+
// A valid affine dimension may appear as a symbol in affine.apply operations.
1379+
// This function rewrites symbol operands that are valid dims, but not valid
1380+
// symbols, into actual dims. Without such a legalization, the affine.apply
1381+
// will be invalid. This method is the inverse of canonicalizePromotedSymbols.
//
// Demoted operands are placed right after the existing dims, in the order in
// which they appeared among the symbols; the operands that stay symbols are
// renumbered contiguously starting from 0. Both `mapOrSet` and `operands` are
// updated in place. A null `mapOrSet` or an empty operand list is a no-op.
1382+
template <class MapOrSet>
1383+
static void legalizeDemotedDims(MapOrSet *mapOrSet,
1384+
SmallVectorImpl<Value> &operands) {
1385+
if (!mapOrSet || operands.empty())
1386+
return;
1387+
1388+
assert(mapOrSet->getNumInputs() == operands.size() &&
1389+
"map/set inputs must match number of operands");
1390+
1391+
auto *context = mapOrSet->getContext();
1392+
// Collects the original dims followed by the operands that remain symbols.
SmallVector<Value, 8> resultOperands;
1393+
resultOperands.reserve(operands.size());
1394+
// Collects the symbol operands that are being demoted to dims.
SmallVector<Value, 8> remappedDims;
1395+
remappedDims.reserve(operands.size());
1396+
unsigned nextSym = 0;
1397+
unsigned nextDim = 0;
1398+
unsigned oldNumDims = mapOrSet->getNumDims();
1399+
// Replacement expression for each old symbol, indexed by old symbol position.
SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1400+
for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1401+
// Inputs at positions >= oldNumDims are the symbol operands.
if (i >= oldNumDims) {
1402+
if (operands[i] && isValidDim(operands[i]) &&
1403+
!isValidSymbol(operands[i])) {
1404+
// This is a valid dim that appears as a symbol, legalize it.
// It becomes a new dim numbered after all of the original dims.
1405+
symRemapping[i - oldNumDims] =
1406+
getAffineDimExpr(oldNumDims + nextDim++, context);
1407+
remappedDims.push_back(operands[i]);
1408+
} else {
1409+
// Not demoted: keep it as a symbol at the next free symbol position.
symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
1410+
resultOperands.push_back(operands[i]);
1411+
}
1412+
} else {
1413+
// Original dim operands keep their positions.
resultOperands.push_back(operands[i]);
1414+
}
1415+
}
1416+
1417+
// Splice the demoted operands in directly after the original dims so that the
// operand order matches the dim positions assigned in symRemapping.
resultOperands.insert(resultOperands.begin() + oldNumDims,
1418+
remappedDims.begin(), remappedDims.end());
1419+
operands = resultOperands;
1420+
// No dim replacements are passed; only the symbol remapping is supplied, with
// oldNumDims + nextDim and nextSym given as the resulting dim/symbol counts.
*mapOrSet = mapOrSet->replaceDimsAndSymbols(
1421+
/*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);
1422+
1423+
assert(mapOrSet->getNumInputs() == operands.size() &&
1424+
"map/set inputs must match number of operands");
1425+
}
1426+
13691427
// Works for either an affine map or an integer set.
13701428
template <class MapOrSet>
13711429
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
@@ -1380,6 +1438,7 @@ static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
13801438
"map/set inputs must match number of operands");
13811439

13821440
canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1441+
legalizeDemotedDims<MapOrSet>(mapOrSet, *operands);
13831442

13841443
// Check to see what dims are used.
13851444
llvm::SmallBitVector usedDims(mapOrSet->getNumDims());

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,8 +1460,8 @@ func.func @mod_of_mod(%lb: index, %ub: index, %step: index) -> (index, index) {
14601460
func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
14611461
// CHECK: affine.for [[I_0_:%.+]] = 0 to 8 {
14621462
affine.for %arg3 = 0 to 8 {
1463-
%1 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg3]
1464-
// CHECK: affine.prefetch [[PARAM_0_]][symbol([[I_0_]]) * 64], read, locality<3>, data : memref<512xf32>
1463+
%1 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
1464+
// CHECK: affine.prefetch [[PARAM_0_]][[[I_0_]] * 64], read, locality<3>, data : memref<512xf32>
14651465
affine.prefetch %arg0[%1], read, locality<3>, data : memref<512xf32>
14661466
}
14671467
return

mlir/test/Dialect/Affine/invalid.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,3 +563,17 @@ func.func @no_upper_bound() {
563563
}
564564
return
565565
}
566+
567+
// -----
568+
569+
// Negative test: affine.for induction variables are passed in the symbol
// positions of affine.apply, which the verifier is expected to reject with the
// diagnostic below.
func.func @invalid_symbol() {
570+
affine.for %arg1 = 0 to 1 {
571+
affine.for %arg2 = 0 to 26 {
572+
affine.for %arg3 = 0 to 23 {
573+
affine.apply affine_map<()[s0, s1] -> (s0 * 23 + s1)>()[%arg1, %arg3]
574+
// expected-error@above {{dimensional operand cannot be used as a symbol}}
575+
}
576+
}
577+
}
578+
return
579+
}

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,8 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
496496

497497
// -----
498498

499-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
500-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
499+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
500+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
501501
// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
502502
// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
503503
func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -518,16 +518,16 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
518518
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
519519
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
520520
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
521-
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
522-
// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]]]
521+
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
522+
// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
523523
// CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
524-
// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
524+
// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
525525
// CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
526526

527527
// -----
528528

529-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 * 1024 + s1)>
530-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
529+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
530+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
531531
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
532532
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
533533
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -549,14 +549,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0:
549549
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
550550
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
551551
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
552-
// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]], %[[ARG4]]]
553-
// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
552+
// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
553+
// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
554554
// CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
555555

556556
// -----
557557

558-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 1024)>
559-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
558+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1025 + d1)>
559+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
560560
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
561561
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
562562
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -578,14 +578,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
578578
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
579579
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
580580
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
581-
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])[%[[ARG3]]]
582-
// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
581+
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
582+
// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
583583
// CHECK-NEXT: affine.load %[[ARG0]][%[[TMP1]], %[[TMP3]]] : memref<1024x1024xf32>
584584

585585
// -----
586586

587-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 1024)>
588-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
587+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 1024)>
588+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
589589
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
590590
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
591591
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -608,8 +608,8 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c
608608
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
609609
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
610610
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
611-
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]]]
612-
// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
611+
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
612+
// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
613613
// CHECK-NEXT: memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>
614614

615615
// -----
@@ -678,7 +678,7 @@ func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index,
678678
// -----
679679

680680
// CHECK-LABEL: func @fold_store_keep_nontemporal(
681-
// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32>
681+
// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32>
682682
func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
683683
%0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
684684
memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>

0 commit comments

Comments
 (0)