Skip to content

Commit 34265da

Browse files
Denis Khalikovtensorflower-gardener
authored andcommitted
[spirv] Add CompositeConstruct operation.
Closes tensorflow/mlir#308 COPYBARA_INTEGRATE_REVIEW=tensorflow/mlir#308 from denis0x0D:sandbox/composite_construct 9ef7180f77f9374bcd05afc4f9e6c1d2d72d02b7 PiperOrigin-RevId: 284613617
1 parent 2c7e8ed commit 34265da

File tree

5 files changed

+193
-15
lines changed

5 files changed

+193
-15
lines changed

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,6 +1075,7 @@ def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
10751075
def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
10761076
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
10771077
def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
1078+
def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>;
10781079
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
10791080
def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>;
10801081
def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>;
@@ -1171,20 +1172,21 @@ def SPV_OpcodeAttr :
11711172
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
11721173
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
11731174
SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
1174-
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert,
1175-
SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
1176-
SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
1177-
SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
1178-
SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
1179-
SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
1180-
SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
1181-
SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
1182-
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
1183-
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
1184-
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
1185-
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
1186-
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
1187-
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
1175+
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
1176+
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU,
1177+
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
1178+
SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
1179+
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
1180+
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
1181+
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
1182+
SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
1183+
SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
1184+
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
1185+
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
1186+
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
1187+
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
1188+
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
1189+
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
11881190
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
11891191
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
11901192
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,

mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,58 @@
2525

2626
include "mlir/Dialect/SPIRV/SPIRVBase.td"
2727

28+
// -----
29+
30+
def SPV_CompositeConstructOp : SPV_Op<"CompositeConstruct", [NoSideEffect]> {
31+
let summary = [{
32+
Construct a new composite object from a set of constituent objects that
33+
will fully form it.
34+
}];
35+
36+
let description = [{
37+
Result Type must be a composite type, whose top-level
38+
members/elements/components/columns have the same type as the types of
39+
the operands, with one exception. The exception is that for constructing
40+
a vector, the operands may also be vectors with the same component type
41+
as the Result Type component type. When constructing a vector, the total
42+
number of components in all the operands must equal the number of
43+
components in Result Type.
44+
45+
Constituents will become members of a structure, or elements of an
46+
array, or components of a vector, or columns of a matrix. There must be
47+
exactly one Constituent for each top-level
48+
member/element/component/column of the result, with one exception. The
49+
exception is that for constructing a vector, a contiguous subset of the
50+
scalars consumed can be represented by a vector operand instead. The
51+
Constituents must appear in the order needed by the definition of the
52+
type of the result. When constructing a vector, there must be at least
53+
two Constituent operands.
54+
55+
### Custom assembly form
56+
57+
``` {.ebnf}
58+
composite-construct-op ::= ssa-id `=` `spv.CompositeConstruct`
59+
(ssa-use (`,` ssa-use)* )? `:` composite-type
60+
```
61+
62+
For example:
63+
64+
```
65+
%0 = spv.CompositeConstruct %1, %2, %3 : vector<3xf32>
66+
```
67+
}];
68+
69+
let arguments = (ins
70+
Variadic<SPV_Type>:$constituents
71+
);
72+
73+
let results = (outs
74+
SPV_Composite:$result
75+
);
76+
}
77+
78+
// -----
79+
2880
def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
2981
let summary = "Extract a part of a composite object.";
3082

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,73 @@ static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
10691069
return success();
10701070
}
10711071

