
Commit 44e6318

[mlir][transforms] Revamp the implementation of mapping loops to GPUs
This revision significantly simplifies the specification and implementation of mapping loops to GPU ids.

Each type of mapping (block, warpgroup, warp, thread) now comes with 2 mapping modes:
1. a 3-D "grid-like" mode, subject to alignment considerations on threadIdx.x, on which predication may occur on a per-dimension 3-D sub-rectangle basis.
2. an n-D linearized mode, on which predication may only occur on a linear basis.

In the process, better size and alignment requirement inference are introduced along with improved runtime verification messages.

The `warp_dims` attribute was deemed confusing and is removed from the transform in favor of better size inference.

Differential Revision: https://reviews.llvm.org/D155941
1 parent 2ee4d03 commit 44e6318
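
The difference between the two predication modes described in the commit message can be sketched in a few lines of C++. This is only an illustration, not the code from this revision: the helper names `isActive3D` and `isActiveLinear` are hypothetical, and the sketch assumes that predication amounts to a "less than the required size" test on each id.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of MLIR): 3-D "grid-like" predication.
// An id is active only if every coordinate lies inside the requested
// 3-D sub-rectangle, so the test is applied per dimension.
inline bool isActive3D(const std::array<int64_t, 3> &id,
                       const std::array<int64_t, 3> &required) {
  return id[0] < required[0] && id[1] < required[1] && id[2] < required[2];
}

// Hypothetical helper (not part of MLIR): linear predication.
// Only a single comparison on the linearized id is possible, so the
// active set is a 1-D prefix of the linearized ids.
inline bool isActiveLinear(int64_t linearId, int64_t requiredCount) {
  return linearId < requiredCount;
}
```

Under these assumptions, a 3-D request of `(2, 1, 1)` keeps ids `(0, 0, 0)` and `(1, 0, 0)` active, while a linear request of 6 keeps linear ids 0 through 5 active regardless of how they are later reinterpreted in a new basis.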

File tree

13 files changed

+1152
-516
lines changed


mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td

Lines changed: 169 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -20,107 +20,214 @@ include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
2020
def DimX : I64EnumAttrCase<"DimX", 0, "x">;
2121
def DimY : I64EnumAttrCase<"DimY", 1, "y">;
2222
def DimZ : I64EnumAttrCase<"DimZ", 2, "z">;
23-
24-
def ThreadsEnum : I64EnumAttr<"Threads", "threads for loop mapping", [
25-
DimX, DimY, DimZ]> {
23+
def LinearDim0 : I64EnumAttrCase<"LinearDim0", 3, "linear_dim_0">;
24+
def LinearDim1 : I64EnumAttrCase<"LinearDim1", 4, "linear_dim_1">;
25+
def LinearDim2 : I64EnumAttrCase<"LinearDim2", 5, "linear_dim_2">;
26+
def LinearDim3 : I64EnumAttrCase<"LinearDim3", 6, "linear_dim_3">;
27+
def LinearDim4 : I64EnumAttrCase<"LinearDim4", 7, "linear_dim_4">;
28+
def LinearDim5 : I64EnumAttrCase<"LinearDim5", 8, "linear_dim_5">;
29+
def LinearDim6 : I64EnumAttrCase<"LinearDim6", 9, "linear_dim_6">;
30+
def LinearDim7 : I64EnumAttrCase<"LinearDim7", 10, "linear_dim_7">;
31+
def LinearDim8 : I64EnumAttrCase<"LinearDim8", 11, "linear_dim_8">;
32+
def LinearDim9 : I64EnumAttrCase<"LinearDim9", 12, "linear_dim_9">;
33+
34+
// TODO: This would be better represented with separate Grid and Linear Mapping
35+
// ids. Unfortunately it is not yet possible to have an optional EnumParameter
36+
// so we currently embed the 2 modes in the same enum.
37+
def MappingIdEnum : I64EnumAttr<"MappingId", "Mapping ids for loop mapping", [
38+
DimX, DimY, DimZ,
39+
LinearDim0, LinearDim1, LinearDim2, LinearDim3, LinearDim4,
40+
LinearDim5, LinearDim6, LinearDim7, LinearDim8, LinearDim9]> {
2641
let cppNamespace = "::mlir::gpu";
2742
}
2843

29-
def GPUThreadMappingAttr
30-
: GPU_Attr<"GPUThreadMapping", "thread", [
31-
DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ]> {
44+
def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [
45+
DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > {
3246
let parameters = (ins
33-
EnumParameter<ThreadsEnum>:$thread
47+
EnumParameter<MappingIdEnum>:$block
3448
);
3549
let assemblyFormat = "`<` params `>`";
3650
let description = [{
37-
An attribute that allows defining thread parallelism for GPU devices.
51+
An attribute that allows defining thread block parallelism for GPU devices.
3852

39-
Thread (aka work item) are grouped into a thread blocks where block may be
40-
described by a 1-, 2-, or 3-dimensional rectangle. This attribute indicates
41-
that thread parallelism is desired. It can be consumed by lowering to
42-
generate GPU.
43-
}];
44-
}
53+
Thread blocks (aka workgroups) are grouped into a grid described by a
54+
3-dimensional rectangle.
55+
This attribute indicates that thread block parallelism is desired.
56+
It can be consumed by lowering to generate GPU code.
57+
2 modes are supported: (1) 3D mapping mode and (2) linear mapping mode.
4558

46-
def WarpsEnum : I64EnumAttr<"Warps", "threads for loop mapping", [
47-
DimX, DimY, DimZ]> {
48-
let cppNamespace = "::mlir::gpu";
59+
#### 3D mapping mode
60+
61+
The 3D block id is simply the 3D index of the block `(bidx, bidy, bidz)`.
62+
If required, predication occurs on a per-dimension basis. This allows
63+
specifying predication on a 3D sub-rectangle of the grid.
64+
65+
#### Linear mapping mode
66+
67+
The linear block id is obtained by linearizing the index of the block.
68+
If required, predication occurs on the linear id. This allows specifying
69+
predication on a 1D subset of the (linearized) grid.
70+
71+
For instance, if the basis is denoted as (GX, GY, GZ) and the block id is
72+
denoted by (bx, by, bz), the linear block id is:
73+
`linear_id = bx + by * GX + bz * GX * GY`.
74+
The linear block id is fixed for the duration of a GPU kernel.
75+
76+
This linear id mapping attribute indicates a different linearization relation
77+
is applied locally to a loop nest.
78+
79+
For instance, if the new basis is denoted as (LBD0, LBD1, LBD2, LBD3) the
80+
block id in the new basis is:
81+
```(linear_id mod LBD0 ,
82+
(linear_id / LBD0) mod LBD1,
83+
(linear_id / (LBD0 * LBD1)) mod LBD2,
84+
(linear_id / (LBD0 * LBD1 * LBD2)) mod LBD3)```.
85+
This reinterpretation is only fixed for the duration of a loop nest.
86+
}];
4987
}
5088

51-
def GPUWarpMappingAttr : GPU_Attr<"GPUWarpMapping", "warp", [
52-
DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > {
89+
def GPUWarpgroupMappingAttr
90+
: GPU_Attr<"GPUWarpgroupMapping", "warpgroup", [
91+
DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ]> {
5392
let parameters = (ins
54-
EnumParameter<WarpsEnum>:$warp
93+
EnumParameter<MappingIdEnum>:$warpgroup
5594
);
5695
let assemblyFormat = "`<` params `>`";
5796
let description = [{
58-
An attribute that allows defining thread block parallelism for GPU devices.
97+
An attribute that allows defining warpgroup parallelism for GPU devices.
5998

60-
Warp (aka subgroup) are grouped into a grid where grid may be
61-
described by a 1-, 2-, or 3-dimensional rectangle. This attribute indicates
62-
that thread block parallelism is desired. It can be consumed by lowering to
63-
generate GPU code.
64-
}];
65-
}
99+
Threads of proper granularity (e.g. multiple of
100+
"kNumWarpsPerGroup * kWarpSize" on CUDA devices) can be grouped into
101+
warpgroups described by a 3-dimensional rectangle.
102+
This attribute indicates that warpgroup parallelism is desired.
103+
It can be consumed by lowering to generate GPU code.
104+
2 modes are supported: (1) 3D mapping mode and (2) linear mapping mode.
66105

67-
def LinearIdEnum : I64EnumAttr<"LinearId", "linear ids for loop mapping", [
68-
DimX, DimY, DimZ]> {
69-
let cppNamespace = "::mlir::gpu";
106+
#### 3D mapping mode
107+
108+
The 3D warpgroup id is simply the adjusted 3D index of the thread
109+
`(tidx / (kNumWarpsPerGroup * kWarpSize), tidy, tidz)`.
110+
If required, predication occurs on a per-dimension basis. This allows
111+
specifying predication on a 3D sub-rectangle of the warpgroups.
112+
113+
#### Linear mapping mode
114+
115+
The linear warpgroup id is obtained by linearizing the index of the warpgroup.
116+
If required, predication occurs on the linear id. This allows specifying
117+
predication on a 1D "kNumWarpsPerGroup * kWarpSize"-aligned subset of the
118+
(linearized) block.
119+
120+
For instance, if the basis is denoted as (BX, BY, BZ) and the thread id is
121+
denoted by (tx, ty, tz), the linear warpgroup id is:
122+
```linear_id = (tx + ty * BX + tz * BX * BY)
123+
/ (kNumWarpsPerGroup * kWarpSize)```.
124+
The linear warpgroup id is fixed for the duration of a GPU kernel.
125+
126+
This linear id mapping attribute indicates a different linearization relation
127+
is applied locally to a loop nest.
128+
129+
For instance, if the new basis is denoted as (LWGD0, LWGD1, LWGD2, LWGD3) the
130+
warpgroup id in the new basis is:
131+
```(linear_id mod LWGD0 ,
132+
(linear_id / LWGD0) mod LWGD1,
133+
(linear_id / (LWGD0 * LWGD1)) mod LWGD2,
134+
(linear_id / (LWGD0 * LWGD1 * LWGD2)) mod LWGD3)```.
135+
This reinterpretation is only fixed for the duration of a loop nest.
136+
}];
70137
}
71138

72-
def GPULinearIdMapping : GPU_Attr<"GPULinearIdMapping", "linear", [
73-
DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > {
139+
def GPUWarpMappingAttr
140+
: GPU_Attr<"GPUWarpMapping", "warp", [
141+
DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ]> {
74142
let parameters = (ins
75-
EnumParameter<LinearIdEnum>:$linear_id
143+
EnumParameter<MappingIdEnum>:$warp
76144
);
77145
let assemblyFormat = "`<` params `>`";
78146
let description = [{
79-
An attribute to allow re-interpreting the linear mapping for threads in GPU
80-
devices.
147+
An attribute that allows defining warp parallelism for GPU devices.
81148

82-
Threads (aka work item) are grouped into a thread block where block may be
83-
described by a 1-, 2-, or 3-dimensional rectangular basis.
84-
The linear thread id is obtained by linearizing the 1-, 2- or 3-dimensional
85-
index. For instance, if the basis is denoted as (BX, BY, BZ) and the thread
86-
id is denoted by (tx, ty, tz), the linear thread id is:
87-
`linear_id = tx + ty * BX + tz * BX * BY)`.
88-
The linear thread id is fixed for the duration of a GPU kernel.
149+
Threads of proper granularity (e.g. multiple of "warp size" on CUDA devices)
150+
can be grouped into warps described by a 3-dimensional rectangle.
151+
This attribute indicates that warp parallelism is desired.
152+
It can be consumed by lowering to generate GPU code.
153+
2 modes are supported: (1) 3D mapping mode and (2) linear mapping mode.
154+
155+
#### 3D mapping mode
156+
157+
The 3D warp id is simply the adjusted 3D index of the thread
158+
`(tidx / kWarpSize, tidy, tidz)`.
159+
If required, predication occurs on a per-dimension basis. This allows
160+
specifying predication on a 3D sub-rectangle of the warps.
161+
162+
#### Linear mapping mode
163+
164+
The linear warp id is obtained by linearizing the index of the warp.
165+
If required, predication occurs on the linear id. This allows specifying
166+
predication on a 1D "kWarpSize"-aligned subset of the (linearized) block.
167+
168+
For instance, if the basis is denoted as (BX, BY, BZ) and the thread id is
169+
denoted by (tx, ty, tz), the linear warp id is:
170+
`linear_id = (tx + ty * BX + tz * BX * BY) / kWarpSize`.
171+
The linear warp id is fixed for the duration of a GPU kernel.
89172

90173
This linear id mapping attribute indicates a different linearization relation
91174
is applied locally to a loop nest.
92175

93-
For instance, if the new basis is denoted as (LBX, LBY, LBZ) the thread id
94-
in the new basis is:
95-
`(linear_id mod LBX , (linear_id / LBX) mod * LBY, linear_id / (LBX * LBY))`.
96-
This reinterpretation is only fixe for the duration of a loop nest.
97-
98-
It can be consumed by lowering to generate GPU code.
176+
For instance, if the new basis is denoted as (LWD0, LWD1, LWD2, LWD3) the
177+
warp id in the new basis is:
178+
```(linear_id mod LWD0 ,
179+
(linear_id / LWD0) mod LWD1,
180+
(linear_id / (LWD0 * LWD1)) mod LWD2,
181+
(linear_id / (LWD0 * LWD1 * LWD2)) mod LWD3)```.
182+
This reinterpretation is only fixed for the duration of a loop nest.
99183
}];
100184
}
101185

102-
def BlocksEnum : I64EnumAttr<"Blocks", "threads for loop mapping", [
103-
DimX, DimY, DimZ]> {
104-
let cppNamespace = "::mlir::gpu";
105-
}
106-
107-
def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [
108-
DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > {
186+
def GPUThreadMappingAttr
187+
: GPU_Attr<"GPUThreadMapping", "thread", [
188+
DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ]> {
109189
let parameters = (ins
110-
EnumParameter<BlocksEnum>:$block
190+
EnumParameter<MappingIdEnum>:$thread
111191
);
112192
let assemblyFormat = "`<` params `>`";
113193
let description = [{
114-
An attribute that allows defining thread block parallelism for GPU devices.
194+
An attribute that allows defining thread parallelism for GPU devices.
195+
196+
Threads (aka work items) are grouped into thread blocks described by a
197+
3-dimensional rectangle.
198+
This attribute indicates that thread parallelism is desired.
199+
It can be consumed by lowering to generate GPU code.
200+
201+
#### 3D mapping mode
202+
203+
The 3D thread id is simply the 3D index of the thread `(tidx, tidy, tidz)`.
204+
If required, predication occurs on a per-dimension basis. This allows
205+
specifying predication on a 3D sub-rectangle of the block.
206+
207+
#### Linear mapping mode
115208

116-
Thread blocks (aka work-group) are grouped into a grid where grid may be
117-
described by a 1-, 2-, or 3-dimensional rectangle. This attribute indicates
118-
that thread block parallelism is desired. It can be consumed by lowering to
119-
generate GPU code.
209+
The linear thread id is obtained by linearizing the index of the thread.
210+
If required, predication occurs on the linear id. This allows specifying
211+
predication on a 1D subset of the (linearized) block.
212+
213+
For instance, if the basis is denoted as (BX, BY, BZ) and the thread id is
214+
denoted by (tx, ty, tz), the linear thread id is:
215+
```linear_id = (tx + ty * BX + tz * BX * BY)```.
216+
The linear thread id is fixed for the duration of a GPU kernel.
217+
218+
This linear id mapping attribute indicates a different linearization relation
219+
is applied locally to a loop nest.
220+
221+
For instance, if the new basis is denoted as (LTD0, LTD1, LTD2, LTD3) the
222+
thread id in the new basis is:
223+
```(linear_id mod LTD0 ,
224+
(linear_id / LTD0) mod LTD1,
225+
(linear_id / (LTD0 * LTD1)) mod LTD2,
226+
(linear_id / (LTD0 * LTD1 * LTD2)) mod LTD3)```.
227+
This reinterpretation is only fixed for the duration of a loop nest.
120228
}];
121229
}
122230

123-
124231
def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space", [
125232
DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > {
126233
let parameters = (ins
@@ -138,5 +245,4 @@ def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space"
138245
}];
139246
}
140247

141-
142248
#endif // GPU_DEVICE_MAPPING_ATTR
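
The linearization and reinterpretation formulas in the attribute descriptions above can be checked with a small standalone C++ sketch. The function names `linearize3D` and `delinearize` are hypothetical and are not taken from this commit; the sketch simply transcribes `linear_id = tx + ty * BX + tz * BX * BY` and the `(linear_id / stride) mod dim` pattern, assuming every basis entry is a positive static size.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: linearize a 3-D id (tx, ty, tz) in the basis
// (BX, BY, BZ). BZ is not needed to compute the linear id.
inline int64_t linearize3D(int64_t tx, int64_t ty, int64_t tz, int64_t BX,
                           int64_t BY) {
  return tx + ty * BX + tz * BX * BY;
}

// Hypothetical helper: reinterpret a linear id in a new n-D basis.
// Entry i is (linearId / (basis[0] * ... * basis[i-1])) mod basis[i],
// with the fastest-varying dimension first.
inline std::vector<int64_t> delinearize(int64_t linearId,
                                        const std::vector<int64_t> &newBasis) {
  std::vector<int64_t> ids;
  ids.reserve(newBasis.size());
  int64_t stride = 1; // Product of the basis entries seen so far.
  for (int64_t dim : newBasis) {
    ids.push_back((linearId / stride) % dim);
    stride *= dim;
  }
  return ids;
}
```

For example, thread `(1, 2, 3)` in a block with `BX = 4` and `BY = 5` has linear id `1 + 2 * 4 + 3 * 20 = 69`; reinterpreting 69 in the basis `(3, 4, 5, 2)` yields `(0, 3, 0, 1)`, and `0 + 3 * 3 + 0 * 12 + 1 * 60` recovers 69.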

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,49 +33,49 @@ namespace transform {
3333
namespace gpu {
3434
struct GpuIdBuilder;
3535

36-
/// Map the top level `scf.forall` op to GPU Thread Blocks.
36+
/// Map the top level `scf.forall` op to GPU blocks.
3737
/// Mapping is one-to-one and the induction variables of `scf.forall` are
3838
/// rewritten to gpu.block_id according to the thread_dim_mapping attribute.
3939
///
4040
/// Dynamic, `scf.forall` trip counts are currently not supported.
41-
/// Dynamic block dim sizes are currently not supported.
41+
/// Dynamic `gridDims` are currently not supported.
4242
DiagnosedSilenceableFailure
4343
mapForallToBlocksImpl(RewriterBase &rewriter, TransformOpInterface transformOp,
4444
scf::ForallOp forallOp,
4545
SmallVectorImpl<int64_t> &gridDims,
4646
const GpuIdBuilder &gpuIdBuilder);
4747

4848
/// Search `scf.forall` ops nested under `target` and map each such op to an
49-
/// explicit GPU implementation along `availableMappingSizes`.
49+
/// explicit GPU implementation along `blockDims`.
5050
/// The mapping is one-to-one and the induction variables of `scf.forall` are
5151
/// rewritten to gpuIdBuilder.idBuilder according to the
5252
/// gpuIdBuilder.mappingAttributes attribute.
5353
///
5454
/// Dynamic, `scf.forall` trip counts are currently not supported.
55-
/// Dynamic `availableMappingSizes` sizes are currently not supported.
56-
/// `availableMappingSizes` is expected to be of size 3.
57-
DiagnosedSilenceableFailure mapOneForallToThreadsImpl(
58-
RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
59-
scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
60-
bool syncAfterDistribute, const GpuIdBuilder &gpuIdBuilder);
55+
/// Dynamic `blockDims` sizes are currently not supported.
56+
/// `blockDims` is expected to be of size 3.
57+
DiagnosedSilenceableFailure
58+
mapOneForallToThreadsImpl(RewriterBase &rewriter,
59+
std::optional<TransformOpInterface> transformOp,
60+
scf::ForallOp forallOp, ArrayRef<int64_t> blockDims,
61+
int64_t warpSize, bool syncAfterDistribute);
6162

6263
/// Search `scf.forall` ops nested under `target` and map each such op to an
63-
/// explicit GPU implementation along blockDims and warpDims.
64+
/// explicit GPU implementation along `blockDims`.
6465
/// The mapping is one-to-one and the induction variables of `scf.forall` are
65-
/// rewritten to threads and warps ids according to the mapping attribute.
66+
/// rewritten to appropriate ids according to the mapping attribute.
6667
///
6768
/// Dynamic, `scf.forall` trip counts are currently not supported.
68-
/// Dynamic `blockDims` or `warpDims` or `linearDims` sizes are currently not
69-
/// supported.
70-
/// `blockDims` is expected to be of size 3.
71-
/// `warpDims` is expected to be empty or of size 3.
69+
/// Dynamic `blockDims` or `newBasis` entries are currently not
70+
/// supported. `blockDims` is expected to be of size 3.
7271
///
7372
/// The insertion point of the `rewriter` is expected to be set at the
7473
/// beginning of the `target` body block and dominate all other blocks.
75-
DiagnosedSilenceableFailure mapNestedForallToThreadsImpl(
76-
RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
77-
Operation *target, ArrayRef<int64_t> blockDimsOfr,
78-
ArrayRef<int64_t> warpDims, bool syncAfterDistribute);
74+
DiagnosedSilenceableFailure
75+
mapNestedForallToThreadsImpl(RewriterBase &rewriter,
76+
std::optional<TransformOpInterface> transformOp,
77+
Operation *target, ArrayRef<int64_t> blockDims,
78+
int64_t warpSize, bool syncAfterDistribute);
7979

8080
} // namespace gpu
8181
} // namespace transform
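
Since the functions above now take a `warpSize` instead of explicit `warpDims`, the warp ids have to be derived from `blockDims`. The following C++ sketch shows one way that derivation could look, following the formulas in the `GPUWarpMappingAttr` description. All function names are hypothetical, and the requirement that `blockDims[0]` be a multiple of `warpSize` in 3-D mode is an assumption based on the "alignment considerations on threadIdx.x" mentioned in the commit message, not a statement about the actual verifier.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helper: 3-D warp id, i.e. (tidx / warpSize, tidy, tidz).
inline std::array<int64_t, 3> warpId3D(int64_t tidx, int64_t tidy,
                                       int64_t tidz, int64_t warpSize) {
  return {tidx / warpSize, tidy, tidz};
}

// Hypothetical helper: linear warp id, i.e.
// (tx + ty * BX + tz * BX * BY) / warpSize.
inline int64_t linearWarpId(int64_t tx, int64_t ty, int64_t tz, int64_t BX,
                            int64_t BY, int64_t warpSize) {
  return (tx + ty * BX + tz * BX * BY) / warpSize;
}

// Hypothetical helper: number of warps available per dimension in 3-D mode.
// Assumption: the x dimension must be a multiple of warpSize; otherwise no
// 3-D warp mapping is inferred and std::nullopt is returned.
inline std::optional<std::array<int64_t, 3>>
warpsPerBlock3D(const std::array<int64_t, 3> &blockDims, int64_t warpSize) {
  if (warpSize <= 0 || blockDims[0] % warpSize != 0)
    return std::nullopt;
  return std::array<int64_t, 3>{blockDims[0] / warpSize, blockDims[1],
                                blockDims[2]};
}
```

With `blockDims = (64, 2, 1)` and `warpSize = 32`, this sketch infers `(2, 2, 1)` warps, which is the kind of information that previously had to be passed explicitly through `warp_dims`.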

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,15 +167,15 @@ def MapNestedForallToThreads :
167167

168168
let arguments = (ins TransformHandleTypeInterface:$target,
169169
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$block_dims,
170-
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$warp_dims,
171-
DefaultValuedAttr<BoolAttr, "true">:$sync_after_distribute);
170+
DefaultValuedAttr<BoolAttr, "true">:$sync_after_distribute,
171+
DefaultValuedAttr<I64Attr, "32">:$warp_size);
172172
let results = (outs TransformHandleTypeInterface:$result);
173173

174174
let assemblyFormat = [{
175175
$target
176176
`block_dims` `=` $block_dims
177-
(`warp_dims` `=` $warp_dims^)?
178177
(`sync_after_distribute` `=` $sync_after_distribute^)?
178+
(`warp_size` `=` $warp_size^)?
179179
attr-dict
180180
`:` functional-type($target, $result)
181181
}];
