
[mlir][xegpu] Add SIMT distribution patterns for UpdateNdOffset and PrefetchNd ops. #138033

Merged

2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -409,7 +409,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
}

def XeGPU_UpdateNdOffsetOp : XeGPU_Op<"update_nd_offset",
[AllTypesMatch<["TensorDesc", "result"]>]> {
[Pure, AllTypesMatch<["TensorDesc", "result"]>]> {
let summary = "It updates the offsets for the TensorDesc.";
let description = [{The op updates the offset of the given TensorDesc.
The offsets are relative offset to the current position in the number
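An aside on the Pure trait added above (a plausible reading of the motivation, not stated explicitly in the patch): the distribution patterns below deliberately leave a dead update_nd_offset op inside the warp region, and the author notes in the review thread that the upstream dead-result cleanup removes it. That cleanup can only erase the op if it is side-effect free, which is what Pure declares. A minimal C++ illustration of the property, using only the standard mlir::isPure query:

#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

// Illustration only: a Pure op is side-effect free and speculatable, so an
// unused instance (e.g. the dead update_nd_offset left inside the warp
// region) can be erased by generic dead-code elimination.
static bool canEraseIfUnused(mlir::Operation *op) {
  return mlir::isPure(op) && op->use_empty();
}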
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,6 +189,15 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return scatter_attr.getChunkSize().getInt();
return 1;
}

/// Helper to drop all layout information from the TensorDesc type.
TensorDescType dropLayouts() {
if (!getLayoutAttr())
return *this;

return get(getContext(), getShape(), getElementType(), getEncoding(),
xegpu::LayoutAttr());
}
}];

let hasCustomAssemblyFormat = true;
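For illustration, a minimal sketch of how the new dropLayouts member is used by the distribution patterns in the next file (the wrapper function name here is hypothetical; the call sites below invoke tdescTy.dropLayouts() directly):

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"

// Hypothetical helper for illustration: the SIMT (per-lane) form of a tensor
// descriptor keeps the shape, element type, and encoding, but carries no
// subgroup-level layout attribute.
static mlir::xegpu::TensorDescType
toSIMTDescType(mlir::xegpu::TensorDescType tdescTy) {
  return tdescTy.dropLayouts();
}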
223 changes: 189 additions & 34 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -301,6 +301,10 @@ class LayoutInfoPropagation
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);

void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);

void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
@@ -352,6 +356,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
.Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
})
.Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
visitPrefetchNdOp(prefetchNdOp, operands, results);
})
// No need to propagate the layout to operands in CreateNdDescOp because
// they are scalars (offsets, sizes, etc.).
.Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
@@ -381,6 +388,18 @@ LogicalResult LayoutInfoPropagation::visitOperation(
return success();
}

void LayoutInfoPropagation::visitPrefetchNdOp(
xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// Here we assign the default layout to the tensor descriptor operand of
// prefetch.
auto tdescTy = prefetch.getTensorDescType();
auto prefetchLayout = getDefaultLayoutInfo(
VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
// Propagate the layout to the source tensor descriptor.
propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}

void LayoutInfoPropagation::visitVectorMultiReductionOp(
vector::MultiDimReductionOp reduction,
ArrayRef<LayoutInfoLattice *> operands,
@@ -865,18 +884,6 @@ getDistVecTypeBasedOnLaneLayout(
return VectorType::get(distributedShape, originalType.getElementType());
}

// Drop the layout attribute from the tensor descriptor type if layout is
// present.
static xegpu::TensorDescType dropLayouts(xegpu::TensorDescType tensorDesc) {
if (tensorDesc.getLayoutAttr() == xegpu::LayoutAttr())
return tensorDesc;

return xegpu::TensorDescType::get(
tensorDesc.getContext(), tensorDesc.getShape(),
tensorDesc.getElementType(), tensorDesc.getEncoding(),
xegpu::LayoutAttr());
}

/// Helper function to resolve types if the distributed type out of
/// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
/// Example 1:
@@ -1023,12 +1030,12 @@ struct MoveFuncBodyToWarpExecuteOnLane0
/// Example:
///
/// ```
/// #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
/// (!xegpu.tensor_desc<4x8xf32, #lo0>) {
/// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
/// ...
/// %td = xegpu.create_nd_tdesc %arg0[0, 0]
/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #lo0>
/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
/// vector.yield %td
/// }
/// ```
@@ -1037,7 +1044,7 @@ struct MoveFuncBodyToWarpExecuteOnLane0
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
/// ...
/// %dead = xegpu.create_nd_tdesc %arg0[0, 0]
/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #lo0>
/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
/// vector.yield %arg0, %dead
/// }
/// %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32>
@@ -1080,8 +1087,8 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
}
rewriter.setInsertionPointAfter(newWarpOp);
xegpu::TensorDescType distributedTensorDescTy =
dropLayouts(descOp.getType()); // Distributed tensor descriptor type
// does not contain layout info.
descOp.getType().dropLayouts(); // Distributed tensor descriptor type
// does not contain layout info.
auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
descOp->getAttrs());
@@ -1101,23 +1108,23 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
/// Example:
///
/// ```
/// #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
/// ...
/// xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
/// !xegpu.tensor_desc<4x8xf32, #lo0>
/// !xegpu.tensor_desc<4x8xf32, #layout0>
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
/// !xegpu.tensor_desc<4x8xf32, #lo0>) {
/// !xegpu.tensor_desc<4x8xf32, #layout0>) {
/// gpu.yield %arg0, %arg1: vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32,
/// #lo0>
/// #layout0>
/// }
/// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
/// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
/// #lo0>
/// #layout0>
/// -> !xegpu.tensor_desc<4x8xf32>
/// xegpu.store_nd %0, %1: vector<4xf32>,
/// !xegpu.tensor_desc<4x8xf32>
@@ -1173,10 +1180,10 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
newStoreOperands.push_back(resolveDistributedTy(
newWarpOp.getResult(newRetIndices[0]),
storeNdDistributedValueTyOrFailure.value(), rewriter));
// For the tensor descriptor operand, the layout attibute is dropped after
// For the tensor descriptor operand, the layout attribute is dropped after
// distribution. Types needs to be resolved in this case also.
xegpu::TensorDescType distributedTensorDescTy =
dropLayouts(storeOp.getTensorDescType());
storeOp.getTensorDescType().dropLayouts();
newStoreOperands.push_back(
resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
distributedTensorDescTy, rewriter));
@@ -1201,25 +1208,26 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
/// Example:
///
/// ```
/// #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
/// (vector<4x1xf32>) {
/// ...
/// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #lo0> ->
/// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
/// ->
/// vector<4x8xf32>
/// gpu.yield %ld
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
/// !xegpu.tensor_desc<4x8xf32, #lo0>) {
/// !xegpu.tensor_desc<4x8xf32, #layout0>) {
/// ...
/// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #lo0> ->
/// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
/// vector<4x8xf32> gpu.yield %dead, %arg0
/// }
/// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
/// #lo0> -> !xegpu.tensor_desc<4x8xf32>
/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
/// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
/// %2 = vector.shape_cast %r#0: vector<4xf32> to vector<4x1xf32>
///
@@ -1260,9 +1268,9 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
return rewriter.notifyMatchFailure(
loadOp, "Failed to get distributed vector type for the load op");
xegpu::TensorDescType distributedTensorDescTy =
dropLayouts(loadOp.getTensorDescType()); // Distributed tensor
// descriptor type does not
// contain layout info.
loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
// descriptor type does not
// contain layout info.
auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
@@ -1412,6 +1420,152 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
}
};

/// Sink an update_nd_offset op feeding into yield op of an enclosing
/// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
/// original op that will not be used by the yield op (and should be cleaned
/// up later). The yield op will bypass the updateOp's arguments. The tensor
/// descriptor type is not distributed. Appropriate cast ops are inserted if
/// the distributed types do not match the expected xegpu SIMT types.
/// Example:
/// ```
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
/// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
/// ...
/// %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
/// !xegpu.tensor_desc<4x8xf32, #layout0>
/// gpu.yield %update
/// }
/// ...
/// ```
/// To
/// ```
/// %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (
/// !xegpu.tensor_desc<4x8xf32, #layout0>,
/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
/// ...
/// %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
/// !xegpu.tensor_desc<4x8xf32, #layout0>
/// gpu.yield %dead, %arg0, %c32, %c16
/// }
/// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
/// %1 = xegpu.update_nd_offset %0, [%r#2, %r#3]:
/// !xegpu.tensor_desc<4x8xf32>
/// ...
/// ```
struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
if (!operand)
return rewriter.notifyMatchFailure(
subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
unsigned operandIdx = operand->getOperandNumber();
// The new update op does not have a layout attribute.
xegpu::TensorDescType newTensorDescTy =
updateOp.getTensorDescType().dropLayouts();

SmallVector<Value, 3> newYieldValues;
SmallVector<Type, 3> newYieldTypes;
for (Value operand : updateOp->getOperands()) {
newYieldValues.push_back(operand);
if (isa<xegpu::TensorDescType>(operand.getType())) {
newYieldTypes.push_back(newTensorDescTy);
} else {
newYieldTypes.push_back(operand.getType());
}
}
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
SmallVector<Value> newUpdateOperands;
for (size_t i : newRetIndices) {
// For the tensor descriptor operand, the layout attribute is dropped
// after distribution. Types need to be resolved in this case.
Review thread:

Contributor: What does "resolve" mean here?

Contributor (author): In XeGPU SIMT code, the layout is dropped in the TensorDesc. But the upstream warp-op distribution infra does not understand the TensorDesc type, so it will still return the descriptor with the layout. "Resolve" adds an unrealized_conversion_cast to go from the descriptor with a layout to one without.

Example:

/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   gpu.warp_execute_on_lane_0(%laneid) -> () {
///     ...
///     xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0>
///   }
/// ```
/// To
/// ```
///   %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> (
///     !xegpu.tensor_desc<4x8xf32, #layout0>) {
///     gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #layout0>
///   }
///   %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
///     #layout0> -> !xegpu.tensor_desc<4x8xf32>
///   xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32>
/// ```

Contributor: I see. Thanks.

Contributor: If updateOp is not used anymore, is it safe to drop all of its uses and remove it from the warp op?

Contributor (author): This is done by the upstream pattern warpDeadResult, so we don't need to do anything here. It's a clean separation of concerns.

if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
newUpdateOperands.push_back(resolveDistributedTy(
newWarpOp.getResult(i), newTensorDescTy, rewriter));
} else {
newUpdateOperands.push_back(newWarpOp.getResult(i));
}
}
// Create a new update op outside the warp op.
auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
removeTemporaryLayoutAttributes(updateOp->getAttrs()));
Value distributedVal = newWarpOp.getResult(operandIdx);
rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
return success();
}
};
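The review thread above asks what "resolve" means. As the author explains, resolution bridges the layout-carrying type yielded by the warp op and the layout-free SIMT type with an unrealized_conversion_cast. Below is a minimal sketch of that idea, assuming the generic builtin cast op; the file's actual resolveDistributedTy helper may differ in detail.

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Sketch only: if the value already has the expected SIMT type, pass it
// through; otherwise insert an unrealized_conversion_cast to bridge the
// layout-carrying type and the layout-free type.
static Value resolveDistributedTySketch(Value orig, Type expectedTy,
                                        PatternRewriter &rewriter) {
  if (orig.getType() == expectedTy)
    return orig;
  SmallVector<Type, 1> resultTys = {expectedTy};
  SmallVector<Value, 1> inputs = {orig};
  auto cast = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
                                                          resultTys, inputs);
  return cast.getResult(0);
}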

/// Distribute a prefetch_nd op at the end of enclosing
/// `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed
/// through the warp op interface, they would be propagated as returned values.
/// Tensor descriptor shape is not distributed because it is a uniform value
/// across all work items within the subgroup. Appropriate cast ops are inserted
/// if the distributed types do not match the expected xegpu SIMT types.
///
/// Example:
///
/// ```
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
/// ...
/// xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0>
/// }
/// ```
/// To
/// ```
/// %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> (
/// !xegpu.tensor_desc<4x8xf32, #layout0>) {
/// gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #layout0>
/// }
/// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
/// xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
Operation *lastNode = yield->getPrevNode();
auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
if (!prefetchOp)
return failure();
xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
if (!layout)
return rewriter.notifyMatchFailure(
prefetchOp, "the source tensor descriptor lacks layout attribute");

SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
// Create a new prefetch op outside the warp op with updated tensor
// descriptor type. The source tensor descriptor requires type resolution.
xegpu::TensorDescType newTensorDescTy =
prefetchOp.getTensorDescType().dropLayouts();
rewriter.setInsertionPointAfter(newWarpOp);
SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
rewriter.create<xegpu::PrefetchNdOp>(
newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
rewriter.eraseOp(prefetchOp);
return success();
}
};

} // namespace

namespace {
@@ -1430,7 +1584,8 @@ struct XeGPUSubgroupDistributePass final
void xegpu::populateXeGPUSubgroupDistributePatterns(
RewritePatternSet &patterns) {
patterns.add<CreateNdDescDistribution, StoreNdDistribution,
LoadNdDistribution, DpasDistribution>(patterns.getContext());
LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
UpdateNdOffsetDistribution>(patterns.getContext());
}

void XeGPUSubgroupDistributePass::runOnOperation() {
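For context, a hedged sketch of how the populated patterns are typically driven (an assumption mirroring common MLIR practice, not necessarily what this pass's runOnOperation does): collect the pattern set, then hand it to the greedy pattern rewriter.

#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Sketch only: gather the XeGPU SIMT distribution patterns (now including
// PrefetchNdDistribution and UpdateNdOffsetDistribution) and apply them
// greedily to the given root operation.
static LogicalResult distributeSubgroupOps(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  return applyPatternsGreedily(root, std::move(patterns));
}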