Skip to content

Commit 3847a6a

Browse files
ergawyantiagainst
authored andcommitted
[MLIR][SPIRV] Support two memory access attributes in OpCopyMemory.
This commit augments spv.CopyMemory's implementation to support 2 memory access operands. Hence, more closely following the spec. The following changes are introduces: - Customize logic for spv.CopyMemory serialization and deserialization. - Add 2 additional attributes for source memory access operand. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D83241
1 parent ce22527 commit 3847a6a

File tree

6 files changed

+331
-21
lines changed

6 files changed

+331
-21
lines changed

mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
198198
```
199199
copy-memory-op ::= `spv.CopyMemory ` storage-class ssa-use
200200
storage-class ssa-use
201-
(`[` memory-access `]`)?
201+
(`[` memory-access `]` (`, [` memory-access `]`)?)?
202202
` : ` spirv-element-type
203203
```
204204

@@ -215,12 +215,16 @@ def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
215215
SPV_AnyPtr:$target,
216216
SPV_AnyPtr:$source,
217217
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
218-
OptionalAttr<I32Attr>:$alignment
218+
OptionalAttr<I32Attr>:$alignment,
219+
OptionalAttr<SPV_MemoryAccessAttr>:$source_memory_access,
220+
OptionalAttr<I32Attr>:$source_alignment
219221
);
220222

221223
let results = (outs);
222224

223225
let verifier = [{ return verifyCopyMemory(*this); }];
226+
227+
let autogenSerialization = 0;
224228
}
225229

226230
// -----

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp

Lines changed: 167 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@
2828
using namespace mlir;
2929

3030
// TODO: generate these strings using ODS.
31+
static constexpr const char kMemoryAccessAttrName[] = "memory_access";
32+
static constexpr const char kSourceMemoryAccessAttrName[] =
33+
"source_memory_access";
3134
static constexpr const char kAlignmentAttrName[] = "alignment";
35+
static constexpr const char kSourceAlignmentAttrName[] = "source_alignment";
3236
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
3337
static constexpr const char kCallee[] = "callee";
3438
static constexpr const char kClusterSize[] = "cluster_size";
@@ -157,6 +161,12 @@ parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
157161
return success();
158162
}
159163

164+
/// Parses optional memory access attributes attached to a memory access
165+
/// operand/pointer. Specifically, parses the following syntax:
166+
/// (`[` memory-access `]`)?
167+
/// where:
168+
/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
169+
/// integer-literal | `"NonTemporal"`
160170
static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
161171
OperationState &state) {
162172
// Parse an optional list of attributes staring with '['
@@ -166,7 +176,8 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
166176
}
167177

168178
spirv::MemoryAccess memoryAccessAttr;
169-
if (parseEnumStrAttr(memoryAccessAttr, parser, state)) {
179+
if (parseEnumStrAttr(memoryAccessAttr, parser, state,
180+
kMemoryAccessAttrName)) {
170181
return failure();
171182
}
172183

@@ -183,19 +194,90 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
183194
return parser.parseRSquare();
184195
}
185196

