Skip to content

Commit 5c16eeb

Browse files
committed
[mlir][spirv] Define spv.IAddCarry
Based on `spv.ISubBorrow` from D127909. Also resolved some clang-tidy warnings. Reviewed By: antiagainst, ThomasRaoux Differential Revision: https://reviews.llvm.org/D131281
1 parent 8d2901d commit 5c16eeb

File tree

4 files changed

+168
-21
lines changed

4 files changed

+168
-21
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,55 @@ def SPV_IAddOp : SPV_ArithmeticBinaryOp<"IAdd",
310310

311311
// -----
312312

313+
def SPV_IAddCarryOp : SPV_BinaryOp<"IAddCarry",
314+
SPV_AnyStruct, SPV_Integer,
315+
[Commutative, NoSideEffect]> {
316+
let summary = [{
317+
Integer addition of Operand 1 and Operand 2, including the carry.
318+
}];
319+
320+
let description = [{
321+
Result Type must be from OpTypeStruct. The struct must have two
322+
members, and the two members must be the same type. The member type
323+
must be a scalar or vector of integer type, whose Signedness operand is
324+
0.
325+
326+
Operand 1 and Operand 2 must have the same type as the members of Result
327+
Type. These are consumed as unsigned integers.
328+
329+
Results are computed per component.
330+
331+
Member 0 of the result gets the low-order bits (full component width) of
332+
the addition.
333+
334+
Member 1 of the result gets the high-order (carry) bit of the result of
335+
the addition. That is, it gets the value 1 if the addition overflowed
336+
the component width, and 0 otherwise.
337+
338+
<!-- End of AutoGen section -->
339+
340+
#### Example:
341+
342+
```mlir
343+
%2 = spv.IAddCarry %0, %1 : !spv.struct<(i32, i32)>
344+
%2 = spv.IAddCarry %0, %1 : !spv.struct<(vector<2xi32>, vector<2xi32>)>
345+
```
346+
}];
347+
348+
let arguments = (ins
349+
SPV_ScalarOrVectorOf<SPV_Integer>:$operand1,
350+
SPV_ScalarOrVectorOf<SPV_Integer>:$operand2
351+
);
352+
353+
let results = (outs
354+
SPV_AnyStruct:$result
355+
);
356+
357+
let hasVerifier = 1;
358+
}
359+
360+
// -----
361+
313362
def SPV_IMulOp : SPV_ArithmeticBinaryOp<"IMul",
314363
SPV_Integer,
315364
[Commutative, UsableInSpecConstantOp]> {

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4094,6 +4094,7 @@ def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
40944094
def SPV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
40954095
def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
40964096
def SPV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
4097+
def SPV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>;
40974098
def SPV_OC_OpISubBorrow : I32EnumAttrCase<"OpISubBorrow", 150>;
40984099
def SPV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>;
40994100
def SPV_OC_OpIsInf : I32EnumAttrCase<"OpIsInf", 157>;
@@ -4219,16 +4220,16 @@ def SPV_OpcodeAttr :
42194220
SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
42204221
SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
42214222
SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpVectorTimesScalar,
4222-
SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpISubBorrow,
4223-
SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered, SPV_OC_OpUnordered,
4224-
SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
4225-
SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
4226-
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
4227-
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
4228-
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
4229-
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
4230-
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
4231-
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
4223+
SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIAddCarry,
4224+
SPV_OC_OpISubBorrow, SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered,
4225+
SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
4226+
SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
4227+
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
4228+
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
4229+
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
4230+
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
4231+
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
4232+
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
42324233
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
42334234
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
42344235
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
2020
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
2121
#include "mlir/IR/Builders.h"
22-
#include "mlir/IR/BuiltinOps.h"
2322
#include "mlir/IR/BuiltinTypes.h"
2423
#include "mlir/IR/FunctionImplementation.h"
2524
#include "mlir/IR/OpDefinition.h"
@@ -2840,6 +2839,55 @@ void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
28402839
printGroupNonUniformArithmeticOp(*this, p);
28412840
}
28422841

