Skip to content

Commit fbd5821

Browse files
committed
Implement the conversion from a sparse constant to a sparse tensor.
The sparse constant provides a constant tensor in coordinate format. We first split the sparse constant into a constant tensor for indices and a constant tensor for values. We then generate a loop to fill a sparse tensor in coordinate format using the tensors for the indices and the values. Finally, we convert the sparse tensor in coordinate format to the destination sparse tensor format. Add tests. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D110373
1 parent 5357a98 commit fbd5821

File tree

4 files changed

+192
-22
lines changed

4 files changed

+192
-22
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 100 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -195,16 +195,39 @@ static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
195195
llvm_unreachable("Unknown element type");
196196
}
197197

198+
/// Generates the code to read the value from tensor[ivs], and conditionally
199+
/// stores the indices ivs to the memory in ind. The generated code looks like
200+
/// the following and the insertion point after this routine is inside the
201+
/// if-then branch behind the assignment to ind. This is to ensure that the
202+
/// addEltX call generated after is inside the if-then branch.
203+
/// if (tensor[ivs]!=0) {
204+
/// ind = ivs
205+
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
206+
Operation *op, Type eltType, Value tensor,
207+
Value ind, ValueRange ivs) {
208+
Location loc = op->getLoc();
209+
Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
210+
Value cond = genIsNonzero(rewriter, loc, eltType, val);
211+
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
212+
rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
213+
unsigned i = 0;
214+
for (auto iv : ivs) {
215+
Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
216+
rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
217+
}
218+
return val;
219+
}
220+
198221
/// Generates a call that adds one element to a coordinate scheme.
199222
/// In particular, this generates code like the following:
200223
/// val = a[i1,..,ik];
201224
/// if val != 0
202225
/// t->add(val, [i1,..,ik], [p1,..,pk]);
203226
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
204-
Value ptr, Value tensor, Value ind, Value perm,
205-
ValueRange ivs) {
227+
Type eltType, Value ptr, Value val, Value ind,
228+
Value perm) {
229+
Location loc = op->getLoc();
206230
StringRef name;
207-
Type eltType = tensor.getType().cast<ShapedType>().getElementType();
208231
if (eltType.isF64())
209232
name = "addEltF64";
210233
else if (eltType.isF32())
@@ -219,16 +242,6 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
219242
name = "addEltI8";
220243
else
221244
llvm_unreachable("Unknown element type");
222-
Location loc = op->getLoc();
223-
Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
224-
Value cond = genIsNonzero(rewriter, loc, eltType, val);
225-
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
226-
rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
227-
unsigned i = 0;
228-
for (auto iv : ivs) {
229-
Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
230-
rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
231-
}
232245
SmallVector<Value, 8> params;
233246
params.push_back(ptr);
234247
params.push_back(val);
@@ -240,6 +253,41 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
240253
params);
241254
}
242255

