Skip to content

Commit a150766

Browse files
committed
[mlir][sparse] Implement hybrid quick sort for sparse_tensor.sort.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D143227
1 parent 550cb76 commit a150766

File tree

5 files changed

+216
-56
lines changed

5 files changed

+216
-56
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,11 @@ inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) {
275275
return builder.create<arith::ConstantIndexOp>(loc, i);
276276
}
277277

278+
/// Generates a constant of `i64` type.
279+
inline Value constantI64(OpBuilder &builder, Location loc, int64_t i) {
280+
return builder.create<arith::ConstantIntOp>(loc, i, 64);
281+
}
282+
278283
/// Generates a constant of `i32` type.
279284
inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) {
280285
return builder.create<arith::ConstantIntOp>(loc, i, 32);

mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp

Lines changed: 144 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Arith/IR/Arith.h"
1717
#include "mlir/Dialect/Func/IR/FuncOps.h"
1818
#include "mlir/Dialect/Linalg/IR/Linalg.h"
19+
#include "mlir/Dialect/Math/IR/Math.h"
1920
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2021
#include "mlir/Dialect/SCF/IR/SCF.h"
2122
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -38,12 +39,13 @@ static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
3839
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
3940
static constexpr const char kBinarySearchFuncNamePrefix[] =
4041
"_sparse_binary_search_";
41-
static constexpr const char kSortNonstableFuncNamePrefix[] =
42-
"_sparse_sort_nonstable_";
42+
static constexpr const char kHybridQuickSortFuncNamePrefix[] =
43+
"_sparse_hybrid_qsort_";
4344
static constexpr const char kSortStableFuncNamePrefix[] =
4445
"_sparse_sort_stable_";
4546
static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
4647
static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
48+
static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
4749

4850
using FuncGeneratorType = function_ref<void(
4951
OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
@@ -916,41 +918,19 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
916918
builder.create<func::ReturnOp>(loc);
917919
}
918920

919-
/// Creates a function to perform quick sort on the value in the range of
920-
/// index [lo, hi).
921-
//
922-
// The generate IR corresponds to this C like algorithm:
923-
// void quickSort(lo, hi, data) {
924-
// if (lo < hi) {
925-
// p = partition(low, high, data);
926-
// quickSort(lo, p, data);
927-
// quickSort(p + 1, hi, data);
928-
// }
929-
// }
930-
static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
931-
func::FuncOp func, uint64_t nx, uint64_t ny,
932-
bool isCoo, uint32_t nTrailingP) {
933-
(void)nTrailingP;
934-
OpBuilder::InsertionGuard insertionGuard(builder);
935-
Block *entryBlock = func.addEntryBlock();
936-
builder.setInsertionPointToStart(entryBlock);
937-
921+
static void createQuickSort(OpBuilder &builder, ModuleOp module,
922+
func::FuncOp func, ValueRange args, uint64_t nx,
923+
uint64_t ny, bool isCoo, uint32_t nTrailingP) {
938924
MLIRContext *context = module.getContext();
939925
Location loc = func.getLoc();
940-
ValueRange args = entryBlock->getArguments();
941926
Value lo = args[loIdx];
942927
Value hi = args[hiIdx];
943-
Value cond =
944-
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
945-
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
946-
947-
// The if-stmt true branch.
948-
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
949928
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
950929
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
951-
ny, isCoo, args, createPartitionFunc);
952-
auto p = builder.create<func::CallOp>(
953-
loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
930+
ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
931+
auto p = builder.create<func::CallOp>(loc, partitionFunc,
932+
TypeRange{IndexType::get(context)},
933+
args.drop_back(nTrailingP));
954934

955935
SmallVector<Value> lowOperands{lo, p.getResult(0)};
956936
lowOperands.append(args.begin() + xStartIdx, args.end());
@@ -962,10 +942,6 @@ static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
962942
hi};
963943
highOperands.append(args.begin() + xStartIdx, args.end());
964944
builder.create<func::CallOp>(loc, func, highOperands);
965-
966-
// After the if-stmt.
967-
builder.setInsertionPointAfter(ifOp);
968-
builder.create<func::ReturnOp>(loc);
969945
}
970946

971947
/// Creates a function to perform insertion sort on the values in the range of
@@ -1054,6 +1030,116 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
10541030
builder.create<func::ReturnOp>(loc);
10551031
}
10561032

