Skip to content

Commit 1d0dc9b

Browse files
committed
[MLIR][SPIRV] Add rewrite pattern to convert select+cmp into GLSL clamp.
Adds rewrite patterns to convert select+cmp instructions into clamp instructions whenever possible. Support is added to convert: - FOrdLessThan, FOrdLessThanEqual to GLSLFClampOp. - SLessThan, SLessThanEqual to GLSLSClampOp. - ULessThan, ULessThanEqual to GLSLUClampOp. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D93618
1 parent 031743c commit 1d0dc9b

File tree

8 files changed

+252
-0
lines changed

8 files changed

+252
-0
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- SPIRVGLSLCanonicalization.h - GLSL-specific patterns -----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares a function to register SPIR-V GLSL-specific
10+
// canonicalization patterns.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVGLSLCANONICALIZATION_H_
15+
#define MLIR_DIALECT_SPIRV_IR_SPIRVGLSLCANONICALIZATION_H_
16+
17+
#include "mlir/IR/MLIRContext.h"
18+
#include "mlir/IR/PatternMatch.h"
19+
20+
//===----------------------------------------------------------------------===//
21+
// GLSL canonicalization patterns
22+
//===----------------------------------------------------------------------===//
23+
24+
namespace mlir {
25+
namespace spirv {
26+
void populateSPIRVGLSLCanonicalizationPatterns(
27+
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context);
28+
} // namespace spirv
29+
} // namespace mlir
30+
31+
#endif // MLIR_DIALECT_SPIRV_IR_SPIRVGLSLCANONICALIZATION_H_

mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
55
add_mlir_dialect_library(MLIRSPIRV
66
SPIRVAttributes.cpp
77
SPIRVCanonicalization.cpp
8+
SPIRVGLSLCanonicalization.cpp
89
SPIRVDialect.cpp
910
SPIRVEnums.cpp
1011
SPIRVOps.cpp

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,33 @@ def ConvertLogicalNotOfLogicalEqual : Pat<
3838
def ConvertLogicalNotOfLogicalNotEqual : Pat<
3939
(SPV_LogicalNotOp (SPV_LogicalNotEqualOp $lhs, $rhs)),
4040
(SPV_LogicalEqualOp $lhs, $rhs)>;
41+
42+
//===----------------------------------------------------------------------===//
43+
// Re-write spv.Select + spv.<less_than_op> to a suitable variant of
44+
// spv.<glsl_clamp_op>
45+
//===----------------------------------------------------------------------===//
46+
47+
def ValuesAreEqual : Constraint<CPred<"$0 == $1">>;
48+
49+
foreach CmpClampPair = [
50+
[SPV_FOrdLessThanOp, SPV_GLSLFClampOp],
51+
[SPV_FOrdLessThanEqualOp, SPV_GLSLFClampOp],
52+
[SPV_SLessThanOp, SPV_GLSLSClampOp],
53+
[SPV_SLessThanEqualOp, SPV_GLSLSClampOp],
54+
[SPV_ULessThanOp, SPV_GLSLUClampOp],
55+
[SPV_ULessThanEqualOp, SPV_GLSLUClampOp]] in {
56+
def ConvertComparisonIntoClamp#CmpClampPair[0] : Pat<
57+
(SPV_SelectOp
58+
(CmpClampPair[0]
59+
(SPV_SelectOp:$middle0
60+
(CmpClampPair[0] $min, $input),
61+
$input,
62+
$min
63+
),
64+
$max
65+
),
66+
$middle1,
67+
$max),
68+
(CmpClampPair[1] $input, $min, $max),
69+
[(ValuesAreEqual $middle0, $middle1)]>;
70+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//===- SPIRVGLSLCanonicalization.cpp - SPIR-V GLSL canonicalization patterns =//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines the canonicalization patterns for SPIR-V GLSL-specific ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h"
14+
15+
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16+
17+
using namespace mlir;
18+
19+
namespace {
20+
#include "SPIRVCanonicalization.inc"
21+
} // end anonymous namespace
22+
23+
namespace mlir {
24+
namespace spirv {
25+
void populateSPIRVGLSLCanonicalizationPatterns(
26+
OwningRewritePatternList &results, MLIRContext *context) {
27+
results.insert<ConvertComparisonIntoClampSPV_FOrdLessThanOp,
28+
ConvertComparisonIntoClampSPV_FOrdLessThanEqualOp,
29+
ConvertComparisonIntoClampSPV_SLessThanOp,
30+
ConvertComparisonIntoClampSPV_SLessThanEqualOp,
31+
ConvertComparisonIntoClampSPV_ULessThanOp,
32+
ConvertComparisonIntoClampSPV_ULessThanEqualOp>(context);
33+
}
34+
} // namespace spirv
35+
} // namespace mlir
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// RUN: mlir-opt -test-spirv-glsl-canonicalization -split-input-file -verify-diagnostics %s | FileCheck %s
2+
3+
// CHECK: func @clamp_fordlessthan(%[[INPUT:.*]]: f32)
4+
func @clamp_fordlessthan(%input: f32) -> f32 {
5+
// CHECK: %[[MIN:.*]] = spv.constant
6+
%min = spv.constant 0.5 : f32
7+
// CHECK: %[[MAX:.*]] = spv.constant
8+
%max = spv.constant 1.0 : f32
9+
10+
// CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
11+
%0 = spv.FOrdLessThan %min, %input : f32
12+
%mid = spv.Select %0, %input, %min : i1, f32
13+
%1 = spv.FOrdLessThan %mid, %max : f32
14+
%2 = spv.Select %1, %mid, %max : i1, f32
15+
16+
// CHECK-NEXT: spv.ReturnValue [[RES]]
17+
spv.ReturnValue %2 : f32
18+
}
19+
20+
// -----
21+
22+
// CHECK: func @clamp_fordlessthanequal(%[[INPUT:.*]]: f32)
23+
func @clamp_fordlessthanequal(%input: f32) -> f32 {
24+
// CHECK: %[[MIN:.*]] = spv.constant
25+
%min = spv.constant 0.5 : f32
26+
// CHECK: %[[MAX:.*]] = spv.constant
27+
%max = spv.constant 1.0 : f32
28+
29+
// CHECK: [[RES:%.*]] = spv.GLSL.FClamp %[[INPUT]], %[[MIN]], %[[MAX]]
30+
%0 = spv.FOrdLessThanEqual %min, %input : f32
31+
%mid = spv.Select %0, %input, %min : i1, f32
32+
%1 = spv.FOrdLessThanEqual %mid, %max : f32
33+
%2 = spv.Select %1, %mid, %max : i1, f32
34+
35+
// CHECK-NEXT: spv.ReturnValue [[RES]]
36+
spv.ReturnValue %2 : f32
37+
}
38+
39+
// -----
40+
41+
// CHECK: func @clamp_slessthan(%[[INPUT:.*]]: si32)
42+
func @clamp_slessthan(%input: si32) -> si32 {
43+
// CHECK: %[[MIN:.*]] = spv.constant
44+
%min = spv.constant 0 : si32
45+
// CHECK: %[[MAX:.*]] = spv.constant
46+
%max = spv.constant 10 : si32
47+
48+
// CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
49+
%0 = spv.SLessThan %min, %input : si32
50+
%mid = spv.Select %0, %input, %min : i1, si32
51+
%1 = spv.SLessThan %mid, %max : si32
52+
%2 = spv.Select %1, %mid, %max : i1, si32
53+
54+
// CHECK-NEXT: spv.ReturnValue [[RES]]
55+
spv.ReturnValue %2 : si32
56+
}
57+
58+
// -----
59+
60+
// CHECK: func @clamp_slessthanequal(%[[INPUT:.*]]: si32)
61+
func @clamp_slessthanequal(%input: si32) -> si32 {
62+
// CHECK: %[[MIN:.*]] = spv.constant
63+
%min = spv.constant 0 : si32
64+
// CHECK: %[[MAX:.*]] = spv.constant
65+
%max = spv.constant 10 : si32
66+
67+
// CHECK: [[RES:%.*]] = spv.GLSL.SClamp %[[INPUT]], %[[MIN]], %[[MAX]]
68+
%0 = spv.SLessThanEqual %min, %input : si32
69+
%mid = spv.Select %0, %input, %min : i1, si32
70+
%1 = spv.SLessThanEqual %mid, %max : si32
71+
%2 = spv.Select %1, %mid, %max : i1, si32
72+
73+
// CHECK-NEXT: spv.ReturnValue [[RES]]
74+
spv.ReturnValue %2 : si32
75+
}
76+
77+
// -----
78+
79+
// CHECK: func @clamp_ulessthan(%[[INPUT:.*]]: i32)
80+
func @clamp_ulessthan(%input: i32) -> i32 {
81+
// CHECK: %[[MIN:.*]] = spv.constant
82+
%min = spv.constant 0 : i32
83+
// CHECK: %[[MAX:.*]] = spv.constant
84+
%max = spv.constant 10 : i32
85+
86+
// CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
87+
%0 = spv.ULessThan %min, %input : i32
88+
%mid = spv.Select %0, %input, %min : i1, i32
89+
%1 = spv.ULessThan %mid, %max : i32
90+
%2 = spv.Select %1, %mid, %max : i1, i32
91+
92+
// CHECK-NEXT: spv.ReturnValue [[RES]]
93+
spv.ReturnValue %2 : i32
94+
}
95+
96+
// -----
97+
98+
// CHECK: func @clamp_ulessthanequal(%[[INPUT:.*]]: i32)
99+
func @clamp_ulessthanequal(%input: i32) -> i32 {
100+
// CHECK: %[[MIN:.*]] = spv.constant
101+
%min = spv.constant 0 : i32
102+
// CHECK: %[[MAX:.*]] = spv.constant
103+
%max = spv.constant 10 : i32
104+
105+
// CHECK: [[RES:%.*]] = spv.GLSL.UClamp %[[INPUT]], %[[MIN]], %[[MAX]]
106+
%0 = spv.ULessThanEqual %min, %input : i32
107+
%mid = spv.Select %0, %input, %min : i1, i32
108+
%1 = spv.ULessThanEqual %mid, %max : i32
109+
%2 = spv.Select %1, %mid, %max : i1, i32
110+
111+
// CHECK-NEXT: spv.ReturnValue [[RES]]
112+
spv.ReturnValue %2 : i32
113+
}

mlir/test/lib/Dialect/SPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
add_mlir_library(MLIRSPIRVTestPasses
33
TestAvailability.cpp
44
TestEntryPointAbi.cpp
5+
TestGLSLCanonicalization.cpp
56
TestModuleCombiner.cpp
67

78
EXCLUDE_FROM_LIBMLIR
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//===- TestGLSLCanonicalization.cpp - Pass to test GLSL-specific pattterns ===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h"
10+
#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h"
11+
#include "mlir/Pass/Pass.h"
12+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
14+
using namespace mlir;
15+
16+
namespace {
17+
class TestGLSLCanonicalizationPass
18+
: public PassWrapper<TestGLSLCanonicalizationPass,
19+
OperationPass<mlir::ModuleOp>> {
20+
public:
21+
TestGLSLCanonicalizationPass() = default;
22+
TestGLSLCanonicalizationPass(const TestGLSLCanonicalizationPass &) {}
23+
void runOnOperation() override;
24+
};
25+
} // namespace
26+
27+
void TestGLSLCanonicalizationPass::runOnOperation() {
28+
OwningRewritePatternList patterns;
29+
spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns, &getContext());
30+
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
31+
}
32+
33+
namespace mlir {
34+
void registerTestSpirvGLSLCanonicalizationPass() {
35+
PassRegistration<TestGLSLCanonicalizationPass> registration(
36+
"test-spirv-glsl-canonicalization",
37+
"Tests SPIR-V canonicalization patterns for GLSL extension.");
38+
}
39+
} // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ void registerTestPrintDefUsePass();
4747
void registerTestPrintNestingPass();
4848
void registerTestReducer();
4949
void registerTestSpirvEntryPointABIPass();
50+
void registerTestSpirvGLSLCanonicalizationPass();
5051
void registerTestSpirvModuleCombinerPass();
5152
void registerTestTraitsPass();
5253
void registerTosaTestQuantUtilAPIPass();
@@ -115,6 +116,7 @@ void registerTestPasses() {
115116
registerTestPrintNestingPass();
116117
registerTestReducer();
117118
registerTestSpirvEntryPointABIPass();
119+
registerTestSpirvGLSLCanonicalizationPass();
118120
registerTestSpirvModuleCombinerPass();
119121
registerTestTraitsPass();
120122
registerVectorizerTestPass();

0 commit comments

Comments
 (0)