Commit 7a66746

[mlir][xegpu] Handle scalar uniform ops in SIMT distribution. (#138593)
This PR adds support for moving scalar uniform ops (GPU index ops, constants, etc.) outside the `gpu.warp_execute_on_lane_0` op. These kinds of ops do not require distribution and are safe to move out of the warp op. This also avoids adding separate distribution patterns for these ops.

Example:
```
%1 = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
  ...
  %block_id_x = gpu.block_id x
  gpu.yield %block_id_x
}
// use %1
```
To:
```
%block_id_x = gpu.block_id x
%1 = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
  ...
  gpu.yield %block_id_x
}
// use %1
```
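For intuition, an op only qualifies for this hoisting if it is uniform across lanes, has no side effects, and produces nothing that the warp op distributes. A predicate along the following lines captures that criterion; this is a minimal sketch for illustration, and the name `isScalarUniform` and the exact checks are assumptions, not the helper this commit uses (`vector::moveScalarUniformCode`).

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

using namespace mlir;

// Sketch only: decide whether `op`, currently nested inside the warp region,
// could be moved above the warp op without changing semantics.
static bool isScalarUniform(Operation *op, Region &warpRegion) {
  // Side-effecting ops (loads, stores, barriers) must stay where they are.
  if (!isMemoryEffectFree(op))
    return false;
  // Vector results are lane-distributed by the warp op and cannot be hoisted.
  for (Value result : op->getResults())
    if (isa<VectorType>(result.getType()))
      return false;
  // Every operand must be defined outside the warp region; otherwise the
  // hoisted op would reference a value that is no longer in scope.
  for (Value operand : op->getOperands())
    if (warpRegion.isAncestor(operand.getParentRegion()))
      return false;
  return true;
}
```

Such a check would be queried for each op in the body of a `gpu.warp_execute_on_lane_0` region before deciding to hoist it.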
1 parent 3a5af23 commit 7a66746

2 files changed, +57 -1 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 9 additions & 0 deletions
@@ -1463,6 +1463,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       signalPassFailure();
       return;
     }
+    // At this point, we have moved the entire function body inside the warpOp.
+    // Now move any scalar uniform code outside of the warpOp (like GPU index
+    // ops, scalar constants, etc.). This will simplify the later lowering and
+    // avoid custom patterns for these ops.
+    getOperation()->walk([&](Operation *op) {
+      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
+        vector::moveScalarUniformCode(warpOp);
+      }
+    });
   }
   // Finally, do the SIMD to SIMT distribution.
   RewritePatternSet patterns(&getContext());
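The walk in this hunk hands each `gpu.warp_execute_on_lane_0` op to `vector::moveScalarUniformCode`. Conceptually, such a hoist only needs to traverse the warp body in order and move qualifying ops in front of the warp op. The sketch below shows that shape with the predicate left as a parameter; the function name and signature are illustrative assumptions, not the actual MLIR helper.

```cpp
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Sketch only: hoist ops that `canHoist` approves out of the warp op's single
// region. Walking in op order means that by the time an op is moved, any of
// its operands that were themselves hoisted already sit above the warp op, so
// dominance is preserved.
static void hoistScalarUniformOps(Operation *warpOp,
                                  function_ref<bool(Operation &)> canHoist) {
  Block &body = warpOp->getRegion(0).front();
  // early_inc_range keeps iteration valid while ops are moved out of the
  // block; the terminator (gpu.yield) always stays behind.
  for (Operation &op : llvm::make_early_inc_range(body.without_terminator()))
    if (canHoist(op))
      op.moveBefore(warpOp);
}
```

A predicate like the `isScalarUniform` sketch above could be plugged in as `canHoist`, capturing the warp op's region to check operand scoping.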

mlir/test/Dialect/XeGPU/subgroup-distribution.mlir

Lines changed: 48 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xegpu-subgroup-distribute -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xegpu-subgroup-distribute -cse -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: gpu.func @store_nd_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
@@ -160,3 +160,50 @@ gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @gemm_loop
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
+// CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
+// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
+// CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
+// CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+// CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
+// CHECK: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+// CHECK: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
+// CHECK: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+// CHECK: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
+// CHECK: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
+// CHECK: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: scf.yield %[[T16]] : vector<8x1xf32>
+// CHECK: }
+// CHECK: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: xegpu.store_nd %[[T9]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+gpu.module @test {
+  gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %c8 = arith.constant 8 : index
+    %c1024 = arith.constant 1024 : index
+    %0 = gpu.block_id x
+    %1 = gpu.block_id y
+    %2 = arith.muli %0, %c8 : index
+    %3 = arith.muli %1, %c16 : index
+    %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+    %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+    %6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) {
+      %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+      %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+      %9 = xegpu.load_nd %7 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
+      %10 = xegpu.load_nd %8 : !xegpu.tensor_desc<16x16xbf16> -> vector<16x16xbf16>
+      %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+      scf.yield %11 : vector<8x16xf32>
+    }
+    xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+    gpu.return
+  }
+}
