
Commit 431f6a5

[sparse][mlir][vectorization] add support for shift-by-invariant

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D140596

1 parent e91e62d
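Before this change, the sparse vectorizer rejected all shift operations (see the old "// TODO: shift by invariant?" below). It now accepts arith.shli, arith.shrui, and arith.shrsi when the shift amount is loop-invariant; shifts by a per-element amount are still rejected. A hypothetical scalar loop of the newly accepted form (illustration only, not taken from the patch):

    %s = arith.constant 4 : i32
    scf.for %i = %c0 to %c1024 step %c1 {
      %x = memref.load %mem[%i] : memref<1024xi32>
      %y = arith.shrsi %x, %s : i32   // shift factor %s is loop-invariant
      memref.store %y, %mem[%i] : memref<1024xi32>
    }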

2 files changed: +61 −20 lines
mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp

Lines changed: 47 additions & 16 deletions
@@ -50,6 +50,16 @@ static bool isIntValue(Value val, int64_t idx) {
   return false;
 }
 
+/// Helper test for invariant value (defined outside given block).
+static bool isInvariantValue(Value val, Block *block) {
+  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
+}
+
+/// Helper test for invariant argument (defined outside given block).
+static bool isInvariantArg(BlockArgument arg, Block *block) {
+  return arg.getOwner() != block;
+}
+
 /// Constructs vector type for element type.
 static VectorType vectorType(VL vl, Type etp) {
   unsigned numScalableDims = vl.enableVLAVectorization;
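These helpers centralize the invariance tests used throughout the pattern. As a minimal illustration (hypothetical IR, not part of the patch): a value produced by an operation outside the loop body is invariant, while a block argument of the body itself, such as the induction variable, is not:

    %inv = arith.constant 3 : index     // defined outside the body block:
                                        //   isInvariantValue(%inv, body) is true
    scf.for %i = %lo to %hi step %c1 {  // %i is owned by the body block:
      ...                               //   isInvariantArg(%i, body) is false
    }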
@@ -236,13 +246,15 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                 Value vmask, SmallVectorImpl<Value> &idxs) {
   unsigned d = 0;
   unsigned dim = subs.size();
+  Block *block = &forOp.getRegion().front();
   for (auto sub : subs) {
     bool innermost = ++d == dim;
     // Invariant subscripts in outer dimensions simply pass through.
     // Note that we rely on LICM to hoist loads where all subscripts
     // are invariant in the innermost loop.
-    if (sub.getDefiningOp() &&
-        sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
+    // Example:
+    //   a[inv][i] for inv
+    if (isInvariantValue(sub, block)) {
       if (innermost)
         return false;
       if (codegen)
@@ -252,9 +264,10 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
     // Invariant block arguments (including outer loop indices) in outer
     // dimensions simply pass through. Direct loop indices in the
     // innermost loop simply pass through as well.
-    if (auto barg = sub.dyn_cast<BlockArgument>()) {
-      bool invariant = barg.getOwner() != &forOp.getRegion().front();
-      if (invariant == innermost)
+    // Example:
+    //   a[i][j] for both i and j
+    if (auto arg = sub.dyn_cast<BlockArgument>()) {
+      if (isInvariantArg(arg, block) == innermost)
         return false;
       if (codegen)
         idxs.push_back(sub);
@@ -281,6 +294,8 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
     // values, there is no good way to state that the indices are unsigned,
     // which creates the potential of incorrect address calculations in the
     // unlikely case we need such extremely large offsets.
+    // Example:
+    //    a[ ind[i] ]
     if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
       if (!innermost)
         return false;
@@ -303,18 +318,20 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
       continue; // success so far
     }
     // Address calculation 'i = add inv, idx' (after LICM).
+    // Example:
+    //   a[base + i]
     if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
       Value inv = load.getOperand(0);
       Value idx = load.getOperand(1);
-      if (inv.getDefiningOp() &&
-          inv.getDefiningOp()->getBlock() != &forOp.getRegion().front() &&
-          idx.dyn_cast<BlockArgument>()) {
-        if (!innermost)
-          return false;
-        if (codegen)
-          idxs.push_back(
-              rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
-        continue; // success so far
+      if (isInvariantValue(inv, block)) {
+        if (auto arg = idx.dyn_cast<BlockArgument>()) {
+          if (isInvariantArg(arg, block) || !innermost)
+            return false;
+          if (codegen)
+            idxs.push_back(
+                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
+          continue; // success so far
+        }
       }
     }
     return false;
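The rewritten 'add inv, idx' case is slightly stricter than before: besides an invariant left operand, it now also insists that the index operand is the loop's own induction argument, so an invariant block argument from an enclosing loop bails out. A hypothetical subscript that matches the accepted form:

    // Hypothetical: %base was hoisted by LICM; %i is the innermost loop
    // index, so the subscript 'a[base + i]' vectorizes.
    %idx = arith.addi %base, %i : index
    %x = memref.load %a[%idx] : memref<?xf32>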
@@ -389,7 +406,8 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
   }
   // Something defined outside the loop-body is invariant.
   Operation *def = exp.getDefiningOp();
-  if (def->getBlock() != &forOp.getRegion().front()) {
+  Block *block = &forOp.getRegion().front();
+  if (def->getBlock() != block) {
     if (codegen)
       vexp = genVectorInvariantValue(rewriter, vl, exp);
     return true;
@@ -450,6 +468,17 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                       vx) &&
         vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                       vy)) {
+      // We only accept shift-by-invariant (where the same shift factor applies
+      // to all packed elements). In the vector dialect, this is still
+      // represented with an expanded vector at the right-hand-side, however,
+      // so that we do not have to special case the code generation.
+      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
+          isa<arith::ShRSIOp>(def)) {
+        Value shiftFactor = def->getOperand(1);
+        if (!isInvariantValue(shiftFactor, block))
+          return false;
+      }
+      // Generate code.
       BINOP(arith::MulFOp)
       BINOP(arith::MulIOp)
       BINOP(arith::DivFOp)
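In other words, only shifts whose right-hand operand is defined outside the loop body reach the BINOP expansion; everything else falls through to the final return false and the expression stays scalar. A hypothetical pair for contrast:

    %a = arith.shli %x, %s : i32  // %s defined above the loop: accepted
    %b = arith.shli %x, %y : i32  // %y computed inside the body: rejected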
@@ -462,8 +491,10 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
       BINOP(arith::AndIOp)
       BINOP(arith::OrIOp)
       BINOP(arith::XOrIOp)
+      BINOP(arith::ShLIOp)
+      BINOP(arith::ShRUIOp)
+      BINOP(arith::ShRSIOp)
       // TODO: complex?
-      // TODO: shift by invariant?
     }
   }
   return false;
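Because both operands were already vectorized by the recursive vectorizeExpr calls, the invariant shift factor reaches the BINOP expansion as a broadcast vector, which is what the comment about an "expanded vector at the right-hand-side" refers to. Presumably the emitted IR matches the splat constants in the test below, e.g.:

    %c4 = arith.constant dense<4> : vector<8xi32>  // splat of the invariant factor
    %v = arith.shrsi %x, %c4 : vector<8xi32>       // operand types now match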

mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir

Lines changed: 14 additions & 4 deletions
@@ -17,6 +17,8 @@
 // CHECK-DAG: %[[C1:.*]] = arith.constant dense<2.000000e+00> : vector<8xf32>
 // CHECK-DAG: %[[C2:.*]] = arith.constant dense<1.000000e+00> : vector<8xf32>
 // CHECK-DAG: %[[C3:.*]] = arith.constant dense<255> : vector<8xi64>
+// CHECK-DAG: %[[C4:.*]] = arith.constant dense<4> : vector<8xi32>
+// CHECK-DAG: %[[C5:.*]] = arith.constant dense<1> : vector<8xi32>
 // CHECK: scf.for
 // CHECK: %[[VAL_14:.*]] = vector.load
 // CHECK: %[[VAL_15:.*]] = math.absf %[[VAL_14]] : vector<8xf32>
@@ -38,15 +40,20 @@
 // CHECK: %[[VAL_31:.*]] = arith.andi %[[VAL_30]], %[[C3]] : vector<8xi64>
 // CHECK: %[[VAL_32:.*]] = arith.trunci %[[VAL_31]] : vector<8xi64> to vector<8xi16>
 // CHECK: %[[VAL_33:.*]] = arith.extsi %[[VAL_32]] : vector<8xi16> to vector<8xi32>
-// CHECK: %[[VAL_34:.*]] = arith.uitofp %[[VAL_33]] : vector<8xi32> to vector<8xf32>
-// CHECK: vector.store %[[VAL_34]]
+// CHECK: %[[VAL_34:.*]] = arith.shrsi %[[VAL_33]], %[[C4]] : vector<8xi32>
+// CHECK: %[[VAL_35:.*]] = arith.shrui %[[VAL_34]], %[[C4]] : vector<8xi32>
+// CHECK: %[[VAL_36:.*]] = arith.shli %[[VAL_35]], %[[C5]] : vector<8xi32>
+// CHECK: %[[VAL_37:.*]] = arith.uitofp %[[VAL_36]] : vector<8xi32> to vector<8xf32>
+// CHECK: vector.store %[[VAL_37]]
 // CHECK: }
 func.func @vops(%arga: tensor<1024xf32, #DenseVector>,
                 %argb: tensor<1024xf32, #DenseVector>) -> tensor<1024xf32> {
   %init = bufferization.alloc_tensor() : tensor<1024xf32>
   %o = arith.constant 1.0 : f32
   %c = arith.constant 2.0 : f32
   %i = arith.constant 255 : i64
+  %s = arith.constant 4 : i32
+  %t = arith.constant 1 : i32
   %0 = linalg.generic #trait
      ins(%arga, %argb: tensor<1024xf32, #DenseVector>, tensor<1024xf32, #DenseVector>)
     outs(%init: tensor<1024xf32>) {
@@ -69,8 +76,11 @@ func.func @vops(%arga: tensor<1024xf32, #DenseVector>,
         %15 = arith.andi %14, %i : i64
         %16 = arith.trunci %15 : i64 to i16
         %17 = arith.extsi %16 : i16 to i32
-        %18 = arith.uitofp %17 : i32 to f32
-        linalg.yield %18 : f32
+        %18 = arith.shrsi %17, %s : i32
+        %19 = arith.shrui %18, %s : i32
+        %20 = arith.shli %19, %t : i32
+        %21 = arith.uitofp %20 : i32 to f32
+        linalg.yield %21 : f32
   } -> tensor<1024xf32>
   return %0 : tensor<1024xf32>
 }
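For contrast, a shift amount computed inside the body would defeat the new guard and keep the kernel scalar; a hypothetical (untested) variant of the yield region:

    %e = arith.fptosi %a : f32 to i32
    %f = arith.shrsi %e, %e : i32  // the shift factor now varies per element,
                                   // so vectorizeExpr rejects the expression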
