Skip to content

Commit cf2b0dc

Browse files
author
git apple-llvm automerger
committed
Merge commit '995c3984efe3' from llvm.org/main into next
2 parents 9a6900e + 995c398 commit cf2b0dc

File tree

12 files changed

+323
-56
lines changed

12 files changed

+323
-56
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- MathToSPIRV.h - Math to SPIR-V Patterns ------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Provides patterns to convert Math dialect to SPIR-V dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRV_H
14+
#define MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRV_H
15+
16+
#include "mlir/Transforms/DialectConversion.h"
17+
18+
namespace mlir {
19+
class SPIRVTypeConverter;
20+
21+
/// Appends to a pattern list additional patterns for translating Math ops
22+
/// to SPIR-V ops.
23+
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
24+
RewritePatternSet &patterns);
25+
26+
} // namespace mlir
27+
28+
#endif // MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRV_H
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===- MathToSPIRVPass.h - Math to SPIR-V Passes ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Provides passes to convert Math dialect to SPIR-V dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRVPASS_H
14+
#define MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRVPASS_H
15+
16+
#include "mlir/Pass/Pass.h"
17+
18+
namespace mlir {
19+
20+
/// Creates a pass to convert Math ops to SPIR-V ops.
21+
std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToSPIRVPass();
22+
23+
} // namespace mlir
24+
25+
#endif // MLIR_CONVERSION_MATHTOSPIRV_MATHTOSPIRVPASS_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
2525
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
2626
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
27+
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
2728
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
2829
#include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
2930
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,16 @@ def ConvertMathToLLVM : FunctionPass<"convert-math-to-llvm"> {
268268
let dependentDialects = ["LLVM::LLVMDialect"];
269269
}
270270

