@@ -50,6 +50,16 @@ static bool isIntValue(Value val, int64_t idx) {
50
50
return false ;
51
51
}
52
52
53
+ // / Helper test for invariant value (defined outside given block).
54
+ static bool isInvariantValue (Value val, Block *block) {
55
+ return val.getDefiningOp () && val.getDefiningOp ()->getBlock () != block;
56
+ }
57
+
58
+ // / Helper test for invariant argument (defined outside given block).
59
+ static bool isInvariantArg (BlockArgument arg, Block *block) {
60
+ return arg.getOwner () != block;
61
+ }
62
+
53
63
// / Constructs vector type for element type.
54
64
static VectorType vectorType (VL vl, Type etp) {
55
65
unsigned numScalableDims = vl.enableVLAVectorization ;
@@ -236,13 +246,15 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
236
246
Value vmask, SmallVectorImpl<Value> &idxs) {
237
247
unsigned d = 0 ;
238
248
unsigned dim = subs.size ();
249
+ Block *block = &forOp.getRegion ().front ();
239
250
for (auto sub : subs) {
240
251
bool innermost = ++d == dim;
241
252
// Invariant subscripts in outer dimensions simply pass through.
242
253
// Note that we rely on LICM to hoist loads where all subscripts
243
254
// are invariant in the innermost loop.
244
- if (sub.getDefiningOp () &&
245
- sub.getDefiningOp ()->getBlock () != &forOp.getRegion ().front ()) {
255
+ // Example:
256
+ // a[inv][i] for inv
257
+ if (isInvariantValue (sub, block)) {
246
258
if (innermost)
247
259
return false ;
248
260
if (codegen)
@@ -252,9 +264,10 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
252
264
// Invariant block arguments (including outer loop indices) in outer
253
265
// dimensions simply pass through. Direct loop indices in the
254
266
// innermost loop simply pass through as well.
255
- if (auto barg = sub.dyn_cast <BlockArgument>()) {
256
- bool invariant = barg.getOwner () != &forOp.getRegion ().front ();
257
- if (invariant == innermost)
267
+ // Example:
268
+ // a[i][j] for both i and j
269
+ if (auto arg = sub.dyn_cast <BlockArgument>()) {
270
+ if (isInvariantArg (arg, block) == innermost)
258
271
return false ;
259
272
if (codegen)
260
273
idxs.push_back (sub);
@@ -281,6 +294,8 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
281
294
// values, there is no good way to state that the indices are unsigned,
282
295
// which creates the potential of incorrect address calculations in the
283
296
// unlikely case we need such extremely large offsets.
297
+ // Example:
298
+ // a[ ind[i] ]
284
299
if (auto load = cast.getDefiningOp <memref::LoadOp>()) {
285
300
if (!innermost)
286
301
return false ;
@@ -303,18 +318,20 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
303
318
continue ; // success so far
304
319
}
305
320
// Address calculation 'i = add inv, idx' (after LICM).
321
+ // Example:
322
+ // a[base + i]
306
323
if (auto load = cast.getDefiningOp <arith::AddIOp>()) {
307
324
Value inv = load.getOperand (0 );
308
325
Value idx = load.getOperand (1 );
309
- if (inv. getDefiningOp () &&
310
- inv. getDefiningOp ()-> getBlock () != &forOp. getRegion (). front () &&
311
- idx. dyn_cast <BlockArgument>()) {
312
- if (!innermost)
313
- return false ;
314
- if (codegen)
315
- idxs. push_back (
316
- rewriter. create <arith::AddIOp>(forOp. getLoc (), inv, idx));
317
- continue ; // success so far
326
+ if (isInvariantValue (inv, block)) {
327
+ if ( auto arg = idx. dyn_cast <BlockArgument>()) {
328
+ if ( isInvariantArg (arg, block) || !innermost)
329
+ return false ;
330
+ if (codegen)
331
+ idxs. push_back (
332
+ rewriter. create <arith::AddIOp>(forOp. getLoc (), inv, idx));
333
+ continue ; // success so far
334
+ }
318
335
}
319
336
}
320
337
return false ;
@@ -389,7 +406,8 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
389
406
}
390
407
// Something defined outside the loop-body is invariant.
391
408
Operation *def = exp.getDefiningOp ();
392
- if (def->getBlock () != &forOp.getRegion ().front ()) {
409
+ Block *block = &forOp.getRegion ().front ();
410
+ if (def->getBlock () != block) {
393
411
if (codegen)
394
412
vexp = genVectorInvariantValue (rewriter, vl, exp);
395
413
return true ;
@@ -450,6 +468,17 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
450
468
vx) &&
451
469
vectorizeExpr (rewriter, forOp, vl, def->getOperand (1 ), codegen, vmask,
452
470
vy)) {
471
+ // We only accept shift-by-invariant (where the same shift factor applies
472
+ // to all packed elements). In the vector dialect, this is still
473
+ // represented with an expanded vector at the right-hand-side, however,
474
+ // so that we do not have to special case the code generation.
475
+ if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
476
+ isa<arith::ShRSIOp>(def)) {
477
+ Value shiftFactor = def->getOperand (1 );
478
+ if (!isInvariantValue (shiftFactor, block))
479
+ return false ;
480
+ }
481
+ // Generate code.
453
482
BINOP (arith::MulFOp)
454
483
BINOP (arith::MulIOp)
455
484
BINOP (arith::DivFOp)
@@ -462,8 +491,10 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
462
491
BINOP (arith::AndIOp)
463
492
BINOP (arith::OrIOp)
464
493
BINOP (arith::XOrIOp)
494
+ BINOP (arith::ShLIOp)
495
+ BINOP (arith::ShRUIOp)
496
+ BINOP (arith::ShRSIOp)
465
497
// TODO: complex?
466
- // TODO: shift by invariant?
467
498
}
468
499
}
469
500
return false ;
0 commit comments