Commit cb2e5a8

Introduce the AffineSymbol trait and use it so that the gpu.thread_id op can be used as a valid affine symbol in inner loops.
1 parent 26d513d commit cb2e5a8

File tree

8 files changed: +93, -31 lines

mlir/docs/Dialects/Affine.md

Lines changed: 3 additions & 2 deletions
@@ -69,10 +69,11 @@ immediately enclosed by the latter),
 3. a value that dominates the `AffineScope` op enclosing the value's
    use,
 4. the result of a constant operation,
-5. the result of an
+5. the result of an `AffineSymbol` op,
+6. the result of an
    [`affine.apply` operation](#affineapply-mliraffineapplyop) that recursively takes as
    arguments any valid symbolic identifiers, or
-6. the result of a
+7. the result of a
    [`dim` operation](MemRef.md/#memrefdim-mlirmemrefdimop) on either a memref that
    is an argument to a `AffineScope` op or a memref where the corresponding
    dimension is either static or a dynamic one in turn bound to a valid symbol.
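To illustrate the new rule 5, here is a minimal sketch (the kernel name, buffer shape, and loop bounds are invented for illustration, not part of the commit): because `gpu.thread_id` now carries the `AffineSymbol` trait, its result is a valid symbol even when the op is created inside a loop rather than at the top level of the enclosing `AffineScope`.

gpu.module @example {
  gpu.func @rule_5_sketch(%buf: memref<256xf32>) kernel {
    affine.for %i = 0 to 4 {
      // Rule 5: the result of an `AffineSymbol` op is a valid symbol here,
      // which in turn makes the `affine.apply` below legal (rule 6).
      %tid = gpu.thread_id x
      %idx = affine.apply affine_map<()[s0] -> (s0 mod 32)>()[%tid]
      %v = memref.load %buf[%idx] : memref<256xf32>
    }
    gpu.return
  }
}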

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -223,7 +223,7 @@ def GPU_GridDimOp : GPU_IndexOp<"grid_dim"> {
     There is an implicit upper bound of `kMaxDim` (currently uint32_t::max).
   }];
 }
-def GPU_ThreadIdOp : GPU_IndexOp<"thread_id"> {
+def GPU_ThreadIdOp : GPU_IndexOp<"thread_id", [AffineSymbol]> {
   let description = [{
     Returns the thread id, i.e. the index of the current thread within the block
     along the x, y, or z `dimension`.
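As a hedged usage sketch of what attaching `[AffineSymbol]` enables (the memref type and loop bounds below are assumptions for illustration): the thread id can now appear in the symbol position of affine operations even when it is materialized inside a loop nest of the kernel body.

// Assumed to sit inside a gpu.func body with %buf : memref<?xf32> as an argument.
affine.for %i = 0 to 4 {
  %tid = gpu.thread_id x
  // %tid is referenced as a symbol in the affine.load subscript.
  %v = affine.load %buf[%i * 64 + symbol(%tid)] : memref<?xf32>
}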

mlir/include/mlir/IR/OpBase.td

Lines changed: 2 additions & 0 deletions
@@ -63,6 +63,8 @@ class PredOpTrait<string descr, Pred pred, list<Trait> traits = []>
 
 // Op defines an affine scope.
 def AffineScope : NativeOpTrait<"AffineScope">;
+// Op defines an affine symbol.
+def AffineSymbol : NativeOpTrait<"AffineSymbol">;
 // Op defines an automatic allocation scope.
 def AutomaticAllocationScope :
   NativeOpTrait<"AutomaticAllocationScope">;

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 11 additions & 0 deletions
@@ -365,6 +365,7 @@ LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
 LogicalResult verifyNoRegionArguments(Operation *op);
 LogicalResult verifyElementwise(Operation *op);
 LogicalResult verifyIsIsolatedFromAbove(Operation *op);
+LogicalResult verifyIndexResultType(Operation *op);
 } // namespace impl
 
 /// Helper class for implementing traits. Clients are not expected to interact
@@ -1268,6 +1269,16 @@ class AffineScope : public TraitBase<ConcreteType, AffineScope> {
   }
 };
 
+/// A trait of operation. The single `index` result of an operation carrying
+/// this trait can be used as a valid affine symbol.
+template <typename ConcreteType>
+class AffineSymbol : public TraitBase<ConcreteType, AffineSymbol> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyIndexResultType(op);
+  }
+};
+
 /// A trait of region holding operations that define a new scope for automatic
 /// allocations, i.e., allocations that are freed when control is transferred
 /// back from the operation's region. Any operations performing such allocations

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 4 additions & 0 deletions
@@ -443,6 +443,10 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
   if (matchPattern(defOp, m_Constant(&operandCst)))
     return true;
 
+  // AffineSymbol op: its result is always a valid symbol.
+  if (defOp->hasTrait<OpTrait::AffineSymbol>())
+    return true;
+
   // Affine apply operation is ok if all of its operands are ok.
   if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
     return applyOp.isValidSymbol(region);
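To make the effect of this check concrete, a small contrast (the surrounding `gpu.func` and the value names are assumed for illustration): a trait-carrying result is accepted by `isValidSymbol` wherever it is defined, while an ordinary `index` value produced at the same point is still rejected.

// Assumed to sit inside a gpu.func body.
affine.for %i = 0 to 128 step 32 {
  %tid = gpu.thread_id x        // has the AffineSymbol trait: valid symbol
  %a = affine.apply affine_map<()[s0] -> (s0 floordiv 4)>()[%tid]

  %sum = arith.addi %tid, %tid : index   // no trait, not top level of the scope
  // %b = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%sum]  // would be rejected
}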

mlir/lib/IR/Operation.cpp

Lines changed: 8 additions & 0 deletions
@@ -1390,6 +1390,14 @@ LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) {
   return success();
 }
 
+LogicalResult OpTrait::impl::verifyIndexResultType(Operation *op) {
+  if (op->getNumResults() != 1)
+    return op->emitError("operation must have exactly one result");
+  if (!mlir::isa<IndexType>(op->getResult(0).getType()))
+    return op->emitError("operation's result must be of 'index' type");
+  return success();
+}
+
 bool OpTrait::hasElementwiseMappableTraits(Operation *op) {
   return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() &&
          op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>();

mlir/test/Dialect/Affine/ops.mlir

Lines changed: 36 additions & 0 deletions
@@ -324,3 +324,39 @@ module attributes {gpu.container_module} {
 // CHECK: affine.for %[[VAL_4:.*]] = %[[VAL_3]] to %[[VAL_2]] step 32 {
 // CHECK: }
 // CHECK: gpu.return
+
+// -----
+
+#map = affine_map<()[s0] -> (s0 mod 32)>
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 mod 32)>
+
+module {
+  gpu.module @gpu {
+    gpu.func @affine_thread_id(%arg0: memref<?x?xf32>) kernel {
+      %c3 = arith.constant 3 : index
+      %dim = memref.dim %arg0, %c3 : memref<?x?xf32>
+      %c0 = arith.constant 0 : index
+      affine.for %arg3 = %c0 to %dim step 32 {
+        %thread_id_x = gpu.thread_id x
+        %0 = affine.apply #map()[%thread_id_x]
+        %c128 = arith.constant 128 : index
+        affine.for %arg4 = %0 to %c128 step 8 {
+          %c32 = arith.constant 32 : index
+        }
+      }
+      gpu.return
+    }
+  }
+}
+
+// CHECK-LABEL: @affine_thread_id
+// CHECK-SAME: (%[[VAL_0:.*]]: memref<?x?xf32>) kernel {
+// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_2:.*]] = memref.dim %[[VAL_0]], %[[VAL_1]] : memref<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: affine.for %[[VAL_4:.*]] = %[[VAL_3]] to %[[VAL_2]] step 32 {
+// CHECK: %[[VAL_5:.*]] = gpu.thread_id x
+// CHECK: %[[VAL_6:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[VAL_5]]]
+// CHECK: %[[VAL_7:.*]] = arith.constant 128 : index
+// CHECK: affine.for %[[VAL_8:.*]] = %[[VAL_6]] to %[[VAL_7]] step 8 {
mlir/test/Dialect/GPU/transform-gpu.mlir

Lines changed: 28 additions & 28 deletions
@@ -43,7 +43,7 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 floordiv 128)>
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 128)>
 
 // CHECK-LABEL: func.func @warpgroup_3d(
 // CHECK-SAME: %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
@@ -61,7 +61,7 @@ func.func @warpgroup_3d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream
 // CHECK: gpu.launch
 // CHECK: %[[TIDX:.*]] = gpu.thread_id x
 // CHECK: %[[TIDY:.*]] = gpu.thread_id y
-// CHECK-DAG: %[[WG:.*]] = affine.apply #[[$MAP]](%[[TIDX]])
+// CHECK-DAG: %[[WG:.*]] = affine.apply #[[$MAP]]()[%[[TIDX]]]
 // CHECK-DAG: %[[CMPX:.*]] = arith.cmpi ult, %[[TIDX]], %[[C384]] : index
 // CHECK-DAG: %[[CMPY:.*]] = arith.cmpi ult, %[[TIDY]], %[[C1]] : index
 // CHECK: %[[COND:.*]] = arith.andi %[[CMPX]], %[[CMPY]] : i1
@@ -95,7 +95,7 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 16)>
 
 // CHECK-LABEL: func.func @warp_3d(
 // CHECK-SAME: %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
@@ -114,7 +114,7 @@ func.func @warp_3d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !g
 // CHECK: gpu.launch
 // CHECK: %[[TIDX:.*]] = gpu.thread_id x
 // CHECK: %[[TIDY:.*]] = gpu.thread_id y
-// CHECK-DAG: %[[W:.*]] = affine.apply #[[$MAP]](%[[TIDX]])
+// CHECK-DAG: %[[W:.*]] = affine.apply #[[$MAP]]()[%[[TIDX]]]
 // CHECK-DAG: %[[CMPX:.*]] = arith.cmpi ult, %[[TIDX]], %[[C32]] : index
 // CHECK-DAG: %[[CMPY:.*]] = arith.cmpi ult, %[[TIDY]], %[[C3]] : index
 // CHECK: %[[COND:.*]] = arith.andi %[[CMPX]], %[[CMPY]] : i1
@@ -354,9 +354,9 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAPWGLIN:.*]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 32 + d2 * 256)>
-// CHECK-DAG: #[[$MAPWGX:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 32) floordiv 128) mod 2)>
-// CHECK-DAG: #[[$MAPWGY:.*]] = affine_map<(d0, d1, d2) -> (d2 + ((d0 + d1 * 32) floordiv 128) floordiv 2)>
+// CHECK-DAG: #[[$MAPWGLIN:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 256)>
+// CHECK-DAG: #[[$MAPWGX:.*]] = affine_map<()[s0, s1] -> (((s0 + s1 * 32) floordiv 128) mod 2)>
+// CHECK-DAG: #[[$MAPWGY:.*]] = affine_map<()[s0, s1, s2] -> (s2 + ((s0 + s1 * 32) floordiv 128) floordiv 2)>
 
 // CHECK-LABEL: func.func @warpgroup_linear(
 // CHECK-SAME: %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
@@ -376,9 +376,9 @@ func.func @warpgroup_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %st
 // CHECK-DAG: %[[TIDX:.*]] = gpu.thread_id x
 // CHECK-DAG: %[[TIDY:.*]] = gpu.thread_id y
 // CHECK-DAG: %[[TIDZ:.*]] = gpu.thread_id z
-// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWGLIN]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
-// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWGX]](%[[TIDX]], %[[TIDY]])
-// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWGY]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWGLIN]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
+// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWGX]]()[%[[TIDX]], %[[TIDY]]]
+// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWGY]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
 // CHECK-DAG: %[[CMPLIN:.*]] = arith.cmpi ult, %[[WIDLIN]], %[[C768]] : index
 // CHECK: scf.if %[[CMPLIN]]
 // CHECK: memref.load %[[ARGX]][%[[WIDX]], %[[WIDY]]]
@@ -410,9 +410,9 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAPWLIN:.*]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 32 + d2 * 256)>
-// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<(d0, d1, d2) -> ((d1 + d2 * 8 + d0 floordiv 32) mod 2)>
-// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<(d0, d1, d2) -> ((d1 + d2 * 8 + d0 floordiv 32) floordiv 2)>
+// CHECK-DAG: #[[$MAPWLIN:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 256)>
+// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<()[s0, s1, s2] -> ((s1 + s2 * 8 + s0 floordiv 32) mod 2)>
+// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<()[s0, s1, s2] -> ((s1 + s2 * 8 + s0 floordiv 32) floordiv 2)>
 
 // CHECK-LABEL: func.func @warp_linear(
 // CHECK-SAME: %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
@@ -432,9 +432,9 @@ func.func @warp_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream
 // CHECK-DAG: %[[TIDX:.*]] = gpu.thread_id x
 // CHECK-DAG: %[[TIDY:.*]] = gpu.thread_id y
 // CHECK-DAG: %[[TIDZ:.*]] = gpu.thread_id z
-// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWLIN]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
-// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
-// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+// CHECK-DAG: %[[WIDLIN:.*]] = affine.apply #[[$MAPWLIN]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
+// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
+// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
 // CHECK-DAG: %[[CMPLIN:.*]] = arith.cmpi ult, %[[WIDLIN]], %[[C192]] : index
 // CHECK: scf.if %[[CMPLIN]]
 // CHECK: memref.load %[[ARGX]][%[[WIDX]], %[[WIDY]]]
@@ -466,12 +466,12 @@ module attributes {transform.with_named_sequence} {
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
-// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 18) floordiv 32) mod 3)>
-// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<(d0, d1) -> ((((d0 + d1 * 18) floordiv 32) mod 6) floordiv 3)>
+// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<()[s0, s1] -> (((s0 + s1 * 18) floordiv 32) mod 3)>
+// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<()[s0, s1] -> ((((s0 + s1 * 18) floordiv 32) mod 6) floordiv 3)>
 
-// CHECK-DAG: #[[$MAPLIN:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 18)>
-// CHECK-DAG: #[[$MAPLX:.*]] = affine_map<(d0, d1) -> ((d0 + d1 * 18) mod 10)>
-// CHECK-DAG: #[[$MAPLY:.*]] = affine_map<(d0, d1) -> ((d0 + d1 * 18) floordiv 10)>
+// CHECK-DAG: #[[$MAPLIN:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 18)>
+// CHECK-DAG: #[[$MAPLX:.*]] = affine_map<()[s0, s1] -> ((s0 + s1 * 18) mod 10)>
+// CHECK-DAG: #[[$MAPLY:.*]] = affine_map<()[s0, s1] -> ((s0 + s1 * 18) floordiv 10)>
 
 // CHECK-LABEL: func.func @map_multi_level_linear(
 func.func @map_multi_level_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
@@ -504,9 +504,9 @@ func.func @map_multi_level_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f3
       memref.store %6, %y[%i, %j] : !type
     } { mapping = [#gpu.thread<y>, #gpu.thread<x>]}
 
-// CHECK-DAG: %[[LIN:.*]] = affine.apply #[[$MAPLIN]](%[[TIDX]], %[[TIDY]])
-// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]](%[[TIDX]], %[[TIDY]])
-// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]](%[[TIDX]], %[[TIDY]])
+// CHECK-DAG: %[[LIN:.*]] = affine.apply #[[$MAPLIN]]()[%[[TIDX]], %[[TIDY]]]
+// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]]()[%[[TIDX]], %[[TIDY]]]
+// CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]]()[%[[TIDX]], %[[TIDY]]]
 // CHECK-DAG: %[[CMPLIN:.*]] = arith.cmpi ult, %[[LIN]], %[[C192]] : index
 // CHECK: scf.if %[[CMPLIN]]
     scf.forall (%i, %j, %k) in (%c3, %c2, %c1) {
@@ -515,8 +515,8 @@ func.func @map_multi_level_linear(%x: !type, %y: !type, %t: !type1d, %alpha : f3
       memref.store %8, %y[%i, %j] : !type
     } {mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>, #gpu.warp<linear_dim_2>] }
 
-// CHECK-DAG: %[[LIDX:.*]] = affine.apply #[[$MAPLX]](%[[TIDX]], %[[TIDY]])
-// CHECK-DAG: %[[LIDY:.*]] = affine.apply #[[$MAPLY]](%[[TIDX]], %[[TIDY]])
+// CHECK-DAG: %[[LIDX:.*]] = affine.apply #[[$MAPLX]]()[%[[TIDX]], %[[TIDY]]]
+// CHECK-DAG: %[[LIDY:.*]] = affine.apply #[[$MAPLY]]()[%[[TIDX]], %[[TIDY]]]
 // CHECK-DAG: %[[COND:.*]] = arith.cmpi ult, %[[LIN]], %[[C20]] : index
 // CHECK: scf.if %[[COND]]
 // CHECK: memref.load %{{.*}}[%[[LIDX]]] : memref<32xf32>
@@ -648,7 +648,7 @@ module attributes {transform.with_named_sequence} {
 #map1 = affine_map<(d0) -> (d0 * 32)>
 
 // CHECK-DAG: #[[$MAPB:.*]] = affine_map<(d0) -> (d0 * 128)>
-// CHECK-DAG: #[[$MAPW:.*]] = affine_map<(d0, d1, d2) -> (d2 * 32 + ((d0 + d1 * 4) floordiv 32) * 32)>
+// CHECK-DAG: #[[$MAPW:.*]] = affine_map<()[s0, s1, s2] -> (s2 * 32 + ((s0 + s1 * 4) floordiv 32) * 32)>
 
 // CHECK-LABEL: func.func @simple_fill(
 func.func @simple_fill(%arg0: memref<128xf32>) -> memref<128xf32> {
@@ -667,7 +667,7 @@ func.func @simple_fill(%arg0: memref<128xf32>) -> memref<128xf32> {
 // CHECK: %[[TIDX:.*]] = gpu.thread_id x
 // CHECK: %[[TIDY:.*]] = gpu.thread_id y
 // CHECK: %[[TIDZ:.*]] = gpu.thread_id z
-// CHECK: %[[THX:.*]] = affine.apply #[[$MAPW]](%[[TIDX]], %[[TIDY]], %[[TIDZ]])
+// CHECK: %[[THX:.*]] = affine.apply #[[$MAPW]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]]
 // CHECK-NOT: scf.if
 // CHECK: memref.subview %{{.*}}[%[[THX]]]
     %1 = affine.apply #map1(%arg2)
