Skip to content

Reland "[mlir][spirv] Add a generic convert-to-spirv pass" #96359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- ConvertToSPIRVPass.h - Conversion to SPIR-V pass ---*- C++ -*-=========//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
#define MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H

#include <memory>

namespace mlir {
class Pass;

#define GEN_PASS_DECL_CONVERTTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"

} // namespace mlir

#endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
];
}

//===----------------------------------------------------------------------===//
// ToSPIRV
//===----------------------------------------------------------------------===//

def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
let summary = "Convert to SPIR-V";
let description = [{
This is a generic pass to convert to SPIR-V.
}];
let dependentDialects = ["spirv::SPIRVDialect"];
}

//===----------------------------------------------------------------------===//
// AffineToStandard
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_subdirectory(ControlFlowToLLVM)
add_subdirectory(ControlFlowToSCF)
add_subdirectory(ControlFlowToSPIRV)
add_subdirectory(ConvertToLLVM)
add_subdirectory(ConvertToSPIRV)
add_subdirectory(FuncToEmitC)
add_subdirectory(FuncToLLVM)
add_subdirectory(FuncToSPIRV)
Expand Down
32 changes: 32 additions & 0 deletions mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
set(LLVM_OPTIONAL_SOURCES
ConvertToSPIRVPass.cpp
)

add_mlir_conversion_library(MLIRConvertToSPIRVPass
ConvertToSPIRVPass.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ConvertToSPIRV

DEPENDS
MLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRArithToSPIRV
MLIRArithTransforms
MLIRFuncToSPIRV
MLIRIndexToSPIRV
MLIRIR
MLIRPass
MLIRRewrite
MLIRSCFToSPIRV
MLIRSPIRVConversion
MLIRSPIRVDialect
MLIRSPIRVTransforms
MLIRSupport
MLIRTransforms
MLIRTransformUtils
MLIRUBToSPIRV
MLIRVectorToSPIRV
MLIRVectorTransforms
)
71 changes: 71 additions & 0 deletions mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>

#define DEBUG_TYPE "convert-to-spirv"

namespace mlir {
#define GEN_PASS_DEF_CONVERTTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// A pass to perform the SPIR-V conversion.
struct ConvertToSPIRVPass final
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {

void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();

spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
SPIRVTypeConverter typeConverter(targetAttr);

RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;

// Populate patterns.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
populateVectorToSPIRVPatterns(typeConverter, patterns);
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);

std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);

if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
};

} // namespace
218 changes: 218 additions & 0 deletions mlir/test/Conversion/ConvertToSPIRV/arith.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s

//===----------------------------------------------------------------------===//
// arithmetic ops
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @int32_scalar
func.func @int32_scalar(%lhs: i32, %rhs: i32) {
// CHECK: spirv.IAdd %{{.*}}, %{{.*}}: i32
%0 = arith.addi %lhs, %rhs: i32
// CHECK: spirv.ISub %{{.*}}, %{{.*}}: i32
%1 = arith.subi %lhs, %rhs: i32
// CHECK: spirv.IMul %{{.*}}, %{{.*}}: i32
%2 = arith.muli %lhs, %rhs: i32
// CHECK: spirv.SDiv %{{.*}}, %{{.*}}: i32
%3 = arith.divsi %lhs, %rhs: i32
// CHECK: spirv.UDiv %{{.*}}, %{{.*}}: i32
%4 = arith.divui %lhs, %rhs: i32
// CHECK: spirv.UMod %{{.*}}, %{{.*}}: i32
%5 = arith.remui %lhs, %rhs: i32
return
}

// CHECK-LABEL: @int32_scalar_srem
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) {
// CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
// CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
// CHECK: %[[ABS:.+]] = spirv.UMod %[[LABS]], %[[RABS]] : i32
// CHECK: %[[POS:.+]] = spirv.IEqual %[[LHS]], %[[LABS]] : i32
// CHECK: %[[NEG:.+]] = spirv.SNegate %[[ABS]] : i32
// CHECK: %{{.+}} = spirv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
%0 = arith.remsi %lhs, %rhs: i32
return
}

// -----

//===----------------------------------------------------------------------===//
// arith bit ops
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @bitwise_scalar
func.func @bitwise_scalar(%arg0 : i32, %arg1 : i32) {
// CHECK: spirv.BitwiseAnd
%0 = arith.andi %arg0, %arg1 : i32
// CHECK: spirv.BitwiseOr
%1 = arith.ori %arg0, %arg1 : i32
// CHECK: spirv.BitwiseXor
%2 = arith.xori %arg0, %arg1 : i32
return
}

// CHECK-LABEL: @bitwise_vector
func.func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
// CHECK: spirv.BitwiseAnd
%0 = arith.andi %arg0, %arg1 : vector<4xi32>
// CHECK: spirv.BitwiseOr
%1 = arith.ori %arg0, %arg1 : vector<4xi32>
// CHECK: spirv.BitwiseXor
%2 = arith.xori %arg0, %arg1 : vector<4xi32>
return
}

