
Commit 7f6d6ef

Adding Vector to AMDGPU conversion lowering

1 parent 967ab7e

7 files changed (+269, -0 lines)

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h"
 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 10 additions & 0 deletions
@@ -1333,6 +1333,16 @@ def ConvertVectorToArmSMEPass : Pass<"convert-vector-to-arm-sme"> {
   let dependentDialects = ["arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// VectorToAMDGPU
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToAMDGPUPass : Pass<"convert-vector-to-amdgpu"> {
+  let summary = "Lower the operations from the vector dialect into the AMDGPU "
+                "dialect";
+  let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSMEToSCF
 //===----------------------------------------------------------------------===//
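With the definition registered in Passes.td, the pass becomes available through the standard conversion-pass plumbing, and the new test file below drives it from mlir-opt in exactly this way (here `input.mlir` is a placeholder name, not a file from this commit):

  mlir-opt --convert-vector-to-amdgpu --split-input-file input.mlir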
mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+//===- VectorToAMDGPU.h - Vector to AMDGPU dialect conversion ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
+#define MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTVECTORTOAMDGPUPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateVectorToAMDGPUConversionPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
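The header exposes the patterns separately from the pass, so downstream pipelines can mix them into their own pattern sets. A minimal sketch of that use, assuming an out-of-tree helper (the function name and error handling are illustrative, not part of this commit):

  // Hypothetical out-of-tree consumer of the populate hook added above.
  #include "mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  static void runVectorToAMDGPULowering(mlir::Operation *root) {
    mlir::RewritePatternSet patterns(root->getContext());
    // Mix in the vector -> amdgpu rewrites; other pattern sets could be
    // added to the same set before applying.
    mlir::populateVectorToAMDGPUConversionPatterns(patterns);
    if (mlir::failed(mlir::applyPatternsGreedily(root, std::move(patterns))))
      root->emitError("vector-to-amdgpu pattern application failed");
  }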

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ add_subdirectory(TosaToSCF)
 add_subdirectory(TosaToTensor)
 add_subdirectory(UBToLLVM)
 add_subdirectory(UBToSPIRV)
+add_subdirectory(VectorToAMDGPU)
 add_subdirectory(VectorToArmSME)
 add_subdirectory(VectorToGPU)
 add_subdirectory(VectorToLLVM)
mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRVectorToAMDGPU
+  VectorToAMDGPU.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMDGPU
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRAMDGPUDialect
+  MLIRVectorDialect
+  MLIRPass
+  MLIRTransforms
+  )
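Out-of-tree consumers only need to link the new library to pick up both the pass and the populate function; a hedged sketch (the `my-mlir-tool` target name is hypothetical):

  # Hypothetical downstream target; MLIRVectorToAMDGPU is the library added here.
  target_link_libraries(my-mlir-tool PRIVATE MLIRVectorToAMDGPU)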
mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp

Lines changed: 147 additions & 0 deletions

@@ -0,0 +1,147 @@
+//===- VectorToAMDGPU.cpp - Vector to AMDGPU dialect conversion ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOAMDGPUPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+/// This pattern supports lowering of `vector.transfer_read` to a combination
+/// of `vector.load`, `arith.select`, and `vector.broadcast` if all of the
+/// following hold:
+/// - The transfer op is masked.
+/// - The memref is in the buffer address space.
+/// - The stride of the most minor memref dimension is 1.
+/// - Out-of-bounds masking is not required.
+/// - If the memref's element type is a vector type, then it coincides with
+///   the result type.
+/// - The permutation map doesn't permute dimensions (broadcasting is allowed).
+/// Note: these conditions mostly come from the
+/// TransferReadToVectorLoadLowering pattern.
+static LogicalResult
+transferPreconditions(PatternRewriter &rewriter,
+                      VectorTransferOpInterface xferOp,
+                      SmallVector<unsigned> &broadcastedDims,
+                      VectorType &unbroadcastedVectorType) {
+  if (!xferOp.getMask())
+    return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");
+
+  // Permutations are handled by VectorToSCF or
+  // populateVectorTransferPermutationMapLoweringPatterns.
+  // We let the 0-d corner case pass through as it is supported.
+  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
+          &broadcastedDims))
+    return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");
+
+  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
+  if (!memRefType)
+    return rewriter.notifyMatchFailure(xferOp, "not a memref source");
+
+  auto addrSpace = llvm::dyn_cast_if_present<amdgpu::AddressSpaceAttr>(
+      memRefType.getMemorySpace());
+  if (!addrSpace ||
+      addrSpace.getValue() != amdgpu::AddressSpace::FatRawBuffer)
+    return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
+
+  // Non-unit strides are handled by VectorToSCF.
+  if (!memRefType.isLastDimUnitStride())
+    return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
+
+  // If there is broadcasting involved, we first load the unbroadcasted
+  // vector and then broadcast it with `vector.broadcast`.
+  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
+  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
+  for (unsigned i : broadcastedDims)
+    unbroadcastedVectorShape[i] = 1;
+  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
+      unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
+
+  // `vector.load` supports vector types as memref's elements only when the
+  // resulting vector type is the same as the element type.
+  auto memrefElTy = memRefType.getElementType();
+  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
+    return rewriter.notifyMatchFailure(xferOp, "incompatible element type");
+
+  // Otherwise, the element types of the memref and the vector must match.
+  if (!isa<VectorType>(memrefElTy) &&
+      memrefElTy != xferOp.getVectorType().getElementType())
+    return rewriter.notifyMatchFailure(xferOp, "non-matching element type");
+
+  // Out-of-bounds dims are handled by MaterializeTransferMask.
+  if (xferOp.hasOutOfBoundsDim())
+    return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");
+
+  // `vector.maskedload` operates on 1-D vectors.
+  if (xferOp.getVectorType().getRank() != 1)
+    return rewriter.notifyMatchFailure(
+        xferOp, "vector type is not rank 1, can't create masked load, needs "
+                "VectorToSCF");
+
+  return success();
+}
+
+struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    SmallVector<unsigned> broadcastedDims;
+    VectorType unbroadcastedVectorType;
+    if (failed(transferPreconditions(rewriter, readOp, broadcastedDims,
+                                     unbroadcastedVectorType))) {
+      return failure();
+    }
+
+    Value fill = rewriter.create<vector::SplatOp>(
+        readOp.getLoc(), unbroadcastedVectorType, readOp.getPadding());
+    Value load = rewriter.create<vector::LoadOp>(
+        readOp.getLoc(), unbroadcastedVectorType, readOp.getSource(),
+        readOp.getIndices());
+    Value res = rewriter.create<arith::SelectOp>(
+        readOp.getLoc(), unbroadcastedVectorType, readOp.getMask(), load, fill);
+
+    // Insert a broadcasting op if required.
+    if (!broadcastedDims.empty()) {
+      res = rewriter.create<vector::BroadcastOp>(readOp.getLoc(),
+                                                 readOp.getVectorType(), res);
+    }
+
+    rewriter.replaceOp(readOp, res);
+
+    return success();
+  }
+};
+
+void mlir::populateVectorToAMDGPUConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferReadLowering>(patterns.getContext());
+}
+
+struct ConvertVectorToAMDGPUPass
+    : public impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorToAMDGPUConversionPatterns(patterns);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
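Concretely, for a masked 1-D read from a fat-raw-buffer memref, the pattern trades the transfer op for a splat of the padding value, an unmasked vector.load, and a select on the mask. A schematic before/after (shapes borrowed from the tests below; the trailing vector.broadcast is omitted since it only appears when the permutation map broadcasts):

  // Before:
  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]}
    : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>

  // After:
  %fill = vector.splat %cf0 : vector<4xf32>
  %load = vector.load %mem[%idx, %idx]
    : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
  %res = arith.select %mask, %load, %fill : vector<4xi1>, vector<4xf32>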
mlir/test/Conversion/VectorToAMDGPU (new test file)

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s -convert-vector-to-amdgpu --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: return %[[SELECT]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_to_maskedload_regular(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_broadcasting(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
+#broadcast_1d = affine_map<(d0, d1) -> (0)>
+func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
+    {in_bounds = [true], permutation_map = #broadcast_1d}
+    : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
+// CHECK: return %[[BROADCAST]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_scalar(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
+func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<1xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
+    {in_bounds = [true]}
+    : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
+  return %res : vector<1xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: return %[[SELECT]] : vector<1xf32>
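These cases run under lit like the rest of the MLIR suite; locally they are picked up by the usual target, e.g. from a configured build directory:

  cmake --build . --target check-mlir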
