
Commit 8063bd1

[MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] (#142797)
This PR adds support for lowering elementwise (unary & binary) operations from Workgroup to Subgroup.
1 parent 556e69b commit 8063bd1
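
In short: an elementwise op whose layout_result_0 attribute carries an sg_layout is rewritten into one or more subgroup-level copies of the same op, with sg_layout/sg_data dropped from the attribute. A minimal before/after sketch of the rewrite, with shapes borrowed from the tests below (%v and %v_sg are illustrative value names, not from this commit):

  // Before: workgroup level, a 24x32 vector distributed over a 2x4 subgroup grid.
  %e = math.exp %v
    {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
    : vector<24x32xf32>

  // After: each subgroup applies the same op to its own 12x8 tile.
  %e_sg = math.exp %v_sg
    {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
    : vector<12x8xf32>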

File tree

2 files changed: +252 −1


mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 88 additions & 1 deletion
@@ -8,17 +8,20 @@
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include <optional>
 
 namespace mlir {
 namespace xegpu {
@@ -328,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+// This pattern transforms elementwise ops to work at subgroup level.
+struct WgToSgElementwiseOp : public ConversionPattern {
+  WgToSgElementwiseOp(MLIRContext *ctx)
+      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only match ops with elementwise trait and single result.
+    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+      return failure();
+
+    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+    assert(resultType && "Expected result to be a VectorType");
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+    size_t numVariants = operands.empty() ? 0 : operands.front().size();
+
+    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
+          return operandVec.size() != numVariants;
+        }))
+      return failure();
+
+    SmallVector<Value> newResults;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    for (size_t i = 0; i < numVariants; ++i) {
+      SmallVector<Value> opOperands;
+      for (auto &operandVec : operands)
+        opOperands.push_back(operandVec[i]);
+
+      OperationState state(op->getLoc(), op->getName());
+      state.addOperands(opOperands);
+      state.addTypes(newResultType);
+      // Copy all attributes, but update "layout_result_0" to drop
+      // sgLayout/sgData.
+      for (auto attr : op->getAttrs()) {
+        if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue()))
+          state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
+        else
+          state.addAttribute(attr.getName(), attr.getValue());
+      }
+      Operation *newOp = rewriter.create(state);
+      newResults.push_back(newOp->getResult(0));
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newResults});
+    return success();
+  }
+};
+
 // Handles UnrealizedConversionCastOp generated during
 // SCFStructuralTypeConversions (step 1). This op may appear as either a
 // target or source materialization for Vector values, e.g.:
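
A note on numVariants in the pattern above: the conversion infrastructure hands each workgroup-level operand to matchAndRewrite as a list of subgroup-level values, one per tile assigned to a subgroup, so the pattern emits one subgroup-level op per tile. A worked example, assuming round-robin assignment as in the elementwise_ops_rr_assignment test below: for wgShape = [24, 32] with sg_layout = [4, 4] and sg_data = [2, 2], the subgroup grid covers an 8x8 block per round, so each subgroup owns

  (24 / (4 * 2)) * (32 / (4 * 2)) = 3 * 4 = 12

tiles of 2x2; numVariants is 12, and the loop creates twelve vector<2x2xf32> ops, matching the CHECK-COUNT-12 lines in that test.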
@@ -411,7 +473,8 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern>(patterns.getContext());
+               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
+      patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -518,6 +581,30 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(layout);
   });
 
+  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
+      [=](Operation *op) -> std::optional<bool> {
+        // Only handle elementwise mappable ops.
+        if (!OpTrait::hasElementwiseMappableTraits(op))
+          return true;
+
+        VectorType resultType =
+            dyn_cast<VectorType>(op->getResult(0).getType());
+        if (!resultType)
+          return true;
+
+        // Check if all operands are vectors of the same shape.
+        // TODO: Support other types.
+        for (Value operand : op->getOperands()) {
+          VectorType operandType = dyn_cast<VectorType>(operand.getType());
+          if (!operandType || operandType.getShape() != resultType.getShape()) {
+            return true;
+          }
+        }
+
+        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+        return isLegal(layout);
+      });
+
   target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
       [=](UnrealizedConversionCastOp op) {
         return llvm::is_contained(existingCastOps, op.getOperation());
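
How the legality hook and the pattern interact: an arith.addf producing vector<24x32xf32> whose layout_result_0 still carries sg_layout = [2, 4] fails isLegal(layout), so the driver invokes WgToSgElementwiseOp on it; the rewritten ops produce vector<12x8xf32> with only lane_layout/lane_data left, at which point the callback reports them legal and conversion converges. Ops without the elementwise trait, with non-vector results, or with operand shapes that differ from the result shape are conservatively reported legal and left untouched.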
Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_elementwise_ops {
+  // CHECK-LABEL: unary_ops
+  gpu.func @unary_ops(%a: memref<24x32xf32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    // CHECK: math.exp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+    %exp = math.exp %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    // CHECK: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+    %negf = arith.negf %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: binary_ops
+  gpu.func @binary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    // CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>
+    %addf = arith.addf %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    // CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>
+    %powf = math.powf %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: ternary_ops
+  gpu.func @ternary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi1>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi1>
+      -> !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_c = xegpu.load_nd %tdesc_c
+      : !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xi1>
+    // CHECK: arith.select {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xi1>, vector<12x8xf32>
+    %select = arith.select %load_c, %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xi1>, vector<24x32xf32>
+    // CHECK: math.fma {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>
+    %fma = math.fma %load_a, %load_b, %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: type_conversion_ops
+  gpu.func @type_conversion_ops(%a: memref<24x32xf32>, %b: memref<24x32xi32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32>
+      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xi32>
+    // CHECK: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32> to vector<12x8xf16>
+    %truncf = arith.truncf %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32> to vector<24x32xf16>
+    // CHECK: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xi32> to vector<12x8xf32>
+    %bitcast = arith.bitcast %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xi32> to vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: comparison_ops
+  gpu.func @comparison_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_c = xegpu.load_nd %tdesc_c
+      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xi32>
+    %load_d = xegpu.load_nd %tdesc_d
+      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xi32>
+    // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>
+    %cmpf = arith.cmpf ult, %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    // CHECK: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xi32>
+    %cmpi = arith.cmpi eq, %load_c, %load_d
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xi32>
+    gpu.return
+  }
+
+  // 1 to N decomposition of elementwise operations
+  // CHECK-LABEL: elementwise_ops_rr_assignment
+  gpu.func @elementwise_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+    // CHECK-NOT: arith.negf
+    %negf = arith.negf %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+    // CHECK-NOT: math.powf
+    %powf = math.powf %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    gpu.return
+  }
+}
