
[MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size #144447

Merged: 20 commits, Jun 17, 2025
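
Summary of the change (as reflected in the diff below): the scatter-op unrolling patterns (UnrollCreateDescOp, UnrollLoadGatherOp, UnrollPrefetchOp, UnrollStoreScatterOp, UnrollUpdateOffsetOp) now key on tdescTy.isScattered() instead of the descriptor rank, and a descriptor with chunk_size > 1 is unrolled by splitting it into originalChunkSize / blockedChunkSize smaller descriptors. The minimal C++ sketch below shows the splitting arithmetic the patterns share; the helper name and example sizes are illustrative and not part of the patch.

#include <cstdint>
#include <vector>

// Illustrative helper (not part of the patch): when a scattered descriptor with
// chunk size `originalChunkSize` is blocked to `blockedChunkSize`, it is split
// into originalChunkSize / blockedChunkSize descriptors, and the i-th split
// shifts each lane's index by i * blockedChunkSize elements.
std::vector<int64_t> chunkSplitOffsets(int64_t originalChunkSize,
                                       int64_t blockedChunkSize) {
  std::vector<int64_t> offsets;
  for (int64_t i = 0; i < originalChunkSize / blockedChunkSize; ++i)
    offsets.push_back(i * blockedChunkSize);
  return offsets;
}

With chunk size 8 blocked to 2 this yields offsets {0, 2, 4, 6}, matching the arith.addi increments that UnrollCreateDescOp creates below.
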
176 changes: 133 additions & 43 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -402,30 +402,58 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getType();
TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
VectorType indiceVecTy = indiceVec.getType();

// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (!tdescTy.isScattered())
return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
VectorType indiceVecTy = indiceVec.getType();
SmallVector<int64_t> targetIndiceShape(*targetShape);
int64_t originalChunkSize = tdescTy.getChunkSize();
// indiceVec has one fewer dimension than tdescTy when chunkSize is larger than 1.
if (originalChunkSize > 1)
targetIndiceShape.pop_back();
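// For example (hypothetical shapes): a 16x8 tdesc with chunk_size = 8 takes a
// vector<16xindex> of offsets, so a target shape of [8, 2] gives a target
// indice shape of [8].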

auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
SmallVector<Type> convertedIndiceTypes =
getUnrolledTypes(indiceVecTy, *targetShape);
getUnrolledTypes(indiceVecTy, targetIndiceShape);
SmallVector<Value> convertedIndiceVec =
pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

SmallVector<Value> newOps;
for (auto indice : convertedIndiceVec) {
auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
op.getSource(), indice);
newOps.push_back(newOp);

// More indices are needed when chunkSize > 1, since one large load from a
// single address may be broken into multiple smaller loads.
if (originalChunkSize > 1) {
int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
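// For example (hypothetical sizes): originalChunkSize = 8 blocked to
// blockedChunkSize = 2 gives numNewChunks = 4, so the per-lane indices below
// are shifted by 0, 2, 4, and 6 elements respectively.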

for (auto [indice, indiceType] :
llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
for (int64_t i = 0; i < numNewChunks; ++i) {
// Compute the offset
Value inc = rewriter.create<arith::ConstantIndexOp>(
loc, i * blockedChunkSize);
Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
Value offsetIndice =
rewriter.create<arith::AddIOp>(loc, indice, incVec);

auto newOp = rewriter.create<xegpu::CreateDescOp>(
loc, newTdescTy, op.getSource(), offsetIndice);

newOps.push_back(newOp);
}
}
} else {
for (auto indice : convertedIndiceVec) {
auto newOp = rewriter.create<xegpu::CreateDescOp>(
loc, newTdescTy, op.getSource(), indice);
newOps.push_back(newOp);
}
}

Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
@@ -444,16 +472,18 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();

// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (!tdescTy.isScattered())
return failure();

VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

SmallVector<int64_t> targetMaskShape(*targetShape);
int64_t originalChunkSize = tdescTy.getChunkSize();

VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

Type elemTy = tdescTy.getElementType();
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

@@ -462,10 +492,29 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

SmallVector<Type> convertedMaskTypes =
getUnrolledTypes(maskTy, *targetShape);
SmallVector<Value> convertedMasks =
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;

if (originalChunkSize > 1) {
targetMaskShape.pop_back();
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
SmallVector<Value> convertedMasks1D = pack(
op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = originalChunkSize / blockedChunkSize;

for (auto mask : convertedMasks1D) {
for (int64_t i = 0; i < numNewChunks; ++i)
convertedMasks.push_back(mask);
}
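// Each chunk split of a lane group reuses that group's 1-D mask, since all
// splits read the same lanes.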
// This is to handle the transpose effect when chunkSize > 1.
std::swap((*targetShape)[0], (*targetShape)[1]);
newValueTy = valueTy.cloneWith(*targetShape, elemTy);
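// For example (hypothetical shapes): a target shape of [8, 2] becomes [2, 8]
// above, so each small load produces a vector<2x8xelemTy> value with the
// blocked chunk as the leading dimension.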
} else {
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
loc, rewriter);
}

SmallVector<Value> newOps;
for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
@@ -476,7 +525,6 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
}

Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

rewriter.replaceOp(op, castOp);
return success();
}
@@ -489,8 +537,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();

// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (!tdescTy.isScattered())
return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -519,30 +566,51 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();

// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (!tdescTy.isScattered())
return failure();

VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
SmallVector<int64_t> targetIndiceShape(*targetShape);
int64_t originalChunkSize = tdescTy.getChunkSize();

VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);

SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

SmallVector<Type> convertedMaskTypes =
getUnrolledTypes(maskTy, *targetShape);
SmallVector<Value> convertedMasks =
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;

if (originalChunkSize > 1) {
int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
SmallVector<Value> convertedMasks1D = pack(
op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);

for (auto mask : convertedMasks1D) {
for (int64_t i = 0; i < numNewChunks; ++i) {
convertedMasks.push_back(mask);
}
}
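// As in the load case, each chunk split reuses the 1-D mask of its lane group.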
// This is to handle the transpose effect when chunkSize > 1.
std::swap((*targetShape)[0], (*targetShape)[1]);
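// After the swap, *targetShape is in [chunk, lanes] order, so the value types
// computed after this if/else match the transposed layout of the stored data.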

} else {
convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
convertedMasks =
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
}

SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

for (size_t i = 0; i < convertedValues.size(); ++i) {
Value v = convertedValues[i];
@@ -565,8 +633,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();

// check if the tensor descriptor type is a 1d vector type
if (tdescTy.getRank() > 1)
if (tdescTy.getRank() > 2)
return failure();

if (!tdescTy.isScattered())
return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
Expand All @@ -580,12 +650,32 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {

TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
VectorType offsetVecTy = offsetVec.getType();
SmallVector<Type> convertedOffsetTypes =
getUnrolledTypes(offsetVecTy, *targetShape);
SmallVector<Value> convertedOffsetVec =
pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);

SmallVector<Type> convertedOffsetTypes;
SmallVector<Value> convertedOffsetVec;
SmallVector<Value> newOps;
int64_t originalChunkSize = tdescTy.getChunkSize();
if (originalChunkSize > 1) {
SmallVector<int64_t> shape1D(targetShape->begin(),
targetShape->end() - 1);
convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
SmallVector<Value> convertedOffsetVec1D =
pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);

int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = originalChunkSize / blockedChunkSize;

for (auto offset : convertedOffsetVec1D) {
for (int64_t i = 0; i < numNewChunks; ++i) {
convertedOffsetVec.push_back(offset);
}
}
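// The same offset delta applies to every chunk split derived from one lane
// group, so each 1-D offset vector is repeated numNewChunks times.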

} else {
convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
convertedOffsetVec =
pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
}

for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
auto newOp =
rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);