16
16
#include " mlir/Dialect/Arith/IR/Arith.h"
17
17
#include " mlir/Dialect/Func/IR/FuncOps.h"
18
18
#include " mlir/Dialect/Linalg/IR/Linalg.h"
19
+ #include " mlir/Dialect/Math/IR/Math.h"
19
20
#include " mlir/Dialect/MemRef/IR/MemRef.h"
20
21
#include " mlir/Dialect/SCF/IR/SCF.h"
21
22
#include " mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -38,12 +39,13 @@ static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
38
39
static constexpr const char kPartitionFuncNamePrefix [] = " _sparse_partition_" ;
39
40
static constexpr const char kBinarySearchFuncNamePrefix [] =
40
41
" _sparse_binary_search_" ;
41
- static constexpr const char kSortNonstableFuncNamePrefix [] =
42
- " _sparse_sort_nonstable_ " ;
42
+ static constexpr const char kHybridQuickSortFuncNamePrefix [] =
43
+ " _sparse_hybrid_qsort_ " ;
43
44
static constexpr const char kSortStableFuncNamePrefix [] =
44
45
" _sparse_sort_stable_" ;
45
46
static constexpr const char kShiftDownFuncNamePrefix [] = " _sparse_shift_down_" ;
46
47
static constexpr const char kHeapSortFuncNamePrefix [] = " _sparse_heap_sort_" ;
48
+ static constexpr const char kQuickSortFuncNamePrefix [] = " _sparse_qsort_" ;
47
49
48
50
using FuncGeneratorType = function_ref<void (
49
51
OpBuilder &, ModuleOp, func::FuncOp, uint64_t , uint64_t , bool , uint32_t )>;
@@ -916,41 +918,19 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
916
918
builder.create <func::ReturnOp>(loc);
917
919
}
918
920
919
- // / Creates a function to perform quick sort on the value in the range of
920
- // / index [lo, hi).
921
- //
922
- // The generate IR corresponds to this C like algorithm:
923
- // void quickSort(lo, hi, data) {
924
- // if (lo < hi) {
925
- // p = partition(low, high, data);
926
- // quickSort(lo, p, data);
927
- // quickSort(p + 1, hi, data);
928
- // }
929
- // }
930
- static void createSortNonstableFunc (OpBuilder &builder, ModuleOp module ,
931
- func::FuncOp func, uint64_t nx, uint64_t ny,
932
- bool isCoo, uint32_t nTrailingP) {
933
- (void )nTrailingP;
934
- OpBuilder::InsertionGuard insertionGuard (builder);
935
- Block *entryBlock = func.addEntryBlock ();
936
- builder.setInsertionPointToStart (entryBlock);
937
-
921
+ static void createQuickSort (OpBuilder &builder, ModuleOp module ,
922
+ func::FuncOp func, ValueRange args, uint64_t nx,
923
+ uint64_t ny, bool isCoo, uint32_t nTrailingP) {
938
924
MLIRContext *context = module .getContext ();
939
925
Location loc = func.getLoc ();
940
- ValueRange args = entryBlock->getArguments ();
941
926
Value lo = args[loIdx];
942
927
Value hi = args[hiIdx];
943
- Value cond =
944
- builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
945
- scf::IfOp ifOp = builder.create <scf::IfOp>(loc, cond, /* else=*/ false );
946
-
947
- // The if-stmt true branch.
948
- builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
949
928
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc (
950
929
builder, func, {IndexType::get (context)}, kPartitionFuncNamePrefix , nx,
951
- ny, isCoo, args, createPartitionFunc);
952
- auto p = builder.create <func::CallOp>(
953
- loc, partitionFunc, TypeRange{IndexType::get (context)}, ValueRange (args));
930
+ ny, isCoo, args.drop_back (nTrailingP), createPartitionFunc);
931
+ auto p = builder.create <func::CallOp>(loc, partitionFunc,
932
+ TypeRange{IndexType::get (context)},
933
+ args.drop_back (nTrailingP));
954
934
955
935
SmallVector<Value> lowOperands{lo, p.getResult (0 )};
956
936
lowOperands.append (args.begin () + xStartIdx, args.end ());
@@ -962,10 +942,6 @@ static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
962
942
hi};
963
943
highOperands.append (args.begin () + xStartIdx, args.end ());
964
944
builder.create <func::CallOp>(loc, func, highOperands);
965
-
966
- // After the if-stmt.
967
- builder.setInsertionPointAfter (ifOp);
968
- builder.create <func::ReturnOp>(loc);
969
945
}
970
946
971
947
// / Creates a function to perform insertion sort on the values in the range of
@@ -1054,6 +1030,116 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
1054
1030
builder.create <func::ReturnOp>(loc);
1055
1031
}
1056
1032
1033
+ /// Creates a function to perform quick sort or a hybrid quick sort on the
1034
+ /// values in the range of index [lo, hi).
1035
+ //
1036
+ // nTrailingP must be 0 or 1: 0 selects quick sort, 1 selects the hybrid sort.
1037
+ // When nTrailingP == 0, the generated IR corresponds to this C-like algorithm:
1038
+ // void quickSort(lo, hi, data) {
1039
+ //   if (lo + 1 < hi) {
1040
+ //     p = partition(lo, hi, data);
1041
+ //     quickSort(lo, p, data);
1042
+ //     quickSort(p + 1, hi, data);
1043
+ //   }
1044
+ // }
1045
+ //
1046
+ // When nTrailingP == 1, the generated IR corresponds to this C-like algorithm:
1047
+ // void hybridQuickSort(lo, hi, data, depthLimit) {
1048
+ //   if (lo + 1 < hi) {
1049
+ //     len = hi - lo;
1050
+ //     if (len <= limit) {
1051
+ //       insertionSort(lo, hi, data);
1052
+ //     } else {
1053
+ //       depthLimit --;
1054
+ //       if (depthLimit <= 0) {
1055
+ //         heapSort(lo, hi, data);
1056
+ //       } else {
1057
+ //         p = partition(lo, hi, data);
1058
+ //         hybridQuickSort(lo, p, data, depthLimit);
1059
+ //         hybridQuickSort(p + 1, hi, data, depthLimit);
1060
+ //       }
1061
+ //       depthLimit ++;
1062
+ //     }
1063
+ //   }
1064
+ // }
1065
+ // Here depthLimit is the value held in the memref passed as the last argument.
1066
+ static void createQuickSortFunc (OpBuilder &builder, ModuleOp module ,
1067
+ func::FuncOp func, uint64_t nx, uint64_t ny,
1068
+ bool isCoo, uint32_t nTrailingP) {
1069
+ assert (nTrailingP == 1 || nTrailingP == 0 );
1070
+ bool isHybrid = (nTrailingP == 1 );
1071
+ OpBuilder::InsertionGuard insertionGuard (builder);
1072
+ Block *entryBlock = func.addEntryBlock ();
1073
+ builder.setInsertionPointToStart (entryBlock);
1074
+
1075
+ Location loc = func.getLoc ();
1076
+ ValueRange args = entryBlock->getArguments ();
1077
+ Value lo = args[loIdx];
1078
+ Value hi = args[hiIdx];
1079
+ Value loCmp =
1080
+ builder.create <arith::AddIOp>(loc, lo, constantIndex (builder, loc, 1 ));
1081
+ Value cond =
1082
+ builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loCmp, hi);
1083
+ scf::IfOp ifOp = builder.create <scf::IfOp>(loc, cond, /* else=*/ false );
1084
+
1085
+ // The if-stmt true branch: lo + 1 < hi, so the range has at least two values.
1086
+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
1087
+ Value pDepthLimit;
1088
+ Value savedDepthLimit;
1089
+ scf::IfOp depthIf;
1090
+
1091
+ if (isHybrid) {
1092
+ Value len = builder.create <arith::SubIOp>(loc, hi, lo);
1093
+ Value lenLimit = constantIndex (builder, loc, 30 );
1094
+ Value lenCond = builder.create <arith::CmpIOp>(
1095
+ loc, arith::CmpIPredicate::ule, len, lenLimit);
1096
+ scf::IfOp lenIf = builder.create <scf::IfOp>(loc, lenCond, /* else=*/ true );
1097
+
1098
+ // When len <= limit (30): call the insertion sort helper without the depth arg.
1099
+ builder.setInsertionPointToStart (&lenIf.getThenRegion ().front ());
1100
+ FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc (
1101
+ builder, func, TypeRange (), kSortStableFuncNamePrefix , nx, ny, isCoo,
1102
+ args.drop_back (nTrailingP), createSortStableFunc);
1103
+ builder.create <func::CallOp>(loc, insertionSortFunc, TypeRange (),
1104
+ ValueRange (args.drop_back (nTrailingP)));
1105
+
1106
+ // When len > limit: load the depth limit from the last argument, decrement it
1107
+ builder.setInsertionPointToStart (&lenIf.getElseRegion ().front ());
1108
+ pDepthLimit = args.back ();
1109
+ savedDepthLimit = builder.create <memref::LoadOp>(loc, pDepthLimit);
1110
+ Value depthLimit = builder.create <arith::SubIOp>(
1111
+ loc, savedDepthLimit, constantI64 (builder, loc, 1 ));
1112
+ builder.create <memref::StoreOp>(loc, depthLimit, pDepthLimit);
1113
+ Value depthCond =
1114
+ builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
1115
+ depthLimit, constantI64 (builder, loc, 0 ));
1116
+ depthIf = builder.create <scf::IfOp>(loc, depthCond, /* else=*/ true );
1117
+
1118
+ // When depth exceeds limit (decremented depthLimit <= 0, unsigned): heap sort.
1119
+ builder.setInsertionPointToStart (&depthIf.getThenRegion ().front ());
1120
+ FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc (
1121
+ builder, func, TypeRange (), kHeapSortFuncNamePrefix , nx, ny, isCoo,
1122
+ args.drop_back (nTrailingP), createHeapSortFunc);
1123
+ builder.create <func::CallOp>(loc, heapSortFunc, TypeRange (),
1124
+ ValueRange (args.drop_back (nTrailingP)));
1125
+
1126
+ // When depth doesn't exceed limit: createQuickSort below emits into this else.
1127
+ builder.setInsertionPointToStart (&depthIf.getElseRegion ().front ());
1128
+ }
1129
+
1130
+ createQuickSort (builder, module , func, args, nx, ny, isCoo, nTrailingP);
1131
+
1132
+ if (isHybrid) {
1133
+ // Restore depthLimit to the value that was loaded before the decrement.
1134
+ builder.setInsertionPointAfter (depthIf);
1135
+ builder.create <memref::StoreOp>(loc, savedDepthLimit, pDepthLimit);
1136
+ }
1137
+
1138
+ // After the if-stmt: emit the function's return.
1139
+ builder.setInsertionPointAfter (ifOp);
1140
+ builder.create <func::ReturnOp>(loc);
1141
+ }
1142
+
1057
1143
// / Implements the rewriting for operator sort and sort_coo.
1058
1144
template <typename OpTy>
1059
1145
LogicalResult matchAndRewriteSortOp (OpTy op, ValueRange xys, uint64_t nx,
@@ -1078,10 +1164,30 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
1078
1164
FuncGeneratorType funcGenerator;
1079
1165
uint32_t nTrailingP = 0 ;
1080
1166
switch (op.getAlgorithm ()) {
1081
- case SparseTensorSortKind::HybridQuickSort:
1167
+ case SparseTensorSortKind::HybridQuickSort: {
1168
+ funcName = kHybridQuickSortFuncNamePrefix ;
1169
+ funcGenerator = createQuickSortFunc;
1170
+ nTrailingP = 1 ;
1171
+ Value pDepthLimit = rewriter.create <memref::AllocaOp>(
1172
+ loc, MemRefType::get ({}, rewriter.getI64Type ()));
1173
+ operands.push_back (pDepthLimit);
1174
+ // As a heuristic, set depthLimit = 2 * log2(n).
1175
+ Value lo = operands[loIdx];
1176
+ Value hi = operands[hiIdx];
1177
+ Value len = rewriter.create <arith::IndexCastOp>(
1178
+ loc, rewriter.getI64Type (),
1179
+ rewriter.create <arith::SubIOp>(loc, hi, lo));
1180
+ Value depthLimit = rewriter.create <arith::SubIOp>(
1181
+ loc, constantI64 (rewriter, loc, 64 ),
1182
+ rewriter.create <math::CountLeadingZerosOp>(loc, len));
1183
+ depthLimit = rewriter.create <arith::ShLIOp>(loc, depthLimit,
1184
+ constantI64 (rewriter, loc, 1 ));
1185
+ rewriter.create <memref::StoreOp>(loc, depthLimit, pDepthLimit);
1186
+ break ;
1187
+ }
1082
1188
case SparseTensorSortKind::QuickSort:
1083
- funcName = kSortNonstableFuncNamePrefix ;
1084
- funcGenerator = createSortNonstableFunc ;
1189
+ funcName = kQuickSortFuncNamePrefix ;
1190
+ funcGenerator = createQuickSortFunc ;
1085
1191
break ;
1086
1192
case SparseTensorSortKind::InsertionSortStable:
1087
1193
funcName = kSortStableFuncNamePrefix ;
0 commit comments