// CHECK-LABEL: @logical_scalar
func.func @logical_scalar(%arg0 : i1, %arg1 : i1) {
// CHECK: spirv.LogicalAnd
%0 = arith.andi %arg0, %arg1 : i1
// CHECK: spirv.LogicalOr
%1 = arith.ori %arg0, %arg1 : i1
// CHECK: spirv.LogicalNotEqual
%2 = arith.xori %arg0, %arg1 : i1
return
}

// CHECK-LABEL: @logical_vector
func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
// CHECK: spirv.LogicalAnd
%0 = arith.andi %arg0, %arg1 : vector<4xi1>
// CHECK: spirv.LogicalOr
%1 = arith.ori %arg0, %arg1 : vector<4xi1>
// CHECK: spirv.LogicalNotEqual
%2 = arith.xori %arg0, %arg1 : vector<4xi1>
return
}

// CHECK-LABEL: @shift_scalar
func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
// CHECK: spirv.ShiftLeftLogical
%0 = arith.shli %arg0, %arg1 : i32
// CHECK: spirv.ShiftRightArithmetic
%1 = arith.shrsi %arg0, %arg1 : i32
// CHECK: spirv.ShiftRightLogical
%2 = arith.shrui %arg0, %arg1 : i32
return
}

// CHECK-LABEL: @shift_vector
func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
// CHECK: spirv.ShiftLeftLogical
%0 = arith.shli %arg0, %arg1 : vector<4xi32>
// CHECK: spirv.ShiftRightArithmetic
%1 = arith.shrsi %arg0, %arg1 : vector<4xi32>
// CHECK: spirv.ShiftRightLogical
%2 = arith.shrui %arg0, %arg1 : vector<4xi32>
return
}

// -----

//===----------------------------------------------------------------------===//
// arith.cmpf
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @cmpf
func.func @cmpf(%arg0 : f32, %arg1 : f32) {
// CHECK: spirv.FOrdEqual
%1 = arith.cmpf oeq, %arg0, %arg1 : f32
return
}

// CHECK-LABEL: @vec1cmpf
func.func @vec1cmpf(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>) {
// CHECK: spirv.FOrdGreaterThan
%0 = arith.cmpf ogt, %arg0, %arg1 : vector<1xf32>
// CHECK: spirv.FUnordLessThan
%1 = arith.cmpf ult, %arg0, %arg1 : vector<1xf32>
return
}

// -----

//===----------------------------------------------------------------------===//
// arith.cmpi
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @cmpi
func.func @cmpi(%arg0 : i32, %arg1 : i32) {
// CHECK: spirv.IEqual
%0 = arith.cmpi eq, %arg0, %arg1 : i32
return
}

// CHECK-LABEL: @indexcmpi
func.func @indexcmpi(%arg0 : index, %arg1 : index) {
// CHECK: spirv.IEqual
%0 = arith.cmpi eq, %arg0, %arg1 : index
return
}

// CHECK-LABEL: @vec1cmpi
func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) {
// CHECK: spirv.ULessThan
%0 = arith.cmpi ult, %arg0, %arg1 : vector<1xi32>
// CHECK: spirv.SGreaterThan
%1 = arith.cmpi sgt, %arg0, %arg1 : vector<1xi32>
return
}

// CHECK-LABEL: @boolcmpi_equality
func.func @boolcmpi_equality(%arg0 : i1, %arg1 : i1) {
// CHECK: spirv.LogicalEqual
%0 = arith.cmpi eq, %arg0, %arg1 : i1
// CHECK: spirv.LogicalNotEqual
%1 = arith.cmpi ne, %arg0, %arg1 : i1
return
}

// CHECK-LABEL: @boolcmpi_unsigned
func.func @boolcmpi_unsigned(%arg0 : i1, %arg1 : i1) {
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.UGreaterThanEqual
%0 = arith.cmpi uge, %arg0, %arg1 : i1
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.ULessThan
%1 = arith.cmpi ult, %arg0, %arg1 : i1
return
}

// CHECK-LABEL: @vec1boolcmpi_equality
func.func @vec1boolcmpi_equality(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
// CHECK: spirv.LogicalEqual
%0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1>
// CHECK: spirv.LogicalNotEqual
%1 = arith.cmpi ne, %arg0, %arg1 : vector<1xi1>
return
}

// CHECK-LABEL: @vec1boolcmpi_unsigned
func.func @vec1boolcmpi_unsigned(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.UGreaterThanEqual
%0 = arith.cmpi uge, %arg0, %arg1 : vector<1xi1>
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.ULessThan
%1 = arith.cmpi ult, %arg0, %arg1 : vector<1xi1>
return
}

// CHECK-LABEL: @vecboolcmpi_equality
func.func @vecboolcmpi_equality(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
// CHECK: spirv.LogicalEqual
%0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1>
// CHECK: spirv.LogicalNotEqual
%1 = arith.cmpi ne, %arg0, %arg1 : vector<4xi1>
return
}

// CHECK-LABEL: @vecboolcmpi_unsigned
func.func @vecboolcmpi_unsigned(%arg0 : vector<3xi1>, %arg1 : vector<3xi1>) {
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.UGreaterThanEqual
%0 = arith.cmpi uge, %arg0, %arg1 : vector<3xi1>
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.ULessThan
%1 = arith.cmpi ult, %arg0, %arg1 : vector<3xi1>
return
}
Loading
Loading