Skip to content

Commit 7b26cbe

Browse files
committed
[MLIR][Affine] Add functionality to demote invalid symbols to dims
Fixes: #120189 Introduce a method to demote a symoblic operand to a dimensional one (the inverse of the current canonicalizePromotedSymbols). Demote operands that could/should have been valid affine dimensional values (affine loop IVs or their functions) from symbols to dims. This is a general method that can be used to legalize a map + operands post construction depending on its operands. In several cases, affine.apply operands that could be dims for the purpose of affine analysis were being used in symbolic positions undetected. Fix the verifier so that affine dim-only SSA values aren't used in symbolic positions; otherwise, it was possible for `-canonicalize` to have generated invalid IR. This doesn't affect other users in other context where the operands were neither valid dims or symbols (for eg. in scf.for or other region ops). In some cases, this change also leads to better simplified operands, duplicates eliminated as shown in one of the test cases where the same operand appeared as a symbol and as a dim. This PR also fixes test cases where dimensional positions should have been ideally used with affine.apply (for affine loop IVs for example). For several use cases, the IR builder/user wouldn't need to worry about dim/symbols since canonicalizeMapOrSetAndOperands is updated to transparently take care of it to ensure valid IR. Users outside of affine analyses/dialects remain unaffected.
1 parent 86ef031 commit 7b26cbe

File tree

5 files changed

+99
-26
lines changed

5 files changed

+99
-26
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,15 @@ LogicalResult AffineApplyOp::verify() {
568568
if (affineMap.getNumResults() != 1)
569569
return emitOpError("mapping must produce one value");
570570

571+
// Do not allow valid dims to be used in symbol positions. We do allow
572+
// affine.apply to use operands for values that may neither qualify as affine
573+
// dims or affine symbols due to usage outside of affine ops, analyses, etc.
574+
Region *region = getAffineScope(*this);
575+
for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
576+
if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
577+
return emitError("dimensional operand cannot be used as a symbol");
578+
}
579+
571580
return success();
572581
}
573582

@@ -1351,13 +1360,62 @@ static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
13511360

13521361
resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
13531362
*operands = resultOperands;
1354-
*mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1355-
oldNumSyms + nextSym);
1363+
*mapOrSet = mapOrSet->replaceDimsAndSymbols(
1364+
dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);
13561365

13571366
assert(mapOrSet->getNumInputs() == operands->size() &&
13581367
"map/set inputs must match number of operands");
13591368
}
13601369

