12
12
13
13
#include " mlir/Dialect/Arith/IR/Arith.h"
14
14
#include " mlir/Dialect/MemRef/IR/MemRef.h"
15
+ #include " mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
15
16
#include " mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
16
17
#include " mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17
18
#include " mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18
19
#include " mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20
+ #include " mlir/IR/BuiltinAttributes.h"
19
21
#include " mlir/IR/BuiltinTypes.h"
22
+ #include " mlir/IR/MLIRContext.h"
23
+ #include " mlir/IR/Visitors.h"
24
+ #include " mlir/Support/LogicalResult.h"
20
25
#include " llvm/Support/Debug.h"
26
+ #include < cassert>
21
27
#include < optional>
22
28
23
29
#define DEBUG_TYPE " memref-to-spirv-pattern"
@@ -439,6 +445,52 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
439
445
// LoadOp
440
446
// ===----------------------------------------------------------------------===//
441
447
448
+ using AlignmentRequirements =
449
+ FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
450
+
451
+ // / Given an accessed SPIR-V pointer, calculates its alignment requirements, if
452
+ // / any.
453
+ static AlignmentRequirements calculateRequiredAlignment (Value accessedPtr) {
454
+ auto ptrType = cast<spirv::PointerType>(accessedPtr.getType ());
455
+ if (ptrType.getStorageClass () != spirv::StorageClass::PhysicalStorageBuffer)
456
+ return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
457
+
458
+ // PhysicalStorageBuffers require the `Aligned` attribute.
459
+ auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType ());
460
+ if (!pointeeType)
461
+ return failure ();
462
+
463
+ // For scalar types, the alignment is determined by their size.
464
+ std::optional<int64_t > sizeInBytes = pointeeType.getSizeInBytes ();
465
+ if (!sizeInBytes.has_value ())
466
+ return failure ();
467
+
468
+ MLIRContext *ctx = accessedPtr.getContext ();
469
+ auto memAccessAttr =
470
+ spirv::MemoryAccessAttr::get (ctx, spirv::MemoryAccess::Aligned);
471
+ auto alignment = IntegerAttr::get (IntegerType::get (ctx, 32 ), *sizeInBytes);
472
+ return std::pair{memAccessAttr, alignment};
473
+ }
474
+
475
+ // / Given an accessed SPIR-V pointer and the original memref load/store
476
+ // / `memAccess` op, calculates the alignment requirements, if any. Takes into
477
+ // / account the alignment attributes applied to the load/store op.
478
+ static AlignmentRequirements
479
+ calculateRequiredAlignment (Value accessedPtr, Operation *memrefAccessOp) {
480
+ assert (memrefAccessOp);
481
+ assert ((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
482
+ " Bad op type" );
483
+
484
+ auto memrefMemAccess = memrefAccessOp->getAttrOfType <spirv::MemoryAccessAttr>(
485
+ spirv::attributeName<spirv::MemoryAccess>());
486
+ auto memrefAlignment =
487
+ memrefAccessOp->getAttrOfType <IntegerAttr>(" alignment" );
488
+ if (memrefMemAccess && memrefAlignment)
489
+ return std::pair{memrefMemAccess, memrefAlignment};
490
+
491
+ return calculateRequiredAlignment (accessedPtr);
492
+ }
493
+
442
494
LogicalResult
443
495
IntLoadOpPattern::matchAndRewrite (memref::LoadOp loadOp, OpAdaptor adaptor,
444
496
ConversionPatternRewriter &rewriter) const {
@@ -486,7 +538,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
486
538
// If the rewritten load op has the same bit width, use the loading value
487
539
// directly.
488
540
if (srcBits == dstBits) {
489
- Value loadVal = rewriter.create <spirv::LoadOp>(loc, accessChain);
541
+ AlignmentRequirements alignmentRequirements =
542
+ calculateRequiredAlignment (accessChain, loadOp);
543
+ if (failed (alignmentRequirements))
544
+ return rewriter.notifyMatchFailure (
545
+ loadOp, " failed to determine alignment requirements" );
546
+
547
+ auto [memoryAccess, alignment] = *alignmentRequirements;
548
+ Value loadVal = rewriter.create <spirv::LoadOp>(loc, accessChain,
549
+ memoryAccess, alignment);
490
550
if (isBool)
491
551
loadVal = castIntNToBool (loc, loadVal, rewriter);
492
552
rewriter.replaceOp (loadOp, loadVal);
@@ -508,11 +568,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
508
568
assert (accessChainOp.getIndices ().size () == 2 );
509
569
Value adjustedPtr = adjustAccessChainForBitwidth (typeConverter, accessChainOp,
510
570
srcBits, dstBits, rewriter);
511
- Value spvLoadOp = rewriter.create <spirv::LoadOp>(
512
- loc, dstType, adjustedPtr,
513
- loadOp->getAttrOfType <spirv::MemoryAccessAttr>(
514
- spirv::attributeName<spirv::MemoryAccess>()),
515
- loadOp->getAttrOfType <IntegerAttr>(" alignment" ));
571
+ AlignmentRequirements alignmentRequirements =
572
+ calculateRequiredAlignment (adjustedPtr, loadOp);
573
+ if (failed (alignmentRequirements))
574
+ return rewriter.notifyMatchFailure (
575
+ loadOp, " failed to determine alignment requirements" );
576
+
577
+ auto [memoryAccess, alignment] = *alignmentRequirements;
578
+ Value spvLoadOp = rewriter.create <spirv::LoadOp>(loc, dstType, adjustedPtr,
579
+ memoryAccess, alignment);
516
580
517
581
// Shift the bits to the rightmost.
518
582
// ____XXXX________ -> ____________XXXX
@@ -552,14 +616,21 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
552
616
auto memrefType = cast<MemRefType>(loadOp.getMemref ().getType ());
553
617
if (memrefType.getElementType ().isSignlessInteger ())
554
618
return failure ();
555
- auto loadPtr = spirv::getElementPtr (
619
+ Value loadPtr = spirv::getElementPtr (
556
620
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref (),
557
621
adaptor.getIndices (), loadOp.getLoc (), rewriter);
558
622
559
623
if (!loadPtr)
560
624
return failure ();
561
625
562
- rewriter.replaceOpWithNewOp <spirv::LoadOp>(loadOp, loadPtr);
626
+ AlignmentRequirements requiredAlignment = calculateRequiredAlignment (loadPtr);
627
+ if (failed (requiredAlignment))
628
+ return rewriter.notifyMatchFailure (
629
+ loadOp, " failed to determine alignment requirements" );
630
+
631
+ auto [memAccessAttr, alignment] = *requiredAlignment;
632
+ rewriter.replaceOpWithNewOp <spirv::LoadOp>(loadOp, loadPtr, memAccessAttr,
633
+ alignment);
563
634
return success ();
564
635
}
565
636
@@ -618,10 +689,18 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
618
689
assert (dstBits % srcBits == 0 );
619
690
620
691
if (srcBits == dstBits) {
692
+ AlignmentRequirements requiredAlignment =
693
+ calculateRequiredAlignment (accessChain);
694
+ if (failed (requiredAlignment))
695
+ return rewriter.notifyMatchFailure (
696
+ storeOp, " failed to determine alignment requirements" );
697
+
698
+ auto [memAccessAttr, alignment] = *requiredAlignment;
621
699
Value storeVal = adaptor.getValue ();
622
700
if (isBool)
623
701
storeVal = castBoolToIntN (loc, storeVal, dstType, rewriter);
624
- rewriter.replaceOpWithNewOp <spirv::StoreOp>(storeOp, accessChain, storeVal);
702
+ rewriter.replaceOpWithNewOp <spirv::StoreOp>(storeOp, accessChain, storeVal,
703
+ memAccessAttr, alignment);
625
704
return success ();
626
705
}
627
706
@@ -768,8 +847,15 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
768
847
if (!storePtr)
769
848
return rewriter.notifyMatchFailure (storeOp, " type conversion failed" );
770
849
771
- rewriter.replaceOpWithNewOp <spirv::StoreOp>(storeOp, storePtr,
772
- adaptor.getValue ());
850
+ AlignmentRequirements requiredAlignment =
851
+ calculateRequiredAlignment (storePtr, storeOp);
852
+ if (failed (requiredAlignment))
853
+ return rewriter.notifyMatchFailure (
854
+ storeOp, " failed to determine alignment requirements" );
855
+
856
+ auto [memAccessAttr, alignment] = *requiredAlignment;
857
+ rewriter.replaceOpWithNewOp <spirv::StoreOp>(
858
+ storeOp, storePtr, adaptor.getValue (), memAccessAttr, alignment);
773
859
return success ();
774
860
}
775
861
0 commit comments