Skip to content

Commit d9b4245

Browse files
abialas1ThomasRaoux
authored andcommitted
[mlir][spirv] Add block read and write from SPV_INTEL_subgroups
Added support to OpSubgroupBlockReadINTEL and OpSubgroupBlockWriteINTEL Differential Revision: https://reviews.llvm.org/D86876
1 parent f7e04b7 commit d9b4245

File tree

5 files changed

+268
-2
lines changed

5 files changed

+268
-2
lines changed

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3252,6 +3252,8 @@ def SPV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoa
32523252
def SPV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
32533253
def SPV_OC_OpCooperativeMatrixMulAddNV : I32EnumAttrCase<"OpCooperativeMatrixMulAddNV", 5361>;
32543254
def SPV_OC_OpCooperativeMatrixLengthNV : I32EnumAttrCase<"OpCooperativeMatrixLengthNV", 5362>;
3255+
def SPV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
3256+
def SPV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
32553257

32563258
def SPV_OpcodeAttr :
32573259
SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@@ -3308,7 +3310,8 @@ def SPV_OpcodeAttr :
33083310
SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
33093311
SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
33103312
SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
3311-
SPV_OC_OpCooperativeMatrixLengthNV
3313+
SPV_OC_OpCooperativeMatrixLengthNV, SPV_OC_OpSubgroupBlockReadINTEL,
3314+
SPV_OC_OpSubgroupBlockWriteINTEL
33123315
]>;
33133316

33143317
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
8888
let assemblyFormat = [{
8989
$execution_scope operands attr-dict `:` type($value) `,` type($localid)
9090
}];
91-
9291
}
9392

9493
// -----
@@ -147,4 +146,104 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
147146

148147
// -----
149148

149+
def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
150+
let summary = "See extension SPV_INTEL_subgroups";
151+
152+
let description = [{
153+
Reads one or more components of Result data for each invocation in the
154+
subgroup from the specified Ptr as a block operation.
155+
156+
The data is read strided, so the first value read is:
157+
Ptr[ SubgroupLocalInvocationId ]
158+
159+
and the second value read is:
160+
Ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
161+
etc.
162+
163+
Result Type may be a scalar or vector type, and its component type must be
164+
equal to the type pointed to by Ptr.
165+
166+
The type of Ptr must be a pointer type, and must point to a scalar type.
167+
168+
<!-- End of AutoGen section -->
169+
170+
```
171+
subgroup-block-read-INTEL-op ::= ssa-id `=` `spv.SubgroupBlockReadINTEL`
172+
storage-class ssa_use `:` spirv-element-type
173+
```mlir
174+
175+
#### Example:
176+
177+
```
178+
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
179+
```
180+
}];
181+
182+
let availability = [
183+
MinVersion<SPV_V_1_0>,
184+
MaxVersion<SPV_V_1_5>,
185+
Extension<[SPV_INTEL_subgroups]>,
186+
Capability<[SPV_C_SubgroupBufferBlockIOINTEL]>
187+
];
188+
189+
let arguments = (ins
190+
SPV_AnyPtr:$ptr
191+
);
192+
193+
let results = (outs
194+
SPV_Type:$value
195+
);
196+
}
197+
198+
// -----
199+
200+
def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> {
201+
let summary = "See extension SPV_INTEL_subgroups";
202+
203+
let description = [{
204+
Writes one or more components of Data for each invocation in the subgroup
205+
from the specified Ptr as a block operation.
206+
207+
The data is written strided, so the first value is written to:
208+
Ptr[ SubgroupLocalInvocationId ]
209+
210+
and the second value written is:
211+
Ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
212+
etc.
213+
214+
The type of Ptr must be a pointer type, and must point to a scalar type.
215+
216+
The component type of Data must be equal to the type pointed to by Ptr.
217+
218+
<!-- End of AutoGen section -->
219+
220+
```
221+
subgroup-block-write-INTEL-op ::= ssa-id `=` `spv.SubgroupBlockWriteINTEL`
222+
storage-class ssa_use `,` ssa-use `:` spirv-element-type
223+
```mlir
224+
225+
#### Example:
226+
227+
```
228+
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
229+
```
230+
}];
231+
232+
let availability = [
233+
MinVersion<SPV_V_1_0>,
234+
MaxVersion<SPV_V_1_5>,
235+
Extension<[SPV_INTEL_subgroups]>,
236+
Capability<[SPV_C_SubgroupBufferBlockIOINTEL]>
237+
];
238+
239+
let arguments = (ins
240+
SPV_AnyPtr:$ptr,
241+
SPV_Type:$value
242+
);
243+
244+
let results = (outs);
245+
}
246+
247+
// -----
248+
150249
#endif // SPIRV_GROUP_OPS

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,19 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
468468
return success();
469469
}
470470