197+
// TODO Make sure to merge this and the previous function into one template
198+
// parameterized by memroy access attribute name and alignment. Doing so now
199+
// results in VS2017 in producing an internal error (at the call site) that's
200+
// not detailed enough to understand what is happenning.
201+
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
202+
OperationState &state) {
203+
// Parse an optional list of attributes staring with '['
204+
if (parser.parseOptionalLSquare()) {
205+
// Nothing to do
206+
return success();
207+
}
208+
209+
spirv::MemoryAccess memoryAccessAttr;
210+
if (parseEnumStrAttr(memoryAccessAttr, parser, state,
211+
kSourceMemoryAccessAttrName)) {
212+
return failure();
213+
}
214+
215+
if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
216+
// Parse integer attribute for alignment.
217+
Attribute alignmentAttr;
218+
Type i32Type = parser.getBuilder().getIntegerType(32);
219+
if (parser.parseComma() ||
220+
parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
221+
state.attributes)) {
222+
return failure();
223+
}
224+
}
225+
return parser.parseRSquare();
226+
}
227+
186228
template <typename MemoryOpTy>
187-
static void
188-
printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer,
189-
SmallVectorImpl<StringRef> &elidedAttrs) {
229+
static void printMemoryAccessAttribute(
230+
MemoryOpTy memoryOp, OpAsmPrinter &printer,
231+
SmallVectorImpl<StringRef> &elidedAttrs,
232+
Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
233+
Optional<llvm::APInt> alignmentAttrValue = None) {
190234
// Print optional memory access attribute.
191-
if (auto memAccess = memoryOp.memory_access()) {
192-
elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>());
235+
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
236+
: memoryOp.memory_access())) {
237+
elidedAttrs.push_back(kMemoryAccessAttrName);
238+
193239
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
194240

195-
// Print integer alignment attribute.
196-
if (auto alignment = memoryOp.alignment()) {
197-
elidedAttrs.push_back(kAlignmentAttrName);
198-
printer << ", " << alignment;
241+
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
242+
// Print integer alignment attribute.
243+
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
244+
: memoryOp.alignment())) {
245+
elidedAttrs.push_back(kAlignmentAttrName);
246+
printer << ", " << alignment;
247+
}
248+
}
249+
printer << "]";
250+
}
251+
elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
252+
}
253+
254+
// TODO Make sure to merge this and the previous function into one template
255+
// parameterized by memroy access attribute name and alignment. Doing so now
256+
// results in VS2017 in producing an internal error (at the call site) that's
257+
// not detailed enough to understand what is happenning.
258+
template <typename MemoryOpTy>
259+
static void printSourceMemoryAccessAttribute(
260+
MemoryOpTy memoryOp, OpAsmPrinter &printer,
261+
SmallVectorImpl<StringRef> &elidedAttrs,
262+
Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
263+
Optional<llvm::APInt> alignmentAttrValue = None) {
264+
265+
printer << ", ";
266+
267+
// Print optional memory access attribute.
268+
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
269+
: memoryOp.memory_access())) {
270+
elidedAttrs.push_back(kSourceMemoryAccessAttrName);
271+
272+
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
273+
274+
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
275+
// Print integer alignment attribute.
276+
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
277+
: memoryOp.alignment())) {
278+
elidedAttrs.push_back(kSourceAlignmentAttrName);
279+
printer << ", " << alignment;
280+
}
199281
}
200282
printer << "]";
201283
}
@@ -249,7 +331,7 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
249331
// memory-access attribute is Aligned, then the alignment attribute must be
250332
// present.
251333
auto *op = memoryOp.getOperation();
252-
auto memAccessAttr = op->getAttr(spirv::attributeName<spirv::MemoryAccess>());
334+
auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
253335
if (!memAccessAttr) {
254336
// Alignment attribute shouldn't be present if memory access attribute is
255337
// not present.
@@ -283,6 +365,50 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
283365
return success();
284366
}
285367

368+
// TODO Make sure to merge this and the previous function into one template
369+
// parameterized by memroy access attribute name and alignment. Doing so now
370+
// results in VS2017 in producing an internal error (at the call site) that's
371+
// not detailed enough to understand what is happenning.
372+
template <typename MemoryOpTy>
373+
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
374+
// ODS checks for attributes values. Just need to verify that if the
375+
// memory-access attribute is Aligned, then the alignment attribute must be
376+
// present.
377+
auto *op = memoryOp.getOperation();
378+
auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
379+
if (!memAccessAttr) {
380+
// Alignment attribute shouldn't be present if memory access attribute is
381+
// not present.
382+
if (op->getAttr(kSourceAlignmentAttrName)) {
383+
return memoryOp.emitOpError(
384+
"invalid alignment specification without aligned memory access "
385+
"specification");
386+
}
387+
return success();
388+
}
389+
390+
auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
391+
auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
392+
393+
if (!memAccess) {
394+
return memoryOp.emitOpError("invalid memory access specifier: ")
395+
<< memAccessVal;
396+
}
397+
398+
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
399+
if (!op->getAttr(kSourceAlignmentAttrName)) {
400+
return memoryOp.emitOpError("missing alignment value");
401+
}
402+
} else {
403+
if (op->getAttr(kSourceAlignmentAttrName)) {
404+
return memoryOp.emitOpError(
405+
"invalid alignment specification with non-aligned memory access "
406+
"specification");
407+
}
408+
}
409+
return success();
410+
}
411+
286412
template <typename BarrierOp>
287413
static LogicalResult verifyMemorySemantics(BarrierOp op) {
288414
// According to the SPIR-V specification:
@@ -2832,6 +2958,9 @@ static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
28322958

28332959
SmallVector<StringRef, 4> elidedAttrs;
28342960
printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
2961+
printSourceMemoryAccessAttribute(copyMemory, printer, elidedAttrs,
2962+
copyMemory.source_memory_access(),
2963+
copyMemory.source_alignment());
28352964

28362965
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
28372966

@@ -2854,12 +2983,23 @@ static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
28542983
parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
28552984
parseEnumStrAttr(sourceStorageClass, parser) ||
28562985
parser.parseOperand(sourcePtrInfo) ||
2857-
parseMemoryAccessAttributes(parser, state) ||
2858-
parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
2859-
parser.parseType(elementType)) {
2986+
parseMemoryAccessAttributes(parser, state)) {
28602987
return failure();
28612988
}
28622989

