Skip to content

Commit d28a4f1

Browse files
[mlir][affine]introducing new symbol rules that the result of a Pure operation that whose operands are valid symbolic identifiers (llvm#118478)
introducing new symbol rules that the result of a Pure operation that whose operands are valid symbolic identifiers.
1 parent eff6b64 commit d28a4f1

File tree

11 files changed

+162
-119
lines changed

11 files changed

+162
-119
lines changed

mlir/docs/Dialects/Affine.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,7 @@ immediately enclosed by the latter),
6969
3. a value that dominates the `AffineScope` op enclosing the value's
7070
use,
7171
4. the result of a constant operation,
72-
5. the result of an
73-
[`affine.apply` operation](#affineapply-mliraffineapplyop) that recursively takes as
74-
arguments any valid symbolic identifiers, or
72+
5. the result of a `Pure` operation whose operands are valid symbolic identifiers.
7573
6. the result of a
7674
[`dim` operation](MemRef.md/#memrefdim-mlirmemrefdimop) on either a memref that
7775
is an argument to a `AffineScope` op or a memref where the corresponding

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,8 @@ bool mlir::affine::isValidSymbol(Value value) {
410410
/// A value can be used as a symbol for `region` iff it meets one of the
411411
/// following conditions:
412412
/// *) It is a constant.
413-
/// *) It is the result of an affine apply operation with symbol arguments.
413+
/// *) It is a result of a `Pure` operation whose operands are valid symbolic
414+
/// *) identifiers.
414415
/// *) It is a result of the dim op on a memref whose corresponding size is
415416
/// a valid symbol.
416417
/// *) It is defined at the top level of 'region' or is its argument.
@@ -443,9 +444,12 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
443444
if (matchPattern(defOp, m_Constant(&operandCst)))
444445
return true;
445446

446-
// Affine apply operation is ok if all of its operands are ok.
447-
if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
448-
return applyOp.isValidSymbol(region);
447+
// `Pure` operation that whose operands are valid symbolic identifiers.
448+
if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
449+
return affine::isValidSymbol(operand, region);
450+
})) {
451+
return true;
452+
}
449453

450454
// Dim op results could be valid symbols at any level.
451455
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))

mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -638,13 +638,13 @@ func.func @vecdim_reduction_complex_ub(%in: memref<256x512xf32>, %out: memref<25
638638
return
639639
}
640640

641-
// CHECK: #[[$map3:.*]] = affine_map<([[d0:.*]], [[d1:.*]]) -> ([[d0]], [[d1]] * 2)>
642-
// CHECK: #[[$map3_sub:.*]] = affine_map<([[d0:.*]], [[d1:.*]]) -> ([[d0]] - [[d1]])>
641+
// CHECK: #[[$map3:.*]] = affine_map<(d0, d1) -> (d0, d1 * 2)>
642+
// CHECK: #[[$map3_sub:.*]] = affine_map<(d0)[s0] -> (-d0 + s0)>
643643
// CHECK-LABEL: @vecdim_reduction_complex_ub
644644
// CHECK: %[[vzero:.*]] = arith.constant dense<0.000000e+00> : vector<128xf32>
645645
// CHECK: %{{.*}} = affine.for %[[iv:.*]] = 0 to min #[[$map3]](%[[M:.*]], %[[N:.*]]) step 128 iter_args(%[[red_iter:.*]] = {{.*}}) -> (vector<128xf32>) {
646646
// CHECK: %[[ub:.*]] = affine.min #[[$map3]](%[[M]], %[[N]])
647-
// CHECK: %[[elems_left:.*]] = affine.apply #[[$map3_sub]](%[[ub]], %[[iv]])
647+
// CHECK: %[[elems_left:.*]] = affine.apply #[[$map3_sub]](%[[iv]])[%[[ub]]]
648648
// CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1>
649649
// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32>
650650
// CHECK: %[[select:.*]] = arith.select %[[mask]], %[[ld]], %[[vzero]] : vector<128xi1>, vector<128xf32>

mlir/test/Dialect/Affine/invalid.mlir

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,36 +20,6 @@ func.func @affine_apply_resul_non_index(%arg0 : index) {
2020
return
2121
}
2222

