Skip to content

Commit d9bbc85

Browse files
authored
[GPU] Register initial GPU pipeline that uses IMEX (#329)
Signed-off-by: dchigarev <[email protected]>
1 parent bd49e3a commit d9bbc85

File tree

9 files changed

+275
-99
lines changed

9 files changed

+275
-99
lines changed

include/gc/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ std::unique_ptr<Pass> createMergeAllocPass();
115115
void populateFrontendPasses(mlir::OpPassManager &);
116116
void populateCPUPipeline(mlir::OpPassManager &);
117117

118+
#ifdef GC_USE_IMEX
119+
void populateGPUPipeline(mlir::OpPassManager &);
120+
#endif
121+
118122
#define GEN_PASS_DECL
119123
#include "gc/Transforms/Passes.h.inc"
120124

lib/gc/Transforms/GPU/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
gc_add_mlir_library(GcGpuPasses
22
LinalgToXeGPU.cpp
3+
Pipeline.cpp
34

45
DEPENDS
56
GraphCompilerPassIncGen
@@ -18,3 +19,7 @@ gc_add_mlir_library(GcGpuPasses
1819
GcUtilsIR
1920
)
2021

22+
include(imex)
23+
get_property(IMEX_INCLUDES GLOBAL PROPERTY IMEX_INCLUDES)
24+
target_include_directories(GcGpuPasses PRIVATE ${IMEX_INCLUDES})
25+

lib/gc/Transforms/GPU/LinalgToXeGPU.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,6 +1405,17 @@ LogicalResult createMemoryFillKernel(linalg::LinalgOp linalgOp,
14051405
auto outputType = cast<ShapedType>(output.getType());
14061406
auto outputShape = outputType.getShape();
14071407

1408+
if (outputShape.size() != 2) {
1409+
return rewriter.notifyMatchFailure(
1410+
linalgOp, "Memory fill operation expects 2D output");
1411+
}
1412+
1413+
// Otherwise 'xegpu-to-vc' pass will fail to convert it to VC
1414+
if (outputShape[0] * outputShape[1] < 16) {
1415+
return rewriter.notifyMatchFailure(
1416+
linalgOp, "Memory fill operation is to small to be converted to xegpu");
1417+
}
1418+
14081419
// Extract SIMD sized sub-tiles
14091420
int maxSizeSIMD = 256;
14101421
int64_t subTileCols = outputShape[1];

lib/gc/Transforms/GPU/Pipeline.cpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
//===- Pipeline.cpp - Graph Compiler GPU pipeline ---------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/Passes.h"
10+
#include "mlir/Dialect/Arith/Transforms/Passes.h"
11+
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
12+
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
13+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14+
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
15+
#include "mlir/Dialect/Linalg/Passes.h"
16+
#include "mlir/Dialect/Math/Transforms/Passes.h"
17+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
18+
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
19+
#include "mlir/Dialect/SCF/IR/SCF.h"
20+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
21+
#include "mlir/IR/DialectRegistry.h"
22+
#include "mlir/InitAllPasses.h"
23+
#include "mlir/Pass/PassManager.h"
24+
#include "mlir/Support/LogicalResult.h"
25+
#include "mlir/Transforms/Passes.h"
26+
#include <iostream>
27+
28+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
29+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
30+
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
31+
32+
#include <imex/Conversion/Passes.h>
33+
#include <imex/Transforms/Passes.h>
34+
35+
#include <string>
36+
37+
#include "gc/Transforms/Passes.h"
38+
39+
namespace mlir::gc {
40+
41+
void populateGPUPipeline(mlir::OpPassManager &pm) {
42+
pm.addNestedPass<func::FuncOp>(createIterativeTilingAndFusion());
43+
44+
pm.addPass(bufferization::createEmptyTensorEliminationPass());
45+
pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());
46+
47+
bufferization::OneShotBufferizationOptions options;
48+
options.bufferizeFunctionBoundaries = true;
49+
options.setFunctionBoundaryTypeConversion(
50+
bufferization::LayoutMapOption::IdentityLayoutMap);
51+
pm.addPass(bufferization::createOneShotBufferizePass(options));
52+
53+
pm.addPass(bufferization::createDropEquivalentBufferResultsPass());
54+
pm.addNestedPass<func::FuncOp>(
55+
bufferization::createFinalizingBufferizePass());
56+
pm.addPass(createCanonicalizerPass());
57+
pm.addPass(createCSEPass());
58+
pm.addPass(bufferization::createDropEquivalentBufferResultsPass());
59+
pm.addPass(memref::createExpandReallocPass());
60+
pm.addPass(createCanonicalizerPass());
61+
pm.addPass(bufferization::createOwnershipBasedBufferDeallocationPass());
62+
pm.addPass(createCanonicalizerPass());
63+
pm.addPass(bufferization::createBufferDeallocationSimplificationPass());
64+
pm.addPass(bufferization::createLowerDeallocationsPass());
65+
pm.addPass(createCSEPass());
66+
pm.addPass(createCanonicalizerPass());
67+
pm.addPass(createBufferizationToMemRefPass());
68+
69+
pm.addNestedPass<func::FuncOp>(createForallToParallelLoopPass());
70+
pm.addNestedPass<func::FuncOp>(createLinalgToXeGPU(
71+
{/*kTile=*/16, /*stages=*/1, /*dpasTiles=*/{8, 16, 16}}));
72+
73+
pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
74+
pm.addPass(xegpu::createXeGPUFoldAliasOps());
75+
pm.addPass(memref::createFoldMemRefAliasOpsPass());
76+
pm.addNestedPass<func::FuncOp>(createGpuMapParallelLoopsPass());
77+
pm.addNestedPass<func::FuncOp>(createParallelLoopToGpuPass());
78+
79+
pm.addNestedPass<func::FuncOp>(imex::createInsertGPUAllocsPass("opencl"));
80+
pm.addPass(createGpuKernelOutliningPass());
81+
pm.addPass(createCanonicalizerPass());
82+
pm.addPass(imex::createSetSPIRVCapabilitiesPass());
83+
pm.addNestedPass<gpu::GPUModuleOp>(
84+
imex::createSetSPIRVAbiAttributePass("opencl"));
85+
pm.addPass(createLowerAffinePass());
86+
pm.addPass(imex::createVectorLinearizePass());
87+
pm.addNestedPass<gpu::GPUModuleOp>(imex::createConvertXeGPUToVCPass());
88+
pm.addPass(createReconcileUnrealizedCastsPass());
89+
pm.addPass(imex::createBF16ToGPUPass());
90+
pm.addNestedPass<gpu::GPUModuleOp>(createConvertFuncToSPIRVPass());
91+
pm.addNestedPass<gpu::GPUModuleOp>(createConvertVectorToSPIRVPass());
92+
pm.addPass(imex::createConvertGPUXToSPIRVPass());
93+
pm.addNestedPass<spirv::ModuleOp>(spirv::createSPIRVLowerABIAttributesPass());
94+
pm.addNestedPass<spirv::ModuleOp>(spirv::createSPIRVUpdateVCEPass());
95+
pm.addNestedPass<func::FuncOp>(LLVM::createRequestCWrappersPass());
96+
pm.addPass(imex::createSerializeSPIRVPass());
97+
pm.addPass(createConvertVectorToSCFPass());
98+
pm.addPass(imex::createConvertGPUToGPUXPass());
99+
pm.addPass(createConvertSCFToCFPass());
100+
pm.addPass(createConvertControlFlowToLLVMPass());
101+
pm.addPass(createConvertVectorToLLVMPass());
102+
pm.addPass(createConvertIndexToLLVMPass());
103+
pm.addPass(createArithToLLVMConversionPass());
104+
pm.addPass(createConvertFuncToLLVMPass());
105+
pm.addPass(createConvertMathToLLVMPass());
106+
pm.addPass(imex::createConvertGPUXToLLVMPass());
107+
pm.addPass(createConvertIndexToLLVMPass());
108+
pm.addPass(memref::createExpandStridedMetadataPass());
109+
pm.addPass(createLowerAffinePass());
110+
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
111+
pm.addPass(createReconcileUnrealizedCastsPass());
112+
}
113+
114+
void registerGPUPipeline() {
115+
PassPipelineRegistration<>("gc-gpu-pipeline",
116+
"The GPU pipeline for Graph Compiler with IMEX",
117+
populateGPUPipeline);
118+
}
119+
120+
} // namespace mlir::gc

src/gc-opt/gc-opt.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636

3737
namespace mlir::gc {
3838
void registerCPUPipeline();
39+
#ifdef GC_USE_IMEX
40+
void registerGPUPipeline();
41+
#endif
3942
} // namespace mlir::gc
4043

4144
int main(int argc, char *argv[]) {
@@ -47,6 +50,7 @@ int main(int argc, char *argv[]) {
4750
imex::registerConvertGPUXToSPIRV();
4851
imex::registerConvertXeGPUToVC();
4952
imex::registerConvertXeTileToXeGPU();
53+
mlir::gc::registerGPUPipeline();
5054
#endif
5155
mlir::registerAllPasses();
5256
mlir::gc::registerCPUPipeline();

0 commit comments

Comments
 (0)