@@ -2510,6 +2510,48 @@ static MemoryAccess::ReductionType getReductionType(const BinaryOperator *BinOp,
2510
2510
}
2511
2511
}
2512
2512
2513
+ // / True if @p AllAccs intersects with @p MemAccs execpt @p LoadMA and @p
2514
+ // / StoreMA
2515
+ bool hasIntersectingAccesses (isl::set AllAccs, MemoryAccess *LoadMA,
2516
+ MemoryAccess *StoreMA, isl::set Domain,
2517
+ SmallVector<MemoryAccess *, 8 > &MemAccs) {
2518
+ bool HasIntersectingAccs = false ;
2519
+ for (MemoryAccess *MA : MemAccs) {
2520
+ if (MA == LoadMA || MA == StoreMA)
2521
+ continue ;
2522
+
2523
+ isl::map AccRel = MA->getAccessRelation ().intersect_domain (Domain);
2524
+ isl::set Accs = AccRel.range ();
2525
+
2526
+ if (AllAccs.has_equal_space (Accs)) {
2527
+ isl::set OverlapAccs = Accs.intersect (AllAccs);
2528
+ bool DoesIntersect = !OverlapAccs.is_empty ();
2529
+ HasIntersectingAccs |= DoesIntersect;
2530
+ }
2531
+ }
2532
+ return HasIntersectingAccs;
2533
+ }
2534
+
2535
+ // / Test if the accesses of @p LoadMA and @p StoreMA can form a reduction
2536
+ bool checkCandidatePairAccesses (MemoryAccess *LoadMA, MemoryAccess *StoreMA,
2537
+ isl::set Domain,
2538
+ SmallVector<MemoryAccess *, 8 > &MemAccs) {
2539
+ isl::map LoadAccs = LoadMA->getAccessRelation ();
2540
+ isl::map StoreAccs = StoreMA->getAccessRelation ();
2541
+
2542
+ // Skip those with obviously unequal base addresses.
2543
+ bool Valid = LoadAccs.has_equal_space (StoreAccs);
2544
+
2545
+ // And check if the remaining for overlap with other memory accesses.
2546
+ if (Valid) {
2547
+ isl::map AllAccsRel = LoadAccs.unite (StoreAccs);
2548
+ AllAccsRel = AllAccsRel.intersect_domain (Domain);
2549
+ isl::set AllAccs = AllAccsRel.range ();
2550
+ Valid = !hasIntersectingAccesses (AllAccs, LoadMA, StoreMA, Domain, MemAccs);
2551
+ }
2552
+ return Valid;
2553
+ }
2554
+
2513
2555
void ScopBuilder::checkForReductions (ScopStmt &Stmt) {
2514
2556
SmallVector<MemoryAccess *, 2 > Loads;
2515
2557
SmallVector<std::pair<MemoryAccess *, MemoryAccess *>, 4 > Candidates;
@@ -2528,34 +2570,10 @@ void ScopBuilder::checkForReductions(ScopStmt &Stmt) {
2528
2570
2529
2571
// Then check each possible candidate pair.
2530
2572
for (const auto &CandidatePair : Candidates) {
2531
- bool Valid = true ;
2532
- isl::map LoadAccs = CandidatePair.first ->getAccessRelation ();
2533
- isl::map StoreAccs = CandidatePair.second ->getAccessRelation ();
2534
-
2535
- // Skip those with obviously unequal base addresses.
2536
- if (!LoadAccs.has_equal_space (StoreAccs)) {
2537
- continue ;
2538
- }
2539
-
2540
- // And check if the remaining for overlap with other memory accesses.
2541
- isl::map AllAccsRel = LoadAccs.unite (StoreAccs);
2542
- AllAccsRel = AllAccsRel.intersect_domain (Stmt.getDomain ());
2543
- isl::set AllAccs = AllAccsRel.range ();
2544
-
2545
- for (MemoryAccess *MA : Stmt) {
2546
- if (MA == CandidatePair.first || MA == CandidatePair.second )
2547
- continue ;
2548
-
2549
- isl::map AccRel =
2550
- MA->getAccessRelation ().intersect_domain (Stmt.getDomain ());
2551
- isl::set Accs = AccRel.range ();
2552
-
2553
- if (AllAccs.has_equal_space (Accs)) {
2554
- isl::set OverlapAccs = Accs.intersect (AllAccs);
2555
- Valid = Valid && OverlapAccs.is_empty ();
2556
- }
2557
- }
2558
-
2573
+ MemoryAccess *LoadMA = CandidatePair.first ;
2574
+ MemoryAccess *StoreMA = CandidatePair.second ;
2575
+ bool Valid = checkCandidatePairAccesses (LoadMA, StoreMA, Stmt.getDomain (),
2576
+ Stmt.MemAccs );
2559
2577
if (!Valid)
2560
2578
continue ;
2561
2579
@@ -2566,8 +2584,8 @@ void ScopBuilder::checkForReductions(ScopStmt &Stmt) {
2566
2584
2567
2585
// If no overlapping access was found we mark the load and store as
2568
2586
// reduction like.
2569
- CandidatePair. first ->markAsReductionLike (RT);
2570
- CandidatePair. second ->markAsReductionLike (RT);
2587
+ LoadMA ->markAsReductionLike (RT);
2588
+ StoreMA ->markAsReductionLike (RT);
2571
2589
}
2572
2590
}
2573
2591
0 commit comments