271+
//===----------------------------------------------------------------------===//
272+
// MathToSPIRV
273+
//===----------------------------------------------------------------------===//
274+
275+
def ConvertMathToSPIRV : Pass<"convert-math-to-spirv", "ModuleOp"> {
276+
let summary = "Convert Math dialect to SPIR-V dialect";
277+
let constructor = "mlir::createConvertMathToSPIRVPass()";
278+
let dependentDialects = ["spirv::SPIRVDialect"];
279+
}
280+
271281
//===----------------------------------------------------------------------===//
272282
// MemRefToLLVM
273283
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_subdirectory(LinalgToStandard)
1414
add_subdirectory(LLVMCommon)
1515
add_subdirectory(MathToLibm)
1616
add_subdirectory(MathToLLVM)
17+
add_subdirectory(MathToSPIRV)
1718
add_subdirectory(MemRefToLLVM)
1819
add_subdirectory(OpenACCToLLVM)
1920
add_subdirectory(OpenACCToSCF)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
add_mlir_conversion_library(MLIRMathToSPIRV
2+
MathToSPIRV.cpp
3+
MathToSPIRVPass.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
7+
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
8+
9+
DEPENDS
10+
MLIRConversionPassIncGen
11+
12+
LINK_LIBS PUBLIC
13+
MLIRIR
14+
MLIRMath
15+
MLIRPass
16+
MLIRSPIRV
17+
MLIRSPIRVConversion
18+
MLIRSupport
19+
MLIRTransformUtils
20+
)
21+
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements patterns to convert Math dialect to SPIR-V dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Math/IR/Math.h"
14+
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15+
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16+
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17+
#include "llvm/Support/Debug.h"
18+
19+
#define DEBUG_TYPE "math-to-spirv-pattern"
20+
21+
using namespace mlir;
22+
23+
//===----------------------------------------------------------------------===//
24+
// Operation conversion
25+
//===----------------------------------------------------------------------===//
26+
27+
// Note that DRR cannot be used for the patterns in this file: we may need to
28+
// convert type along the way, which requires ConversionPattern. DRR generates
29+
// normal RewritePattern.
30+
31+
namespace {
32+
33+
/// Converts unary and binary standard operations to SPIR-V operations.
34+
template <typename StdOp, typename SPIRVOp>
35+
class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> {
36+
public:
37+
using OpConversionPattern<StdOp>::OpConversionPattern;
38+
39+
LogicalResult
40+
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
41+
ConversionPatternRewriter &rewriter) const override {
42+
assert(operands.size() <= 2);
43+
auto dstType = this->getTypeConverter()->convertType(operation.getType());
44+
if (!dstType)
45+
return failure();
46+
if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
47+
dstType != operation.getType()) {
48+
return operation.emitError(
49+
"bitwidth emulation is not implemented yet on unsigned op");
50+
}
51+
rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands);
52+
return success();
53+
}
54+
};
55+
56+
/// Converts math.log1p to SPIR-V ops.
57+
///
58+
/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
59+
/// these operations.
60+
class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
61+
public:
62+
using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
63+
64+
LogicalResult
65+
matchAndRewrite(math::Log1pOp operation, ArrayRef<Value> operands,
66+
ConversionPatternRewriter &rewriter) const override {
67+
assert(operands.size() == 1);
68+
Location loc = operation.getLoc();
69+
auto type =
70+
this->getTypeConverter()->convertType(operation.operand().getType());
71+
auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
72+
auto onePlus = rewriter.create<spirv::FAddOp>(loc, one, operands[0]);
73+
rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
74+
return success();
75+
}
76+
};
77+
78+
} // namespace
79+
80+
//===----------------------------------------------------------------------===//
81+
// Pattern population
82+
//===----------------------------------------------------------------------===//
83+
84+
namespace mlir {
85+
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
86+
RewritePatternSet &patterns) {
87+
patterns.add<Log1pOpPattern,
88+
UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
89+
UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>,
90+
UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>,
91+
UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
92+
UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>,
93+
UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
94+
UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
95+
UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
96+
typeConverter, patterns.getContext());
97+
}
98+
99+
} // namespace mlir
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//===- MathToSPIRVPass.cpp - Math to SPIR-V Passes ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert standard dialect to SPIR-V dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
14+
#include "../PassDetail.h"
15+
#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
16+
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17+
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18+
19+
using namespace mlir;
20+
21+
namespace {
22+
/// A pass converting MLIR Math operations into the SPIR-V dialect.
23+
class ConvertMathToSPIRVPass
24+
: public ConvertMathToSPIRVBase<ConvertMathToSPIRVPass> {
25+
void runOnOperation() override;
26+
};
27+
} // namespace
28+
29+
void ConvertMathToSPIRVPass::runOnOperation() {
30+
MLIRContext *context = &getContext();
31+
ModuleOp module = getOperation();
32+
33+
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
34+
std::unique_ptr<ConversionTarget> target =
35+
SPIRVConversionTarget::get(targetAttr);
36+
37+
SPIRVTypeConverter typeConverter(targetAttr);
38+
39+
RewritePatternSet patterns(context);
40+
populateMathToSPIRVPatterns(typeConverter, patterns);
41+
42+
if (failed(applyPartialConversion(module, *target, std::move(patterns))))
43+
return signalPassFailure();
44+
}
45+
46+
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToSPIRVPass() {
47+
return std::make_unique<ConvertMathToSPIRVPass>();
48+
}

mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13-
#include "mlir/Dialect/Math/IR/Math.h"
1413
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1514
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1615
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -317,28 +316,6 @@ class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> {
317316
}
318317
};
319318

320-
/// Converts math.log1p to SPIR-V ops.
321-
///
322-
/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
323-
/// these operations.
324-
class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
325-
public:
326-
using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
327-
328-
LogicalResult
329-
matchAndRewrite(math::Log1pOp operation, ArrayRef<Value> operands,
330-
ConversionPatternRewriter &rewriter) const override {
331-
assert(operands.size() == 1);
332-
Location loc = operation.getLoc();
333-
auto type =
334-
this->getTypeConverter()->convertType(operation.operand().getType());
335-
auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
336-
auto onePlus = rewriter.create<spirv::FAddOp>(loc, one, operands[0]);
337-
rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
338-
return success();
339-
}
340-
};
341-
342319
/// Converts std.remi_signed to SPIR-V ops.
343320
///
344321
/// This cannot be merged into the template unary/binary pattern due to
@@ -1336,17 +1313,6 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
13361313
MLIRContext *context = patterns.getContext();
13371314

