
Commit 02d34d8

[mlir][vector][xegpu] Vector to XeGPU conversion pass (#107419)
Add a pass for Vector to XeGPU dialect conversion and initial conversion patterns for the vector.transfer_read and vector.transfer_write operations.
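
As a rough sketch of the intended lowering (the xegpu op spellings below are illustrative of the dialect's assembly format, and the function and value names are made up for the example; this is not output captured from the commit), an in-bounds 2-D vector.transfer_read over a statically shaped memref maps to a tensor-descriptor creation followed by an nd-block load:

    // Input: a plain in-bounds 2-D read with zero padding.
    func.func @load_2d(%src: memref<8x16xf32>) -> vector<8x16xf32> {
      %c0 = arith.constant 0 : index
      %pad = arith.constant 0.0 : f32
      %v = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
          : memref<8x16xf32>, vector<8x16xf32>
      return %v : vector<8x16xf32>
    }

    // Expected result of --convert-vector-to-xegpu (sketch):
    //   %desc = xegpu.create_nd_tdesc %src[%c0, %c0]
    //       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
    //   %v = xegpu.load_nd %desc
    //       : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>

vector.transfer_write is handled analogously, lowering to xegpu.store_nd.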
1 parent 3191587 commit 02d34d8

8 files changed, +677 -0 lines changed

mlir/include/mlir/Conversion/Passes.h (1 addition, 0 deletions)

@@ -78,6 +78,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
+#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
 
 namespace mlir {

mlir/include/mlir/Conversion/Passes.td (14 additions, 0 deletions)

@@ -1421,4 +1421,18 @@ def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// VectorToXeGPU
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
+  let summary = "Lower the operations from the vector dialect into the XeGPU "
+                "dialect";
+  let constructor = "mlir::createConvertVectorToXeGPUPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect", "arith::ArithDialect",
+    "vector::VectorDialect", "xegpu::XeGPUDialect"
+  ];
+}
+
 #endif // MLIR_CONVERSION_PASSES
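
This TableGen entry registers the convert-vector-to-xegpu pipeline flag, so a lit test for the pass could start with the usual header (a sketch; no test file appears in this excerpt):

    // RUN: mlir-opt %s --convert-vector-to-xegpu --split-input-file | FileCheck %s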

mlir/include/mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h (new file, 29 additions)

@@ -0,0 +1,29 @@
//===- VectorToXeGPU.h - Convert vector to XeGPU dialect --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_VECTORTOXEGPU_VECTORTOXEGPU_H
#define MLIR_CONVERSION_VECTORTOXEGPU_VECTORTOXEGPU_H

#include "mlir/IR/PatternMatch.h"

namespace mlir {
class Pass;
class RewritePatternSet;

#define GEN_PASS_DECL_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"

/// Collect a set of patterns to convert from the vector to XeGPU ops.
void populateVectorToXeGPUConversionPatterns(RewritePatternSet &patterns);

/// Create a pass to convert ops from vector to XeGPU.
std::unique_ptr<Pass> createConvertVectorToXeGPUPass();

} // namespace mlir

#endif // MLIR_CONVERSION_VECTORTOXEGPU_VECTORTOXEGPU_H

mlir/lib/Conversion/CMakeLists.txt (1 addition, 0 deletions)

@@ -69,3 +69,4 @@ add_subdirectory(VectorToGPU)
 add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToSCF)
 add_subdirectory(VectorToSPIRV)
+add_subdirectory(VectorToXeGPU)

mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt (new file, 16 additions)

@@ -0,0 +1,16 @@
add_mlir_conversion_library(MLIRVectorToXeGPU
  VectorToXeGPU.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToXeGPU

  DEPENDS
  MLIRConversionPassIncGen

  LINK_LIBS PUBLIC
  MLIRArithDialect
  MLIRMemRefDialect
  MLIRTransforms
  MLIRVectorDialect
  MLIRXeGPUDialect
  )

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp (new file, 257 additions)

@@ -0,0 +1,257 @@
//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(xferOp, "Expects 1D or 2D vector");

  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
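
// For illustration (example IR is a sketch): the preconditions above admit a
// plain transfer such as
//   %v = vector.transfer_read %A[%i, %j], %pad
//       : memref<?x32xf32>, vector<8x16xf32>
// but reject a projected map over a 3-D source that touches a non-innermost
// dimension, e.g. affine_map<(d0, d1, d2) -> (d0, d2)>, since d0 is not one
// of the two innermost input dimensions.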

static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = getStridesAndOffset(srcTy);

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                    getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, source's shape and strides have to be
    // explicitly provided.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal ? *staticVal : ShapedType::kDynamic);
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // Last stride is guaranteed to be static and unit.
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}
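
// For illustration, a dynamically shaped source such as memref<?x?xf32>
// yields explicit shape and stride operands on the descriptor; for a
// row-major layout the only dynamic stride is the outer one, which equals
// the inner dimension size %d1 (sketch of the generated IR; the exact
// printed form follows the XeGPU assembly format):
//   %d0 = memref.dim %src, %c0 : memref<?x?xf32>
//   %d1 = memref.dim %src, %c1 : memref<?x?xf32>
//   %desc = xegpu.create_nd_tdesc %src[%i, %j], [%d0, %d1], [%d1, 1]
//       : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>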

struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape{vecTy.getShape()};
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*scattered=*/false, /*array_length=*/1,
        xegpu::MemoryScope::Global,
        /*boundary_check=*/isOutOfBounds);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};
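
// For illustration, a transposed read such as
//   %v = vector.transfer_read %A[%c0, %c0], %pad
//       {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
//       : memref<16x8xf32>, vector<8x16xf32>
// is expected to produce a descriptor in the source's 16x8 layout and a load
// carrying the transpose attribute (sketch):
//   %desc = xegpu.create_nd_tdesc %A[%c0, %c0]
//       : memref<16x8xf32> -> !xegpu.tensor_desc<16x8xf32>
//   %v = xegpu.load_nd %desc <{transpose = array<i64: 1, 0>}>
//       : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>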

struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    if (writeOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(writeOp,
                                         "Unsupported out-of-bounds write");
    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*scattered=*/false, /*array_length=*/1, xegpu::MemoryScope::Global,
        /*boundary_check=*/false);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType,
        dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
        writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};
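
// For illustration, an in-bounds write
//   vector.transfer_write %vec, %A[%c0, %c0] {in_bounds = [true, true]}
//       : vector<8x16xf32>, memref<8x16xf32>
// is expected to become (sketch):
//   %desc = xegpu.create_nd_tdesc %A[%c0, %c0]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   xegpu.store_nd %vec, %desc
//       : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>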

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering>(
      patterns.getContext());
}

std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
  return std::make_unique<ConvertVectorToXeGPUPass>();
}
