Skip to content

[MLIR][Affine] Fix affine.apply verifier and add functionality to demote invalid symbols to dims #128289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,25 @@ def AffineApplyOp : Affine_Op<"apply", [Pure]> {
let description = [{
The `affine.apply` operation applies an [affine mapping](#affine-maps)
to a list of SSA values, yielding a single SSA value. The number of
dimension and symbol arguments to `affine.apply` must be equal to the
dimension and symbol operands to `affine.apply` must be equal to the
respective number of dimensional and symbolic inputs to the affine mapping;
the affine mapping has to be one-dimensional, and so the `affine.apply`
operation always returns one value. The input operands and result must all
have ‘index’ type.

An operand that is a valid dimension but not a valid symbol, as per the
[rules on valid affine dimensions and symbols](#restrictions-on-dimensions-and-symbols),
cannot be used as a symbolic operand.

Example:

```mlir
#map10 = affine_map<(d0, d1) -> (d0 floordiv 8 + d1 floordiv 128)>
#map = affine_map<(d0, d1) -> (d0 floordiv 8 + d1 floordiv 128)>
...
%1 = affine.apply #map10 (%s, %t)
%1 = affine.apply #map (%s, %t)

// Inline example.
%2 = affine.apply affine_map<(i)[s0] -> (i+s0)> (%42)[%n]
%2 = affine.apply affine_map<(i)[s0] -> (i + s0)> (%42)[%n]
```
}];
let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$mapOperands);
Expand Down
65 changes: 63 additions & 2 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,15 @@ LogicalResult AffineApplyOp::verify() {
if (affineMap.getNumResults() != 1)
return emitOpError("mapping must produce one value");

// Do not allow valid dims to be used in symbol positions. We do allow
// affine.apply to use operands for values that may qualify neither as affine
// dims nor as affine symbols due to usage outside of affine ops, analyses, etc.
Region *region = getAffineScope(*this);
for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
return emitError("dimensional operand cannot be used as a symbol");
}

return success();
}

Expand Down Expand Up @@ -1359,13 +1368,64 @@ static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,

resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
*operands = resultOperands;
*mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
oldNumSyms + nextSym);
*mapOrSet = mapOrSet->replaceDimsAndSymbols(
dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);

assert(mapOrSet->getNumInputs() == operands->size() &&
"map/set inputs must match number of operands");
}

/// Demotes symbol operands that are only legal as dims into dim positions.
///
/// A valid affine dimension may appear as a symbol in affine.apply operations.
/// Given an application of `operands` to an affine map or integer set
/// `mapOrSet`, this function canonicalizes symbols of `mapOrSet` that are valid
/// dims, but not valid symbols, into actual dims. Without such a legalization,
/// the affine.apply will be invalid. This method is the exact inverse of
/// canonicalizePromotedSymbols.
///
/// Both arguments are updated in place. On return `operands` is ordered as
/// [original dims..., demoted symbols..., remaining symbols...] and `mapOrSet`
/// has its symbols rewritten to match that ordering. Nothing is changed when
/// `mapOrSet` is null or `operands` is empty.
template <class MapOrSet>
static void legalizeDemotedDims(MapOrSet &mapOrSet,
                                SmallVectorImpl<Value> &operands) {
  if (!mapOrSet || operands.empty())
    return;

  unsigned numOperands = operands.size();

  // `mapOrSet` is a reference in this function (unlike the pointer-taking
  // canonicalization helpers), so its members are accessed with `.`.
  assert(mapOrSet.getNumInputs() == numOperands &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet.getContext();
  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(numOperands);
  SmallVector<Value, 8> remappedDims;
  remappedDims.reserve(numOperands);
  SmallVector<Value, 8> symOperands;
  symOperands.reserve(mapOrSet.getNumSymbols());
  unsigned nextSym = 0;
  unsigned nextDim = 0;
  unsigned oldNumDims = mapOrSet.getNumDims();
  // symRemapping[k] is the expression that replaces the old symbol `s_k`.
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet.getNumSymbols());
  // The existing dim operands keep their positions.
  resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
  for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
    if (operands[i] && isValidDim(operands[i]) && !isValidSymbol(operands[i])) {
      // This is a valid dim that appears as a symbol, legalize it. The new
      // dim is numbered after all of the pre-existing dims.
      symRemapping[i - oldNumDims] =
          getAffineDimExpr(oldNumDims + nextDim++, context);
      remappedDims.push_back(operands[i]);
    } else {
      // Stays a symbol, but is renumbered to close the gaps left by the
      // symbols that were demoted.
      symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
      symOperands.push_back(operands[i]);
    }
  }

  append_range(resultOperands, remappedDims);
  append_range(resultOperands, symOperands);
  operands = resultOperands;
  mapOrSet = mapOrSet.replaceDimsAndSymbols(
      /*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);

  assert(mapOrSet.getNumInputs() == operands.size() &&
         "map/set inputs must match number of operands");
}

// Works for either an affine map or an integer set.
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
Expand All @@ -1380,6 +1440,7 @@ static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
"map/set inputs must match number of operands");

canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);

// Check to see what dims are used.
llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Affine/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1460,8 +1460,8 @@ func.func @mod_of_mod(%lb: index, %ub: index, %step: index) -> (index, index) {
func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
// CHECK: affine.for [[I_0_:%.+]] = 0 to 8 {
affine.for %arg3 = 0 to 8 {
%1 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg3]
// CHECK: affine.prefetch [[PARAM_0_]][symbol([[I_0_]]) * 64], read, locality<3>, data : memref<512xf32>
%1 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
// CHECK: affine.prefetch [[PARAM_0_]][[[I_0_]] * 64], read, locality<3>, data : memref<512xf32>
affine.prefetch %arg0[%1], read, locality<3>, data : memref<512xf32>
}
return
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Affine/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,17 @@ func.func @no_upper_bound() {
}
return
}

// -----

// Negative test: affine.apply must be rejected when affine.for induction
// variables are passed in its symbol operand list.
func.func @invalid_symbol() {
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 26 {
affine.for %arg3 = 0 to 23 {
// The map has no dims; %arg1 and %arg3 are bound to symbols s0 and s1.
affine.apply affine_map<()[s0, s1] -> (s0 * 23 + s1)>()[%arg1, %arg3]
// expected-error@above {{dimensional operand cannot be used as a symbol}}
}
}
}
return
}
36 changes: 18 additions & 18 deletions mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,8 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16

// -----

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
Expand All @@ -518,16 +518,16 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]]]
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
// CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
// CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>

// -----

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 * 1024 + s1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
Expand All @@ -549,14 +549,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0:
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]], %[[ARG4]]]
// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
// CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>

// -----

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 1024)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1025 + d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
Expand All @@ -578,14 +578,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])[%[[ARG3]]]
// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
// CHECK-NEXT: affine.load %[[ARG0]][%[[TMP1]], %[[TMP3]]] : memref<1024x1024xf32>

// -----

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 1024)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 1024)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
Expand All @@ -608,8 +608,8 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]]]
// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
// CHECK-NEXT: memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>

// -----
Expand Down Expand Up @@ -678,7 +678,7 @@ func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index,
// -----

// CHECK-LABEL: func @fold_store_keep_nontemporal(
// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32>
// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32>
func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
%0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
Expand Down
Loading