[GPU] Add MLP test and linalg.fill lowering in 'linalg-to-xegpu' #220

Merged 28 commits on Sep 11, 2024
2 changes: 1 addition & 1 deletion cmake/imex.cmake
@@ -8,7 +8,7 @@ if (NOT DEFINED IMEX_INCLUDES)

# TODO: Change to main https://github.com/intel/mlir-extensions when all the
# required functionality is merged.
gc_fetch_content(imex 496b240093b5e132b60c5ee69878300fe69be300 https://github.com/Menooker/mlir-extensions
gc_fetch_content(imex d5bbd635dee500b8cff138686833bacfac5ade78 https://github.com/Menooker/mlir-extensions
SET IMEX_CHECK_LLVM_VERSION=ON IMEX_ENABLE_L0_RUNTIME=0
)

2 changes: 1 addition & 1 deletion include/gc/Transforms/Passes.td
@@ -47,7 +47,7 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
"DPAS register block sizes MxNxK">,
];
}
#endif
#endif // GC_USE_IMEX

def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
"func::FuncOp"> {
2 changes: 1 addition & 1 deletion lib/gc/ExecutionEngine/Driver/CMakeLists.txt
@@ -27,7 +27,7 @@ else()
endif()

set(GC_PASSES GcInterface GcPasses)
if(GC_UNABLE_GPU)
if(GC_ENABLE_IMEX)
list(APPEND GC_PASSES GcGpuPasses)
endif()

34 changes: 29 additions & 5 deletions lib/gc/ExecutionEngine/OpenCLRuntime/OpenCLRuntimeWrappers.cpp
@@ -129,11 +129,35 @@ template <typename T> size_t countUntil(T *ptr, T &&elem) {
} // namespace

static cl_device_id getDevice(cl_device_type *devtype) {
cl_platform_id platform; // OpenCL platform
cl_device_id device; // device ID
CL_SAFE_CALL(clGetPlatformIDs(1, &platform, NULL));
CL_SAFE_CALL(clGetDeviceIDs(platform, *devtype, 1, &device, NULL));
return device;
cl_uint numPlatforms;
CL_SAFE_CALL(clGetPlatformIDs(0, nullptr, &numPlatforms)); // get num platforms

std::vector<cl_platform_id> platforms(numPlatforms);
CL_SAFE_CALL(clGetPlatformIDs(numPlatforms, platforms.data(),
nullptr)); // get available platforms

for (cl_uint i = 0; i < numPlatforms; ++i) {
// Get GPU device IDs for each platform
cl_uint numDevices;
cl_int status =
clGetDeviceIDs(platforms[i], *devtype, 0, /*devices.data()=*/nullptr,
&numDevices); // get num devices with 'devtype'
if (status != CL_SUCCESS) {
if (status == CL_DEVICE_NOT_FOUND) {
continue; // No GPU devices found on this platform
}
fprintf(stderr, "CL error %d @ line=%d (%s)\n", status, __LINE__,
"Error getting device IDs");
abort();
}

std::vector<cl_device_id> devices(numDevices);
clGetDeviceIDs(platforms[i], *devtype, numDevices, devices.data(), nullptr);
return devices[0];
}

fprintf(stderr, "No suitable devices found.");
abort();
}

struct GPUCLQUEUE {
132 changes: 121 additions & 11 deletions lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
@@ -597,12 +597,22 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
Location loc, ValueRange tiles,
ArrayRef<int64_t> offsets) {
SmallVector<Value> updatedTiles;
// convert static offsets to dynamic because of this IMEX bug:
// https://github.com/intel/mlir-extensions/issues/815
std::vector<Value> dynOffsets;
for (auto &x : offsets) {
Value offset = rewriter.create<arith::ConstantIndexOp>(loc, x);
dynOffsets.push_back(offset);
}
ValueRange newOffsets{dynOffsets};
for (auto tile : tiles) {
auto updatedTile =
rewriter
.create<xegpu::UpdateNdOffsetOp>(loc, tile.getType(), tile,
/*offsets=*/ValueRange{}, offsets)
.getResult();
auto updatedTile = rewriter
.create<xegpu::UpdateNdOffsetOp>(
loc, tile.getType(), tile,
/*offsets=*/newOffsets,
SmallVector<int64_t>{ShapedType::kDynamic,
ShapedType::kDynamic})
.getResult();
updatedTiles.push_back(updatedTile);
}

@@ -648,11 +658,17 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,

SmallVector<Value> tiles;
for (int i = 0; i < loadShape[0]; i += descTile[0]) {
// convert static offsets to dynamic because of this IMEX bug:
// https://github.com/intel/mlir-extensions/issues/815
Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
auto tile = rewriter
.create<xegpu::UpdateNdOffsetOp>(
loc, descType, rootTile,
/*offsets=*/ValueRange{}, SmallVector<int64_t>{i, j})
/*offsets=*/ValueRange{newRowOffs, newColOffs},
SmallVector<int64_t>{ShapedType::kDynamic,
ShapedType::kDynamic})
.getResult();
tiles.push_back(tile);
}
@@ -732,17 +748,18 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,

VectorType vecLoadType =
VectorType::get(tileType.getShape(), tileType.getElementType());
UnitAttr vnniAxisAttr = nullptr;
mlir::UnitAttr packedAttr = nullptr;
if (vnniConf) {
vnniAxisAttr = UnitAttr::get(rewriter.getContext());
vecLoadType = getVnniVector(tileType.getShape(), tileType.getElementType(),
*vnniConf);
packedAttr = mlir::UnitAttr::get(rewriter.getContext());
}

IntegerAttr transpose_bit = nullptr;
SmallVector<Value> loadVec;
for (auto tile : loadTiles) {

auto loadOp = rewriter.create<xegpu::LoadNdOp>(
loc, vecLoadType, tile, vnniAxisAttr, transpose, nullptr,
loc, vecLoadType, tile, packedAttr, transpose, transpose_bit,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
loadVec.push_back(loadOp);
@@ -1057,7 +1074,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,

// Load A sub-tiles.
SmallVector<Value> loadVecA =
loadNdDescTiles(rewriter, loc, tilesA, readCacheHint, vnniConfA);
loadNdDescTiles(rewriter, loc, tilesA, readCacheHint);
auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());

// Load B sub-tiles.
@@ -1371,6 +1388,88 @@ struct ConvertNamedEltwiseToXeGPU : public OpRewritePattern<LinalgOpTy> {
LinalgToXeGPUOptions options;
};

// Create XeGPU kernel out of memory fill operation.
LogicalResult createMemoryFillKernel(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) {
Location loc = linalgOp.getLoc();
auto ctx = linalgOp.getContext();

auto scalar = linalgOp.getDpsInputs()[0];
auto output = linalgOp.getDpsInits()[0];
auto outputType = cast<ShapedType>(output.getType());
auto outputShape = outputType.getShape();

// Extract SIMD sized sub-tiles
int maxSizeSIMD = 256;
int64_t subTileCols = outputShape[1];
int64_t subTileRows = std::min(outputShape[0], maxSizeSIMD / subTileCols);

// Output descriptors for later stores.
SmallVector<Value> outputTiles = createDescriptorTiles(
rewriter, loc, output, outputShape, {0, 0}, {subTileRows, subTileCols});

SmallVector<Value> results;
for (size_t i = 0; i < outputTiles.size(); i++) {
// Broadcast the fill value to match each output sub-tile's shape.
auto bcastType = VectorType::get({subTileRows, subTileCols},
outputType.getElementType());
auto res = rewriter.create<vector::BroadcastOp>(loc, bcastType, scalar);
if (!res)
return failure();

results.push_back(res.getResult());
}

// Store results.
auto writeCacheHint =
xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::WRITE_BACK);
for (size_t i = 0; i < outputTiles.size(); i++) {
rewriter.create<xegpu::StoreNdOp>(loc, results[i], outputTiles[i],
/*l1_hint=*/writeCacheHint,
/*l2_hint=*/writeCacheHint,
/*l3_hint=*/writeCacheHint);
}

rewriter.eraseOp(linalgOp);

return success();
}

// Convert a named fill operation to an XeGPU kernel.
template <typename LinalgOpTy>
struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
using OpRewritePattern<LinalgOpTy>::OpRewritePattern;

ConvertMemoryFillToXeGPU(MLIRContext *ctx, LinalgToXeGPUOptions options)
: OpRewritePattern<LinalgOpTy>(ctx), options(options) {}

LogicalResult matchAndRewrite(LinalgOpTy linalgOp,
PatternRewriter &rewriter) const override {
if (!linalgOp.hasPureBufferSemantics()) {
return rewriter.notifyMatchFailure(
linalgOp, "Linalg eltwise to GPU expects memref type");
}
if (linalgOp.hasDynamicShape()) {
return rewriter.notifyMatchFailure(
linalgOp, "Expect static shape when mapping to GPU");
}
auto isInputValid =
success(linalgOp.isScalar(linalgOp.getDpsInputOperand(0)));
if (failed(isInputValid))
return isInputValid;

auto isOutputValid =
isValidMemrefOperand(linalgOp, linalgOp.getDpsInits()[0], rewriter);
if (failed(isOutputValid))
return isOutputValid;

return createMemoryFillKernel(linalgOp, rewriter);
}

private:
LinalgToXeGPUOptions options;
};

// TODO: Finalize BRGEMM support and register the pattern.
void populateLinalgGemmToXeGPUPatterns(RewritePatternSet &patterns,
LinalgToXeGPUOptions options) {
@@ -1395,6 +1494,12 @@ void populateLinalgEltwiseToXeGPUPatterns(RewritePatternSet &patterns,
options);
}

void populateLinalgMemoryFillToXeGPUPatterns(RewritePatternSet &patterns,
LinalgToXeGPUOptions options) {
patterns.add<ConvertMemoryFillToXeGPU<linalg::FillOp>>(patterns.getContext(),
options);
}

struct LinalgToXeGPU : public gc::impl::LinalgToXeGPUBase<LinalgToXeGPU> {
using LinalgToXeGPUBase::LinalgToXeGPUBase;

Expand All @@ -1406,6 +1511,11 @@ struct LinalgToXeGPU : public gc::impl::LinalgToXeGPUBase<LinalgToXeGPU> {
populateLinalgGemmToXeGPUPatterns(gemmPatterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(gemmPatterns));

// Convert memory fill ops.
RewritePatternSet fillPatterns(&getContext());
populateLinalgMemoryFillToXeGPUPatterns(fillPatterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(fillPatterns));

// Convert other remaining ops.
RewritePatternSet patterns(&getContext());
populateLinalgEltwiseToXeGPUPatterns(patterns, options);
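For orientation, here is a minimal sketch of what the new fill lowering produces. The function name, the 32x32xf16 shape, and the resulting 8x32 sub-tiles (from the maxSizeSIMD = 256 heuristic in createMemoryFillKernel) are assumptions for illustration, not an excerpt from the PR's tests.

// Input: a linalg.fill with a scalar value on a statically shaped memref (illustrative).
func.func @fill_example(%arg0: memref<32x32xf16>) {
  %cst = arith.constant 0.0 : f16
  linalg.fill ins(%cst : f16) outs(%arg0 : memref<32x32xf16>)
  return
}

// Roughly the expected lowering, FileCheck-style (four 8x32 sub-tiles assumed):
// CHECK: %[[root:.+]] = xegpu.create_nd_tdesc %[[OUT:.+]]
// CHECK-COUNT-4: xegpu.update_nd_offset %[[root]]
// CHECK-COUNT-4: vector.broadcast %{{.*}} : f16 to vector<8x32xf16>
// CHECK-COUNT-4: xegpu.store_nd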
1 change: 1 addition & 0 deletions lib/gc/Transforms/Pipeline.cpp
@@ -129,6 +129,7 @@ void populateLoweringToLLVMPasses(mlir::OpPassManager &pm) {
void populateLLVMPasses(mlir::OpPassManager &pm) {
pm.addPass(memref::createExpandOpsPass());
pm.addPass(memref::createExpandStridedMetadataPass());
pm.addPass(createLowerAffinePass());
populateLoweringToLLVMPasses(pm);
}

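A brief note on the new createLowerAffinePass: expand-strided-metadata can leave affine.apply index arithmetic behind, and the LLVM conversion in populateLoweringToLLVMPasses does not handle affine ops directly, so they must be lowered to arith first. The map below is illustrative, not taken from the pipeline's actual IR.

// Before -lower-affine (index math left by memref expansion, illustrative):
%idx = affine.apply affine_map<(d0, d1) -> (d0 * 32 + d1)>(%i, %j)

// After -lower-affine, the same index is computed with plain arith ops:
%c32 = arith.constant 32 : index
%0 = arith.muli %i, %c32 : index
%idx = arith.addi %0, %j : index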
8 changes: 4 additions & 4 deletions test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas.mlir
@@ -18,7 +18,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

// Create output initial value load tiles.
// CHECK: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]]
// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [0, 0]
// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [%c0, %c0]
// CHECK-COUNT-7: xegpu.update_nd_offset %[[rootC]]

// Load initial accumulator values.
@@ -31,9 +31,9 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

// Create input load tiles.
// CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]]
// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [0, 0]
// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [%c0, %c0]
// CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]]
// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [0, 0]
// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
// CHECK-COUNT-1: xegpu.update_nd_offset %[[rootB]]

// Create DPAS computation loop over tiled reduction dimension.
@@ -63,7 +63,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

// Extract DPAS-sized chunks from larger loaded tile A.
// Tile B is already in the correct shape.
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x8x2xf16> to vector<512xf16>
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16>
// CHECK-COUNT-3: vector.extract_strided_slice
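The MLP test mentioned in the title is not visible in the hunks above (this view shows changes from 16 of the 28 commits). As a hypothetical sketch of how the pieces fit together: the fill zero-initializes a layer's output buffer, which the new pattern lowers to XeGPU stores, while the matmul goes through the existing DPAS lowering. Names and shapes below are assumptions, not the PR's actual test.

func.func @mlp_layer(%in: memref<32x32xf16>, %weights: memref<32x32xf16>,
                     %out: memref<32x32xf16>) {
  %cst = arith.constant 0.0 : f16
  // Zero-initialize the accumulator (handled by the new fill pattern).
  linalg.fill ins(%cst : f16) outs(%out : memref<32x32xf16>)
  // Accumulate the product into the initialized buffer (existing DPAS lowering).
  linalg.matmul ins(%in, %weights : memref<32x32xf16>, memref<32x32xf16>)
                outs(%out : memref<32x32xf16>)
  return
}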