Commit 11069cb (parent 073534c)
Author: Peiming Liu

[mlir][sparse] refactoring: split translateIndices.

translateIndicesArray takes an array of SSA values and converts it into another
array of SSA values based on the reassociation, which makes it easier to reuse
from the `foreach` operator (whose indices are given as an array of SSA values).

Reviewed By: aartbik, bixia

Differential Revision: https://reviews.llvm.org/D134918

4 files changed: +77 −43 lines

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 49 additions & 0 deletions
@@ -199,3 +199,52 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
   return builder.create<complex::NotEqualOp>(loc, v, zero);
   llvm_unreachable("Non-numeric type");
 }
+
+void mlir::sparse_tensor::translateIndicesArray(
+    OpBuilder &builder, Location loc,
+    ArrayRef<ReassociationIndices> reassociation, ValueRange srcIndices,
+    ArrayRef<Value> srcShape, ArrayRef<Value> dstShape,
+    SmallVectorImpl<Value> &dstIndices) {
+  unsigned i = 0;
+  unsigned start = 0;
+  unsigned dstRank = dstShape.size();
+  unsigned srcRank = srcShape.size();
+  assert(srcRank == srcIndices.size());
+  bool isCollapse = srcRank > dstRank;
+  ArrayRef<Value> shape = isCollapse ? srcShape : dstShape;
+  // Iterate over reassociation map.
+  for (const auto &map : llvm::enumerate(reassociation)) {
+    // Prepare strides information in dimension slice.
+    Value linear = constantIndex(builder, loc, 1);
+    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
+      linear = builder.create<arith::MulIOp>(loc, linear, shape[j]);
+    }
+    // Start expansion.
+    Value val;
+    if (!isCollapse)
+      val = srcIndices[i];
+    // Iterate over dimension slice.
+    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
+      linear = builder.create<arith::DivUIOp>(loc, linear, shape[j]);
+      if (isCollapse) {
+        Value old = srcIndices[j];
+        Value mul = builder.create<arith::MulIOp>(loc, old, linear);
+        val = val ? builder.create<arith::AddIOp>(loc, val, mul) : mul;
+      } else {
+        Value old = val;
+        val = builder.create<arith::DivUIOp>(loc, val, linear);
+        assert(dstIndices.size() == j);
+        dstIndices.push_back(val);
+        val = builder.create<arith::RemUIOp>(loc, old, linear);
+      }
+    }
+    // Finalize collapse.
+    if (isCollapse) {
+      assert(dstIndices.size() == i);
+      dstIndices.push_back(val);
+    }
+    start += map.value().size();
+    i++;
+  }
+  assert(dstIndices.size() == dstRank);
+}
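For intuition, the IR built above performs plain row-major linearization (collapse) and delinearization (expansion), reusing one running `linear` stride per reassociation group. Below is a minimal standalone sketch, not part of the commit, that mirrors the same arithmetic on ordinary integers; the 10x10 <-> 100 shapes and the single reassociation group [[0, 1]] are made-up example inputs.

// Standalone sketch (hypothetical, not in the commit): the stride arithmetic
// that translateIndicesArray emits as arith ops, evaluated on integers.
#include <cassert>
#include <cstdio>

int main() {
  unsigned shape[2] = {10, 10}; // one reassociation group [[0, 1]]
  unsigned i = 3, j = 7;

  // Collapse (srcRank > dstRank): linear starts as the group's total size,
  // then is divided down to the stride of each dimension in turn.
  unsigned linear = shape[0] * shape[1]; // 100
  linear /= shape[0];                    // stride of dim 0: 10
  unsigned val = i * linear;
  linear /= shape[1];                    // stride of dim 1: 1
  val += j * linear;                     // 3 * 10 + 7 = 37

  // Expansion (the other branch): divui/remui recover (i, j) from val.
  unsigned rem = val;
  linear = shape[0] * shape[1];
  linear /= shape[0];
  unsigned d0 = rem / linear; // 37 / 10 = 3
  rem = rem % linear;         // 37 % 10 = 7
  linear /= shape[1];
  unsigned d1 = rem / linear; // 7 / 1 = 7
  assert(d0 == i && d1 == j);
  std::printf("collapsed %u -> expanded (%u, %u)\n", val, d0, d1);
  return 0;
}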

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h

Lines changed: 8 additions & 0 deletions
@@ -16,7 +16,9 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/ExecutionEngine/SparseTensor/Enums.h"
+#include "mlir/ExecutionEngine/SparseTensorUtils.h"
 #include "mlir/IR/Builders.h"

 namespace mlir {
@@ -193,6 +195,12 @@ constantDimLevelTypeEncoding(OpBuilder &builder, Location loc,
                     static_cast<uint8_t>(dimLevelTypeEncoding(dlt)));
 }

+/// Helper method to translate indices during a reshaping operation.
+void translateIndicesArray(OpBuilder &builder, Location loc,
+                           ArrayRef<ReassociationIndices> reassociation,
+                           ValueRange srcIndices, ArrayRef<Value> srcShape,
+                           ArrayRef<Value> dstShape,
+                           SmallVectorImpl<Value> &dstIndices);
 } // namespace sparse_tensor
 } // namespace mlir
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 14 additions & 37 deletions
@@ -475,44 +475,21 @@ static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
                              ArrayRef<Value> srcShape) {
   unsigned dstRank = dstTp.getRank();
   unsigned srcRank = srcTp.getRank();
-  unsigned start = 0;
-  unsigned i = 0;
-  bool isExpand = srcRank > dstRank;
-  ArrayRef<Value> shape = isExpand ? srcShape : dstShape;
-  // Iterate over reassociation map.
-  for (const auto &map : llvm::enumerate(reassociation)) {
-    // Prepare strides information in dimension slice.
-    Value linear = constantIndex(rewriter, loc, 1);
-    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
-      linear = rewriter.create<arith::MulIOp>(loc, linear, shape[j]);
-    }
-    // Start collapse.
-    Value idx = constantIndex(rewriter, loc, i++);
-    Value val;
-    if (!isExpand)
-      val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
-    // Iterate over dimension slice.
-    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
-      linear = rewriter.create<arith::DivUIOp>(loc, linear, shape[j]);
-      Value jdx = constantIndex(rewriter, loc, j);
-      if (isExpand) {
-        Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
-        Value mul = rewriter.create<arith::MulIOp>(loc, old, linear);
-        val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
-      } else {
-        Value old = val;
-        val = rewriter.create<arith::DivUIOp>(loc, val, linear);
-        rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
-        val = rewriter.create<arith::RemUIOp>(loc, old, linear);
-      }
-    }
-    // Finalize expansion.
-    if (isExpand)
-      rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
-    start += map.value().size();
+
+  SmallVector<Value, 4> srcIndices;
+  for (unsigned i = 0; i < srcRank; i++) {
+    Value idx = rewriter.create<memref::LoadOp>(
+        loc, srcIdx, constantIndex(rewriter, loc, i));
+    srcIndices.push_back(idx);
   }
-  // Sanity.
-  assert((isExpand && i == dstRank) || (!isExpand && i == srcRank));
+
+  SmallVector<Value, 4> dstIndices;
+  translateIndicesArray(rewriter, loc, reassociation, srcIndices, srcShape,
+                        dstShape, dstIndices);
+
+  for (unsigned i = 0; i < dstRank; i++)
+    rewriter.create<memref::StoreOp>(loc, dstIndices[i], dstIdx,
+                                     constantIndex(rewriter, loc, i));
 }

 /// Helper method to compute the shape of destination tensor of a reshape
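The payoff claimed in the commit message is visible here: translateIndices still bridges its memref-staged indices via loads and stores, but a lowering that already holds indices as SSA values, such as a future `foreach` pattern, can call the helper directly. A hypothetical fragment under that assumption (neither the pattern nor the `ivs` variable is part of this commit):

  // Hypothetical caller whose indices already live as SSA values.
  SmallVector<Value, 4> dstIndices;
  translateIndicesArray(rewriter, loc, reassociation, /*srcIndices=*/ivs,
                        srcShape, dstShape, dstIndices);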

mlir/test/Dialect/SparseTensor/sparse_reshape.mlir

Lines changed: 6 additions & 6 deletions
The CHECK lines only change order: the refactored lowering now materializes all source indices before the translation arithmetic and emits the destination stores afterwards.

@@ -26,8 +26,8 @@
 // CHECK-CONV: } do {
 // CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<1xindex>
 // CHECK-CONV: %[[D:.*]] = arith.divui %[[X]], %[[C10]] : index
-// CHECK-CONV: memref.store %[[D]], %{{.*}}[%[[C0]]] : memref<2xindex>
 // CHECK-CONV: %[[R:.*]] = arith.remui %[[X]], %[[C10]] : index
+// CHECK-CONV: memref.store %[[D]], %{{.*}}[%[[C0]]] : memref<2xindex>
 // CHECK-CONV: memref.store %[[R]], %{{.*}}[%[[C1]]] : memref<2xindex>
 // CHECK-CONV: call @addEltF64
 // CHECK-CONV: scf.yield
@@ -64,8 +64,8 @@ func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10x
 // CHECK-CONV: scf.condition
 // CHECK-CONV: } do {
 // CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<2xindex>
-// CHECK-CONV: %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index
 // CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex>
+// CHECK-CONV: %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index
 // CHECK-CONV: %[[A:.*]] = arith.addi %[[M]], %[[Y]] : index
 // CHECK-CONV: memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<1xindex>
 // CHECK-CONV: call @addEltF64
@@ -103,14 +103,14 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
 // CHECK-CONV: call @getNextF64
 // CHECK-CONV: scf.condition
 // CHECK-CONV: } do {
-// CHECK-CONV: %[[M:.*]] = arith.muli %[[D1]], %[[C10]] : index
 // CHECK-CONV: %[[L:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<1xindex>
+// CHECK-CONV: %[[M:.*]] = arith.muli %[[D1]], %[[C10]] : index
 // CHECK-CONV: %[[D2:.*]] = arith.divui %[[M]], %[[D1]] : index
 // CHECK-CONV: %[[D3:.*]] = arith.divui %[[L]], %[[D2]] : index
-// CHECK-CONV: memref.store %[[D3]], %{{.*}}[%[[C0]]] : memref<2xindex>
 // CHECK-CONV: %[[R:.*]] = arith.remui %[[L]], %[[D2]] : index
 // CHECK-CONV: %[[D4:.*]] = arith.divui %[[D2]], %[[C10]] : index
 // CHECK-CONV: %[[D5:.*]] = arith.divui %[[R]], %[[D4]] : index
+// CHECK-CONV: memref.store %[[D3]], %{{.*}}[%[[C0]]] : memref<2xindex>
 // CHECK-CONV: memref.store %[[D5]], %{{.*}}[%[[C1]]] : memref<2xindex>
 // CHECK-CONV: call @addEltF64
 // CHECK-CONV: scf.yield
@@ -147,11 +147,11 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<
 // CHECK-CONV: call @getNextF64
 // CHECK-CONV: scf.condition
 // CHECK-CONV: } do {
-// CHECK-CONV: %[[D1:.*]] = arith.divui %[[M1]], %[[C10]] : index
 // CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<2xindex>
+// CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex>
+// CHECK-CONV: %[[D1:.*]] = arith.divui %[[M1]], %[[C10]] : index
 // CHECK-CONV: %[[M2:.*]] = arith.muli %[[X]], %[[D1]] : index
 // CHECK-CONV: %[[D2:.*]] = arith.divui %[[D1]], %{{.*}} : index
-// CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex>
 // CHECK-CONV: %[[M3:.*]] = arith.muli %[[Y]], %[[D2]] : index
 // CHECK-CONV: %[[A:.*]] = arith.addi %[[M2]], %[[M3]] : index
 // CHECK-CONV: memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<1xindex>