1033+
/// Creates a function to perform quick sort or a hybrid quick sort on the
1034+
/// values in the range of index [lo, hi).
1035+
//
1036+
//
1037+
// When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
1038+
// void quickSort(lo, hi, data) {
1039+
// if (lo + 1 < hi) {
1040+
// p = partition(low, high, data);
1041+
// quickSort(lo, p, data);
1042+
// quickSort(p + 1, hi, data);
1043+
// }
1044+
// }
1045+
//
1046+
// When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
1047+
// void hybridQuickSort(lo, hi, data, depthLimit) {
1048+
// if (lo + 1 < hi) {
1049+
// len = hi - lo;
1050+
// if (len <= limit) {
1051+
// insertionSort(lo, hi, data);
1052+
// } else {
1053+
// depthLimit --;
1054+
// if (depthLimit <= 0) {
1055+
// heapSort(lo, hi, data);
1056+
// } else {
1057+
// p = partition(low, high, data);
1058+
// quickSort(lo, p, data);
1059+
// quickSort(p + 1, hi, data);
1060+
// }
1061+
// depthLimit ++;
1062+
// }
1063+
// }
1064+
// }
1065+
//
1066+
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
1067+
func::FuncOp func, uint64_t nx, uint64_t ny,
1068+
bool isCoo, uint32_t nTrailingP) {
1069+
assert(nTrailingP == 1 || nTrailingP == 0);
1070+
bool isHybrid = (nTrailingP == 1);
1071+
OpBuilder::InsertionGuard insertionGuard(builder);
1072+
Block *entryBlock = func.addEntryBlock();
1073+
builder.setInsertionPointToStart(entryBlock);
1074+
1075+
Location loc = func.getLoc();
1076+
ValueRange args = entryBlock->getArguments();
1077+
Value lo = args[loIdx];
1078+
Value hi = args[hiIdx];
1079+
Value loCmp =
1080+
builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
1081+
Value cond =
1082+
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loCmp, hi);
1083+
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
1084+
1085+
// The if-stmt true branch.
1086+
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1087+
Value pDepthLimit;
1088+
Value savedDepthLimit;
1089+
scf::IfOp depthIf;
1090+
1091+
if (isHybrid) {
1092+
Value len = builder.create<arith::SubIOp>(loc, hi, lo);
1093+
Value lenLimit = constantIndex(builder, loc, 30);
1094+
Value lenCond = builder.create<arith::CmpIOp>(
1095+
loc, arith::CmpIPredicate::ule, len, lenLimit);
1096+
scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
1097+
1098+
// When len <= limit.
1099+
builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
1100+
FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
1101+
builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
1102+
args.drop_back(nTrailingP), createSortStableFunc);
1103+
builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
1104+
ValueRange(args.drop_back(nTrailingP)));
1105+
1106+
// When len > limit.
1107+
builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
1108+
pDepthLimit = args.back();
1109+
savedDepthLimit = builder.create<memref::LoadOp>(loc, pDepthLimit);
1110+
Value depthLimit = builder.create<arith::SubIOp>(
1111+
loc, savedDepthLimit, constantI64(builder, loc, 1));
1112+
builder.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
1113+
Value depthCond =
1114+
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
1115+
depthLimit, constantI64(builder, loc, 0));
1116+
depthIf = builder.create<scf::IfOp>(loc, depthCond, /*else=*/true);
1117+
1118+
// When depth exceeds limit.
1119+
builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
1120+
FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
1121+
builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
1122+
args.drop_back(nTrailingP), createHeapSortFunc);
1123+
builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
1124+
ValueRange(args.drop_back(nTrailingP)));
1125+
1126+
// When depth doesn't exceed limit.
1127+
builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
1128+
}
1129+
1130+
createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
1131+
1132+
if (isHybrid) {
1133+
// Restore depthLimit.
1134+
builder.setInsertionPointAfter(depthIf);
1135+
builder.create<memref::StoreOp>(loc, savedDepthLimit, pDepthLimit);
1136+
}
1137+
1138+
// After the if-stmt.
1139+
builder.setInsertionPointAfter(ifOp);
1140+
builder.create<func::ReturnOp>(loc);
1141+
}
1142+
10571143
/// Implements the rewriting for operator sort and sort_coo.
10581144
template <typename OpTy>
10591145
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
@@ -1078,10 +1164,30 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
10781164
FuncGeneratorType funcGenerator;
10791165
uint32_t nTrailingP = 0;
10801166
switch (op.getAlgorithm()) {
1081-
case SparseTensorSortKind::HybridQuickSort:
1167+
case SparseTensorSortKind::HybridQuickSort: {
1168+
funcName = kHybridQuickSortFuncNamePrefix;
1169+
funcGenerator = createQuickSortFunc;
1170+
nTrailingP = 1;
1171+
Value pDepthLimit = rewriter.create<memref::AllocaOp>(
1172+
loc, MemRefType::get({}, rewriter.getI64Type()));
1173+
operands.push_back(pDepthLimit);
1174+
// As a heuristic, set depthLimit = 2 * (64 - ctlz(n)), roughly 2 * log2(n).
1175+
Value lo = operands[loIdx];
1176+
Value hi = operands[hiIdx];
1177+
Value len = rewriter.create<arith::IndexCastOp>(
1178+
loc, rewriter.getI64Type(),
1179+
rewriter.create<arith::SubIOp>(loc, hi, lo));
1180+
Value depthLimit = rewriter.create<arith::SubIOp>(
1181+
loc, constantI64(rewriter, loc, 64),
1182+
rewriter.create<math::CountLeadingZerosOp>(loc, len));
1183+
depthLimit = rewriter.create<arith::ShLIOp>(loc, depthLimit,
1184+
constantI64(rewriter, loc, 1));
1185+
rewriter.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
1186+
break;
1187+
}
10821188
case SparseTensorSortKind::QuickSort:
1083-
funcName = kSortNonstableFuncNamePrefix;
1084-
funcGenerator = createSortNonstableFunc;
1189+
funcName = kQuickSortFuncNamePrefix;
1190+
funcGenerator = createQuickSortFunc;
10851191
break;
10861192
case SparseTensorSortKind::InsertionSortStable:
10871193
funcName = kSortStableFuncNamePrefix;

0 commit comments

Comments
 (0)