471+
template <typename BlockReadWriteOpTy>
472+
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
473+
Value ptr, Value val) {
474+
auto valType = val.getType();
475+
if (auto valVecTy = valType.dyn_cast<VectorType>())
476+
valType = valVecTy.getElementType();
477+
478+
if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
479+
return op.emitOpError("mismatch in result type and pointer type");
480+
}
481+
return success();
482+
}
483+
471484
static ParseResult parseVariableDecorations(OpAsmParser &parser,
472485
OperationState &state) {
473486
auto builtInName = llvm::convertToSnakeFromCamelCase(
@@ -2025,6 +2038,93 @@ static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
20252038
return success();
20262039
}
20272040

2041+
//===----------------------------------------------------------------------===//
2042+
// spv.SubgroupBlockReadINTEL
2043+
//===----------------------------------------------------------------------===//
2044+
2045+
static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser,
2046+
OperationState &state) {
2047+
// Parse the storage class specification
2048+
spirv::StorageClass storageClass;
2049+
OpAsmParser::OperandType ptrInfo;
2050+
Type elementType;
2051+
if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2052+
parser.parseColon() || parser.parseType(elementType)) {
2053+
return failure();
2054+
}
2055+
2056+
auto ptrType = spirv::PointerType::get(elementType, storageClass);
2057+
if (auto valVecTy = elementType.dyn_cast<VectorType>())
2058+
ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2059+
2060+
if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2061+
return failure();
2062+
}
2063+
2064+
state.addTypes(elementType);
2065+
return success();
2066+
}
2067+
2068+
static void print(spirv::SubgroupBlockReadINTELOp blockReadOp,
2069+
OpAsmPrinter &printer) {
2070+
SmallVector<StringRef, 4> elidedAttrs;
2071+
printer << spirv::SubgroupBlockReadINTELOp::getOperationName() << " "
2072+
<< blockReadOp.ptr();
2073+
printer << " : " << blockReadOp.getType();
2074+
}
2075+
2076+
static LogicalResult verify(spirv::SubgroupBlockReadINTELOp blockReadOp) {
2077+
if (failed(verifyBlockReadWritePtrAndValTypes(blockReadOp, blockReadOp.ptr(),
2078+
blockReadOp.value())))
2079+
return failure();
2080+
2081+
return success();
2082+
}
2083+
2084+
//===----------------------------------------------------------------------===//
2085+
// spv.SubgroupBlockWriteINTEL
2086+
//===----------------------------------------------------------------------===//
2087+
2088+
static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser,
2089+
OperationState &state) {
2090+
// Parse the storage class specification
2091+
spirv::StorageClass storageClass;
2092+
SmallVector<OpAsmParser::OperandType, 2> operandInfo;
2093+
auto loc = parser.getCurrentLocation();
2094+
Type elementType;
2095+
if (parseEnumStrAttr(storageClass, parser) ||
2096+
parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2097+
parser.parseType(elementType)) {
2098+
return failure();
2099+
}
2100+
2101+
auto ptrType = spirv::PointerType::get(elementType, storageClass);
2102+
if (auto valVecTy = elementType.dyn_cast<VectorType>())
2103+
ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2104+
2105+
if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2106+
state.operands)) {
2107+
return failure();
2108+
}
2109+
return success();
2110+
}
2111+
2112+
static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,
2113+
OpAsmPrinter &printer) {
2114+
SmallVector<StringRef, 4> elidedAttrs;
2115+
printer << spirv::SubgroupBlockWriteINTELOp::getOperationName() << " "
2116+
<< blockWriteOp.ptr() << ", " << blockWriteOp.value();
2117+
printer << " : " << blockWriteOp.value().getType();
2118+
}
2119+
2120+
static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) {
2121+
if (failed(verifyBlockReadWritePtrAndValTypes(
2122+
blockWriteOp, blockWriteOp.ptr(), blockWriteOp.value())))
2123+
return failure();
2124+
2125+
return success();
2126+
}
2127+
20282128
//===----------------------------------------------------------------------===//
20292129
// spv.GroupNonUniformElectOp
20302130
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,28 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
1919
%0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
2020
spv.ReturnValue %0: f32
2121
}
22+
// CHECK-LABEL: @subgroup_block_read_intel
23+
spv.func @subgroup_block_read_intel(%ptr : !spv.ptr<i32, StorageBuffer>) -> i32 "None" {
24+
// CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : i32
25+
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
26+
spv.ReturnValue %0: i32
27+
}
28+
// CHECK-LABEL: @subgroup_block_read_intel_vector
29+
spv.func @subgroup_block_read_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> "None" {
30+
// CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : vector<3xi32>
31+
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<3xi32>
32+
spv.ReturnValue %0: vector<3xi32>
33+
}
34+
// CHECK-LABEL: @subgroup_block_write_intel
35+
spv.func @subgroup_block_write_intel(%ptr : !spv.ptr<i32, StorageBuffer>, %value: i32) -> () "None" {
36+
// CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : i32
37+
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
38+
spv.Return
39+
}
40+
// CHECK-LABEL: @subgroup_block_write_intel_vector
41+
spv.func @subgroup_block_write_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () "None" {
42+
// CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : vector<3xi32>
43+
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : vector<3xi32>
44+
spv.Return
45+
}
2246
}

mlir/test/Dialect/SPIRV/group-ops.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,43 @@ func @group_broadcast_negative_locid_vec4(%value: f32, %localid: vector<4xi32> )
6161
%0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<4xi32>
6262
return %0: f32
6363
}
64+
65+
// -----
66+
67+
//===----------------------------------------------------------------------===//
68+
// spv.SubgroupBlockReadINTEL
69+
//===----------------------------------------------------------------------===//
70+
71+
func @subgroup_block_read_intel(%ptr : !spv.ptr<i32, StorageBuffer>) -> i32 {
72+
// CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : i32
73+
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
74+
return %0: i32
75+
}
76+
77+
// -----
78+
79+
func @subgroup_block_read_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> {
80+
// CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : vector<3xi32>
81+
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<3xi32>
82+
return %0: vector<3xi32>
83+
}
84+
85+
// -----
86+
87+
//===----------------------------------------------------------------------===//
88+
// spv.SubgroupBlockWriteINTEL
89+
//===----------------------------------------------------------------------===//
90+
91+
func @subgroup_block_write_intel(%ptr : !spv.ptr<i32, StorageBuffer>, %value: i32) -> () {
92+
// CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : i32
93+
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
94+
return
95+
}
96+
97+
// -----
98+
99+
func @subgroup_block_write_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () {
100+
// CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : vector<3xi32>
101+
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : vector<3xi32>
102+
return
103+
}

0 commit comments

Comments
 (0)