1072+
//===----------------------------------------------------------------------===//
1073+
// spv.CompositeConstruct
1074+
//===----------------------------------------------------------------------===//
1075+
1076+
static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
1077+
OperationState &state) {
1078+
SmallVector<OpAsmParser::OperandType, 4> operands;
1079+
Type type;
1080+
auto loc = parser.getCurrentLocation();
1081+
1082+
if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
1083+
return failure();
1084+
}
1085+
auto cType = type.dyn_cast<spirv::CompositeType>();
1086+
if (!cType) {
1087+
return parser.emitError(
1088+
loc, "result type must be a composite type, but provided ")
1089+
<< type;
1090+
}
1091+
1092+
if (operands.size() != cType.getNumElements()) {
1093+
return parser.emitError(loc, "has incorrect number of operands: expected ")
1094+
<< cType.getNumElements() << ", but provided " << operands.size();
1095+
}
1096+
// TODO: Add support for constructing a vector type from the vector operands.
1097+
// According to the spec: "for constructing a vector, the operands may
1098+
// also be vectors with the same component type as the Result Type component
1099+
// type".
1100+
SmallVector<Type, 4> elementTypes;
1101+
elementTypes.reserve(cType.getNumElements());
1102+
for (auto index : llvm::seq<uint32_t>(0, cType.getNumElements())) {
1103+
elementTypes.push_back(cType.getElementType(index));
1104+
}
1105+
state.addTypes(type);
1106+
return parser.resolveOperands(operands, elementTypes, loc, state.operands);
1107+
}
1108+
1109+
static void print(spirv::CompositeConstructOp compositeConstructOp,
1110+
OpAsmPrinter &printer) {
1111+
printer << spirv::CompositeConstructOp::getOperationName() << " ";
1112+
printer.printOperands(compositeConstructOp.constituents());
1113+
printer << " : " << compositeConstructOp.getResult()->getType();
1114+
}
1115+
1116+
static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
1117+
auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>();
1118+
1119+
SmallVector<Value *, 4> constituents(compositeConstructOp.constituents());
1120+
if (constituents.size() != cType.getNumElements()) {
1121+
return compositeConstructOp.emitError(
1122+
"has incorrect number of operands: expected ")
1123+
<< cType.getNumElements() << ", but provided "
1124+
<< constituents.size();
1125+
}
1126+
1127+
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1128+
if (constituents[index]->getType() != cType.getElementType(index)) {
1129+
return compositeConstructOp.emitError(
1130+
"operand type mismatch: expected operand type ")
1131+
<< cType.getElementType(index) << ", but provided "
1132+
<< constituents[index]->getType();
1133+
}
1134+
}
1135+
1136+
return success();
1137+
}
1138+
10721139
//===----------------------------------------------------------------------===//
10731140
// spv.CompositeExtractOp
10741141
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@
22

33
spv.module "Logical" "GLSL450" {
44
func @composite_insert(%arg0 : !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>>, %arg1: !spv.array<4xf32>) -> !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>> {
5-
// CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32, 0 : i32] : !spv.array<4 x f32> into !spv.struct<f32, !spv.struct<!spv.array<4 x f32>, f32>>
5+
// CHECK: spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32, 0 : i32] : !spv.array<4 x f32> into !spv.struct<f32, !spv.struct<!spv.array<4 x f32>, f32>>
66
%0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>>
77
spv.ReturnValue %0: !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>>
88
}
9+
func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
10+
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
11+
%0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32>
12+
spv.ReturnValue %0: vector<3xf32>
13+
}
914
}

mlir/test/Dialect/SPIRV/composite-ops.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,57 @@
11
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
22

3+
//===----------------------------------------------------------------------===//
4+
// spv.CompositeConstruct
5+
//===----------------------------------------------------------------------===//
6+
7+
func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
8+
// CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
9+
%0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32>
10+
return %0: vector<3xf32>
11+
}
12+
13+
// -----
14+
15+
func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spv.array<4xf32>, %arg2 : !spv.struct<f32>) -> !spv.struct<vector<3xf32>, !spv.array<4xf32>, !spv.struct<f32>> {
16+
// CHECK: spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<vector<3xf32>, !spv.array<4 x f32>, !spv.struct<f32>>
17+
%0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<vector<3xf32>, !spv.array<4xf32>, !spv.struct<f32>>
18+
return %0: !spv.struct<vector<3xf32>, !spv.array<4xf32>, !spv.struct<f32>>
19+
}
20+
21+
// -----
22+
23+
func @composite_construct_empty_struct() -> !spv.struct<> {
24+
// CHECK: spv.CompositeConstruct : !spv.struct<>
25+
%0 = spv.CompositeConstruct : !spv.struct<>
26+
return %0: !spv.struct<>
27+
}
28+
29+
// -----
30+
31+
func @composite_construct_invalid_num_of_elements(%arg0: f32) -> f32 {
32+
// expected-error @+1 {{result type must be a composite type, but provided 'f32'}}
33+
%0 = spv.CompositeConstruct %arg0 : f32
34+
return %0: f32
35+
}
36+
37+
// -----
38+
39+
func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
40+
// expected-error @+1 {{has incorrect number of operands: expected 3, but provided 2}}
41+
%0 = spv.CompositeConstruct %arg0, %arg2 : vector<3xf32>
42+
return %0: vector<3xf32>
43+
}
44+
45+
// -----
46+
47+
func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xi32> {
48+
// expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32'}}
49+
%0 = "spv.CompositeConstruct" (%arg0, %arg1, %arg2) : (f32, f32, f32) -> vector<3xi32>
50+
return %0: vector<3xi32>
51+
}
52+
53+
// -----
54+
355
//===----------------------------------------------------------------------===//
456
// spv.CompositeExtractOp
557
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)