-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[MLIR][XeGPU] Add unroll patterns for scatter ops #143602
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bac8bc6
30b099e
8a0e145
c91156c
e532cbe
30cb8d8
f493e52
2606a4b
a3e064d
b383194
75a65ea
402c015
02e6e67
2845237
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -396,11 +396,214 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> { | |
} | ||
}; | ||
|
||
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> { | ||
using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern; | ||
LogicalResult matchAndRewrite(xegpu::CreateDescOp op, | ||
PatternRewriter &rewriter) const override { | ||
Location loc = op.getLoc(); | ||
xegpu::TensorDescType tdescTy = op.getType(); | ||
|
||
// check if the tensor descriptor type is a 1d vector type | ||
if (tdescTy.getRank() > 1) | ||
return failure(); | ||
|
||
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); | ||
if (!targetShape) | ||
return failure(); | ||
|
||
auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0]; | ||
|
||
TypedValue<::mlir::VectorType> indiceVec = op.getOffsets(); | ||
VectorType indiceVecTy = indiceVec.getType(); | ||
|
||
SmallVector<Type> convertedIndiceTypes = | ||
getUnrolledTypes(indiceVecTy, *targetShape); | ||
SmallVector<Value> convertedIndiceVec = | ||
pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter); | ||
|
||
SmallVector<Value> newOps; | ||
for (auto indice : convertedIndiceVec) { | ||
auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, | ||
op.getSource(), indice); | ||
newOps.push_back(newOp); | ||
} | ||
|
||
Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why this function called There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. my understanding is that pack [m, n] to [m/bm, n/bn, bm, bn] so it is 1 to N. unpack does reverse so it is N to 1. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it follows pack/unpack definition in tensor dialect. |
||
rewriter.replaceOp(op, castOp); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> { | ||
using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern; | ||
LogicalResult matchAndRewrite(xegpu::LoadGatherOp op, | ||
PatternRewriter &rewriter) const override { | ||
|
||
Location loc = op.getLoc(); | ||
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType()); | ||
xegpu::TensorDescType tdescTy = op.getTensorDescType(); | ||
|
||
// check if the tensor descriptor type is a 1d vector type | ||
if (tdescTy.getRank() > 1) | ||
return failure(); | ||
|
||
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType()); | ||
|
||
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); | ||
if (!targetShape) | ||
return failure(); | ||
|
||
Type elemTy = tdescTy.getElementType(); | ||
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy); | ||
|
||
SmallVector<Type> convertedTdescTypes = | ||
getUnrolledTypes(tdescTy, *targetShape); | ||
SmallVector<Value> convertedTdescs = pack( | ||
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); | ||
|
||
SmallVector<Type> convertedMaskTypes = | ||
getUnrolledTypes(maskTy, *targetShape); | ||
SmallVector<Value> convertedMasks = | ||
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter); | ||
|
||
SmallVector<Value> newOps; | ||
for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) { | ||
auto newOp = rewriter.create<xegpu::LoadGatherOp>( | ||
loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(), | ||
op.getL2HintAttr(), op.getL3HintAttr()); | ||
newOps.push_back(newOp); | ||
} | ||
|
||
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter); | ||
|
||
rewriter.replaceOp(op, castOp); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> { | ||
using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern; | ||
LogicalResult matchAndRewrite(xegpu::PrefetchOp op, | ||
PatternRewriter &rewriter) const override { | ||
Location loc = op.getLoc(); | ||
xegpu::TensorDescType tdescTy = op.getTensorDescType(); | ||
|
||
// check if the tensor descriptor type is a 1d vector type | ||
if (tdescTy.getRank() > 1) | ||
return failure(); | ||
|
||
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); | ||
if (!targetShape) | ||
return failure(); | ||
|
||
SmallVector<Type> convertedTdescTypes = | ||
getUnrolledTypes(tdescTy, *targetShape); | ||
SmallVector<Value> convertedTdesc = pack( | ||
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); | ||
|
||
for (auto t : convertedTdesc) | ||
rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs()); | ||
|
||
rewriter.eraseOp(op); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> { | ||
using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern; | ||
LogicalResult matchAndRewrite(xegpu::StoreScatterOp op, | ||
PatternRewriter &rewriter) const override { | ||
|
||
Location loc = op.getLoc(); | ||
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType()); | ||
xegpu::TensorDescType tdescTy = op.getTensorDescType(); | ||
|
||
// check if the tensor descriptor type is a 1d vector type | ||
if (tdescTy.getRank() > 1) | ||
return failure(); | ||
|
||
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType()); | ||
|
||
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); | ||
if (!targetShape) | ||
return failure(); | ||
|
||
SmallVector<Type> convertedValTypes = | ||
getUnrolledTypes(valueTy, *targetShape); | ||
SmallVector<Type> convertedTdescTypes = | ||
getUnrolledTypes(tdescTy, *targetShape); | ||
|
||
SmallVector<Value> convertedValues = | ||
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter); | ||
SmallVector<Value> convertedTdescs = pack( | ||
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); | ||
|
||
SmallVector<Type> convertedMaskTypes = | ||
getUnrolledTypes(maskTy, *targetShape); | ||
SmallVector<Value> convertedMasks = | ||
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter); | ||
|
||
for (size_t i = 0; i < convertedValues.size(); ++i) { | ||
Value v = convertedValues[i]; | ||
Value t = convertedTdescs[i]; | ||
Value m = op.getMask() ? convertedMasks[i] : nullptr; | ||
rewriter.create<xegpu::StoreScatterOp>( | ||
loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(), | ||
op.getL2HintAttr(), op.getL3HintAttr()); | ||
} | ||
|
||
rewriter.eraseOp(op); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> { | ||
using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern; | ||
LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op, | ||
PatternRewriter &rewriter) const override { | ||
Location loc = op.getLoc(); | ||
xegpu::TensorDescType tdescTy = op.getTensorDescType(); | ||
|
||
// check if the tensor descriptor type is a 1d vector type | ||
if (tdescTy.getRank() > 1) | ||
return failure(); | ||
|
||
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); | ||
if (!targetShape) | ||
return failure(); | ||
|
||
SmallVector<Type> convertedTdescTypes = | ||
getUnrolledTypes(tdescTy, *targetShape); | ||
SmallVector<Value> convertedTdesc = pack( | ||
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); | ||
|
||
TypedValue<::mlir::VectorType> offsetVec = op.getOffsets(); | ||
VectorType offsetVecTy = offsetVec.getType(); | ||
SmallVector<Type> convertedOffsetTypes = | ||
getUnrolledTypes(offsetVecTy, *targetShape); | ||
SmallVector<Value> convertedOffsetVec = | ||
pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter); | ||
|
||
SmallVector<Value> newOps; | ||
for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) { | ||
auto newOp = | ||
rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o); | ||
newOps.push_back(newOp); | ||
} | ||
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter); | ||
rewriter.replaceOp(op, castOp); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::xegpu::populateXeGPUUnrollPatterns( | ||
RewritePatternSet &patterns, const xegpu::UnrollOptions &options) { | ||
patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp, | ||
UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>( | ||
patterns.getContext(), options); | ||
UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, | ||
UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp, | ||
UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(), | ||
options); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, the targetShape for indices should drop the last dim if chunkSize != 1.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will leave this to next PR.