Skip to content

Commit 666b722

Browse files
if the result of a Pure operation that whose operands are dimensional identifiers,then their results are dimensional identifiers.
1 parent 814b34f commit 666b722

File tree

5 files changed

+190
-114
lines changed

5 files changed

+190
-114
lines changed

mlir/docs/Dialects/Affine.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ location of the SSA use. Dimensions may be bound not only to anything that a
8383
symbol is bound to, but also to induction variables of enclosing
8484
[`affine.for`](#affinefor-mliraffineforop) and
8585
[`affine.parallel`](#affineparallel-mliraffineparallelop) operations, and the result
86-
of an [`affine.apply` operation](#affineapply-mliraffineapplyop) (which recursively
87-
may use other dimensions and symbols).
86+
of a `Pure` operation whose operands are valid dimensional identifiers.
87+
(which recursively may use other dimensions and symbols).
8888

8989
### Affine Expressions
9090

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,8 @@ Region *mlir::affine::getAffineScope(Operation *op) {
274274
// conditions:
275275
// *) It is valid as a symbol.
276276
// *) It is an induction variable.
277-
// *) It is the result of affine apply operation with dimension id arguments.
277+
// *) It is the result of a `Pure` operation whose operands with dimension id
278+
// *) arguments.
278279
bool mlir::affine::isValidDim(Value value) {
279280
// The value must be an index type.
280281
if (!value.getType().isIndex())
@@ -294,7 +295,8 @@ bool mlir::affine::isValidDim(Value value) {
294295
// conditions:
295296
// *) It is valid as a symbol.
296297
// *) It is an induction variable.
297-
// *) It is the result of an affine apply operation with dimension id operands.
298+
// *) It is the result of a `Pure` operation whose operands with dimension id
299+
// *) operands.
298300
bool mlir::affine::isValidDim(Value value, Region *region) {
299301
// The value must be an index type.
300302
if (!value.getType().isIndex())
@@ -304,20 +306,24 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
304306
if (isValidSymbol(value, region))
305307
return true;
306308

307-
auto *op = value.getDefiningOp();
308-
if (!op) {
309+
auto *defOp = value.getDefiningOp();
310+
if (!defOp) {
309311
// This value has to be a block argument for an affine.for or an
310312
// affine.parallel.
311313
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
312314
return isa<AffineForOp, AffineParallelOp>(parentOp);
313315
}
314316

315-
// Affine apply operation is ok if all of its operands are ok.
316-
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
317-
return applyOp.isValidDim(region);
317+
// `Pure` operation is ok if all of its operands are ok.
318+
if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
319+
return affine::isValidDim(operand, region);
320+
})) {
321+
return true;
322+
}
323+
318324
// The dim op is okay if its operand memref/tensor is defined at the top
319325
// level.
320-
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
326+
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
321327
return isTopLevelValue(dimOp.getShapedValue());
322328
return false;
323329
}

mlir/test/Dialect/Affine/invalid.mlir

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -225,18 +225,6 @@ func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
225225

226226
// -----
227227

228-
func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
229-
affine.for %x = 0 to 7 {
230-
%y = arith.addi %x, %x : index
231-
// expected-error@+1 {{operand cannot be used as a dimension id}}
232-
affine.parallel (%i, %j) = (0, 0) to (%y, 100) step (10, 10) {
233-
}
234-
}
235-
return
236-
}
237-
238-
// -----
239-
240228
func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
241229
affine.for %x = 0 to 7 {
242230
%y = arith.addi %x, %x : index

mlir/test/Dialect/Affine/load-store-invalid.mlir

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -33,31 +33,6 @@ func.func @store_too_few_subscripts_map(%arg0: memref<?x?xf32>, %arg1: index, %v
3333

3434
// -----
3535

36-
func.func @load_non_affine_index(%arg0 : index) {
37-
%0 = memref.alloc() : memref<10xf32>
38-
affine.for %i0 = 0 to 10 {
39-
%1 = arith.muli %i0, %arg0 : index
40-
// expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
41-
%v = affine.load %0[%1] : memref<10xf32>
42-
}
43-
return
44-
}
45-
46-
// -----
47-
48-
func.func @store_non_affine_index(%arg0 : index) {
49-
%0 = memref.alloc() : memref<10xf32>
50-
%1 = arith.constant 11.0 : f32
51-
affine.for %i0 = 0 to 10 {
52-
%2 = arith.muli %i0, %arg0 : index
53-
// expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
54-
affine.store %1, %0[%2] : memref<10xf32>
55-
}
56-
return
57-
}
58-
59-
// -----
60-
6136
func.func @invalid_prefetch_rw(%i : index) {
6237
%0 = memref.alloc() : memref<10xf32>
6338
// expected-error@+1 {{rw specifier has to be 'read' or 'write'}}
@@ -73,70 +48,3 @@ func.func @invalid_prefetch_cache_type(%i : index) {
7348
affine.prefetch %0[%i], read, locality<0>, false : memref<10xf32>
7449
return
7550
}
76-
77-
// -----
78-
79-
func.func @dma_start_non_affine_src_index(%arg0 : index) {
80-
%0 = memref.alloc() : memref<100xf32>
81-
%1 = memref.alloc() : memref<100xf32, 2>
82-
%2 = memref.alloc() : memref<1xi32, 4>
83-
%c0 = arith.constant 0 : index
84-
%c64 = arith.constant 64 : index
85-
affine.for %i0 = 0 to 10 {
86-
%3 = arith.muli %i0, %arg0 : index
87-
// expected-error@+1 {{op src index must be a valid dimension or symbol identifier}}
88-
affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
89-
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
90-
}
91-
return
92-
}
93-
94-
// -----
95-
96-
func.func @dma_start_non_affine_dst_index(%arg0 : index) {
97-
%0 = memref.alloc() : memref<100xf32>
98-
%1 = memref.alloc() : memref<100xf32, 2>
99-
%2 = memref.alloc() : memref<1xi32, 4>
100-
%c0 = arith.constant 0 : index
101-
%c64 = arith.constant 64 : index
102-
affine.for %i0 = 0 to 10 {
103-
%3 = arith.muli %i0, %arg0 : index
104-
// expected-error@+1 {{op dst index must be a valid dimension or symbol identifier}}
105-
affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
106-
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
107-
}
108-
return
109-
}
110-
111-
// -----
112-
113-
func.func @dma_start_non_affine_tag_index(%arg0 : index) {
114-
%0 = memref.alloc() : memref<100xf32>
115-
%1 = memref.alloc() : memref<100xf32, 2>
116-
%2 = memref.alloc() : memref<1xi32, 4>
117-
%c0 = arith.constant 0 : index
118-
%c64 = arith.constant 64 : index
119-
affine.for %i0 = 0 to 10 {
120-
%3 = arith.muli %i0, %arg0 : index
121-
// expected-error@+1 {{op tag index must be a valid dimension or symbol identifier}}
122-
affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
123-
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
124-
}
125-
return
126-
}
127-
128-
// -----
129-
130-
func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
131-
%0 = memref.alloc() : memref<100xf32>
132-
%1 = memref.alloc() : memref<100xf32, 2>
133-
%2 = memref.alloc() : memref<1xi32, 4>
134-
%c0 = arith.constant 0 : index
135-
%c64 = arith.constant 64 : index
136-
affine.for %i0 = 0 to 10 {
137-
%3 = arith.muli %i0, %arg0 : index
138-
// expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
139-
affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
140-
}
141-
return
142-
}

mlir/test/Dialect/Affine/ops.mlir

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,3 +409,177 @@ func.func @arith_add_vaild_symbol_lower_bound(%arg : index) {
409409
// CHECK: affine.for %[[VAL_3:.*]] = #[[$ATTR_0]](%[[VAL_2]]){{\[}}%[[VAL_0]]] to 7 {
410410
// CHECK: }
411411
// CHECK: }
412+
413+
// -----
414+
415+
// CHECK-LABEL: func @affine_parallel
416+
417+
func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
418+
affine.for %x = 0 to 7 {
419+
%y = arith.addi %x, %x : index
420+
affine.parallel (%i, %j) = (0, 0) to (%y, 100) step (10, 10) {
421+
}
422+
}
423+
return
424+
}
425+
426+
// CHECK-NEXT: affine.for
427+
// CHECK-SAME: %[[VAL_0:.*]] = 0 to 7 {
428+
// CHECK: %[[VAL_1:.*]] = arith.addi %[[VAL_0]], %[[VAL_0]] : index
429+
// CHECK: affine.parallel (%{{.*}}, %{{.*}}) = (0, 0) to (%[[VAL_1]], 100) step (10, 10) {
430+
// CHECK: }
431+
// CHECK: }
432+
433+
// -----
434+
435+
func.func @load_non_affine_index(%arg0 : index) {
436+
%0 = memref.alloc() : memref<10xf32>
437+
affine.for %i0 = 0 to 10 {
438+
%1 = arith.muli %i0, %arg0 : index
439+
%v = affine.load %0[%1] : memref<10xf32>
440+
}
441+
return
442+
}
443+
444+
// CHECK-LABEL: func @load_non_affine_index
445+
// CHECK-SAME: %[[VAL_0:.*]]: index) {
446+
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<10xf32>
447+
// CHECK: affine.for %[[VAL_2:.*]] = 0 to 10 {
448+
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_0]] : index
449+
// CHECK: %{{.*}} = affine.load %[[VAL_1]]{{\[}}%[[VAL_3]]] : memref<10xf32>
450+
// CHECK: }
451+
452+
// -----
453+
454+
func.func @store_non_affine_index(%arg0 : index) {
455+
%0 = memref.alloc() : memref<10xf32>
456+
%1 = arith.constant 11.0 : f32
457+
affine.for %i0 = 0 to 10 {
458+
%2 = arith.muli %i0, %arg0 : index
459+
affine.store %1, %0[%2] : memref<10xf32>
460+
}
461+
return
462+
}
463+
464+
// CHECK-LABEL: func @store_non_affine_index
465+
// CHECK-SAME: %[[VAL_0:.*]]: index) {
466+
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<10xf32>
467+
// CHECK: %[[VAL_2:.*]] = arith.constant 1.100000e+01 : f32
468+
// CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
469+
// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_0]] : index
470+
// CHECK: affine.store %[[VAL_2]], %[[VAL_1]]{{\[}}%[[VAL_4]]] : memref<10xf32>
471+
// CHECK: }
472+
473+
// -----
474+
475+
func.func @dma_start_non_affine_src_index(%arg0 : index) {
476+
%0 = memref.alloc() : memref<100xf32>
477+
%1 = memref.alloc() : memref<100xf32, 2>
478+
%2 = memref.alloc() : memref<1xi32, 4>
479+
%c0 = arith.constant 0 : index
480+
%c64 = arith.constant 64 : index
481+
affine.for %i0 = 0 to 10 {
482+
%3 = arith.muli %i0, %arg0 : index
483+
affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
484+
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
485+
}
486+
return
487+
}
488+
489+
// CHECK-LABEL: func @dma_start_non_affine_src_index
490+
// CHECK-SAME: %[[VAL_0:.*]]: index) {
491+
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<100xf32>
492+
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<100xf32, 2>
493+
// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32, 4>
494+
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
495+
// CHECK: %[[VAL_5:.*]] = arith.constant 64 : index
496+
// CHECK: affine.for %[[VAL_6:.*]] = 0 to 10 {
497+
// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_0]] : index
498+
// CHECK: affine.dma_start %[[VAL_1]]{{\[}}%[[VAL_7]]], %[[VAL_2]]{{\[}}%[[VAL_6]]], %[[VAL_3]]{{\[}}%[[VAL_4]]], %[[VAL_5]]
499+
// CHECK-SAME: : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
500+
// CHECK: }
501+
502+
// -----
503+
504+
func.func @dma_start_non_affine_dst_index(%arg0 : index) {
505+
%0 = memref.alloc() : memref<100xf32>
506+
%1 = memref.alloc() : memref<100xf32, 2>
507+
%2 = memref.alloc() : memref<1xi32, 4>
508+
%c0 = arith.constant 0 : index
509+
%c64 = arith.constant 64 : index
510+
affine.for %i0 = 0 to 10 {
511+
%3 = arith.muli %i0, %arg0 : index
512+
affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
513+
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
514+
}
515+
return
516+
}
517+
518+
// CHECK-LABEL: func @dma_start_non_affine_dst_index
519+
// CHECK-SAME: %[[VAL_0:.*]]: index) {
520+
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<100xf32>
521+
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<100xf32, 2>
522+
// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32, 4>
523+
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
524+
// CHECK: %[[VAL_5:.*]] = arith.constant 64 : index
525+
// CHECK: affine.for %[[VAL_6:.*]] = 0 to 10 {
526+
// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_0]] : index
527+
// CHECK: affine.dma_start %[[VAL_1]]{{\[}}%[[VAL_6]]], %[[VAL_2]]{{\[}}%[[VAL_7]]], %[[VAL_3]]{{\[}}%[[VAL_4]]], %[[VAL_5]]
528+
// CHECK-SAME: : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
529+
// CHECK: }
530+
531+
// -----
532+
533+
func.func @dma_start_non_affine_tag_index(%arg0 : index) {
534+
%0 = memref.alloc() : memref<100xf32>
535+
%1 = memref.alloc() : memref<100xf32, 2>
536+
%2 = memref.alloc() : memref<1xi32, 4>
537+
%c0 = arith.constant 0 : index
538+
%c64 = arith.constant 64 : index
539+
affine.for %i0 = 0 to 10 {
540+
%3 = arith.muli %i0, %arg0 : index
541+
affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
542+
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
543+
}
544+
return
545+
}
546+
547+
// CHECK-LABEL: func @dma_start_non_affine_tag_index
548+
// CHECK-SAME: %[[VAL_0:.*]]: index) {
549+
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<100xf32>
550+
// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<100xf32, 2>
551+
// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<1xi32, 4>
552+
// CHECK: %{{.*}} = arith.constant 0 : index
553+
// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index
554+
// CHECK: affine.for %[[VAL_5:.*]] = 0 to 10 {
555+
// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_0]] : index
556+
// CHECK: affine.dma_start %[[VAL_1]]{{\[}}%[[VAL_5]]], %[[VAL_2]]{{\[}}%[[VAL_0]]], %[[VAL_3]]{{\[}}%[[VAL_6]]], %[[VAL_4]]
557+
// CHECK-SAME: : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
558+
// CHECK: }
559+
560+
// -----
561+
562+
func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
563+
%0 = memref.alloc() : memref<100xf32>
564+
%1 = memref.alloc() : memref<100xf32, 2>
565+
%2 = memref.alloc() : memref<1xi32, 4>
566+
%c0 = arith.constant 0 : index
567+
%c64 = arith.constant 64 : index
568+
affine.for %i0 = 0 to 10 {
569+
%3 = arith.muli %i0, %arg0 : index
570+
affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
571+
}
572+
return
573+
}
574+
575+
// CHECK-LABEL: func @dma_wait_non_affine_tag_index
576+
// CHECK-SAME: %[[VAL_0:.*]]: index) {
577+
// CHECK: %{{.*}} = memref.alloc() : memref<100xf32>
578+
// CHECK: %{{.*}} = memref.alloc() : memref<100xf32, 2>
579+
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<1xi32, 4>
580+
// CHECK: %{{.*}} = arith.constant 0 : index
581+
// CHECK: %[[VAL_2:.*]] = arith.constant 64 : index
582+
// CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
583+
// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_0]] : index
584+
// CHECK: affine.dma_wait %[[VAL_1]]{{\[}}%[[VAL_4]]], %[[VAL_2]] : memref<1xi32, 4>
585+
// CHECK: }

0 commit comments

Comments
 (0)