28
28
using namespace mlir ;
29
29
30
30
// TODO: generate these strings using ODS.
31
+ static constexpr const char kMemoryAccessAttrName [] = " memory_access" ;
32
+ static constexpr const char kSourceMemoryAccessAttrName [] =
33
+ " source_memory_access" ;
31
34
static constexpr const char kAlignmentAttrName [] = " alignment" ;
35
+ static constexpr const char kSourceAlignmentAttrName [] = " source_alignment" ;
32
36
static constexpr const char kBranchWeightAttrName [] = " branch_weights" ;
33
37
static constexpr const char kCallee [] = " callee" ;
34
38
static constexpr const char kClusterSize [] = " cluster_size" ;
@@ -157,6 +161,12 @@ parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
157
161
return success ();
158
162
}
159
163
164
+ // / Parses optional memory access attributes attached to a memory access
165
+ // / operand/pointer. Specifically, parses the following syntax:
166
+ // / (`[` memory-access `]`)?
167
+ // / where:
168
+ // / memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
169
+ // / integer-literal | `"NonTemporal"`
160
170
static ParseResult parseMemoryAccessAttributes (OpAsmParser &parser,
161
171
OperationState &state) {
162
172
// Parse an optional list of attributes staring with '['
@@ -166,7 +176,8 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
166
176
}
167
177
168
178
spirv::MemoryAccess memoryAccessAttr;
169
- if (parseEnumStrAttr (memoryAccessAttr, parser, state)) {
179
+ if (parseEnumStrAttr (memoryAccessAttr, parser, state,
180
+ kMemoryAccessAttrName )) {
170
181
return failure ();
171
182
}
172
183
@@ -183,19 +194,90 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
183
194
return parser.parseRSquare ();
184
195
}
185
196
197
+ // TODO Make sure to merge this and the previous function into one template
198
+ // parameterized by memroy access attribute name and alignment. Doing so now
199
+ // results in VS2017 in producing an internal error (at the call site) that's
200
+ // not detailed enough to understand what is happenning.
201
+ static ParseResult parseSourceMemoryAccessAttributes (OpAsmParser &parser,
202
+ OperationState &state) {
203
+ // Parse an optional list of attributes staring with '['
204
+ if (parser.parseOptionalLSquare ()) {
205
+ // Nothing to do
206
+ return success ();
207
+ }
208
+
209
+ spirv::MemoryAccess memoryAccessAttr;
210
+ if (parseEnumStrAttr (memoryAccessAttr, parser, state,
211
+ kSourceMemoryAccessAttrName )) {
212
+ return failure ();
213
+ }
214
+
215
+ if (spirv::bitEnumContains (memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
216
+ // Parse integer attribute for alignment.
217
+ Attribute alignmentAttr;
218
+ Type i32Type = parser.getBuilder ().getIntegerType (32 );
219
+ if (parser.parseComma () ||
220
+ parser.parseAttribute (alignmentAttr, i32Type, kSourceAlignmentAttrName ,
221
+ state.attributes )) {
222
+ return failure ();
223
+ }
224
+ }
225
+ return parser.parseRSquare ();
226
+ }
227
+
186
228
template <typename MemoryOpTy>
187
- static void
188
- printMemoryAccessAttribute (MemoryOpTy memoryOp, OpAsmPrinter &printer,
189
- SmallVectorImpl<StringRef> &elidedAttrs) {
229
+ static void printMemoryAccessAttribute (
230
+ MemoryOpTy memoryOp, OpAsmPrinter &printer,
231
+ SmallVectorImpl<StringRef> &elidedAttrs,
232
+ Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
233
+ Optional<llvm::APInt> alignmentAttrValue = None) {
190
234
// Print optional memory access attribute.
191
- if (auto memAccess = memoryOp.memory_access ()) {
192
- elidedAttrs.push_back (spirv::attributeName<spirv::MemoryAccess>());
235
+ if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
236
+ : memoryOp.memory_access ())) {
237
+ elidedAttrs.push_back (kMemoryAccessAttrName );
238
+
193
239
printer << " [\" " << stringifyMemoryAccess (*memAccess) << " \" " ;
194
240
195
- // Print integer alignment attribute.
196
- if (auto alignment = memoryOp.alignment ()) {
197
- elidedAttrs.push_back (kAlignmentAttrName );
198
- printer << " , " << alignment;
241
+ if (spirv::bitEnumContains (*memAccess, spirv::MemoryAccess::Aligned)) {
242
+ // Print integer alignment attribute.
243
+ if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
244
+ : memoryOp.alignment ())) {
245
+ elidedAttrs.push_back (kAlignmentAttrName );
246
+ printer << " , " << alignment;
247
+ }
248
+ }
249
+ printer << " ]" ;
250
+ }
251
+ elidedAttrs.push_back (spirv::attributeName<spirv::StorageClass>());
252
+ }
253
+
254
+ // TODO Make sure to merge this and the previous function into one template
255
+ // parameterized by memroy access attribute name and alignment. Doing so now
256
+ // results in VS2017 in producing an internal error (at the call site) that's
257
+ // not detailed enough to understand what is happenning.
258
+ template <typename MemoryOpTy>
259
+ static void printSourceMemoryAccessAttribute (
260
+ MemoryOpTy memoryOp, OpAsmPrinter &printer,
261
+ SmallVectorImpl<StringRef> &elidedAttrs,
262
+ Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
263
+ Optional<llvm::APInt> alignmentAttrValue = None) {
264
+
265
+ printer << " , " ;
266
+
267
+ // Print optional memory access attribute.
268
+ if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
269
+ : memoryOp.memory_access ())) {
270
+ elidedAttrs.push_back (kSourceMemoryAccessAttrName );
271
+
272
+ printer << " [\" " << stringifyMemoryAccess (*memAccess) << " \" " ;
273
+
274
+ if (spirv::bitEnumContains (*memAccess, spirv::MemoryAccess::Aligned)) {
275
+ // Print integer alignment attribute.
276
+ if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
277
+ : memoryOp.alignment ())) {
278
+ elidedAttrs.push_back (kSourceAlignmentAttrName );
279
+ printer << " , " << alignment;
280
+ }
199
281
}
200
282
printer << " ]" ;
201
283
}
@@ -249,7 +331,7 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
249
331
// memory-access attribute is Aligned, then the alignment attribute must be
250
332
// present.
251
333
auto *op = memoryOp.getOperation ();
252
- auto memAccessAttr = op->getAttr (spirv::attributeName<spirv::MemoryAccess>() );
334
+ auto memAccessAttr = op->getAttr (kMemoryAccessAttrName );
253
335
if (!memAccessAttr) {
254
336
// Alignment attribute shouldn't be present if memory access attribute is
255
337
// not present.
@@ -283,6 +365,50 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
283
365
return success ();
284
366
}
285
367
368
+ // TODO Make sure to merge this and the previous function into one template
369
+ // parameterized by memroy access attribute name and alignment. Doing so now
370
+ // results in VS2017 in producing an internal error (at the call site) that's
371
+ // not detailed enough to understand what is happenning.
372
+ template <typename MemoryOpTy>
373
+ static LogicalResult verifySourceMemoryAccessAttribute (MemoryOpTy memoryOp) {
374
+ // ODS checks for attributes values. Just need to verify that if the
375
+ // memory-access attribute is Aligned, then the alignment attribute must be
376
+ // present.
377
+ auto *op = memoryOp.getOperation ();
378
+ auto memAccessAttr = op->getAttr (kSourceMemoryAccessAttrName );
379
+ if (!memAccessAttr) {
380
+ // Alignment attribute shouldn't be present if memory access attribute is
381
+ // not present.
382
+ if (op->getAttr (kSourceAlignmentAttrName )) {
383
+ return memoryOp.emitOpError (
384
+ " invalid alignment specification without aligned memory access "
385
+ " specification" );
386
+ }
387
+ return success ();
388
+ }
389
+
390
+ auto memAccessVal = memAccessAttr.template cast <IntegerAttr>();
391
+ auto memAccess = spirv::symbolizeMemoryAccess (memAccessVal.getInt ());
392
+
393
+ if (!memAccess) {
394
+ return memoryOp.emitOpError (" invalid memory access specifier: " )
395
+ << memAccessVal;
396
+ }
397
+
398
+ if (spirv::bitEnumContains (*memAccess, spirv::MemoryAccess::Aligned)) {
399
+ if (!op->getAttr (kSourceAlignmentAttrName )) {
400
+ return memoryOp.emitOpError (" missing alignment value" );
401
+ }
402
+ } else {
403
+ if (op->getAttr (kSourceAlignmentAttrName )) {
404
+ return memoryOp.emitOpError (
405
+ " invalid alignment specification with non-aligned memory access "
406
+ " specification" );
407
+ }
408
+ }
409
+ return success ();
410
+ }
411
+
286
412
template <typename BarrierOp>
287
413
static LogicalResult verifyMemorySemantics (BarrierOp op) {
288
414
// According to the SPIR-V specification:
@@ -2832,6 +2958,9 @@ static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
2832
2958
2833
2959
SmallVector<StringRef, 4 > elidedAttrs;
2834
2960
printMemoryAccessAttribute (copyMemory, printer, elidedAttrs);
2961
+ printSourceMemoryAccessAttribute (copyMemory, printer, elidedAttrs,
2962
+ copyMemory.source_memory_access (),
2963
+ copyMemory.source_alignment ());
2835
2964
2836
2965
printer.printOptionalAttrDict (op->getAttrs (), elidedAttrs);
2837
2966
@@ -2854,12 +2983,23 @@ static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
2854
2983
parser.parseOperand (targetPtrInfo) || parser.parseComma () ||
2855
2984
parseEnumStrAttr (sourceStorageClass, parser) ||
2856
2985
parser.parseOperand (sourcePtrInfo) ||
2857
- parseMemoryAccessAttributes (parser, state) ||
2858
- parser.parseOptionalAttrDict (state.attributes ) || parser.parseColon () ||
2859
- parser.parseType (elementType)) {
2986
+ parseMemoryAccessAttributes (parser, state)) {
2860
2987
return failure ();
2861
2988
}
2862
2989
2990
+ if (!parser.parseOptionalComma ()) {
2991
+ // Parse 2nd memory access attributes.
2992
+ if (parseSourceMemoryAccessAttributes (parser, state)) {
2993
+ return failure ();
2994
+ }
2995
+ }
2996
+
2997
+ if (parser.parseColon () || parser.parseType (elementType))
2998
+ return failure ();
2999
+
3000
+ if (parser.parseOptionalAttrDict (state.attributes ))
3001
+ return failure ();
3002
+
2863
3003
auto targetPtrType = spirv::PointerType::get (elementType, targetStorageClass);
2864
3004
auto sourcePtrType = spirv::PointerType::get (elementType, sourceStorageClass);
2865
3005
@@ -2883,7 +3023,19 @@ static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
2883
3023
" both operands must be pointers to the same type" );
2884
3024
}
2885
3025
2886
- return verifyMemoryAccessAttribute (copyMemory);
3026
+ if (failed (verifyMemoryAccessAttribute (copyMemory))) {
3027
+ return failure ();
3028
+ }
3029
+
3030
+ // TODO - According to the spec:
3031
+ //
3032
+ // If two masks are present, the first applies to Target and cannot include
3033
+ // MakePointerVisible, and the second applies to Source and cannot include
3034
+ // MakePointerAvailable.
3035
+ //
3036
+ // Add such verification here.
3037
+
3038
+ return verifySourceMemoryAccessAttribute (copyMemory);
2887
3039
}
2888
3040
2889
3041
// ===----------------------------------------------------------------------===//
0 commit comments