Skip to content

Commit 2abad34

Browse files
committed
[mlir][rocdl] Adding vector to ROCDL dialect lowering
* Created the vector to ROCDL lowering pass * The lowering pass lowers vector transferOps to rocdl mubufOps * Added unit test and functional test
1 parent bff0987 commit 2abad34

File tree

10 files changed

+401
-0
lines changed

10 files changed

+401
-0
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,4 +301,14 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
301301
let constructor = "mlir::createConvertVectorToLLVMPass()";
302302
}
303303

304+
//===----------------------------------------------------------------------===//
305+
// VectorToROCDL
306+
//===----------------------------------------------------------------------===//
307+
308+
// Declares the `convert-vector-to-rocdl` pass. It is anchored on ModuleOp and
// is instantiated through mlir::createConvertVectorToROCDLPass().
def ConvertVectorToROCDL : Pass<"convert-vector-to-rocdl", "ModuleOp"> {
309+
let summary = "Lower the operations from the vector dialect into the ROCDL "
310+
"dialect";
311+
let constructor = "mlir::createConvertVectorToROCDLPass()";
312+
}
313+
304314
#endif // MLIR_CONVERSION_PASSES
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- VectorToROCDL.h - Convert Vector to ROCDL dialect ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_CONVERSION_VECTORTOROCDL_VECTORTOROCDL_H_
9+
#define MLIR_CONVERSION_VECTORTOROCDL_VECTORTOROCDL_H_
10+
11+
#include <memory>
12+
13+
namespace mlir {
14+
class LLVMTypeConverter;
15+
class OwningRewritePatternList;
16+
class ModuleOp;
17+
template <typename OpT>
18+
class OperationPass;
19+
20+
/// Collect a set of patterns to convert from the Vector dialect to ROCDL.
21+
void populateVectorToROCDLConversionPatterns(
22+
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
23+
24+
/// Create a pass to convert vector operations to the ROCDL dialect.
25+
std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToROCDLPass();
26+
27+
} // namespace mlir
28+
#endif // MLIR_CONVERSION_VECTORTOROCDL_VECTORTOROCDL_H_

mlir/include/mlir/InitAllPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
3131
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
3232
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
33+
#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
3334
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
3435
#include "mlir/Dialect/Affine/Passes.h"
3536
#include "mlir/Dialect/GPU/Passes.h"

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ add_subdirectory(ShapeToStandard)
1414
add_subdirectory(SPIRVToLLVM)
1515
add_subdirectory(StandardToLLVM)
1616
add_subdirectory(StandardToSPIRV)
17+
add_subdirectory(VectorToROCDL)
1718
add_subdirectory(VectorToLLVM)
1819
add_subdirectory(VectorToSCF)

mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ add_mlir_conversion_library(MLIRGPUtoROCDLTransforms
1515
MLIRROCDLIR
1616
MLIRPass
1717
MLIRStandardToLLVM
18+
MLIRVectorToROCDL
1819
)

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
1717
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
18+
#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
1819
#include "mlir/Dialect/GPU/GPUDialect.h"
1920
#include "mlir/Dialect/GPU/Passes.h"
2021
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -55,6 +56,7 @@ class LowerGpuOpsToROCDLOpsPass
5556
patterns.clear();
5657

5758
populateVectorToLLVMConversionPatterns(converter, patterns);
59+
populateVectorToROCDLConversionPatterns(converter, patterns);
5860
populateStdToLLVMConversionPatterns(converter, patterns);
5961
populateGpuToROCDLConversionPatterns(converter, patterns);
6062
LLVMConversionTarget target(getContext());
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
add_mlir_conversion_library(MLIRVectorToROCDL
2+
VectorToROCDL.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToROCDL
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
intrinsics_gen
10+
11+
LINK_COMPONENTS
12+
Core
13+
14+
LINK_LIBS PUBLIC
15+
MLIRROCDLIR
16+
MLIRStandardToLLVM
17+
MLIRVector
18+
MLIRTransforms
19+
)
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
//===- VectorToROCDL.cpp - Vector to ROCDL lowering passes ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to generate ROCDLIR operations for higher-level
10+
// Vector operations.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
15+
16+
#include "../PassDetail.h"
17+
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
18+
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
19+
#include "mlir/Dialect/GPU/GPUDialect.h"
20+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
22+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
23+
#include "mlir/Dialect/Vector/VectorOps.h"
24+
#include "mlir/Pass/Pass.h"
25+
#include "mlir/Transforms/DialectConversion.h"
26+
27+
using namespace mlir;
28+
using namespace mlir::vector;
29+
30+
/// Wraps the already-converted `operands` of a vector.transfer_read in its
/// operand adaptor so they can be accessed by name.
/// `xferOp` is used only to select this overload.
static TransferReadOpOperandAdaptor
getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
  OperandAdaptor<TransferReadOp> readAdaptor(operands);
  return readAdaptor;
}
34+
35+
/// Wraps the already-converted `operands` of a vector.transfer_write in its
/// operand adaptor so they can be accessed by name.
/// `xferOp` is used only to select this overload.
static TransferWriteOpOperandAdaptor
getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
  OperandAdaptor<TransferWriteOp> writeAdaptor(operands);
  return writeAdaptor;
}
39+
40+
/// Replaces the transfer read `xferOp` with a ROCDL::MubufLoadOp whose result
/// type is `vecTy`, forwarding `dwordConfig`, `vindex`, `offsetSizeInBytes`,
/// `glc` and `slc` unchanged.
///
/// `operands`, `typeConverter` and `loc` are not used by this overload; they
/// are accepted so that the read and write overloads can be called with the
/// same argument list. Always returns success().
static LogicalResult replaceTransferOpWithMubuf(
    ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
    LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp,
    LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
    Value &offsetSizeInBytes, Value &glc, Value &slc) {
  rewriter.replaceOpWithNewOp<ROCDL::MubufLoadOp>(xferOp, vecTy, dwordConfig,
                                                  vindex, offsetSizeInBytes,
                                                  glc, slc);
  return success();
}
49+
50+
/// Replaces the transfer write `xferOp` with a ROCDL::MubufStoreOp. The value
/// to store is the converted `vector` operand taken from `operands`; the
/// remaining arguments (`dwordConfig`, `vindex`, `offsetSizeInBytes`, `glc`,
/// `slc`) are forwarded unchanged.
///
/// `typeConverter`, `loc` and `vecTy` are not used by this overload; they are
/// accepted so that the read and write overloads can be called with the same
/// argument list. Always returns success().
static LogicalResult replaceTransferOpWithMubuf(
    ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
    LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
    LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
    Value &offsetSizeInBytes, Value &glc, Value &slc) {
  TransferWriteOpOperandAdaptor writeOperands(operands);
  rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(
      xferOp, writeOperands.vector(), dwordConfig, vindex, offsetSizeInBytes,
      glc, slc);
  return success();
}
61+
62+
namespace {
63+
/// Conversion pattern that converts a 1-D vector transfer read/write.
64+
/// Note that this conversion pass only converts vector x2 or x4 f32
65+
/// types. For unsupported cases, they will fall back to the vector to
66+
/// llvm conversion pattern.
67+
template <typename ConcreteOp>
68+
class VectorTransferConversion : public ConvertToLLVMPattern {
69+
public:
70+
explicit VectorTransferConversion(MLIRContext *context,
71+
LLVMTypeConverter &typeConv)
72+
: ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
73+
typeConv) {}
74+
75+
LogicalResult
76+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
77+
ConversionPatternRewriter &rewriter) const override {
78+
auto xferOp = cast<ConcreteOp>(op);
79+
auto adaptor = getTransferOpAdapter(xferOp, operands);
80+
81+
if (xferOp.getVectorType().getRank() > 1 ||
82+
llvm::size(xferOp.indices()) == 0)
83+
return failure();
84+
85+
if (!AffineMap::isMinorIdentity(xferOp.permutation_map()))
86+
return failure();
87+
88+
// Have it handled in vector->llvm conversion pass.
89+
if (!xferOp.isMaskedDim(0))
90+
return failure();
91+
92+
auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
93+
LLVM::LLVMType vecTy =
94+
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
95+
unsigned vecWidth = vecTy.getVectorNumElements();
96+
Location loc = op->getLoc();
97+
98+
// The backend result vector scalarization have trouble scalarize
99+
// <1 x ty> result, exclude the x1 width from the lowering.
100+
if (vecWidth != 2 && vecWidth != 4)
101+
return failure();
102+
103+
// Obtain dataPtr and elementType from the memref.
104+
MemRefType memRefType = xferOp.getMemRefType();
105+
// MUBUF instruction operate only on addresspace 0(unified) or 1(global)
106+
// In case of 3(LDS): fall back to vector->llvm pass
107+
// In case of 5(VGPR): wrong
108+
if ((memRefType.getMemorySpace() != 0) &&
109+
(memRefType.getMemorySpace() != 1))
110+
return failure();
111+
112+
// Note that the dataPtr starts at the offset address specified by
113+
// indices, so no need to calculat offset size in bytes again in
114+
// the MUBUF instruction.
115+
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
116+
adaptor.indices(), rewriter, getModule());
117+
118+
// 1. Create and fill a <4 x i32> dwordConfig with:
119+
// 1st two elements holding the address of dataPtr.
120+
// 3rd element: -1.
121+
// 4th element: 0x27000.
122+
SmallVector<int32_t, 4> constConfigAttr{0, 0, -1, 0x27000};
123+
Type i32Ty = rewriter.getIntegerType(32);
124+
VectorType i32Vecx4 = VectorType::get(4, i32Ty);
125+
Value constConfig = rewriter.create<LLVM::ConstantOp>(
126+
loc, toLLVMTy(i32Vecx4),
127+
DenseElementsAttr::get(i32Vecx4, ArrayRef<int32_t>(constConfigAttr)));
128+
129+
// Treat first two element of <4 x i32> as i64, and save the dataPtr
130+
// to it.
131+
Type i64Ty = rewriter.getIntegerType(64);
132+
Value i64x2Ty = rewriter.create<LLVM::BitcastOp>(
133+
loc,
134+
LLVM::LLVMType::getVectorTy(
135+
toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
136+
constConfig);
137+
Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
138+
loc, toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), dataPtr);
139+
Value zero = createIndexConstant(rewriter, loc, 0);
140+
Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
141+
loc,
142+
LLVM::LLVMType::getVectorTy(
143+
toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
144+
i64x2Ty, dataPtrAsI64, zero);
145+
dwordConfig =
146+
rewriter.create<LLVM::BitcastOp>(loc, toLLVMTy(i32Vecx4), dwordConfig);
147+
148+
// 2. Rewrite op as a buffer read or write.
149+
Value int1False = rewriter.create<LLVM::ConstantOp>(
150+
loc, toLLVMTy(rewriter.getIntegerType(1)),
151+
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
152+
Value int32Zero = rewriter.create<LLVM::ConstantOp>(
153+
loc, toLLVMTy(i32Ty),
154+
rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
155+
return replaceTransferOpWithMubuf(rewriter, operands, typeConverter, loc,
156+
xferOp, vecTy, dwordConfig, int32Zero,
157+
int32Zero, int1False, int1False);
158+
}
159+
};
160+
} // end anonymous namespace
161+
162+
/// Appends the vector-transfer-to-MUBUF conversion patterns to `patterns`:
/// first the pattern for vector.transfer_read, then the one for
/// vector.transfer_write. Both are built with the context of `converter`'s
/// dialect and with `converter` itself.
void mlir::populateVectorToROCDLConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *context = converter.getDialect()->getContext();
  patterns.insert<VectorTransferConversion<TransferReadOp>>(context,
                                                            converter);
  patterns.insert<VectorTransferConversion<TransferWriteOp>>(context,
                                                             converter);
}
168+
169+
namespace {
170+
struct LowerVectorToROCDLPass
171+
: public ConvertVectorToROCDLBase<LowerVectorToROCDLPass> {
172+
void runOnOperation() override;
173+
};
174+
} // namespace
175+
176+
void LowerVectorToROCDLPass::runOnOperation() {
177+
LLVMTypeConverter converter(&getContext());
178+
OwningRewritePatternList patterns;
179+
180+
populateVectorToROCDLConversionPatterns(converter, patterns);
181+
populateStdToLLVMConversionPatterns(converter, patterns);
182+
183+
LLVMConversionTarget target(getContext());
184+
target.addLegalDialect<ROCDL::ROCDLDialect>();
185+
186+
if (failed(applyPartialConversion(getOperation(), target, patterns,
187+
&converter))) {
188+
signalPassFailure();
189+
}
190+
}
191+
192+
/// Factory for the vector-to-ROCDL lowering pass: returns a newly allocated
/// LowerVectorToROCDLPass owned by the returned unique_ptr.
std::unique_ptr<OperationPass<ModuleOp>>
193+
mlir::createConvertVectorToROCDLPass() {
194+
return std::make_unique<LowerVectorToROCDLPass>();
195+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// RUN: mlir-opt %s -convert-vector-to-rocdl | FileCheck %s
2+
3+
gpu.module @test_read{
4+
func @transfer_readx2(%A : memref<?xf32>, %base: index) -> vector<2xf32> {
5+
%f0 = constant 0.0: f32
6+
%f = vector.transfer_read %A[%base], %f0
7+
{permutation_map = affine_map<(d0) -> (d0)>} :
8+
memref<?xf32>, vector<2xf32>
9+
return %f: vector<2xf32>
10+
}
11+
// CHECK-LABEL: @transfer_readx2
12+
// CHECK: rocdl.buffer.load {{.*}} !llvm<"<2 x float>">
13+
14+
func @transfer_readx4(%A : memref<?xf32>, %base: index) -> vector<4xf32> {
15+
%f0 = constant 0.0: f32
16+
%f = vector.transfer_read %A[%base], %f0
17+
{permutation_map = affine_map<(d0) -> (d0)>} :
18+
memref<?xf32>, vector<4xf32>
19+
return %f: vector<4xf32>
20+
}
21+
// CHECK-LABEL: @transfer_readx4
22+
// CHECK: rocdl.buffer.load {{.*}} !llvm<"<4 x float>">
23+
24+
func @transfer_read_dwordConfig(%A : memref<?xf32>, %base: index) -> vector<4xf32> {
25+
%f0 = constant 0.0: f32
26+
%f = vector.transfer_read %A[%base], %f0
27+
{permutation_map = affine_map<(d0) -> (d0)>} :
28+
memref<?xf32>, vector<4xf32>
29+
return %f: vector<4xf32>
30+
}
31+
// CHECK-LABEL: @transfer_read_dwordConfig
32+
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}}
33+
// CHECK: [0, 0, -1, 159744]
34+
// CHECK: %[[i64:.*]] = llvm.ptrtoint %[[gep]]
35+
// CHECK: llvm.insertelement %[[i64]]
36+
}
37+
38+
gpu.module @test_write{
39+
func @transfer_writex2(%A : memref<?xf32>, %B : vector<2xf32>, %base: index) {
40+
vector.transfer_write %B, %A[%base]
41+
{permutation_map = affine_map<(d0) -> (d0)>} :
42+
vector<2xf32>, memref<?xf32>
43+
return
44+
}
45+
// CHECK-LABEL: @transfer_writex2
46+
// CHECK: rocdl.buffer.store {{.*}} !llvm<"<2 x float>">
47+
48+
func @transfer_writex4(%A : memref<?xf32>, %B : vector<4xf32>, %base: index) {
49+
vector.transfer_write %B, %A[%base]
50+
{permutation_map = affine_map<(d0) -> (d0)>} :
51+
vector<4xf32>, memref<?xf32>
52+
return
53+
}
54+
// CHECK-LABEL: @transfer_writex4
55+
// CHECK: rocdl.buffer.store {{.*}} !llvm<"<4 x float>">
56+
57+
func @transfer_write_dwordConfig(%A : memref<?xf32>, %B : vector<2xf32>, %base: index) {
58+
vector.transfer_write %B, %A[%base]
59+
{permutation_map = affine_map<(d0) -> (d0)>} :
60+
vector<2xf32>, memref<?xf32>
61+
return
62+
}
63+
// CHECK-LABEL: @transfer_write_dwordConfig
64+
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}}
65+
// CHECK: [0, 0, -1, 159744]
66+
// CHECK: %[[i64:.*]] = llvm.ptrtoint %[[gep]]
67+
// CHECK: llvm.insertelement %[[i64]]
68+
}

0 commit comments

Comments
 (0)