13381315
patterns.add<
1339-
// Math dialect operations.
1340-
// TODO: Move to separate pass.
1341-
UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
1342-
UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>,
1343-
UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>,
1344-
UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
1345-
UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>,
1346-
UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
1347-
UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
1348-
UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>,
1349-
13501316
// Unary and binary patterns
13511317
BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
13521318
BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
@@ -1369,7 +1335,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
13691335
UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
13701336
UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
13711337
UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
1372-
Log1pOpPattern, SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern,
1338+
SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern,
13731339

13741340
// Comparison patterns
13751341
BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
2+
3+
// CHECK-LABEL: @float32_unary_scalar
4+
func @float32_unary_scalar(%arg0: f32) {
5+
// CHECK: spv.GLSL.Cos %{{.*}}: f32
6+
%0 = math.cos %arg0 : f32
7+
// CHECK: spv.GLSL.Exp %{{.*}}: f32
8+
%1 = math.exp %arg0 : f32
9+
// CHECK: spv.GLSL.Log %{{.*}}: f32
10+
%2 = math.log %arg0 : f32
11+
// CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
12+
// CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
13+
// CHECK: spv.GLSL.Log %[[ADDONE]]
14+
%3 = math.log1p %arg0 : f32
15+
// CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32
16+
%4 = math.rsqrt %arg0 : f32
17+
// CHECK: spv.GLSL.Sqrt %{{.*}}: f32
18+
%5 = math.sqrt %arg0 : f32
19+
// CHECK: spv.GLSL.Tanh %{{.*}}: f32
20+
%6 = math.tanh %arg0 : f32
21+
// CHECK: spv.GLSL.Sin %{{.*}}: f32
22+
%7 = math.sin %arg0 : f32
23+
return
24+
}
25+
26+
// CHECK-LABEL: @float32_unary_vector
27+
func @float32_unary_vector(%arg0: vector<3xf32>) {
28+
// CHECK: spv.GLSL.Cos %{{.*}}: vector<3xf32>
29+
%0 = math.cos %arg0 : vector<3xf32>
30+
// CHECK: spv.GLSL.Exp %{{.*}}: vector<3xf32>
31+
%1 = math.exp %arg0 : vector<3xf32>
32+
// CHECK: spv.GLSL.Log %{{.*}}: vector<3xf32>
33+
%2 = math.log %arg0 : vector<3xf32>
34+
// CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32>
35+
// CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
36+
// CHECK: spv.GLSL.Log %[[ADDONE]]
37+
%3 = math.log1p %arg0 : vector<3xf32>
38+
// CHECK: spv.GLSL.InverseSqrt %{{.*}}: vector<3xf32>
39+
%4 = math.rsqrt %arg0 : vector<3xf32>
40+
// CHECK: spv.GLSL.Sqrt %{{.*}}: vector<3xf32>
41+
%5 = math.sqrt %arg0 : vector<3xf32>
42+
// CHECK: spv.GLSL.Tanh %{{.*}}: vector<3xf32>
43+
%6 = math.tanh %arg0 : vector<3xf32>
44+
// CHECK: spv.GLSL.Sin %{{.*}}: vector<3xf32>
45+
%7 = math.sin %arg0 : vector<3xf32>
46+
return
47+
}
48+
49+
// CHECK-LABEL: @float32_binary_scalar
50+
func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
51+
// CHECK: spv.GLSL.Pow %{{.*}}: f32
52+
%0 = math.powf %lhs, %rhs : f32
53+
return
54+
}
55+
56+
// CHECK-LABEL: @float32_binary_vector
57+
func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
58+
// CHECK: spv.GLSL.Pow %{{.*}}: vector<4xf32>
59+
%0 = math.powf %lhs, %rhs : vector<4xf32>
60+
return
61+
}

0 commit comments

Comments
 (0)