23-
// -----
24-
25-
#map = affine_map<(d0)[s0] -> (d0 + s0)>
26-
27-
func.func @affine_for_lower_bound_invalid_dim(%arg : index) {
28-
affine.for %n0 = 0 to 7 {
29-
%dim = arith.addi %arg, %arg : index
30-
31-
// expected-error@+1 {{operand cannot be used as a dimension id}}
32-
affine.for %n1 = 0 to #map(%dim)[%arg] {
33-
}
34-
}
35-
return
36-
}
37-
38-
// -----
39-
40-
#map = affine_map<(d0)[s0] -> (d0 + s0)>
41-
42-
func.func @affine_for_upper_bound_invalid_dim(%arg : index) {
43-
affine.for %n0 = 0 to 7 {
44-
%dim = arith.addi %arg, %arg : index
45-
46-
// expected-error@+1 {{operand cannot be used as a dimension id}}
47-
affine.for %n1 = #map(%dim)[%arg] to 7 {
48-
}
49-
}
50-
return
51-
}
52-
5323
// -----
5424
func.func @affine_load_invalid_dim(%M : memref<10xi32>) {
5525
"unknown"() ({
@@ -93,20 +63,6 @@ func.func @affine_for_upper_bound_invalid_sym() {
9363

9464
#set0 = affine_set<(i)[N] : (i >= 0, N - i >= 0)>
9565

96-
func.func @affine_if_invalid_dim(%arg : index) {
97-
affine.for %n0 = 0 to 7 {
98-
%dim = arith.addi %arg, %arg : index
99-
100-
// expected-error@+1 {{operand cannot be used as a dimension id}}
101-
affine.if #set0(%dim)[%n0] {}
102-
}
103-
return
104-
}
105-
106-
// -----
107-
108-
#set0 = affine_set<(i)[N] : (i >= 0, N - i >= 0)>
109-
11066
func.func @affine_if_invalid_sym() {
11167
affine.for %i0 = 0 to 7 {
11268
// expected-error@+1 {{operand cannot be used as a symbol}}

mlir/test/Dialect/Affine/ops.mlir

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,88 @@ module attributes {gpu.container_module} {
324324
// CHECK: affine.for %[[VAL_4:.*]] = %[[VAL_3]] to %[[VAL_2]] step 32 {
325325
// CHECK: }
326326
// CHECK: gpu.return
327+
328+
// -----
329+
330+
#map = affine_map<()[s0] -> (s0 mod 32)>
331+
332+
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 mod 32)>
333+
334+
// CHECK-LABEL: gpu.func @affine_thread_id
335+
336+
module {
337+
gpu.module @gpu {
338+
gpu.func @affine_thread_id(%arg0: memref<?x?xf32>) kernel {
339+
%c3 = arith.constant 3 : index
340+
%dim = memref.dim %arg0, %c3 : memref<?x?xf32>
341+
%c0 = arith.constant 0 : index
342+
affine.for %arg3 = %c0 to %dim step 32 {
343+
%thread_id_x = gpu.thread_id x
344+
%0 = affine.apply #map()[%thread_id_x]
345+
%c128 = arith.constant 128 : index
346+
affine.for %arg4 = %0 to %c128 step 8 {
347+
%c32 = arith.constant 32 : index
348+
}
349+
}
350+
gpu.return
351+
}
352+
}
353+
}
354+
355+
// CHECK-SAME: (%[[VAL_0:.*]]: memref<?x?xf32>) kernel {
356+
// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
357+
// CHECK: %[[VAL_2:.*]] = memref.dim %[[VAL_0]], %[[VAL_1]] : memref<?x?xf32>
358+
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
359+
// CHECK: affine.for %[[VAL_4:.*]] = %[[VAL_3]] to %[[VAL_2]] step 32 {
360+
// CHECK: %[[VAL_5:.*]] = gpu.thread_id x
361+
// CHECK: %[[VAL_6:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[VAL_5]]]
362+
// CHECK: %[[VAL_7:.*]] = arith.constant 128 : index
363+
// CHECK: affine.for %{{.*}} = %[[VAL_6]] to %[[VAL_7]] step 8 {
364+
365+
// -----
366+
367+
#map = affine_map<(d0)[s0] -> (d0 + s0)>
368+
369+
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
370+
371+
// CHECK-LABEL: func @arith_add_vaild_symbol_upper_bound
372+
373+
func.func @arith_add_vaild_symbol_upper_bound(%arg : index) {
374+
affine.for %n0 = 0 to 7 {
375+
%dim = arith.addi %arg, %arg : index
376+
affine.for %n1 = 0 to #map(%dim)[%arg] {
377+
}
378+
}
379+
return
380+
}
381+
382+
// CHECK-SAME: %[[VAL_0:.*]]: index) {
383+
// CHECK: affine.for %[[VAL_1:.*]] = 0 to 7 {
384+
// CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_0]], %[[VAL_0]] : index
385+
// CHECK: affine.for %[[VAL_3:.*]] = 0 to #[[$ATTR_0]](%[[VAL_2]]){{\[}}%[[VAL_0]]] {
386+
// CHECK: }
387+
// CHECK: }
388+
389+
// -----
390+
391+
#map = affine_map<(d0)[s0] -> (d0 + s0)>
392+
393+
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
394+
395+
// CHECK-LABEL: func @arith_add_vaild_symbol_lower_bound
396+
397+
func.func @arith_add_vaild_symbol_lower_bound(%arg : index) {
398+
affine.for %n0 = 0 to 7 {
399+
%dim = arith.addi %arg, %arg : index
400+
affine.for %n1 = #map(%dim)[%arg] to 7 {
401+
}
402+
}
403+
return
404+
}
405+
406+
// CHECK-SAME: %[[VAL_0:.*]]: index) {
407+
// CHECK: affine.for %[[VAL_1:.*]] = 0 to 7 {
408+
// CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_0]], %[[VAL_0]] : index
409+
// CHECK: affine.for %[[VAL_3:.*]] = #[[$ATTR_0]](%[[VAL_2]]){{\[}}%[[VAL_0]]] to 7 {
410+
// CHECK: }
411+
// CHECK: }

0 commit comments

Comments
 (0)