1370+
// A valid affine dimension may appear as a symbol in affine.apply operations.
1371+
// This function canonicalizes symbols that are valid dims, but not valid
1372+
// symbols into actual dims. Without such a legalization, the affine.apply will
1373+
// be invalid. This method is the exact inverse of canonicalizePromotedSymbols.
1374+
template <class MapOrSet>
1375+
static void legalizeDemotedDims(MapOrSet *mapOrSet,
1376+
SmallVectorImpl<Value> &operands) {
1377+
if (!mapOrSet || operands.empty())
1378+
return;
1379+
1380+
assert(mapOrSet->getNumInputs() == operands.size() &&
1381+
"map/set inputs must match number of operands");
1382+
1383+
auto *context = mapOrSet->getContext();
1384+
SmallVector<Value, 8> resultOperands;
1385+
resultOperands.reserve(operands.size());
1386+
SmallVector<Value, 8> remappedDims;
1387+
remappedDims.reserve(operands.size());
1388+
unsigned nextSym = 0;
1389+
unsigned nextDim = 0;
1390+
unsigned oldNumDims = mapOrSet->getNumDims();
1391+
SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1392+
for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1393+
if (i >= oldNumDims) {
1394+
if (operands[i] && isValidDim(operands[i]) &&
1395+
!isValidSymbol(operands[i])) {
1396+
// This is a valid dim that appears as a symbol, legalize it.
1397+
symRemapping[i - oldNumDims] =
1398+
getAffineDimExpr(oldNumDims + nextDim++, context);
1399+
remappedDims.push_back(operands[i]);
1400+
} else {
1401+
symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
1402+
resultOperands.push_back(operands[i]);
1403+
}
1404+
} else {
1405+
resultOperands.push_back(operands[i]);
1406+
}
1407+
}
1408+
1409+
resultOperands.insert(resultOperands.begin() + oldNumDims,
1410+
remappedDims.begin(), remappedDims.end());
1411+
operands = resultOperands;
1412+
*mapOrSet = mapOrSet->replaceDimsAndSymbols(
1413+
/*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);
1414+
1415+
assert(mapOrSet->getNumInputs() == operands.size() &&
1416+
"map/set inputs must match number of operands");
1417+
}
1418+
13611419
// Works for either an affine map or an integer set.
13621420
template <class MapOrSet>
13631421
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
@@ -1372,6 +1430,7 @@ static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
13721430
"map/set inputs must match number of operands");
13731431

13741432
canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1433+
legalizeDemotedDims<MapOrSet>(mapOrSet, *operands);
13751434

13761435
// Check to see what dims are used.
13771436
llvm::SmallBitVector usedDims(mapOrSet->getNumDims());

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,8 +1460,8 @@ func.func @mod_of_mod(%lb: index, %ub: index, %step: index) -> (index, index) {
14601460
func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
14611461
// CHECK: affine.for [[I_0_:%.+]] = 0 to 8 {
14621462
affine.for %arg3 = 0 to 8 {
1463-
%1 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg3]
1464-
// CHECK: affine.prefetch [[PARAM_0_]][symbol([[I_0_]]) * 64], read, locality<3>, data : memref<512xf32>
1463+
%1 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
1464+
// CHECK: affine.prefetch [[PARAM_0_]][[[I_0_]] * 64], read, locality<3>, data : memref<512xf32>
14651465
affine.prefetch %arg0[%1], read, locality<3>, data : memref<512xf32>
14661466
}
14671467
return

mlir/test/Dialect/Affine/invalid.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,3 +563,17 @@ func.func @no_upper_bound() {
563563
}
564564
return
565565
}
566+
567+
// -----
568+
569+
func.func @invalid_symbol() {
570+
affine.for %arg1 = 0 to 1 {
571+
affine.for %arg2 = 0 to 26 {
572+
affine.for %arg3 = 0 to 23 {
573+
affine.apply affine_map<()[s0, s1] -> (s0 * 23 + s1)>()[%arg1, %arg3]
574+
// expected-error@above {{dimensional operand cannot be used as a symbol}}
575+
}
576+
}
577+
}
578+
return
579+
}

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,8 @@ func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %produce
360360

361361
// -----
362362

363-
#map = affine_map<()[s0] -> (s0 + 5)>
364-
#map1 = affine_map<()[s0] -> (s0 + 17)>
363+
#map = affine_map<(d0) -> (d0 + 5)>
364+
#map1 = affine_map<(d0) -> (d0 + 17)>
365365

366366
// Test with non-int/float memref types.
367367

@@ -382,8 +382,8 @@ func.func @memref_index_type() {
382382
}
383383
affine.for %arg3 = 0 to 3 {
384384
%4 = affine.load %alloc_2[%arg3] : memref<3xindex>
385-
%5 = affine.apply #map()[%4]
386-
%6 = affine.apply #map1()[%3]
385+
%5 = affine.apply #map(%4)
386+
%6 = affine.apply #map1(%3)
387387
%7 = memref.load %alloc[%5, %6] : memref<8x18xf32>
388388
affine.store %7, %alloc_1[%arg3] : memref<3xf32>
389389
}

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,8 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
496496

497497
// -----
498498

499-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
500-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
499+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
500+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
501501
// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
502502
// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
503503
func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -518,16 +518,16 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
518518
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
519519
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
520520
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
521-
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
522-
// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]]]
521+
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
522+
// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
523523
// CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
524-
// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
524+
// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
525525
// CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
526526

527527
// -----
528528

529-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 * 1024 + s1)>
530-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
529+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
530+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
531531
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
532532
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
533533
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -549,14 +549,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0:
549549
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
550550
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
551551
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
552-
// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]], %[[ARG4]]]
553-
// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
552+
// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
553+
// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
554554
// CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
555555

556556
// -----
557557

558-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 1024)>
559-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
558+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1025 + d1)>
559+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
560560
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
561561
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
562562
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -578,14 +578,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
578578
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
579579
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
580580
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
581-
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])[%[[ARG3]]]
582-
// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
581+
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
582+
// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
583583
// CHECK-NEXT: affine.load %[[ARG0]][%[[TMP1]], %[[TMP3]]] : memref<1024x1024xf32>
584584

585585
// -----
586586

587-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 1024)>
588-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
587+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 1024)>
588+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
589589
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
590590
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
591591
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -608,8 +608,8 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c
608608
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
609609
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
610610
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
611-
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]]]
612-
// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
611+
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
612+
// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
613613
// CHECK-NEXT: memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>
614614

615615
// -----
@@ -678,7 +678,7 @@ func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index,
678678
// -----
679679

680680
// CHECK-LABEL: func @fold_store_keep_nontemporal(
681-
// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32>
681+
// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32>
682682
func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
683683
%0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
684684
memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>

0 commit comments

Comments
 (0)