
Commit 2c7e8ed

antiagainst authored and tensorflower-gardener committed
[spirv] Add spv.IAdd, spv.ISub, and spv.IMul folders
The patterns to be folded away can be commonly generated during lowering to SPIR-V.

PiperOrigin-RevId: 284604855
1 parent 5a48e40 commit 2c7e8ed
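
The folded patterns show up routinely in IR produced while lowering to SPIR-V. As a purely illustrative sketch (not part of the commit; the function and value names are made up), the canonicalizer, e.g. via mlir-opt -canonicalize, can now simplify IR such as:

// Hypothetical lowering output; names are illustrative only.
func @lowering_artifacts(%arg0: i32) -> (i32, i32) {
  %c4 = spv.constant 4 : i32
  %c8 = spv.constant 8 : i32
  // Constant operands now fold away: this becomes spv.constant 32 : i32.
  %0 = spv.IMul %c4, %c8 : i32
  // Subtracting a value from itself now folds to spv.constant 0 : i32.
  %1 = spv.ISub %arg0, %arg0 : i32
  return %0, %1 : i32, i32
}

Before this change only the x + 0, x * 0, and x * 1 identities were folded; this commit adds integer constant folding for spv.IAdd and spv.IMul and introduces a folder for spv.ISub.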

File tree

3 files changed, +193 -5 lines changed


mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td

Lines changed: 2 additions & 0 deletions
@@ -368,6 +368,8 @@ def SPV_ISubOp : SPV_ArithmeticBinaryOp<"ISub", SPV_Integer, []> {
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp

Lines changed: 36 additions & 5 deletions
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 
 #include "mlir/Analysis/CallInterfaces.h"
+#include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Builders.h"
@@ -1753,11 +1754,17 @@ static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
 
 OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "spv.IAdd expects two operands");
-  // lhs + 0 = lhs
+  // x + 0 = x
   if (matchPattern(operand2(), m_Zero()))
     return operand1();
 
-  return nullptr;
+  // According to the SPIR-V spec:
+  //
+  // The resulting value will equal the low-order N bits of the correct result
+  // R, where N is the component width and R is computed with enough precision
+  // to avoid overflow and underflow.
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a + b; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1766,14 +1773,38 @@ OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
 
 OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "spv.IMul expects two operands");
-  // lhs * 0 == 0
+  // x * 0 == 0
   if (matchPattern(operand2(), m_Zero()))
     return operand2();
-  // lhs * 1 = lhs
+  // x * 1 = x
   if (matchPattern(operand2(), m_One()))
     return operand1();
 
-  return nullptr;
+  // According to the SPIR-V spec:
+  //
+  // The resulting value will equal the low-order N bits of the correct result
+  // R, where N is the component width and R is computed with enough precision
+  // to avoid overflow and underflow.
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a * b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ISub
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
+  // x - x = 0
+  if (operand1() == operand2())
+    return Builder(getContext()).getIntegerAttr(getType(), 0);
+
+  // According to the SPIR-V spec:
+  //
+  // The resulting value will equal the low-order N bits of the correct result
+  // R, where N is the component width and R is computed with enough precision
+  // to avoid overflow and underflow.
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a - b; });
 }
 
 //===----------------------------------------------------------------------===//
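
A note on the spec sentence quoted in the folders above: the APInt arithmetic in the lambdas runs at the bit width of the operand type, so the folded constant is exactly the low-order N bits the spec asks for, at any component width. A minimal sketch, assuming a 16-bit integer type is available to the target (hypothetical example, not from the commit):

// 32767 + 1 wraps to the low-order 16 bits, so with these folders the result
// canonicalizes to spv.constant -32768 : i16.
func @iadd_i16_wraps() -> i16 {
  %max = spv.constant 32767 : i16
  %one = spv.constant 1 : i16
  %0 = spv.IAdd %max, %one : i16
  return %0 : i16
}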

mlir/test/Dialect/SPIRV/canonicalize.mlir

Lines changed: 155 additions & 0 deletions
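
The tests below are FileCheck tests; a RUN line of roughly this shape (assumed here, it sits above the diff context and is not shown) drives them, with -split-input-file treating each // ----- marker as the start of an independent input:

// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s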
@@ -187,6 +187,54 @@ func @iadd_zero(%arg0: i32) -> (i32, i32) {
   return %0, %1: i32, i32
 }
 
+// CHECK-LABEL: @const_fold_scalar_iadd_normal
+func @const_fold_scalar_iadd_normal() -> (i32, i32, i32) {
+  %c5 = spv.constant 5 : i32
+  %cn8 = spv.constant -8 : i32
+
+  // CHECK: spv.constant 10
+  // CHECK: spv.constant -16
+  // CHECK: spv.constant -3
+  %0 = spv.IAdd %c5, %c5 : i32
+  %1 = spv.IAdd %cn8, %cn8 : i32
+  %2 = spv.IAdd %c5, %cn8 : i32
+  return %0, %1, %2: i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_iadd_flow
+func @const_fold_scalar_iadd_flow() -> (i32, i32, i32, i32) {
+  %c1 = spv.constant 1 : i32
+  %c2 = spv.constant 2 : i32
+  %c3 = spv.constant 4294967295 : i32  // 2^32 - 1: 0xffff ffff
+  %c4 = spv.constant -2147483648 : i32 // -2^31   : 0x8000 0000
+  %c5 = spv.constant -1 : i32          //         : 0xffff ffff
+  %c6 = spv.constant -2 : i32          //         : 0xffff fffe
+
+  // 0x0000 0001 + 0xffff ffff = 0x1 0000 0000 -> 0x0000 0000
+  // CHECK: spv.constant 0
+  %0 = spv.IAdd %c1, %c3 : i32
+  // 0x0000 0002 + 0xffff ffff = 0x1 0000 0001 -> 0x0000 0001
+  // CHECK: spv.constant 1
+  %1 = spv.IAdd %c2, %c3 : i32
+  // 0x8000 0000 + 0xffff ffff = 0x1 7fff ffff -> 0x7fff ffff
+  // CHECK: spv.constant 2147483647
+  %2 = spv.IAdd %c4, %c5 : i32
+  // 0x8000 0000 + 0xffff fffe = 0x1 7fff fffe -> 0x7fff fffe
+  // CHECK: spv.constant 2147483646
+  %3 = spv.IAdd %c4, %c6 : i32
+  return %0, %1, %2, %3: i32, i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_iadd
+func @const_fold_vector_iadd() -> vector<3xi32> {
+  %vc1 = spv.constant dense<[42, -55, 127]> : vector<3xi32>
+  %vc2 = spv.constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: spv.constant dense<[39, -70, 155]>
+  %0 = spv.IAdd %vc1, %vc2 : vector<3xi32>
+  return %0: vector<3xi32>
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//
@@ -205,6 +253,113 @@ func @imul_zero_one(%arg0: i32) -> (i32, i32) {
   return %0, %1: i32, i32
 }
 
+// CHECK-LABEL: @const_fold_scalar_imul_normal
+func @const_fold_scalar_imul_normal() -> (i32, i32, i32) {
+  %c5 = spv.constant 5 : i32
+  %cn8 = spv.constant -8 : i32
+  %c7 = spv.constant 7 : i32
+
+  // CHECK: spv.constant 35
+  // CHECK: spv.constant -40
+  // CHECK: spv.constant -56
+  %0 = spv.IMul %c7, %c5 : i32
+  %1 = spv.IMul %c5, %cn8 : i32
+  %2 = spv.IMul %cn8, %c7 : i32
+  return %0, %1, %2: i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_imul_flow
+func @const_fold_scalar_imul_flow() -> (i32, i32, i32) {
+  %c1 = spv.constant 2 : i32
+  %c2 = spv.constant 4 : i32
+  %c3 = spv.constant 4294967295 : i32  // 2^32 - 1 : 0xffff ffff
+  %c4 = spv.constant -2147483649 : i32 // -2^31 - 1: 0x7fff ffff
+
+  // (0xffff ffff << 1) = 0x1 ffff fffe -> 0xffff fffe
+  // CHECK: %[[CST2:.*]] = spv.constant -2
+  %0 = spv.IMul %c1, %c3 : i32
+  // (0x7fff ffff << 1) = 0x0 ffff fffe -> 0xffff fffe
+  %1 = spv.IMul %c1, %c4 : i32
+  // (0x7fff ffff << 2) = 0x1 ffff fffc -> 0xffff fffc
+  // CHECK: %[[CST4:.*]] = spv.constant -4
+  %2 = spv.IMul %c4, %c2 : i32
+  // CHECK: return %[[CST2]], %[[CST2]], %[[CST4]]
+  return %0, %1, %2: i32, i32, i32
+}
+
+
+// CHECK-LABEL: @const_fold_vector_imul
+func @const_fold_vector_imul() -> vector<3xi32> {
+  %vc1 = spv.constant dense<[42, -55, 127]> : vector<3xi32>
+  %vc2 = spv.constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: spv.constant dense<[-126, 825, 3556]>
+  %0 = spv.IMul %vc1, %vc2 : vector<3xi32>
+  return %0: vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.ISub
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @isub_x_x
+func @isub_x_x(%arg0: i32) -> i32 {
+  // CHECK: spv.constant 0
+  %0 = spv.ISub %arg0, %arg0: i32
+  return %0: i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_isub_normal
+func @const_fold_scalar_isub_normal() -> (i32, i32, i32) {
+  %c5 = spv.constant 5 : i32
+  %cn8 = spv.constant -8 : i32
+  %c7 = spv.constant 7 : i32
+
+  // CHECK: spv.constant 2
+  // CHECK: spv.constant 13
+  // CHECK: spv.constant -15
+  %0 = spv.ISub %c7, %c5 : i32
+  %1 = spv.ISub %c5, %cn8 : i32
+  %2 = spv.ISub %cn8, %c7 : i32
+  return %0, %1, %2: i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_isub_flow
+func @const_fold_scalar_isub_flow() -> (i32, i32, i32, i32) {
+  %c1 = spv.constant 0 : i32
+  %c2 = spv.constant 1 : i32
+  %c3 = spv.constant 4294967295 : i32  // 2^32 - 1 : 0xffff ffff
+  %c4 = spv.constant -2147483649 : i32 // -2^31 - 1: 0x7fff ffff
+  %c5 = spv.constant -1 : i32          //          : 0xffff ffff
+  %c6 = spv.constant -2 : i32          //          : 0xffff fffe
+
+  // 0x0000 0000 - 0xffff ffff -> 0x0000 0000 + 0x0000 0001 = 0x0000 0001
+  // CHECK: spv.constant 1
+  %0 = spv.ISub %c1, %c3 : i32
+  // 0x0000 0001 - 0xffff ffff -> 0x0000 0001 + 0x0000 0001 = 0x0000 0002
+  // CHECK: spv.constant 2
+  %1 = spv.ISub %c2, %c3 : i32
+  // 0xffff ffff - 0x7fff ffff -> 0xffff ffff + 0x8000 0001 = 0x1 8000 0000
+  // CHECK: spv.constant -2147483648
+  %2 = spv.ISub %c5, %c4 : i32
+  // 0xffff fffe - 0x7fff ffff -> 0xffff fffe + 0x8000 0001 = 0x1 7fff ffff
+  // CHECK: spv.constant 2147483647
+  %3 = spv.ISub %c6, %c4 : i32
+  return %0, %1, %2, %3: i32, i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_isub
+func @const_fold_vector_isub() -> vector<3xi32> {
+  %vc1 = spv.constant dense<[42, -55, 127]> : vector<3xi32>
+  %vc2 = spv.constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: spv.constant dense<[45, -40, 99]>
+  %0 = spv.ISub %vc1, %vc2 : vector<3xi32>
+  return %0: vector<3xi32>
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//
