Skip to content

Commit fda5192

Browse files
georgemitenkovantiagainst
authored andcommitted
[MLIR][SPIRVToLLVM] Add skeleton for SPIR-V to LLVM dialect conversion
These commits set up the skeleton for SPIR-V to LLVM dialect conversion. I created SPIR-V to LLVM pass, registered it in Passes.td, InitAllPasses.h. Added a pattern for `spv.BitwiseAndOp` and tests for it. Integer, float and vector types are converted through LLVMTypeConverter. Differential Revision: https://reviews.llvm.org/D81100
1 parent a6d6b0a commit fda5192

File tree

10 files changed

+209
-1
lines changed

10 files changed

+209
-1
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,15 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
206206
let constructor = "mlir::createConvertShapeToStandardPass()";
207207
}
208208

209+
//===----------------------------------------------------------------------===//
210+
// SPIRVToLLVM
211+
//===----------------------------------------------------------------------===//
212+
213+
def ConvertSPIRVToLLVM : Pass<"convert-spirv-to-llvm", "ModuleOp"> {
214+
let summary = "Convert SPIR-V dialect to LLVM dialect";
215+
let constructor = "mlir::createConvertSPIRVToLLVMPass()";
216+
}
217+
209218
//===----------------------------------------------------------------------===//
210219
// StandardToLLVM
211220
//===----------------------------------------------------------------------===//
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===- ConvertSPIRVToLLVM.h - Convert SPIR-V to LLVM dialect ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Provides patterns to convert SPIR-V dialect to LLVM dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVM_H
14+
#define MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVM_H
15+
16+
#include "mlir/Transforms/DialectConversion.h"
17+
18+
namespace mlir {
19+
class LLVMTypeConverter;
20+
class MLIRContext;
21+
class ModuleOp;
22+
23+
/// Populates the given list with patterns that convert from SPIR-V to LLVM.
24+
void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
25+
LLVMTypeConverter &typeConverter,
26+
OwningRewritePatternList &patterns);
27+
28+
} // namespace mlir
29+
30+
#endif // MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVM_H
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- ConvertSPIRVToLLVMPass.h - SPIR-V dialect to LLVM pass ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Provides a pass to lower from SPIR-V dialect to LLVM dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVMPASS_H
14+
#define MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVMPASS_H
15+
16+
#include <memory>
17+
18+
namespace mlir {
19+
class ModuleOp;
20+
template <typename T>
21+
class OperationPass;
22+
23+
/// Creates a pass to convert SPIR-V operations to the LLVMIR dialect.
24+
std::unique_ptr<OperationPass<ModuleOp>> createConvertSPIRVToLLVMPass();
25+
26+
} // namespace mlir
27+
28+
#endif // MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVMPASS_H_

mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
namespace mlir {
1414
class LLVMTypeConverter;
1515
class ModuleOp;
16-
template <typename T> class OperationPass;
16+
template <typename T>
17+
class OperationPass;
1718

1819
/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
1920
/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics

mlir/include/mlir/InitAllPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
2626
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
2727
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
28+
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
2829
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
2930
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
3031
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_subdirectory(LinalgToStandard)
1111
add_subdirectory(SCFToGPU)
1212
add_subdirectory(SCFToStandard)
1313
add_subdirectory(ShapeToStandard)
14+
add_subdirectory(SPIRVToLLVM)
1415
add_subdirectory(StandardToLLVM)
1516
add_subdirectory(StandardToSPIRV)
1617
add_subdirectory(VectorToLLVM)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
add_mlir_conversion_library(MLIRSPIRVToLLVM
2+
ConvertSPIRVToLLVM.cpp
3+
ConvertSPIRVToLLVMPass.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SPIRVToLLVM
7+
8+
DEPENDS
9+
MLIRConversionPassIncGen
10+
intrinsics_gen
11+
12+
LINK_LIBS PUBLIC
13+
MLIRSPIRV
14+
MLIRLLVMIR
15+
MLIRStandardToLLVM
16+
MLIRIR
17+
MLIRTransforms
18+
)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//===- ConvertSPIRVToLLVM.cpp - SPIR-V dialect to LLVM dialect conversion -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h"
14+
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15+
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17+
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
18+
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
19+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
20+
#include "mlir/IR/Module.h"
21+
#include "mlir/IR/PatternMatch.h"
22+
#include "mlir/Support/LogicalResult.h"
23+
#include "mlir/Transforms/DialectConversion.h"
24+
25+
using namespace mlir;
26+
27+
namespace {
28+
29+
class BitwiseAndOpConversion : public ConvertToLLVMPattern {
30+
public:
31+
explicit BitwiseAndOpConversion(MLIRContext *context,
32+
LLVMTypeConverter &typeConverter)
33+
: ConvertToLLVMPattern(spirv::BitwiseAndOp::getOperationName(), context,
34+
typeConverter) {}
35+
36+
LogicalResult
37+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
38+
ConversionPatternRewriter &rewriter) const override {
39+
auto bitwiseAndOp = cast<spirv::BitwiseAndOp>(op);
40+
auto dstType = typeConverter.convertType(bitwiseAndOp.getType());
41+
if (!dstType)
42+
return failure();
43+
rewriter.replaceOpWithNewOp<LLVM::AndOp>(bitwiseAndOp, dstType, operands);
44+
return success();
45+
}
46+
};
47+
} // namespace
48+
49+
void mlir::populateSPIRVToLLVMConversionPatterns(
50+
MLIRContext *context, LLVMTypeConverter &typeConverter,
51+
OwningRewritePatternList &patterns) {
52+
patterns.insert<BitwiseAndOpConversion>(context, typeConverter);
53+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//===- ConvertSPIRVToLLVMPass.cpp - Convert SPIR-V ops to LLVM ops --------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert MLIR SPIR-V ops into LLVM ops
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
14+
#include "../PassDetail.h"
15+
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h"
16+
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
17+
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
18+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19+
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
20+
21+
using namespace mlir;
22+
23+
namespace {
24+
/// A pass converting MLIR SPIR-V operations into LLVM dialect.
25+
class ConvertSPIRVToLLVMPass
26+
: public ConvertSPIRVToLLVMBase<ConvertSPIRVToLLVMPass> {
27+
void runOnOperation() override;
28+
};
29+
} // namespace
30+
31+
void ConvertSPIRVToLLVMPass::runOnOperation() {
32+
MLIRContext *context = &getContext();
33+
ModuleOp module = getOperation();
34+
LLVMTypeConverter converter(&getContext());
35+
36+
OwningRewritePatternList patterns;
37+
populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
38+
39+
// Currently pulls in Std to LLVM conversion patterns
40+
// that help with testing. This allows to convert
41+
// function arguments to LLVM.
42+
populateStdToLLVMConversionPatterns(converter, patterns);
43+
44+
ConversionTarget target(getContext());
45+
target.addIllegalDialect<spirv::SPIRVDialect>();
46+
target.addLegalDialect<LLVM::LLVMDialect>();
47+
48+
if (failed(applyPartialConversion(module, target, patterns, &converter)))
49+
signalPassFailure();
50+
}
51+
52+
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSPIRVToLLVMPass() {
53+
return std::make_unique<ConvertSPIRVToLLVMPass>();
54+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
2+
3+
func @bitwise_and_scalar(%arg0: i32, %arg1: i32) {
4+
// CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm.i32
5+
%0 = spv.BitwiseAnd %arg0, %arg1 : i32
6+
return
7+
}
8+
9+
func @bitwise_and_vector(%arg0: vector<4xi64>, %arg1: vector<4xi64>) {
10+
// CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
11+
%0 = spv.BitwiseAnd %arg0, %arg1 : vector<4xi64>
12+
return
13+
}

0 commit comments

Comments
 (0)