
[mlir][xegpu] XeGPU distribution patterns for load_nd, store_nd, and create_nd_tdesc. #119783


Closed
Changes from all commits (28 commits)
dab6841  [MLIR] Create GPU utils library & move distribution utils (kurapov-peter, Dec 9, 2024)
fe745c6  Merge remote-tracking branch 'petr_llvm/distribution-utils' into xegp… (charithaintc, Dec 10, 2024)
f6cd50a  pass added (charithaintc, Dec 12, 2024)
1c06920  fix (charithaintc, Dec 12, 2024)
9888c84  fix (charithaintc, Dec 12, 2024)
491625d  fix (charithaintc, Dec 12, 2024)
07f9f9f  fix (charithaintc, Dec 12, 2024)
b842f33  fix (charithaintc, Dec 12, 2024)
69cbc3b  fix (charithaintc, Dec 12, 2024)
b7cb16f  Merge branch 'main' into xegpu-distribution-charitha (charithaintc, Dec 13, 2024)
e7ca3cd  fix (charithaintc, Dec 13, 2024)
8234edd  Merge branch 'main' into xegpu-distribution-charitha (charithaintc, Dec 13, 2024)
2f4b748  Merge branch 'main' into xegpu-distribution-charitha (charithaintc, Jan 30, 2025)
b443c71  sync (charithaintc, Jan 30, 2025)
6f11f3c  fix comments (charithaintc, Jan 31, 2025)
36c5b46  Merge branch 'main' into xegpu-distribution-charitha (charithaintc, Jan 31, 2025)
7d1c7a6  add mem side effects interface (charithaintc, Jan 31, 2025)
eb7ee36  Merge branch 'xegpu-mem-effects' into xegpu-distribution-charitha (charithaintc, Jan 31, 2025)
263d72d  add mem side effects interface (charithaintc, Jan 31, 2025)
1b0bba7  add mem side effects interface (charithaintc, Feb 3, 2025)
91fa249  Merge branch 'main' into xegpu-mem-effects (charithaintc, Feb 3, 2025)
38ee43c  Merge branch 'xegpu-mem-effects' into xegpu-distribution-charitha (charithaintc, Feb 3, 2025)
ae2a3fe  Merge remote-tracking branch 'origin/main' into xegpu-distribution-ch… (charithaintc, Feb 3, 2025)
2d664e8  fix issues (charithaintc, Feb 4, 2025)
615f22d  Merge branch 'main' into xegpu-distribution-charitha (charithaintc, Feb 4, 2025)
983dd4d  Merge branch 'main' into xegpu-distribution-charitha (charithaintc, Feb 5, 2025)
4afbff9  fix comments (charithaintc, Feb 5, 2025)
48fc6d5  fix (charithaintc, Feb 5, 2025)
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,8 @@ namespace xegpu {

/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
/// Patterns for distributing subgroup XeGPU ops to work items.
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);

} // namespace xegpu
} // namespace mlir
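
For context, a minimal sketch of how a pass might consume this new entry point with MLIR's greedy pattern driver. The wrapper function below is hypothetical and not part of this patch; only populateXeGPUSubgroupDistributePatterns comes from the change.

```cpp
// Hypothetical driver (illustration only): collect the XeGPU subgroup
// distribution patterns added by this patch and apply them greedily.
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult distributeXeGPUSubgroupOps(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Entry point introduced by this patch.
  mlir::xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // Rewrite to a fixed point.
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```
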
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUFoldAliasOps.cpp
XeGPUSubgroupDistribute.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
@@ -12,6 +13,10 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
MLIRIR
MLIRMemRefDialect
MLIRXeGPUDialect
MLIRVectorDialect
MLIRVectorUtils
MLIRArithDialect
MLIRFuncDialect
MLIRPass
MLIRTransforms
)
353 changes: 353 additions & 0 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -0,0 +1,353 @@
//=- XeGPUSubgroupDistribute.cpp - distribute XeGPU ops to work items *-C++-*-=//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/Value.h"

#define DEBUG_TYPE "xegpu-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
Reviewer comment on lines +19 to +20 (Contributor): nit: looks unused so could be removed.

using namespace mlir;

namespace {
bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }

/// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
/// `gpu.warp_execute_on_lane_0` and put it after the warp op. The warp op will
/// still contain the original op that will not be used by the yield op (and
/// should be cleaned up later with dce). The yield op will bypass the
/// create_nd_tdesc's arguments. Tensor descriptor is not distributed because it
/// is a uniform value accorss all work items within the subgroup.
Reviewer comment (Contributor), suggested change:
-/// is a uniform value accorss all work items within the subgroup.
+/// is a uniform value across all work items within the subgroup.

///
/// Example:
///
/// ```
/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
/// (!xegpu.tensor_desc<4x8xf32>) {
/// ...
/// %td = xegpu.create_nd_tdesc %arg0[0, 0]
/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
/// vector.yield %td
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> () {
/// ...
/// %dead = xegpu.create_nd_tdesc %arg0[0, 0]
/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
/// vector.yield %arg0, %dead
/// }
/// %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32>
/// -> !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct SubgroupOpTensorDescOp final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
PatternRewriter &rewriter) const override;
};

/// Sink a store_nd op at the end of the enclosing `gpu.warp_execute_on_lane_0`.
/// If the arguments for the store are passed through the warp op interface,
/// they are propagated as returned values. Only the source vector for the
/// store is distributed according to the sg_map attribute.
///
/// Example:
///
/// ```
/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
/// ...
/// xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
/// !xegpu.tensor_desc<4x8xf32>
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> () {
/// gpu.yield %arg0, %arg1: vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32>
/// }
/// xegpu.store_nd %r#0, %r#1: vector<4x1xf32>,
/// !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
PatternRewriter &rewriter) const override;
};

/// Clone a load_nd feeding into vector.yield op for the enclosing
/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
/// The warp op will still contain the original op that will not be used by
/// the yield op (and should be cleaned up later with dce). The yield op will
/// bypass the load's arguments. Only the loaded vector is distributed
/// according to the sg_map attribute; the tensor descriptor type is not
/// distributed.
///
/// Example:
///
/// ```
/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
/// (vector<4x1xf32>) {
/// ...
/// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32> ->
/// vector<4x8xf32>
/// gpu.yield %ld
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> () {
/// ...
/// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32> ->
///       vector<4x8xf32>
///     gpu.yield %arg0, %arg1
/// }
/// %ld = xegpu.load_nd %r#0: !xegpu.tensor_desc<4x8xf32> -> vector<4x1xf32>
///
/// ```
struct SubgroupOpLoadNd final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
PatternRewriter &rewriter) const override;
};

/// Returns the distributed vector type for a source vector type according to
/// the sg_map attribute.
FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
xegpu::SGMapAttr sgMap) {
llvm::SmallVector<int64_t, 2> distributedShape;
auto layout = sgMap.getWiLayout();
auto shape = originalT.getShape();
for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
Reviewer comment (Contributor): nit: could you use more descriptive variable names?

if (!divisible(APInt(64, o), APInt(64, l)))
Reviewer comment (Contributor): Why do we need to go through APInt for this?

return failure();
distributedShape.push_back(o / l);
}
auto newVectorType =
VectorType::get(distributedShape, originalT.getElementType(),
originalT.getScalableDims());
return newVectorType;
}
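
For illustration, a small standalone sketch (local names only, not part of the patch) of the shape arithmetic above: with the wi_layout = [1, 8] used in the examples, a vector<4x8xf32> distributes to a vector<4x1xf32> per work item.

```cpp
// Standalone sketch of getDistributedVectorType's per-dimension rule:
// distributedShape[i] = shape[i] / wiLayout[i], rejecting non-divisible dims.
#include <array>
#include <cassert>
#include <cstdint>

int main() {
  std::array<int64_t, 2> shape = {4, 8};    // vector<4x8xf32>
  std::array<int64_t, 2> wiLayout = {1, 8}; // wi_layout from the sg_map
  std::array<int64_t, 2> distributed = {};
  for (size_t i = 0; i < shape.size(); ++i) {
    assert(shape[i] % wiLayout[i] == 0 && "dimension must be divisible");
    distributed[i] = shape[i] / wiLayout[i];
  }
  assert(distributed[0] == 4 && distributed[1] == 1); // -> vector<4x1xf32>
  return 0;
}
```
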

// Returns the distributed tensor descriptor type for a source tensor descriptor
// type according to the sg_map attribute. Note that tensor descriptor type is
// distributed only for the scattered case. For XeGPU ND operations
// (create_nd_tdesc, load_nd, store_nd), tensor descriptor is considered uniform
// across all work items within the subgroup and therefore is not distributed.
FailureOr<xegpu::TensorDescType>
getDistributedTensorDescType(xegpu::TensorDescType originalT,
xegpu::SGMapAttr sgMap,
xegpu::MemorySpace memSpace) {
Reviewer comment (Contributor): I take it that memSpace is there to propagate it to the newly created type; if so, it's unused at the moment. Does it need to be a separate argument at all, or could it be taken directly from originalT?

llvm::SmallVector<int64_t, 2> distributedShape;
auto layout = sgMap.getWiLayout();
auto shape = originalT.getShape();
for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
if (!divisible(APInt(64, o), APInt(64, l)))
return failure();
// Tensor descriptor is distributed only for the scattered case.
if (originalT.isScattered())
distributedShape.push_back(o / l);
else
distributedShape.push_back(o);
}

return xegpu::TensorDescType::get(
originalT.getContext(), distributedShape, originalT.getElementType(),
originalT.getEncoding(), originalT.getSGMapAttr());
}
} // namespace
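
To make the scattered/nd distinction above concrete, another purely illustrative sketch (local names only): with the same [1, 8] wi_layout, an nd descriptor keeps its 4x8 shape, while a scattered descriptor of the same shape would distribute to 4x1.

```cpp
// Sketch of getDistributedTensorDescType's shape rule: divide by wi_layout
// only for scattered descriptors; nd descriptors stay uniform per subgroup.
#include <array>
#include <cstdint>

static std::array<int64_t, 2>
distributeDescShape(std::array<int64_t, 2> shape,
                    std::array<int64_t, 2> wiLayout, bool isScattered) {
  std::array<int64_t, 2> result = {};
  for (size_t i = 0; i < shape.size(); ++i)
    result[i] = isScattered ? shape[i] / wiLayout[i] : shape[i];
  return result;
}

int main() {
  auto nd = distributeDescShape({4, 8}, {1, 8}, /*isScattered=*/false);
  auto scattered = distributeDescShape({4, 8}, {1, 8}, /*isScattered=*/true);
  return (nd[1] == 8 && scattered[1] == 1) ? 0 : 1; // 4x8 vs. 4x1
}
```
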

LogicalResult
SubgroupOpStoreNd::matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
PatternRewriter &rewriter) const {
auto yield = cast<gpu::YieldOp>(
subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
Operation *lastNode = yield->getPrevNode();
auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
if (!storeOp)
return failure();

auto origType = storeOp.getTensorDescType();
xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
if (!sgMap)
return rewriter.notifyMatchFailure(
storeOp, "the source tensor descriptor lacks sg_map attribute");

if (storeOp.getTensorDescType().getShape().size() != 2)
return rewriter.notifyMatchFailure(storeOp, "unsupported shape");

auto distributedTypeOrFailure =
getDistributedVectorType(storeOp.getValueType(), sgMap);
if (failed(distributedTypeOrFailure))
return rewriter.notifyMatchFailure(storeOp,
"Failed to distribute the type");
VectorType newVectorType = distributedTypeOrFailure.value();

auto distributedDescTypeOrFailure = getDistributedTensorDescType(
Reviewer comment (Contributor): Do we need this at all? I think the TensorDesc can never be scattered for nd ops.

storeOp.getTensorDescType(), sgMap,
storeOp.getTensorDescType().getMemorySpace());
if (failed(distributedDescTypeOrFailure))
return rewriter.notifyMatchFailure(storeOp,
"Failed to distribute the desc type");
xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();

SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, subgroupOp,
ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
TypeRange{newTDescType, newVectorType}, newRetIndices);

rewriter.setInsertionPointAfter(newWarpOp);
auto newStoreOp =
cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
rewriter.eraseOp(storeOp);
newStoreOp.getTensorDescMutable().assign(
newWarpOp.getResult(newRetIndices[0]));
newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));

return success();
}

LogicalResult
SubgroupOpLoadNd::matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
PatternRewriter &rewriter) const {
OpOperand *operand = getWarpResult(subgroupOp, [](Operation *op) {
return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
});

if (!operand)
return rewriter.notifyMatchFailure(subgroupOp,
"warp result is not a xegpu::LoadNd op");

auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();

if (loadOp.getPacked())
return rewriter.notifyMatchFailure(
loadOp, "Packed load distribution not supported");

xegpu::TensorDescType origType = loadOp.getTensorDescType();
xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
if (!sgMap)
return rewriter.notifyMatchFailure(
loadOp, "the source tensor descriptor lacks sg_map attribute");

auto origShape = origType.getShape();
if (origShape.size() != 2)
return rewriter.notifyMatchFailure(loadOp, "unsupported shape");

auto distributedTypeOrFailure =
getDistributedVectorType(loadOp.getType(), sgMap);
if (failed(distributedTypeOrFailure))
return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
VectorType newVectorType = distributedTypeOrFailure.value();

auto distributedDescTypeOrFailure =
getDistributedTensorDescType(loadOp.getTensorDescType(), sgMap,
loadOp.getTensorDescType().getMemorySpace());
if (failed(distributedDescTypeOrFailure))
return rewriter.notifyMatchFailure(loadOp,
"Failed to distribute the desc type");
xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();

unsigned operandIdx = operand->getOperandNumber();

SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, subgroupOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
newRetIndices);

rewriter.setInsertionPointAfter(newWarpOp);

auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
loadOp.getLoc(), newVectorType, loadOp.getTensorDesc(),
loadOp.getPackedAttr(), loadOp.getTransposeAttr(), loadOp.getL1HintAttr(),
loadOp.getL2HintAttr(), loadOp.getL3HintAttr());

newLoadOp.getTensorDescMutable().assign(
newWarpOp.getResult(newRetIndices[0]));
Value distributedVal = newWarpOp.getResult(operandIdx);
rewriter.replaceAllUsesWith(distributedVal, newLoadOp);

return success();
}

LogicalResult
SubgroupOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
PatternRewriter &rewriter) const {
OpOperand *operand = getWarpResult(subgroupOp, [](Operation *op) {
return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse();
});

if (!operand)
return rewriter.notifyMatchFailure(
subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
assert(descOp && "desc op must be not null");
unsigned operandIdx = operand->getOperandNumber();

// TODO: is memref uniform in the region
rewriter.setInsertionPoint(subgroupOp);
auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource());
assert(srcTypedVal && "source value must be not null");

auto descOffsets = descOp.getMixedOffsets();
if (descOffsets.size() != 2)
return rewriter.notifyMatchFailure(descOp,
"offsets size is expected to be 2");

xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr();
if (!sgMap)
return rewriter.notifyMatchFailure(
descOp, "the tensor descriptor lacks sg_map attribute");

auto distributedDescTypeOrFailure = getDistributedTensorDescType(
Reviewer comment (Contributor): Could you add a test case for this?

descOp.getType(), sgMap, descOp.getType().getMemorySpace());
if (failed(distributedDescTypeOrFailure))
return rewriter.notifyMatchFailure(descOp,
"Failed to distribute the desc type");
xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
auto distributedShape = newTDescType.getShape();
// use the base memref strides
SmallVector<OpFoldResult> overwriteStrides =
getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1});
SmallVector<OpFoldResult> overwriteSizes =
getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);

SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, subgroupOp, descOp.getSource(), descOp.getSourceType(),
newRetIndices);

rewriter.setInsertionPointAfter(newWarpOp);
auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
newWarpOp.getLoc(), newTDescType,
dyn_cast<TypedValue<MemRefType>>(newWarpOp.getResult(newRetIndices[0])),
descOffsets);

Value distributedVal = newWarpOp.getResult(operandIdx);
rewriter.replaceAllUsesWith(distributedVal, newDescOp);

return success();
}

void xegpu::populateXeGPUSubgroupDistributePatterns(
RewritePatternSet &patterns) {
patterns.add<SubgroupOpTensorDescOp>(patterns.getContext());
patterns.add<SubgroupOpStoreNd>(patterns.getContext());
patterns.add<SubgroupOpLoadNd>(patterns.getContext());
}
Loading