@@ -445,15 +445,30 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
445
445
// LoadOp
446
446
// ===----------------------------------------------------------------------===//
447
447
448
- using AlignmentRequirements =
449
- FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
448
+ struct MemoryRequirements {
449
+ spirv::MemoryAccessAttr memoryAccess;
450
+ IntegerAttr alignment;
451
+ };
450
452
451
453
// / Given an accessed SPIR-V pointer, calculates its alignment requirements, if
452
454
// / any.
453
- static AlignmentRequirements calculateRequiredAlignment (Value accessedPtr) {
455
+ static FailureOr<MemoryRequirements>
456
+ calculateMemoryRequirements (Value accessedPtr, bool isNontemporal) {
457
+ MLIRContext *ctx = accessedPtr.getContext ();
458
+
459
+ auto memoryAccess = spirv::MemoryAccess::None;
460
+ if (isNontemporal) {
461
+ memoryAccess = spirv::MemoryAccess::Nontemporal;
462
+ }
463
+
454
464
auto ptrType = cast<spirv::PointerType>(accessedPtr.getType ());
455
- if (ptrType.getStorageClass () != spirv::StorageClass::PhysicalStorageBuffer)
456
- return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
465
+ if (ptrType.getStorageClass () != spirv::StorageClass::PhysicalStorageBuffer) {
466
+ if (memoryAccess == spirv::MemoryAccess::None) {
467
+ return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
468
+ }
469
+ return MemoryRequirements{spirv::MemoryAccessAttr::get (ctx, memoryAccess),
470
+ IntegerAttr{}};
471
+ }
457
472
458
473
// PhysicalStorageBuffers require the `Aligned` attribute.
459
474
auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType ());
@@ -465,30 +480,32 @@ static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
465
480
if (!sizeInBytes.has_value ())
466
481
return failure ();
467
482
468
- MLIRContext *ctx = accessedPtr.getContext ();
469
- auto memAccessAttr =
470
- spirv::MemoryAccessAttr::get (ctx, spirv::MemoryAccess::Aligned);
483
+ memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
484
+ auto memAccessAttr = spirv::MemoryAccessAttr::get (ctx, memoryAccess);
471
485
auto alignment = IntegerAttr::get (IntegerType::get (ctx, 32 ), *sizeInBytes);
472
- return std::pair {memAccessAttr, alignment};
486
+ return MemoryRequirements {memAccessAttr, alignment};
473
487
}
474
488
475
489
// / Given an accessed SPIR-V pointer and the original memref load/store
476
490
// / `memAccess` op, calculates the alignment requirements, if any. Takes into
477
491
// / account the alignment attributes applied to the load/store op.
478
- static AlignmentRequirements
479
- calculateRequiredAlignment (Value accessedPtr, Operation *memrefAccessOp) {
480
- assert (memrefAccessOp);
481
- assert ((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
482
- " Bad op type" );
483
-
492
+ template <class LoadOrStoreOp >
493
+ static FailureOr<MemoryRequirements>
494
+ calculateMemoryRequirements (Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
495
+ static_assert (
496
+ llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
497
+ " Must be called on either memref::LoadOp or memref::StoreOp" );
498
+
499
+ Operation *memrefAccessOp = loadOrStoreOp.getOperation ();
484
500
auto memrefMemAccess = memrefAccessOp->getAttrOfType <spirv::MemoryAccessAttr>(
485
501
spirv::attributeName<spirv::MemoryAccess>());
486
502
auto memrefAlignment =
487
503
memrefAccessOp->getAttrOfType <IntegerAttr>(" alignment" );
488
504
if (memrefMemAccess && memrefAlignment)
489
- return std::pair {memrefMemAccess, memrefAlignment};
505
+ return MemoryRequirements {memrefMemAccess, memrefAlignment};
490
506
491
- return calculateRequiredAlignment (accessedPtr);
507
+ return calculateMemoryRequirements (accessedPtr,
508
+ loadOrStoreOp.getNontemporal ());
492
509
}
493
510
494
511
LogicalResult
@@ -538,13 +555,12 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
538
555
// If the rewritten load op has the same bit width, use the loading value
539
556
// directly.
540
557
if (srcBits == dstBits) {
541
- AlignmentRequirements alignmentRequirements =
542
- calculateRequiredAlignment (accessChain, loadOp);
543
- if (failed (alignmentRequirements))
558
+ auto memoryRequirements = calculateMemoryRequirements (accessChain, loadOp);
559
+ if (failed (memoryRequirements))
544
560
return rewriter.notifyMatchFailure (
545
- loadOp, " failed to determine alignment requirements" );
561
+ loadOp, " failed to determine memory requirements" );
546
562
547
- auto [memoryAccess, alignment] = *alignmentRequirements ;
563
+ auto [memoryAccess, alignment] = *memoryRequirements ;
548
564
Value loadVal = rewriter.create <spirv::LoadOp>(loc, accessChain,
549
565
memoryAccess, alignment);
550
566
if (isBool)
@@ -568,13 +584,12 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
568
584
assert (accessChainOp.getIndices ().size () == 2 );
569
585
Value adjustedPtr = adjustAccessChainForBitwidth (typeConverter, accessChainOp,
570
586
srcBits, dstBits, rewriter);
571
- AlignmentRequirements alignmentRequirements =
572
- calculateRequiredAlignment (adjustedPtr, loadOp);
573
- if (failed (alignmentRequirements))
587
+ auto memoryRequirements = calculateMemoryRequirements (adjustedPtr, loadOp);
588
+ if (failed (memoryRequirements))
574
589
return rewriter.notifyMatchFailure (
575
- loadOp, " failed to determine alignment requirements" );
590
+ loadOp, " failed to determine memory requirements" );
576
591
577
- auto [memoryAccess, alignment] = *alignmentRequirements ;
592
+ auto [memoryAccess, alignment] = *memoryRequirements ;
578
593
Value spvLoadOp = rewriter.create <spirv::LoadOp>(loc, dstType, adjustedPtr,
579
594
memoryAccess, alignment);
580
595
@@ -623,13 +638,13 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
623
638
if (!loadPtr)
624
639
return failure ();
625
640
626
- AlignmentRequirements requiredAlignment = calculateRequiredAlignment (loadPtr);
627
- if (failed (requiredAlignment ))
641
+ auto memoryRequirements = calculateMemoryRequirements (loadPtr, loadOp );
642
+ if (failed (memoryRequirements ))
628
643
return rewriter.notifyMatchFailure (
629
- loadOp, " failed to determine alignment requirements" );
644
+ loadOp, " failed to determine memory requirements" );
630
645
631
- auto [memAccessAttr , alignment] = *requiredAlignment ;
632
- rewriter.replaceOpWithNewOp <spirv::LoadOp>(loadOp, loadPtr, memAccessAttr ,
646
+ auto [memoryAccess , alignment] = *memoryRequirements ;
647
+ rewriter.replaceOpWithNewOp <spirv::LoadOp>(loadOp, loadPtr, memoryAccess ,
633
648
alignment);
634
649
return success ();
635
650
}
@@ -689,18 +704,17 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
689
704
assert (dstBits % srcBits == 0 );
690
705
691
706
if (srcBits == dstBits) {
692
- AlignmentRequirements requiredAlignment =
693
- calculateRequiredAlignment (accessChain);
694
- if (failed (requiredAlignment))
707
+ auto memoryRequirements = calculateMemoryRequirements (accessChain, storeOp);
708
+ if (failed (memoryRequirements))
695
709
return rewriter.notifyMatchFailure (
696
- storeOp, " failed to determine alignment requirements" );
710
+ storeOp, " failed to determine memory requirements" );
697
711
698
- auto [memAccessAttr , alignment] = *requiredAlignment ;
712
+ auto [memoryAccess , alignment] = *memoryRequirements ;
699
713
Value storeVal = adaptor.getValue ();
700
714
if (isBool)
701
715
storeVal = castBoolToIntN (loc, storeVal, dstType, rewriter);
702
716
rewriter.replaceOpWithNewOp <spirv::StoreOp>(storeOp, accessChain, storeVal,
703
- memAccessAttr , alignment);
717
+ memoryAccess , alignment);
704
718
return success ();
705
719
}
706
720
@@ -847,15 +861,14 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
847
861
if (!storePtr)
848
862
return rewriter.notifyMatchFailure (storeOp, " type conversion failed" );
849
863
850
- AlignmentRequirements requiredAlignment =
851
- calculateRequiredAlignment (storePtr, storeOp);
852
- if (failed (requiredAlignment))
864
+ auto memoryRequirements = calculateMemoryRequirements (storePtr, storeOp);
865
+ if (failed (memoryRequirements))
853
866
return rewriter.notifyMatchFailure (
854
- storeOp, " failed to determine alignment requirements" );
867
+ storeOp, " failed to determine memory requirements" );
855
868
856
- auto [memAccessAttr , alignment] = *requiredAlignment ;
869
+ auto [memoryAccess , alignment] = *memoryRequirements ;
857
870
rewriter.replaceOpWithNewOp <spirv::StoreOp>(
858
- storeOp, storePtr, adaptor.getValue (), memAccessAttr , alignment);
871
+ storeOp, storePtr, adaptor.getValue (), memoryAccess , alignment);
859
872
return success ();
860
873
}
861
874
0 commit comments