256+
/// If the tensor is a sparse constant, generates and returns the pair of
257+
/// the constants for the indices and the values.
258+
static Optional<std::pair<Value, Value>>
259+
genSplitSparseConstant(ConversionPatternRewriter &rewriter, ConvertOp op,
260+
Value tensor) {
261+
if (auto constOp = tensor.getDefiningOp<ConstantOp>()) {
262+
if (auto attr = constOp.value().dyn_cast<SparseElementsAttr>()) {
263+
Location loc = op->getLoc();
264+
DenseElementsAttr indicesAttr = attr.getIndices();
265+
Value indices = rewriter.create<ConstantOp>(loc, indicesAttr);
266+
DenseElementsAttr valuesAttr = attr.getValues();
267+
Value values = rewriter.create<ConstantOp>(loc, valuesAttr);
268+
return std::make_pair(indices, values);
269+
}
270+
}
271+
return {};
272+
}
273+
274+
/// Generates the code to copy the index at indices[ivs] to ind, and return
275+
/// the value at value[ivs].
276+
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
277+
Operation *op, Value indices,
278+
Value values, Value ind, ValueRange ivs,
279+
unsigned rank) {
280+
Location loc = op->getLoc();
281+
for (unsigned i = 0; i < rank; i++) {
282+
Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i));
283+
Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
284+
ValueRange{ivs[0], idx});
285+
val = rewriter.create<IndexCastOp>(loc, val, rewriter.getIndexType());
286+
rewriter.create<memref::StoreOp>(loc, val, ind, idx);
287+
}
288+
return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
289+
}
290+
243291
//===----------------------------------------------------------------------===//
244292
// Conversion rules.
245293
//===----------------------------------------------------------------------===//
@@ -330,15 +378,26 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
330378
// TODO: sparse => dense
331379
return failure();
332380
}
333-
// This is a dense => sparse conversion, which is handled as follows:
381+
// This is a dense => sparse conversion or a sparse constant in COO =>
382+
// sparse conversion, which is handled as follows:
334383
// t = newSparseCOO()
384+
// ...code to fill the COO tensor t...
385+
// s = newSparseTensor(t)
386+
//
387+
// To fill the COO tensor from a dense tensor:
335388
// for i1 in dim1
336389
// ..
337390
// for ik in dimk
338391
// val = a[i1,..,ik]
339392
// if val != 0
340393
// t->add(val, [i1,..,ik], [p1,..,pk])
341-
// s = newSparseTensor(t)
394+
//
395+
// To fill the COO tensor from a sparse constant in COO format:
396+
// for i in range(NNZ)
397+
// val = values[i]
398+
// [i1,..,ik] = indices[i]
399+
// t->add(val, [i1,..,ik], [p1,..,pk])
400+
//
342401
// Note that the dense tensor traversal code is actually implemented
343402
// using MLIR IR to avoid having to expose too much low-level
344403
// memref traversal details to the runtime support library.
@@ -351,7 +410,6 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
351410
MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
352411
Value perm;
353412
Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
354-
Value tensor = adaptor.getOperands()[0];
355413
Value arg = rewriter.create<ConstantOp>(
356414
loc, rewriter.getIndexAttr(shape.getRank()));
357415
Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
@@ -360,16 +418,38 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
360418
SmallVector<Value> st;
361419
Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
362420
Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
363-
for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
421+
Value tensor = adaptor.getOperands()[0];
422+
auto indicesValues = genSplitSparseConstant(rewriter, op, tensor);
423+
bool isCOOConstant = indicesValues.hasValue();
424+
Value indices;
425+
Value values;
426+
if (isCOOConstant) {
427+
indices = indicesValues->first;
428+
values = indicesValues->second;
364429
lo.push_back(zero);
365-
hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
430+
hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
366431
st.push_back(one);
432+
} else {
433+
for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
434+
lo.push_back(zero);
435+
hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
436+
st.push_back(one);
437+
}
367438
}
439+
Type eltType = shape.getElementType();
440+
unsigned rank = shape.getRank();
368441
scf::buildLoopNest(rewriter, op.getLoc(), lo, hi, st, {},
369442
[&](OpBuilder &builder, Location loc, ValueRange ivs,
370443
ValueRange args) -> scf::ValueVector {
371-
genAddEltCall(rewriter, op, ptr, tensor, ind, perm,
372-
ivs);
444+
Value val;
445+
if (isCOOConstant)
446+
val = genIndexAndValueForSparse(
447+
rewriter, op, indices, values, ind, ivs, rank);
448+
else
449+
val = genIndexAndValueForDense(rewriter, op, eltType,
450+
tensor, ind, ivs);
451+
genAddEltCall(rewriter, op, eltType, ptr, val, ind,
452+
perm);
373453
return {};
374454
});
375455
rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ struct SparseTensorConversionPass
114114
});
115115
// The following operations and dialects may be introduced by the
116116
// rewriting rules, and are therefore marked as legal.
117-
target.addLegalOp<ConstantOp, tensor::CastOp, tensor::ExtractOp, CmpFOp,
118-
CmpIOp>();
117+
target.addLegalOp<ConstantOp, IndexCastOp, tensor::CastOp,
118+
tensor::ExtractOp, CmpFOp, CmpIOp>();
119119
target.addLegalDialect<scf::SCFDialect, LLVM::LLVMDialect,
120120
memref::MemRefDialect>();
121121
// Populate with rules and apply rewriting rules.

mlir/test/Dialect/SparseTensor/conversion.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,45 @@ func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix
182182
return %0 : tensor<2x4xf64, #SparseMatrix>
183183
}
184184

185+
#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>

