Skip to content

Commit 2fb2a64

Browse files
committed
Add a new pass for testing signature conversion patterns in isolation
1 parent e682f15 commit 2fb2a64

File tree

7 files changed

+77
-2
lines changed

7 files changed

+77
-2
lines changed

mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ struct ConvertToSPIRVPass final
5454
config.strictMode = GreedyRewriteStrictness::ExistingOps;
5555
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
5656
return signalPassFailure();
57-
return;
5857
}
5958

6059
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);

mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
1+
// RUN: mlir-opt -test-spirv-func-signature-conversion -split-input-file %s | FileCheck %s
22

33
// CHECK-LABEL: @simple_scalar
44
// CHECK-SAME: (%[[ARG0:.+]]: i32)

mlir/test/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
add_subdirectory(ConvertToSPIRV)
12
add_subdirectory(FuncToLLVM)
23
add_subdirectory(MathToVCIX)
34
add_subdirectory(OneToNTypeConversion)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Exclude tests from libMLIR.so
2+
add_mlir_library(MLIRTestConvertToSPIRV
3+
TestSPIRVFuncSignatureConversion.cpp
4+
5+
EXCLUDE_FROM_LIBMLIR
6+
7+
LINK_LIBS PUBLIC
8+
MLIRArithDialect
9+
MLIRFuncDialect
10+
MLIRSPIRVConversion
11+
MLIRSPIRVDialect
12+
MLIRVectorDialect
13+
MLIRPass
14+
MLIRTransforms
15+
)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//===- TestSPIRVFuncSignatureConversion.cpp - Test signature conversion -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===-------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Arith/IR/Arith.h"
10+
#include "mlir/Dialect/Func/IR/FuncOps.h"
11+
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
12+
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
13+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
14+
#include "mlir/Pass/Pass.h"
15+
#include "mlir/Pass/PassManager.h"
16+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17+
18+
namespace mlir {
19+
namespace {
20+
21+
struct TestSPIRVFuncSignatureConversion final
22+
: PassWrapper<TestSPIRVFuncSignatureConversion, OperationPass<ModuleOp>> {
23+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVFuncSignatureConversion)
24+
25+
StringRef getArgument() const final {
26+
return "test-spirv-func-signature-conversion";
27+
}
28+
29+
StringRef getDescription() const final {
30+
return "Test patterns that convert vector inputs and results in function "
31+
"signatures";
32+
}
33+
34+
void getDependentDialects(DialectRegistry &registry) const override {
35+
registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect,
36+
vector::VectorDialect>();
37+
}
38+
39+
void runOnOperation() override {
40+
RewritePatternSet patterns(&getContext());
41+
populateFuncOpVectorRewritePatterns(patterns);
42+
populateReturnOpVectorRewritePatterns(patterns);
43+
GreedyRewriteConfig config;
44+
config.strictMode = GreedyRewriteStrictness::ExistingOps;
45+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
46+
config);
47+
}
48+
};
49+
50+
} // namespace
51+
52+
namespace test {
53+
void registerTestSPIRVFuncSignatureConversion() {
54+
PassRegistration<TestSPIRVFuncSignatureConversion>();
55+
}
56+
} // namespace test
57+
} // namespace mlir

mlir/tools/mlir-opt/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ if(MLIR_INCLUDE_TESTS)
3636
MLIRSPIRVTestPasses
3737
MLIRTensorTestPasses
3838
MLIRTestAnalysis
39+
MLIRTestConvertToSPIRV
3940
MLIRTestDialect
4041
MLIRTestDynDialect
4142
MLIRTestIR

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ void registerTestSCFWhileOpBuilderPass();
141141
void registerTestSCFWrapInZeroTripCheckPasses();
142142
void registerTestShapeMappingPass();
143143
void registerTestSliceAnalysisPass();
144+
void registerTestSPIRVFuncSignatureConversion();
144145
void registerTestTensorCopyInsertionPass();
145146
void registerTestTensorTransforms();
146147
void registerTestTopologicalSortAnalysisPass();
@@ -273,6 +274,7 @@ void registerTestPasses() {
273274
mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
274275
mlir::test::registerTestShapeMappingPass();
275276
mlir::test::registerTestSliceAnalysisPass();
277+
mlir::test::registerTestSPIRVFuncSignatureConversion();
276278
mlir::test::registerTestTensorCopyInsertionPass();
277279
mlir::test::registerTestTensorTransforms();
278280
mlir::test::registerTestTopologicalSortAnalysisPass();

0 commit comments

Comments
 (0)