@@ -30,6 +30,65 @@ using namespace special_ticks;
30
30
/// and default memory space.
31
31
static bool isMemRefTypeOk(MemRefType type) {
  // The only property inspected here is the shape: the type is accepted
  // exactly when hasStaticShape() reports true.
  const bool shapeIsStatic = type.hasStaticShape();
  return shapeIsStatic;
}
32
32
33
+ static inline int64_t getSizeInBytes (MemRefType &memType) {
34
+ // treat bool (i1) as 1 byte. It may not be true for all targets, but we at
35
+ // least have a large enough size for i1
36
+ int64_t size = memType.getElementTypeBitWidth () / 8 ;
37
+ size = (size > 0 ) ? size : 1 ;
38
+ for (auto v : memType.getShape ()) {
39
+ size *= v;
40
+ }
41
+ return size;
42
+ }
43
+
44
+ static bool needsHoistOutOfParallelLoop (Operation *op) {
45
+ Operation *parent =
46
+ op->getParentWithTrait <OpTrait::AutomaticAllocationScope>();
47
+ if (isa_and_nonnull<scf::ForallOp>(parent)) {
48
+ // check if the current allocation is between the nested pfor, and use
49
+ // inside the inner parallel loop
50
+ SmallVector<Operation *, 4 > parallelOpInCurBlock;
51
+ Block *curBlock = op->getBlock ();
52
+ for (auto &curOp : curBlock->getOperations ()) {
53
+ if (isa<scf::ForallOp>(curOp)) {
54
+ parallelOpInCurBlock.push_back (&curOp);
55
+ }
56
+ }
57
+
58
+ if (parallelOpInCurBlock.empty ())
59
+ return false ;
60
+
61
+ for (auto *use : op->getUsers ()) {
62
+ for (auto *parallelOp : parallelOpInCurBlock) {
63
+ if (parallelOp->isAncestor (use)) {
64
+ return true ;
65
+ }
66
+ }
67
+ }
68
+ }
69
+
70
+ return false ;
71
+ }
72
+
73
+ static bool isForallLoopBoundStatic (Operation *op) {
74
+ auto forallOp = dyn_cast<scf::ForallOp>(op);
75
+ if (!forallOp)
76
+ return false ;
77
+
78
+ auto lbs = forallOp.getMixedLowerBound ();
79
+ auto ubs = forallOp.getMixedUpperBound ();
80
+ auto steps = forallOp.getMixedStep ();
81
+ auto allConstantValue = [](SmallVector<OpFoldResult> vals) -> bool {
82
+ return llvm::all_of (vals, [](OpFoldResult val) {
83
+ std::optional<int64_t > const_val = getConstantIntValue (val);
84
+ return const_val.has_value ();
85
+ });
86
+ };
87
+
88
+ return allConstantValue (lbs) && allConstantValue (ubs) &&
89
+ allConstantValue (steps);
90
+ }
91
+
33
92
void Tick::update (int64_t tick) {
34
93
if (tick == UNTRACEABLE_ACCESS) {
35
94
firstAccess = UNTRACEABLE_ACCESS;
@@ -180,28 +239,60 @@ bool TickCollecter::isMergeableAlloc(TickCollecterStates *s, Operation *op,
180
239
// trait, and is not scf.for
181
240
Operation *TickCollecter::getAllocScope(TickCollecterStates *s,
                                        Operation *op) const {
  // Decided once up front: whether `op` is an allocation that has to be moved
  // out of its enclosing parallel loop(s).
  const bool hoistPastForall = needsHoistOutOfParallelLoop(op);

  // Walk outwards through the enclosing AutomaticAllocationScope ops and stop
  // at the first one that may host the allocation.
  Operation *scope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  while (scope) {
    // scf.for is never used as the allocation scope.
    const bool isSerialLoop = isa<scf::ForOp>(scope);
    // scf.forall is skipped only when hoisting is required and all of its
    // bounds are compile-time constants.
    const bool isSkippableForall = isa<scf::ForallOp>(scope) &&
                                   hoistPastForall &&
                                   isForallLoopBoundStatic(scope);
    if (!isSerialLoop && !isSkippableForall)
      return scope;

    scope = scope->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  }
  // Ran out of enclosing allocation scopes.
  return nullptr;
}
194
261
195
262
FailureOr<size_t > TickCollecter::getAllocSize (TickCollecterStates *s,
196
263
Operation *op) const {
197
264
auto refType = cast<MemRefType>(op->getResultTypes ().front ());
198
- int64_t size = refType.getElementTypeBitWidth () / 8 ;
199
- // treat bool (i1) as 1 byte. It may not be true for all targets, but we at
200
- // least have a large enough size for i1
201
- size = (size != 0 ) ? size : 1 ;
202
- for (auto v : refType.getShape ()) {
203
- size *= v;
265
+
266
+ // Get the total number of threads from the outermost to the current level of
267
+ // the parallel loop that the allocation located in.
268
+ int64_t numThreads = 1 ;
269
+ if (needsHoistOutOfParallelLoop (op)) {
270
+ Operation *parent =
271
+ op->getParentWithTrait <OpTrait::AutomaticAllocationScope>();
272
+ while (auto forallOp = dyn_cast<scf::ForallOp>(parent)) {
273
+ if (!isForallLoopBoundStatic (forallOp))
274
+ break ;
275
+
276
+ OpBuilder builder{forallOp->getContext ()};
277
+ std::optional<int64_t > numIterations;
278
+ for (auto [lb, ub, step] : llvm::zip (forallOp.getLowerBound (builder),
279
+ forallOp.getUpperBound (builder),
280
+ forallOp.getStep (builder))) {
281
+ numIterations = constantTripCount (lb, ub, step);
282
+ if (numIterations.has_value ()) {
283
+ numThreads *= numIterations.value ();
284
+ } else {
285
+ return op->emitError (" Expecting static loop range!" );
286
+ }
287
+ }
288
+
289
+ parent = parent->getParentWithTrait <OpTrait::AutomaticAllocationScope>();
290
+ }
204
291
}
292
+ assert (numThreads > 0 );
293
+
294
+ int64_t size = getSizeInBytes (refType);
295
+ size *= numThreads;
205
296
if (size > 0 ) {
206
297
return static_cast <size_t >(size);
207
298
}
@@ -391,11 +482,113 @@ Value MergeAllocDefaultMutator::buildView(OpBuilder &builder, Block *scope,
391
482
Value mergedAlloc,
392
483
int64_t byteOffset) const {
393
484
builder.setInsertionPoint (origAllocOp);
394
- auto byteShift =
395
- builder.create <arith::ConstantIndexOp>(origAllocOp->getLoc (), byteOffset);
396
- return builder.create <memref::ViewOp>(origAllocOp->getLoc (),
397
- origAllocOp->getResultTypes ().front (),
398
- mergedAlloc, byteShift, ValueRange{});
485
+ auto loc = origAllocOp->getLoc ();
486
+ auto byteShift = builder.create <arith::ConstantIndexOp>(loc, byteOffset);
487
+
488
+ bool moveToUpperParellelLoop = needsHoistOutOfParallelLoop (origAllocOp);
489
+ Operation *parent =
490
+ origAllocOp->getParentWithTrait <OpTrait::AutomaticAllocationScope>();
491
+ if (!moveToUpperParellelLoop || !parent || !isa<scf::ForallOp>(parent))
492
+ return builder.create <memref::ViewOp>(loc,
493
+ origAllocOp->getResultTypes ().front (),
494
+ mergedAlloc, byteShift, ValueRange{});
495
+
496
+ // get the aggregated inductorVar
497
+ Value inductVar;
498
+ bool isOuterMostLoop = true ;
499
+ int64_t innerLoopUpperBound = 1 ;
500
+ while (parent) {
501
+ if (auto forallOp = dyn_cast<scf::ForallOp>(parent)) {
502
+ if (isForallLoopBoundStatic (forallOp)) {
503
+ SmallVector<Value> ubs = forallOp.getUpperBound (builder);
504
+ SmallVector<Value> lbs = forallOp.getLowerBound (builder);
505
+ SmallVector<Value> steps = forallOp.getStep (builder);
506
+ SmallVector<Value> inductionVars = forallOp.getInductionVars ();
507
+
508
+ auto getCurrentVar = [&loc, &builder](Value var, Value lb,
509
+ Value step) -> Value {
510
+ if (!isConstantIntValue (lb, 0 ))
511
+ var = builder.create <arith::SubIOp>(loc, var, lb);
512
+
513
+ if (!isConstantIntValue (step, 1 ))
514
+ var = builder.create <arith::DivSIOp>(loc, var, step);
515
+ return var;
516
+ };
517
+
518
+ auto getAggregatedVar =
519
+ [&loc, &builder, &getCurrentVar](
520
+ const SmallVector<Value> &_lbs, const SmallVector<Value> &_ubs,
521
+ const SmallVector<Value> &_steps,
522
+ const SmallVector<Value> &_inductVars) -> Value {
523
+ Value var;
524
+ if (_ubs.size () == 1 ) {
525
+ var = getCurrentVar (_inductVars[0 ], _lbs[0 ], _steps[0 ]);
526
+ return var;
527
+ } else {
528
+ bool isFirstLoop = true ;
529
+ for (auto [lb, ub, step, inductVar] :
530
+ llvm::zip (_lbs, _ubs, _steps, _inductVars)) {
531
+ if (isFirstLoop) {
532
+ var = getCurrentVar (inductVar, lb, step);
533
+ isFirstLoop = false ;
534
+ } else {
535
+ Value cur_var = getCurrentVar (inductVar, lb, step);
536
+ std::optional<int64_t > bound = constantTripCount (lb, ub, step);
537
+ assert (bound.has_value ());
538
+ Value boundVal =
539
+ builder.create <arith::ConstantIndexOp>(loc, bound.value ());
540
+ Value tmpVal =
541
+ builder.create <arith::MulIOp>(loc, var, boundVal);
542
+ var = builder.create <arith::AddIOp>(loc, tmpVal, cur_var);
543
+ }
544
+ }
545
+ return var;
546
+ }
547
+ };
548
+
549
+ if (isOuterMostLoop) {
550
+ inductVar = getAggregatedVar (lbs, ubs, steps, inductionVars);
551
+ isOuterMostLoop = false ;
552
+ } else {
553
+ Value currentVar = getAggregatedVar (lbs, ubs, steps, inductionVars);
554
+
555
+ Value innerLoopBoundVal =
556
+ builder.create <arith::ConstantIndexOp>(loc, innerLoopUpperBound);
557
+ Value intermediateVal =
558
+ builder.create <arith::MulIOp>(loc, currentVar, innerLoopBoundVal);
559
+ inductVar =
560
+ builder.create <arith::AddIOp>(loc, inductVar, intermediateVal);
561
+ }
562
+ // get aggregated loop bound
563
+ for (auto [lb, ub, step] : llvm::zip (lbs, ubs, steps)) {
564
+ std::optional<int64_t > cur_bound = constantTripCount (lb, ub, step);
565
+ assert (cur_bound.has_value ());
566
+ innerLoopUpperBound *= cur_bound.value ();
567
+ }
568
+ }
569
+ }
570
+
571
+ parent = parent->getParentWithTrait <OpTrait::AutomaticAllocationScope>();
572
+ }
573
+
574
+ if (!isOuterMostLoop) {
575
+ // get original shape size
576
+ auto memType = cast<MemRefType>(origAllocOp->getResultTypes ().front ());
577
+ int64_t size = getSizeInBytes (memType);
578
+ Value origSize = builder.create <arith::ConstantIndexOp>(loc, size);
579
+ Value offsetPerThread =
580
+ builder.create <arith::MulIOp>(loc, inductVar, origSize);
581
+ Value byteShiftPerThread =
582
+ builder.create <arith::AddIOp>(loc, byteShift, offsetPerThread);
583
+
584
+ return builder.create <memref::ViewOp>(
585
+ loc, origAllocOp->getResultTypes ().front (), mergedAlloc,
586
+ byteShiftPerThread, ValueRange{});
587
+ } else {
588
+ return builder.create <memref::ViewOp>(loc,
589
+ origAllocOp->getResultTypes ().front (),
590
+ mergedAlloc, byteShift, ValueRange{});
591
+ }
399
592
}
400
593
401
594
LogicalResult
0 commit comments