// CHECK-LABEL: func @entry() -> !llvm.ptr<i8> {
// CHECK: %[[C1:.*]] = constant 1 : i32
// CHECK: %[[Offset:.*]] = constant dense<[0, 1]> : tensor<2xi64>
// CHECK: %[[Dims:.*]] = constant dense<[8, 7]> : tensor<2xi64>
// CHECK: %[[Base:.*]] = constant dense<[0, 1]> : tensor<2xi8>
// CHECK: %[[I2:.*]] = constant 2 : index
// CHECK: %[[SparseV:.*]] = constant dense<[1.000000e+00, 5.000000e+00]> : tensor<2xf32>
// CHECK: %[[SparseI:.*]] = constant dense<{{\[\[}}0, 0], [1, 6]]> : tensor<2x2xi64>
// CHECK: %[[I1:.*]] = constant 1 : index
// CHECK: %[[I0:.*]] = constant 0 : index
// CHECK: %[[C2:.*]] = constant 2 : i32
// CHECK: %[[BaseD:.*]] = tensor.cast %[[Base]] : tensor<2xi8> to tensor<?xi8>
// CHECK: %[[DimsD:.*]] = tensor.cast %[[Dims]] : tensor<2xi64> to tensor<?xi64>
// CHECK: %[[OffsetD:.*]] = tensor.cast %[[Offset]] : tensor<2xi64> to tensor<?xi64>
// CHECK: %[[TCOO:.*]] = call @newSparseTensor(%[[BaseD]], %[[DimsD]], %[[OffsetD]], %{{.*}}, %{{.*}}, %{{.*}}, %[[C2]], %{{.}})
// CHECK: %[[Index:.*]] = memref.alloca() : memref<2xindex>
// CHECK: %[[IndexD:.*]] = memref.cast %[[Index]] : memref<2xindex> to memref<?xindex>
// CHECK: scf.for %[[IV:.*]] = %[[I0]] to %[[I2]] step %[[I1]] {
// CHECK: %[[VAL0:.*]] = tensor.extract %[[SparseI]]{{\[}}%[[IV]], %[[I0]]] : tensor<2x2xi64>
// CHECK: %[[VAL1:.*]] = index_cast %[[VAL0]] : i64 to index
// CHECK: memref.store %[[VAL1]], %[[Index]]{{\[}}%[[I0]]] : memref<2xindex>
// CHECK: %[[VAL2:.*]] = tensor.extract %[[SparseI]]{{\[}}%[[IV]], %[[I1]]] : tensor<2x2xi64>
// CHECK: %[[VAL3:.*]] = index_cast %[[VAL2]] : i64 to index
// CHECK: memref.store %[[VAL3]], %[[Index]]{{\[}}%[[I1]]] : memref<2xindex>
// CHECK: %[[VAL4:.*]] = tensor.extract %[[SparseV]]{{\[}}%[[IV]]] : tensor<2xf32>
// CHECK: call @addEltF32(%[[TCOO]], %[[VAL4]], %[[IndexD]], %[[OffsetD]])
// CHECK: }
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[BaseD]], %[[DimsD]], %[[OffsetD]], %{{.*}}, %{{.*}}, %[[C1]], %{{.*}})
// CHECK: return %[[T]] : !llvm.ptr<i8>
func @entry() -> tensor<8x7xf32, #CSR>{
  // A sparse constant in COO form: nonzeros at (0,0) and (1,6).
  %coo = constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
  // Lower the constant into a sparse tensor with the #CSR encoding.
  %csr = sparse_tensor.convert %coo : tensor<8x7xf32> to tensor<8x7xf32, #CSR>
  return %csr : tensor<8x7xf32, #CSR>
}
223+
185224
// CHECK-LABEL: func @sparse_convert_3d(
186225
// CHECK-SAME: %[[A:.*]]: tensor<?x?x?xf64>) -> !llvm.ptr<i8>
187226
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: mlir-opt %s \
// RUN:   --sparsification --sparse-tensor-conversion \
// RUN:   --convert-vector-to-scf --convert-scf-to-std \
// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
// RUN:   --std-bufferize --finalizing-bufferize \
// RUN:   --convert-vector-to-llvm --convert-memref-to-llvm --convert-std-to-llvm --reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner \
// RUN:  -e entry -entry-point-result=void \
// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

#Tensor1 = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed", "compressed"]
}>

//
// Integration tests for conversions from sparse constants to sparse tensors.
//
module {
  func @entry() {
    %idx0 = constant 0 : index
    %idx1 = constant 1 : index
    %idx2 = constant 2 : index
    %pad = constant 0.0 : f64

    // A tensor in COO format.
    %coo = constant sparse<[[0, 0], [0, 7], [1, 2], [4, 2], [5, 3], [6, 4], [6, 6], [9, 7]],
                           [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : tensor<10x8xf64>

    // Convert the COO tensor into a sparse tensor annotated with #Tensor1.
    %st = sparse_tensor.convert %coo : tensor<10x8xf64> to tensor<10x8xf64, #Tensor1>

    // CHECK: ( 0, 1, 4, 5, 6, 9 )
    %ind0 = sparse_tensor.indices %st, %idx0 : tensor<10x8xf64, #Tensor1> to memref<?xindex>
    %vec0 = vector.transfer_read %ind0[%idx0], %idx0: memref<?xindex>, vector<6xindex>
    vector.print %vec0 : vector<6xindex>

    // CHECK: ( 0, 7, 2, 2, 3, 4, 6, 7 )
    %ind1 = sparse_tensor.indices %st, %idx1 : tensor<10x8xf64, #Tensor1> to memref<?xindex>
    %vec1 = vector.transfer_read %ind1[%idx0], %idx0: memref<?xindex>, vector<8xindex>
    vector.print %vec1 : vector<8xindex>

    // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8 )
    %vals = sparse_tensor.values %st : tensor<10x8xf64, #Tensor1> to memref<?xf64>
    %vecv = vector.transfer_read %vals[%idx0], %pad: memref<?xf64>, vector<8xf64>
    vector.print %vecv : vector<8xf64>

    return
  }
}
51+

0 commit comments

Comments
 (0)