Skip to content

Commit 1c835b5

Browse files
committed
[mlir][sparse] Allow the push_back operator to skip capacity check and reallocation.
Add UnitAttr `inbounds` for this purpose. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D134913
1 parent 7eee2a2 commit 1c835b5

File tree

4 files changed

+69
-22
lines changed

4 files changed

+69
-22
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def SparseTensor_InsertOp : SparseTensor_Op<"insert",
240240
def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
241241
Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes,
242242
StridedMemRefRankOf<[AnyType], [1]>:$inBuffer,
243-
AnyType:$value, IndexAttr:$idx)>,
243+
AnyType:$value, IndexAttr:$idx, UnitAttr:$inbounds)>,
244244
Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer)> {
245245
string summary = "Pushes a value to the back of a given buffer";
246246
string description = [{
@@ -250,6 +250,11 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
250250
current buffer is full, then `inBuffer.realloc` is called before pushing the
251251
data to the buffer. This is similar to std::vector push_back.
252252

253+
The `inbounds` attribute tells the compiler that the insertion won't go
254+
beyond the current storage buffer. This allows the compiler to not generate
255+
the code for capacity check and reallocation. The typical usage will be for
256+
"dynamic" sparse tensors for which a capacity can be set beforehand.
257+
253258
The operation returns an SSA value for the memref. Referencing the memref
254259
through the old SSA value after this operation is undefined behavior.
255260

@@ -259,9 +264,14 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
259264
%r = sparse_tensor.push_back %bufferSizes, %buffer, %val {idx = 0 : index}
260265
: memref<?xindex>, memref<?xf64>, f64 -> memref<?xf64>
261266
```
267+
268+
```mlir
269+
%r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val
270+
{idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64 -> memref<?xf64>
271+
```
262272
}];
263-
let assemblyFormat = "$bufferSizes `,` $inBuffer `,` $value"
264-
" attr-dict `:` type($bufferSizes) `,`"
273+
let assemblyFormat = "(`inbounds` $inbounds^)? $bufferSizes `,` $inBuffer"
274+
" `,` $value attr-dict `:` type($bufferSizes) `,`"
265275
" type($inBuffer) `,` type($value) `to`"
266276
" type($outBuffer)";
267277
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -350,35 +350,43 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
350350
// buffer = new_buffer
351351
// store(buffer, value)
352352
// size(buffer)++
353+
//
354+
// The capacity check is skipped when the attribute inbounds is presented.
353355
Location loc = op->getLoc();
354356
Value c0 = constantIndex(rewriter, loc, 0);
355357
Value buffer = op.getInBuffer();
356358
Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
357359
Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
358360
Value bufferSizes = op.getBufferSizes();
359361
Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
360-
Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
361-
size, capacity);
362362
Value value = op.getValue();
363-
auto bufferType =
364-
MemRefType::get({ShapedType::kDynamicSize}, value.getType());
365-
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
366-
/*else=*/true);
367-
// True branch.
368-
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
369-
Value c2 = constantIndex(rewriter, loc, 2);
370-
capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
371-
Value newBuffer =
372-
rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
373-
rewriter.create<scf::YieldOp>(loc, newBuffer);
374-
375-
// False branch.
376-
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
377-
rewriter.create<scf::YieldOp>(loc, buffer);
363+
364+
if (!op.getInbounds()) {
365+
Value cond = rewriter.create<arith::CmpIOp>(
366+
loc, arith::CmpIPredicate::uge, size, capacity);
367+
368+
auto bufferType =
369+
MemRefType::get({ShapedType::kDynamicSize}, value.getType());
370+
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
371+
/*else=*/true);
372+
// True branch.
373+
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
374+
Value c2 = constantIndex(rewriter, loc, 2);
375+
capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
376+
Value newBuffer =
377+
rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
378+
rewriter.create<scf::YieldOp>(loc, newBuffer);
379+
380+
// False branch.
381+
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
382+
rewriter.create<scf::YieldOp>(loc, buffer);
383+
384+
// Prepare for adding the value to the end of the buffer.
385+
rewriter.setInsertionPointAfter(ifOp);
386+
buffer = ifOp.getResult(0);
387+
}
378388

379389
// Add the value to the end of the buffer.
380-
rewriter.setInsertionPointAfter(ifOp);
381-
buffer = ifOp.getResult(0);
382390
rewriter.create<memref::StoreOp>(loc, value, buffer, size);
383391

384392
// Increment the size of the buffer by 1.

mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,22 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
2626
return %0 : memref<?xf64>
2727
}
2828

29+
// CHECK-LABEL: func @sparse_push_back_inbound(
30+
// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
31+
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
32+
// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
33+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
34+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
35+
// CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
36+
// CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[P]]]
37+
// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
38+
// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
39+
// CHECK: return %[[B]] : memref<?xf64>
40+
func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
41+
%0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
42+
return %0 : memref<?xf64>
43+
}
44+
2945
// CHECK-LABEL: func.func private @_sparse_less_than_1_i8(
3046
// CHECK-SAME: %[[I:arg0]]: index,
3147
// CHECK-SAME: %[[J:.*]]: index,

mlir/test/Dialect/SparseTensor/roundtrip.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,19 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
145145

146146
// -----
147147

148+
// CHECK-LABEL: func @sparse_push_back_inbound(
149+
// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
150+
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
151+
// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
152+
// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
153+
// CHECK: return %[[D]]
154+
func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
155+
%0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
156+
return %0 : memref<?xf64>
157+
}
158+
159+
// -----
160+
148161
#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
149162

150163
// CHECK-LABEL: func @sparse_expansion(

0 commit comments

Comments
 (0)