2842+
//===----------------------------------------------------------------------===//
2843+
// spv.IAddCarryOp
2844+
//===----------------------------------------------------------------------===//
2845+
2846+
LogicalResult spirv::IAddCarryOp::verify() {
2847+
auto resultType = getType().cast<spirv::StructType>();
2848+
if (resultType.getNumElements() != 2)
2849+
return emitOpError("expected result struct type containing two members");
2850+
2851+
if (!llvm::is_splat(llvm::makeArrayRef(
2852+
{operand1().getType(), operand2().getType(),
2853+
resultType.getElementType(0), resultType.getElementType(1)})))
2854+
return emitOpError(
2855+
"expected all operand types and struct member types are the same");
2856+
2857+
return success();
2858+
}
2859+
2860+
ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
2861+
OperationState &result) {
2862+
SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
2863+
if (parser.parseOptionalAttrDict(result.attributes) ||
2864+
parser.parseOperandList(operands) || parser.parseColon())
2865+
return failure();
2866+
2867+
Type resultType;
2868+
SMLoc loc = parser.getCurrentLocation();
2869+
if (parser.parseType(resultType))
2870+
return failure();
2871+
2872+
auto structType = resultType.dyn_cast<spirv::StructType>();
2873+
if (!structType || structType.getNumElements() != 2)
2874+
return parser.emitError(loc, "expected spv.struct type with two members");
2875+
2876+
SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
2877+
if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
2878+
return failure();
2879+
2880+
result.addTypes(resultType);
2881+
return success();
2882+
}
2883+
2884+
void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
2885+
printer << ' ';
2886+
printer.printOptionalAttrDict((*this)->getAttrs());
2887+
printer.printOperands((*this)->getOperands());
2888+
printer << " : " << getType();
2889+
}
2890+
28432891
//===----------------------------------------------------------------------===//
28442892
// spv.ISubBorrowOp
28452893
//===----------------------------------------------------------------------===//
@@ -2849,22 +2897,19 @@ LogicalResult spirv::ISubBorrowOp::verify() {
28492897
if (resultType.getNumElements() != 2)
28502898
return emitOpError("expected result struct type containing two members");
28512899

2852-
SmallVector<Type, 4> types;
2853-
types.push_back(operand1().getType());
2854-
types.push_back(operand2().getType());
2855-
types.push_back(resultType.getElementType(0));
2856-
types.push_back(resultType.getElementType(1));
2857-
if (!llvm::is_splat(types))
2900+
if (!llvm::is_splat(llvm::makeArrayRef(
2901+
{operand1().getType(), operand2().getType(),
2902+
resultType.getElementType(0), resultType.getElementType(1)})))
28582903
return emitOpError(
28592904
"expected all operand types and struct member types are the same");
28602905

28612906
return success();
28622907
}
28632908

28642909
ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
2865-
OperationState &state) {
2910+
OperationState &result) {
28662911
SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
2867-
if (parser.parseOptionalAttrDict(state.attributes) ||
2912+
if (parser.parseOptionalAttrDict(result.attributes) ||
28682913
parser.parseOperandList(operands) || parser.parseColon())
28692914
return failure();
28702915

@@ -2878,10 +2923,10 @@ ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
28782923
return parser.emitError(loc, "expected spv.struct type with two members");
28792924

28802925
SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
2881-
if (parser.resolveOperands(operands, operandTypes, loc, state.operands))
2926+
if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
28822927
return failure();
28832928

2884-
state.addTypes(resultType);
2929+
result.addTypes(resultType);
28852930
return success();
28862931
}
28872932

mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,58 @@ func.func @isub_scalar(%arg: i32) -> i32 {
150150

151151
// -----
152152

153+
//===----------------------------------------------------------------------===//
154+
// spv.IAddCarry
155+
//===----------------------------------------------------------------------===//
156+
157+
// CHECK-LABEL: @iadd_carry_scalar
158+
func.func @iadd_carry_scalar(%arg: i32) -> !spv.struct<(i32, i32)> {
159+
// CHECK: spv.IAddCarry %{{.+}}, %{{.+}} : !spv.struct<(i32, i32)>
160+
%0 = spv.IAddCarry %arg, %arg : !spv.struct<(i32, i32)>
161+
return %0 : !spv.struct<(i32, i32)>
162+
}
163+
164+
// CHECK-LABEL: @iadd_carry_vector
165+
func.func @iadd_carry_vector(%arg: vector<3xi32>) -> !spv.struct<(vector<3xi32>, vector<3xi32>)> {
166+
// CHECK: spv.IAddCarry %{{.+}}, %{{.+}} : !spv.struct<(vector<3xi32>, vector<3xi32>)>
167+
%0 = spv.IAddCarry %arg, %arg : !spv.struct<(vector<3xi32>, vector<3xi32>)>
168+
return %0 : !spv.struct<(vector<3xi32>, vector<3xi32>)>
169+
}
170+
171+
// -----
172+
173+
func.func @iadd_carry(%arg: i32) -> !spv.struct<(i32, i32, i32)> {
174+
// expected-error @+1 {{expected spv.struct type with two members}}
175+
%0 = spv.IAddCarry %arg, %arg : !spv.struct<(i32, i32, i32)>
176+
return %0 : !spv.struct<(i32, i32, i32)>
177+
}
178+
179+
// -----
180+
181+
func.func @iadd_carry(%arg: i32) -> !spv.struct<(i32)> {
182+
// expected-error @+1 {{expected result struct type containing two members}}
183+
%0 = "spv.IAddCarry"(%arg, %arg): (i32, i32) -> !spv.struct<(i32)>
184+
return %0 : !spv.struct<(i32)>
185+
}
186+
187+
// -----
188+
189+
func.func @iadd_carry(%arg: i32) -> !spv.struct<(i32, i64)> {
190+
// expected-error @+1 {{expected all operand types and struct member types are the same}}
191+
%0 = "spv.IAddCarry"(%arg, %arg): (i32, i32) -> !spv.struct<(i32, i64)>
192+
return %0 : !spv.struct<(i32, i64)>
193+
}
194+
195+
// -----
196+
197+
func.func @iadd_carry(%arg: i64) -> !spv.struct<(i32, i32)> {
198+
// expected-error @+1 {{expected all operand types and struct member types are the same}}
199+
%0 = "spv.IAddCarry"(%arg, %arg): (i64, i64) -> !spv.struct<(i32, i32)>
200+
return %0 : !spv.struct<(i32, i32)>
201+
}
202+
203+
// -----
204+
153205
//===----------------------------------------------------------------------===//
154206
// spv.ISubBorrow
155207
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)