
[mlir][affine] If the operands of a `Pure` operation are valid dimensional identifiers, then its results are valid dimensional identifiers #123542

Closed
4 changes: 2 additions & 2 deletions mlir/docs/Dialects/Affine.md
@@ -83,8 +83,8 @@ location of the SSA use. Dimensions may be bound not only to anything that a
symbol is bound to, but also to induction variables of enclosing
[`affine.for`](#affinefor-mliraffineforop) and
[`affine.parallel`](#affineparallel-mliraffineparallelop) operations, and the result
of an [`affine.apply` operation](#affineapply-mliraffineapplyop) (which recursively
may use other dimensions and symbols).
of a `Pure` operation whose operands are valid dimensional identifiers
(which recursively may use other dimensions and symbols).

### Affine Expressions

22 changes: 14 additions & 8 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -274,7 +274,8 @@ Region *mlir::affine::getAffineScope(Operation *op) {
// conditions:
// *) It is valid as a symbol.
// *) It is an induction variable.
// *) It is the result of affine apply operation with dimension id arguments.
// *) It is the result of a `Pure` operation whose operands are valid
//    dimension id arguments.
bool mlir::affine::isValidDim(Value value) {
// The value must be an index type.
if (!value.getType().isIndex())
@@ -294,7 +295,8 @@ bool mlir::affine::isValidDim(Value value) {
// conditions:
// *) It is valid as a symbol.
// *) It is an induction variable.
// *) It is the result of an affine apply operation with dimension id operands.
// *) It is the result of a `Pure` operation whose operands are valid
//    dimension id operands.
bool mlir::affine::isValidDim(Value value, Region *region) {
// The value must be an index type.
if (!value.getType().isIndex())
@@ -304,20 +306,24 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
if (isValidSymbol(value, region))
return true;

auto *op = value.getDefiningOp();
if (!op) {
auto *defOp = value.getDefiningOp();
if (!defOp) {
// This value has to be a block argument for an affine.for or an
// affine.parallel.
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
return isa<AffineForOp, AffineParallelOp>(parentOp);
}

// Affine apply operation is ok if all of its operands are ok.
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
return applyOp.isValidDim(region);
// `Pure` operation is ok if all of its operands are ok.
if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
return affine::isValidDim(operand, region);
})) {
return true;
}

// The dim op is okay if its operand memref/tensor is defined at the top
// level.
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
return isTopLevelValue(dimOp.getShapedValue());
return false;
}
12 changes: 0 additions & 12 deletions mlir/test/Dialect/Affine/invalid.mlir
@@ -225,18 +225,6 @@ func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {

// -----

func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
affine.for %x = 0 to 7 {
%y = arith.addi %x, %x : index
// expected-error@+1 {{operand cannot be used as a dimension id}}
affine.parallel (%i, %j) = (0, 0) to (%y, 100) step (10, 10) {
}
}
return
}

// -----
Comment on lines -228 to -238
Member
We shouldn't just remove these tests. They are still valid. They should be updated to still emit the message, e.g., by using another operation instead of addi that is not Pure.

Member Author (@linuxlonelyeagle, Jan 26, 2025)
I didn't delete the tests; I moved them to the positive tests. You can look at Affine/ops.mlir.

Member Author
arith.addi is Pure; you can see below and at https://mlir.llvm.org/docs/Dialects/ArithOps/#arithandi-arithandiop

def Pure : TraitList<[AlwaysSpeculatable, NoMemoryEffect]>;

// arith.addi
Traits: AlwaysSpeculatableImplTrait, Commutative, Elementwise, Idempotent, SameOperandsAndResultType, Scalarizable, Tensorizable, Vectorizable

Interfaces: ConditionallySpeculatable, InferIntRangeInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), VectorUnrollOpInterface

Effects: MemoryEffects::Effect{}
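
Following the reviewer's suggestion, a hypothetical negative test (a sketch, not part of this PR) could substitute a non-Pure operation such as `memref.load` for `arith.addi`, so the diagnostic is still exercised:

```mlir
// Sketch: a non-Pure op's result is still rejected as a dimension id,
// because memref.load has memory effects and so lacks the Pure trait.
func.func @non_pure_dim(%m : memref<10xindex>) {
  affine.for %x = 0 to 7 {
    %y = memref.load %m[%x] : memref<10xindex>
    // expected-error@+1 {{operand cannot be used as a dimension id}}
    affine.parallel (%i, %j) = (0, 0) to (%y, 100) step (10, 10) {
    }
  }
  return
}
```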


func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
affine.for %x = 0 to 7 {
%y = arith.addi %x, %x : index
92 changes: 0 additions & 92 deletions mlir/test/Dialect/Affine/load-store-invalid.mlir
@@ -33,31 +33,6 @@ func.func @store_too_few_subscripts_map(%arg0: memref<?x?xf32>, %arg1: index, %v

// -----

func.func @load_non_affine_index(%arg0 : index) {
%0 = memref.alloc() : memref<10xf32>
affine.for %i0 = 0 to 10 {
%1 = arith.muli %i0, %arg0 : index
// expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
%v = affine.load %0[%1] : memref<10xf32>
}
return
}

// -----

func.func @store_non_affine_index(%arg0 : index) {
%0 = memref.alloc() : memref<10xf32>
%1 = arith.constant 11.0 : f32
affine.for %i0 = 0 to 10 {
%2 = arith.muli %i0, %arg0 : index
// expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
affine.store %1, %0[%2] : memref<10xf32>
}
return
}

// -----

func.func @invalid_prefetch_rw(%i : index) {
%0 = memref.alloc() : memref<10xf32>
// expected-error@+1 {{rw specifier has to be 'read' or 'write'}}
@@ -73,70 +48,3 @@ func.func @invalid_prefetch_cache_type(%i : index) {
affine.prefetch %0[%i], read, locality<0>, false : memref<10xf32>
return
}

// -----

func.func @dma_start_non_affine_src_index(%arg0 : index) {
%0 = memref.alloc() : memref<100xf32>
%1 = memref.alloc() : memref<100xf32, 2>
%2 = memref.alloc() : memref<1xi32, 4>
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
// expected-error@+1 {{op src index must be a valid dimension or symbol identifier}}
affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
return
}

// -----

func.func @dma_start_non_affine_dst_index(%arg0 : index) {
%0 = memref.alloc() : memref<100xf32>
%1 = memref.alloc() : memref<100xf32, 2>
%2 = memref.alloc() : memref<1xi32, 4>
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
// expected-error@+1 {{op dst index must be a valid dimension or symbol identifier}}
affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
return
}

// -----

func.func @dma_start_non_affine_tag_index(%arg0 : index) {
%0 = memref.alloc() : memref<100xf32>
%1 = memref.alloc() : memref<100xf32, 2>
%2 = memref.alloc() : memref<1xi32, 4>
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
// expected-error@+1 {{op tag index must be a valid dimension or symbol identifier}}
affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
return
}

// -----

func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
%0 = memref.alloc() : memref<100xf32>
%1 = memref.alloc() : memref<100xf32, 2>
%2 = memref.alloc() : memref<1xi32, 4>
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
// expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
}
return
}
174 changes: 174 additions & 0 deletions mlir/test/Dialect/Affine/ops.mlir
@@ -409,3 +409,177 @@ func.func @arith_add_vaild_symbol_lower_bound(%arg : index) {
// CHECK: affine.for %[[VAL_3:.*]] = #[[$ATTR_0]](%[[VAL_2]]){{\[}}%[[VAL_0]]] to 7 {
// CHECK: }
// CHECK: }

// -----

// CHECK-LABEL: func @affine_parallel

func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
affine.for %x = 0 to 7 {
%y = arith.addi %x, %x : index
affine.parallel (%i, %j) = (0, 0) to (%y, 100) step (10, 10) {
}
}
return
}

// CHECK-NEXT: affine.for
// CHECK-SAME: %[[VAL_0:.*]] = 0 to 7 {
// CHECK: %[[VAL_1:.*]] = arith.addi %[[VAL_0]], %[[VAL_0]] : index
// CHECK: affine.parallel (%{{.*}}, %{{.*}}) = (0, 0) to (%[[VAL_1]], 100) step (10, 10) {
// CHECK: }
// CHECK: }

// -----

func.func @load_non_affine_index(%arg0 : index) {
%0 = memref.alloc() : memref<10xf32>
affine.for %i0 = 0 to 10 {
%1 = arith.muli %i0, %arg0 : index
%v = affine.load %0[%1] : memref<10xf32>
}
return
}

// CHECK-LABEL: func @load_non_affine_index
// CHECK-SAME: %[[VAL_0:.*]]: index) {
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<10xf32>
// CHECK: affine.for %[[VAL_2:.*]] = 0 to 10 {
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_0]] : index
// CHECK: %{{.*}} = affine.load %[[VAL_1]]{{\[}}%[[VAL_3]]] : memref<10xf32>
// CHECK: }

// -----

func.func @store_non_affine_index(%arg0 : index) {
%0 = memref.alloc() : memref<10xf32>
%1 = arith.constant 11.0 : f32
affine.for %i0 = 0 to 10 {
%2 = arith.muli %i0, %arg0 : index
affine.store %1, %0[%2] : memref<10xf32>
}
return
}

// CHECK-LABEL: func @store_non_affine_index
// CHECK-SAME: %[[VAL_0:.*]]: index) {
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<10xf32>
// CHECK: %[[VAL_2:.*]] = arith.constant 1.100000e+01 : f32
// CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_0]] : index
// CHECK: affine.store %[[VAL_2]], %[[VAL_1]]{{\[}}%[[VAL_4]]] : memref<10xf32>
// CHECK: }

// -----

func.func @dma_start_non_affine_src_index(%arg0 : index) {
%0 = memref.alloc() : memref<100xf32>
%1 = memref.alloc() : memref<100xf32, 2>
%2 = memref.alloc() : memref<1xi32, 4>
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
return
}

// CHECK-LABEL: func @dma_start_non_affine_src_index
// CHECK-SAME: %[[VAL_0:.*]]: index) {
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<100xf32>
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<100xf32, 2>
// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32, 4>
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 64 : index
// CHECK: affine.for %[[VAL_6:.*]] = 0 to 10 {
// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_0]] : index
// CHECK: affine.dma_start %[[VAL_1]]{{\[}}%[[VAL_7]]], %[[VAL_2]]{{\[}}%[[VAL_6]]], %[[VAL_3]]{{\[}}%[[VAL_4]]], %[[VAL_5]]
// CHECK-SAME: : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
// CHECK: }

// -----

func.func @dma_start_non_affine_dst_index(%arg0 : index) {
%0 = memref.alloc() : memref<100xf32>
%1 = memref.alloc() : memref<100xf32, 2>
%2 = memref.alloc() : memref<1xi32, 4>
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
return
}

// CHECK-LABEL: func @dma_start_non_affine_dst_index
// CHECK-SAME: %[[VAL_0:.*]]: index) {
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<100xf32>
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<100xf32, 2>
// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32, 4>
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 64 : index
// CHECK: affine.for %[[VAL_6:.*]] = 0 to 10 {
// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_0]] : index
// CHECK: affine.dma_start %[[VAL_1]]{{\[}}%[[VAL_6]]], %[[VAL_2]]{{\[}}%[[VAL_7]]], %[[VAL_3]]{{\[}}%[[VAL_4]]], %[[VAL_5]]
// CHECK-SAME: : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
// CHECK: }

// -----

func.func @dma_start_non_affine_tag_index(%arg0 : index) {
%0 = memref.alloc() : memref<100xf32>
%1 = memref.alloc() : memref<100xf32, 2>
%2 = memref.alloc() : memref<1xi32, 4>
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
return
}

// CHECK-LABEL: func @dma_start_non_affine_tag_index
// CHECK-SAME: %[[VAL_0:.*]]: index) {
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<100xf32>
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<100xf32, 2>
// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32, 4>
// CHECK: %{{.*}} = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index
// CHECK: affine.for %[[VAL_5:.*]] = 0 to 10 {
// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_0]] : index
// CHECK: affine.dma_start %[[VAL_1]]{{\[}}%[[VAL_5]]], %[[VAL_2]]{{\[}}%[[VAL_0]]], %[[VAL_3]]{{\[}}%[[VAL_6]]], %[[VAL_4]]
// CHECK-SAME: : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
// CHECK: }

// -----

func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
%0 = memref.alloc() : memref<100xf32>
%1 = memref.alloc() : memref<100xf32, 2>
%2 = memref.alloc() : memref<1xi32, 4>
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
}
return
}

// CHECK-LABEL: func @dma_wait_non_affine_tag_index
// CHECK-SAME: %[[VAL_0:.*]]: index) {
// CHECK: %{{.*}} = memref.alloc() : memref<100xf32>
// CHECK: %{{.*}} = memref.alloc() : memref<100xf32, 2>
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<1xi32, 4>
// CHECK: %{{.*}} = arith.constant 0 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 64 : index
// CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_0]] : index
// CHECK: affine.dma_wait %[[VAL_1]]{{\[}}%[[VAL_4]]], %[[VAL_2]] : memref<1xi32, 4>
// CHECK: }