
[MLIR][XeGPU] Add unroll patterns for scatter ops #143602

Merged · 14 commits · Jun 16, 2025

207 changes: 205 additions & 2 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -396,11 +396,214 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
}
};

struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getType();

// Only 1-D tensor descriptors are supported for now.
if (tdescTy.getRank() > 1)
return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
VectorType indiceVecTy = indiceVec.getType();

SmallVector<Type> convertedIndiceTypes =
getUnrolledTypes(indiceVecTy, *targetShape);
Contributor:
Here, the targetShape for indices should drop the last dim if chunkSize != 1.

Contributor Author:
I will leave this to the next PR.
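
For context, a hedged illustration of the case flagged above (the shapes and the chunk_size value are hypothetical, not taken from this patch): with a non-unit chunk size the descriptor is 2-D while its offsets stay 1-D, so a targetShape derived from the descriptor shape carries one dim too many for the index vector.

```mlir
// Hypothetical chunked descriptor: the tensor_desc is 2-D (16x2) but the
// offsets are 1-D (vector<16xindex>). Unrolling with inst_data = [8, 2]
// would have to slice the offsets as vector<8xindex>, i.e. with the
// trailing chunk dim dropped from targetShape.
%tdesc = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex>
    -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
```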

SmallVector<Value> convertedIndiceVec =
pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);

SmallVector<Value> newOps;
for (auto indice : convertedIndiceVec) {
auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
op.getSource(), indice);
newOps.push_back(newOp);
}

Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
Contributor:
Why is this function called unpack when it is doing N:1? Shouldn't it be a pack?

Contributor Author:
My understanding is that pack reshapes [m, n] to [m/bm, n/bn, bm, bn], so it is 1-to-N; unpack does the reverse, so it is N-to-1.

Contributor:
It follows the pack/unpack definition in the tensor dialect.
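
To make the naming concrete, a rough sketch of the unrolled IR this pattern produces for the test case further down (slicing via vector.extract_strided_slice is how pack handles 1-D vectors; the cast that reassembles the descriptors is an implementation detail of unpack, shown here as builtin.unrealized_conversion_cast for illustration):

```mlir
// pack: 1 -> N. Slice the original vector<32xindex> offsets in two.
%o0 = vector.extract_strided_slice %offsets {offsets = [0], sizes = [16], strides = [1]}
    : vector<32xindex> to vector<16xindex>
%o1 = vector.extract_strided_slice %offsets {offsets = [16], sizes = [16], strides = [1]}
    : vector<32xindex> to vector<16xindex>
%t0 = xegpu.create_tdesc %src, %o0 : ui64, vector<16xindex>
    -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
%t1 = xegpu.create_tdesc %src, %o1 : ui64, vector<16xindex>
    -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// unpack: N -> 1. Rebuild a value of the original 32-element descriptor type
// so existing users of the rewritten op still type-check.
%t = builtin.unrealized_conversion_cast %t0, %t1
    : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>,
      !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
    to !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>>
```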

rewriter.replaceOp(op, castOp);

return success();
}
};

struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();

// Only 1-D tensor descriptors are supported for now.
if (tdescTy.getRank() > 1)
return failure();

VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

Type elemTy = tdescTy.getElementType();
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

SmallVector<Type> convertedMaskTypes =
getUnrolledTypes(maskTy, *targetShape);
SmallVector<Value> convertedMasks =
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);

SmallVector<Value> newOps;
for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
auto newOp = rewriter.create<xegpu::LoadGatherOp>(
loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
newOps.push_back(newOp);
}

Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

rewriter.replaceOp(op, castOp);
return success();
}
};

struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();

// Only 1-D tensor descriptors are supported for now.
if (tdescTy.getRank() > 1)
return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Value> convertedTdesc = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

for (auto t : convertedTdesc)
rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());

rewriter.eraseOp(op);
return success();
}
};

struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();

// Only 1-D tensor descriptors are supported for now.
if (tdescTy.getRank() > 1)
return failure();

VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);

SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

SmallVector<Type> convertedMaskTypes =
getUnrolledTypes(maskTy, *targetShape);
SmallVector<Value> convertedMasks =
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);

for (size_t i = 0; i < convertedValues.size(); ++i) {
Value v = convertedValues[i];
Value t = convertedTdescs[i];
Value m = op.getMask() ? convertedMasks[i] : nullptr;
rewriter.create<xegpu::StoreScatterOp>(
loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
}

rewriter.eraseOp(op);
return success();
}
};

struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();

// Only 1-D tensor descriptors are supported for now.
if (tdescTy.getRank() > 1)
return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Value> convertedTdesc = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
VectorType offsetVecTy = offsetVec.getType();
SmallVector<Type> convertedOffsetTypes =
getUnrolledTypes(offsetVecTy, *targetShape);
SmallVector<Value> convertedOffsetVec =
pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);

SmallVector<Value> newOps;
for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
auto newOp =
rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
newOps.push_back(newOp);
}
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
rewriter.replaceOp(op, castOp);
return success();
}
};

} // namespace

void mlir::xegpu::populateXeGPUUnrollPatterns(
RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(), options);
}
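
For downstream users, a minimal sketch of driving these patterns from a pass. The setNativeShapeFn hook and header paths below are inferred from the test driver further down, not from this hunk, so check them against the in-tree XeGPU headers:

```cpp
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Unroll XeGPU ops to their per-instruction shapes.
static void runXeGPUUnrolling(Operation *root) {
  xegpu::UnrollOptions options;
  // Return the native shape for an op, or std::nullopt to leave it intact.
  options.setNativeShapeFn(
      [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        if (isa<xegpu::DpasOp>(op))
          return SmallVector<int64_t>{8, 16, 16};
        return std::nullopt;
      });

  RewritePatternSet patterns(root->getContext());
  xegpu::populateXeGPUUnrollPatterns(patterns, options);
  // Spelled applyPatternsAndFoldGreedily on older branches.
  (void)applyPatternsGreedily(root, std::move(patterns));
}
```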
141 changes: 141 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -158,4 +158,145 @@ gpu.module @test {
%c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
gpu.return %c : vector<32x32xf32>
}

//-----

// CHECK-LABEL: test_create_tdesc_vec
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
}

//-----

// CHECK-LABEL: test_create_tdesc_step
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
%step = arith.constant dense<8> : vector<32xindex>
%seq = vector.step : vector<32xindex>
%cst = arith.muli %seq, %step : vector<32xindex>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
}

//-----

// CHECK-LABEL: test_load
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
gpu.func @test_load(%src: ui64) -> vector<32xf32> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>

%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
%ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>

gpu.return %ld : vector<32xf32>
}

//-----

// CHECK-LABEL: test_prefetch
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
gpu.func @test_prefetch(%src: ui64) {

%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>

%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
gpu.return
}

//-----

// CHECK-LABEL: test_store
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
gpu.func @test_store(%src: ui64) {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>

%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>

%st_vec = arith.constant dense<1023.0>: vector<32xf32>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>

gpu.return
}

//-----

// CHECK-LABEL: test_prefetch_load_store_update
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
// CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
// CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>

gpu.func @test_prefetch_load_store_update(%src: ui64) {

%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>

%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

%delta = arith.constant dense<[
32, 32, 32, 32, 32, 32, 32, 32,
32, 32, 32, 32, 32, 32, 32, 64,
128, 128, 128, 128, 128, 128, 128, 128,
128, 128, 128, 128, 128, 128, 128, 256
]> : vector<32xindex>
%new_tdesc = xegpu.update_offset %tdesc, %delta
: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>

%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>

%ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>

%st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
xegpu.store %st_vec, %tdesc, %mask:
vector<32xf32>,
!xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
vector<32xi1>

gpu.return
}
}
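
The RUN line for this test file sits outside the visible hunk; assuming the test pass in the next file registers as test-xegpu-unrolling-patterns, it would look something like:

```mlir
// RUN: mlir-opt --test-xegpu-unrolling-patterns %s | FileCheck %s
```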
23 changes: 23 additions & 0 deletions mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -71,6 +71,29 @@ struct TestXeGPUUnrollingPatterns
}
}

if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
xegpu::TensorDescType tdescTy;
if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
tdescTy = createOp.getType();
} else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
tdescTy = updateOp.getTensorDescType();
} else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
tdescTy = prefetchOp.getTensorDescType();
} else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
tdescTy = loadOp.getTensorDescType();
} else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
tdescTy = storeOp.getTensorDescType();
}

if (auto layout = tdescTy.getLayoutAttr()) {
auto inst_data = layout.getInstData();
if (inst_data && layout.isSgLayout())
return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
inst_data.asArrayRef().end());
}
}

if (isa<xegpu::DpasOp>(op))
return SmallVector<int64_t>{8, 16, 16};