2990+
if (!parser.parseOptionalComma()) {
2991+
// Parse 2nd memory access attributes.
2992+
if (parseSourceMemoryAccessAttributes(parser, state)) {
2993+
return failure();
2994+
}
2995+
}
2996+
2997+
if (parser.parseColon() || parser.parseType(elementType))
2998+
return failure();
2999+
3000+
if (parser.parseOptionalAttrDict(state.attributes))
3001+
return failure();
3002+
28633003
auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
28643004
auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
28653005

@@ -2883,7 +3023,19 @@ static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
28833023
"both operands must be pointers to the same type");
28843024
}
28853025

2886-
return verifyMemoryAccessAttribute(copyMemory);
3026+
if (failed(verifyMemoryAccessAttribute(copyMemory))) {
3027+
return failure();
3028+
}
3029+
3030+
// TODO - According to the spec:
3031+
//
3032+
// If two masks are present, the first applies to Target and cannot include
3033+
// MakePointerVisible, and the second applies to Source and cannot include
3034+
// MakePointerAvailable.
3035+
//
3036+
// Add such verification here.
3037+
3038+
return verifySourceMemoryAccessAttribute(copyMemory);
28873039
}
28883040

28893041
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2511,6 +2511,76 @@ Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
25112511
return success();
25122512
}
25132513

2514+
template <>
2515+
LogicalResult
2516+
Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
2517+
SmallVector<Type, 1> resultTypes;
2518+
size_t wordIndex = 0;
2519+
SmallVector<Value, 4> operands;
2520+
SmallVector<NamedAttribute, 4> attributes;
2521+
2522+
if (wordIndex < words.size()) {
2523+
auto arg = getValue(words[wordIndex]);
2524+
2525+
if (!arg) {
2526+
return emitError(unknownLoc, "unknown result <id> : ")
2527+
<< words[wordIndex];
2528+
}
2529+
2530+
operands.push_back(arg);
2531+
wordIndex++;
2532+
}
2533+
2534+
if (wordIndex < words.size()) {
2535+
auto arg = getValue(words[wordIndex]);
2536+
2537+
if (!arg) {
2538+
return emitError(unknownLoc, "unknown result <id> : ")
2539+
<< words[wordIndex];
2540+
}
2541+
2542+
operands.push_back(arg);
2543+
wordIndex++;
2544+
}
2545+
2546+
bool isAlignedAttr = false;
2547+
2548+
if (wordIndex < words.size()) {
2549+
auto attrValue = words[wordIndex++];
2550+
attributes.push_back(opBuilder.getNamedAttr(
2551+
"memory_access", opBuilder.getI32IntegerAttr(attrValue)));
2552+
isAlignedAttr = (attrValue == 2);
2553+
}
2554+
2555+
if (isAlignedAttr && wordIndex < words.size()) {
2556+
attributes.push_back(opBuilder.getNamedAttr(
2557+
"alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
2558+
}
2559+
2560+
if (wordIndex < words.size()) {
2561+
attributes.push_back(opBuilder.getNamedAttr(
2562+
"source_memory_access",
2563+
opBuilder.getI32IntegerAttr(words[wordIndex++])));
2564+
}
2565+
2566+
if (wordIndex < words.size()) {
2567+
attributes.push_back(opBuilder.getNamedAttr(
2568+
"source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
2569+
}
2570+
2571+
if (wordIndex != words.size()) {
2572+
return emitError(unknownLoc,
2573+
"found more operands than expected when deserializing "
2574+
"spirv::CopyMemoryOp, only ")
2575+
<< wordIndex << " of " << words.size() << " processed";
2576+
}
2577+
2578+
Location loc = createFileLineColLoc(opBuilder);
2579+
opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
2580+
2581+
return success();
2582+
}
2583+
25142584
// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
25152585
// various Deserializer::processOp<...>() specializations.
25162586
#define GET_DESERIALIZATION_FNS

0 commit comments

Comments
 (0)