
[MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] #142797

Merged: 15 commits, Jun 17, 2025
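
In short, the new WgToSgElementwiseOp pattern below rewrites elementwise arith/math ops on workgroup-shaped vectors into subgroup-shaped ones: each new op takes its shape from the layout's sg_data, and sg_layout/sg_data are dropped from the layout_result_0 attribute. A minimal before/after sketch, adapted from the test added in this PR (the %load_a_sg operand name is illustrative only):

// Workgroup level: one op over the full 24x32 vector.
%exp = math.exp %load_a
  {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
  : vector<24x32xf32>

// Subgroup level after --xegpu-wg-to-sg-distribute: each subgroup works on its
// own 12x8 slice, and sg_layout/sg_data are gone from the layout.
%exp_sg = math.exp %load_a_sg
  {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
  : vector<12x8xf32>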
89 changes: 88 additions & 1 deletion mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -8,17 +8,20 @@
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

namespace mlir {
namespace xegpu {
@@ -328,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};

// This pattern transforms elementwise ops to work at subgroup level.
struct WgToSgElementwiseOp : public ConversionPattern {
WgToSgElementwiseOp(MLIRContext *ctx)
: ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const override {
// Only match ops with elementwise trait and single result.
if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
return failure();

auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
assert(resultType && "Expected result to be a VectorType");

ArrayRef<int64_t> wgShape = resultType.getShape();

xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
if (!layout || !layout.getSgLayout())
return failure();

SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

size_t numVariants = operands.empty() ? 0 : operands.front().size();

if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
return operandVec.size() != numVariants;
}))
return failure();

SmallVector<Value> newResults;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());

for (size_t i = 0; i < numVariants; ++i) {
SmallVector<Value> opOperands;
for (auto &operandVec : operands)
opOperands.push_back(operandVec[i]);

OperationState state(op->getLoc(), op->getName());
state.addOperands(opOperands);
state.addTypes(newResultType);
// Copy all attributes, but update "layout_result_0" to drop
// sgLayout/sgData
for (auto attr : op->getAttrs()) {
if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue()))
state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
else
state.addAttribute(attr.getName(), attr.getValue());
}
Operation *newOp = rewriter.create(state);
newResults.push_back(newOp->getResult(0));
}

rewriter.replaceOpWithMultiple(op, {newResults});
return success();
}
};
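
The loop over numVariants also covers 1:N distribution: when sg_layout times sg_data does not cover the full workgroup shape, each subgroup owns several slices (round-robin assignment), operands.front().size() is greater than one, one sub-op is created per slice, and replaceOpWithMultiple maps the single workgroup-level result to all of them. The elementwise_ops_rr_assignment test below exercises this: sg_layout = [4, 4] with sg_data = [2, 2] covers an 8x8 tile per sweep, so a 24x32 vector needs (24/8) x (32/8) = 12 sweeps, and one workgroup-level arith.negf becomes twelve vector<2x2xf32> ops per subgroup.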

// Handles UnrealizedConversionCastOp generated during
// SCFStructuralTypeConversions (step 1). This op may appear as either a
// target or source materialization for Vector values, e.g.:
@@ -411,7 +473,8 @@ namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
UnrealizedConversionCastOpPattern>(patterns.getContext());
UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -518,6 +581,30 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});

target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
[=](Operation *op) -> std::optional<bool> {
// Only handle elementwise mappable ops
if (!OpTrait::hasElementwiseMappableTraits(op))
return true;

VectorType resultType =
dyn_cast<VectorType>(op->getResult(0).getType());
if (!resultType)
return true;

// Check if all operands are vectors of the same shape
// TODO: Support other types.
for (Value operand : op->getOperands()) {
VectorType operandType = dyn_cast<VectorType>(operand.getType());
if (!operandType || operandType.getShape() != resultType.getShape()) {
return true;
}
}

Review comment (Contributor): nit: consider the use of llvm::all_equal on op->getOperandTypes().

Review comment (Contributor): this loop is equivalent to

if (llvm::any_of(op->getOperandTypes(), [&](Type type) { return type != resultType; }))
  return true;

xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
return isLegal(layout);
});
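
For reference, two assumed examples (not from this patch) of elementwise ops this legality rule leaves alone. An op without a vector result is reported legal before any layout check, and an op whose layout has no sg_layout is never rewritten, since WgToSgElementwiseOp returns failure on it; this is also why the subgroup-level ops produced by the pattern, which keep only lane_layout/lane_data, remain legal after conversion:

// Assumed examples, not from this patch.
// Scalar result: the callback returns true without looking at layouts.
%s = arith.addf %x, %y : f32
// Layout without sg_layout (e.g. already at subgroup level): the pattern
// bails out, so the op stays as is.
%v = arith.negf %va
  {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
  : vector<24x32xf32>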

target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
[=](UnrealizedConversionCastOp op) {
return llvm::is_contained(existingCastOps, op.getOperation());
164 changes: 164 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -0,0 +1,164 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s

gpu.module @test_elementwise_ops {
// CHECK-LABEL: unary_ops
gpu.func @unary_ops(%a: memref<24x32xf32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
// CHECK: math.exp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
%exp = math.exp %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32>
// CHECK: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
%negf = arith.negf %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32>
gpu.return
}

// CHECK-LABEL: binary_ops
gpu.func @binary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
// CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xf32>
%addf = arith.addf %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32>
// CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xf32>
%powf = math.powf %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32>
gpu.return
}

// CHECK-LABEL: ternary_ops
gpu.func @ternary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi1>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi1>
-> !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_c = xegpu.load_nd %tdesc_c
: !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xi1>
// CHECK: arith.select {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xi1>, vector<12x8xf32>
%select = arith.select %load_c, %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xi1>, vector<24x32xf32>
// CHECK: math.fma {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xf32>
%fma = math.fma %load_a, %load_b, %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32>
gpu.return
}

// CHECK-LABEL: type_conversion_ops
gpu.func @type_conversion_ops(%a: memref<24x32xf32>, %b: memref<24x32xi32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32>
-> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xi32>
// CHECK: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xf32> to vector<12x8xf16>
%truncf = arith.truncf %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32> to vector<24x32xf16>
// CHECK: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xi32> to vector<12x8xf32>
%bitcast = arith.bitcast %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xi32> to vector<24x32xf32>
gpu.return
}

// CHECK-LABEL: comparison_ops
gpu.func @comparison_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
-> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
-> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_c = xegpu.load_nd %tdesc_c
: !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xi32>
%load_d = xegpu.load_nd %tdesc_d
: !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xi32>
// CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xf32>
%cmpf = arith.cmpf ult, %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32>
// CHECK: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xi32>
%cmpi = arith.cmpi eq, %load_c, %load_d
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xi32>
gpu.return
}

// 1 to N decomposition of elementwise operations
// CHECK-LABEL: elementwise_ops_rr_assignment
gpu.func @elementwise_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-> vector<24x32xf32>
// CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
// CHECK-SAME-COUNT-12: : vector<2x2xf32>
// CHECK-NOT: arith.negf
%negf = arith.negf %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
: vector<24x32xf32>
// CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
// CHECK-SAME-COUNT-12: : vector<2x2xf32>
// CHECK-NOT: math.powf
%powf = math.powf %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
: vector<24x32xf32>
gpu.return
}
}