-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Analysis] Consolidate topological sort utilities #92563
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir-core Author: Christian Ulmann (Dinistro) ChangesThis PR attempts to consolidate the different topological sort utilities into one place. It adds them to the analysis folder because the There are now two different sorting strategies:
This additionally reimplements the region aware topological sorting because the previous implementation had an exponential space complexity. I'm open to suggestions on how to combine this further or how to fuse the test passes. Patch is 38.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92563.diff 25 Files Affected:
diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index d5cdf72c3889f..99279fdfe427c 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -223,11 +223,6 @@ SetVector<Operation *>
getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions = {},
const ForwardSliceOptions &forwardSliceOptions = {});
-/// Multi-root DAG topological sort.
-/// Performs a topological sort of the Operation in the `toSort` SetVector.
-/// Returns a topologically sorted SetVector.
-SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
-
/// Utility to match a generic reduction given a list of iteration-carried
/// arguments, `iterCarriedArgs` and the position of the potential reduction
/// argument within the list, `redPos`. If a reduction is matched, returns the
diff --git a/mlir/include/mlir/Transforms/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
similarity index 86%
rename from mlir/include/mlir/Transforms/TopologicalSortUtils.h
rename to mlir/include/mlir/Analysis/TopologicalSortUtils.h
index 74e44b1dc485d..7aabc5ee457c0 100644
--- a/mlir/include/mlir/Transforms/TopologicalSortUtils.h
+++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
-#define MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
+#ifndef MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
+#define MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
#include "mlir/IR/Block.h"
@@ -104,6 +104,14 @@ bool computeTopologicalSorting(
MutableArrayRef<Operation *> ops,
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
+/// Get a list of blocks that is sorted according to dominance. This sort is
+/// stable.
+SetVector<Block *> getBlocksSortedByDominance(Region ®ion);
+
+/// Sorts all operation in `toSort` topologically while also region semantics.
+/// Does not support multi-sets.
+SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
+
} // end namespace mlir
-#endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
+#endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index f65d0d44eef42..06eebff201d1b 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -87,10 +87,6 @@ LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
LogicalResult runRegionDCE(RewriterBase &rewriter,
MutableArrayRef<Region> regions);
-/// Get a list of blocks that is sorted according to dominance. This sort is
-/// stable.
-SetVector<Block *> getBlocksSortedByDominance(Region ®ion);
-
} // namespace mlir
#endif // MLIR_TRANSFORMS_REGIONUTILS_H_
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 005814ddbec79..38d8415d81c72 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
Liveness.cpp
CFGLoopInfo.cpp
SliceAnalysis.cpp
+ TopologicalSortUtils.cpp
AliasAnalysis/LocalAliasAnalysis.cpp
@@ -28,6 +29,7 @@ add_mlir_library(MLIRAnalysis
Liveness.cpp
CFGLoopInfo.cpp
SliceAnalysis.cpp
+ TopologicalSortUtils.cpp
AliasAnalysis/LocalAliasAnalysis.cpp
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 26fe8e3dc0819..2b1cf411ceeee 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -11,7 +11,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
+#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
@@ -164,62 +165,6 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
return topologicalSort(slice);
}
-namespace {
-/// DFS post-order implementation that maintains a global count to work across
-/// multiple invocations, to help implement topological sort on multi-root DAGs.
-/// We traverse all operations but only record the ones that appear in
-/// `toSort` for the final result.
-struct DFSState {
- DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
- const SetVector<Operation *> &toSort;
- SmallVector<Operation *, 16> topologicalCounts;
- DenseSet<Operation *> seen;
-};
-} // namespace
-
-static void dfsPostorder(Operation *root, DFSState *state) {
- SmallVector<Operation *> queue(1, root);
- std::vector<Operation *> ops;
- while (!queue.empty()) {
- Operation *current = queue.pop_back_val();
- ops.push_back(current);
- for (Operation *op : current->getUsers())
- queue.push_back(op);
- for (Region ®ion : current->getRegions()) {
- for (Operation &op : region.getOps())
- queue.push_back(&op);
- }
- }
-
- for (Operation *op : llvm::reverse(ops)) {
- if (state->seen.insert(op).second && state->toSort.count(op) > 0)
- state->topologicalCounts.push_back(op);
- }
-}
-
-SetVector<Operation *>
-mlir::topologicalSort(const SetVector<Operation *> &toSort) {
- if (toSort.empty()) {
- return toSort;
- }
-
- // Run from each root with global count and `seen` set.
- DFSState state(toSort);
- for (auto *s : toSort) {
- assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
- dfsPostorder(s, &state);
- }
-
- // Reorder and return.
- SetVector<Operation *> res;
- for (auto it = state.topologicalCounts.rbegin(),
- eit = state.topologicalCounts.rend();
- it != eit; ++it) {
- res.insert(*it);
- }
- return res;
-}
-
/// Returns true if `value` (transitively) depends on iteration-carried values
/// of the given `ancestorOp`.
static bool dependsOnCarriedVals(Value value,
diff --git a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
similarity index 58%
rename from mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
rename to mlir/lib/Analysis/TopologicalSortUtils.cpp
index f3a9d217f2c98..2fc1bd582ef47 100644
--- a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -1,4 +1,4 @@
-//===- TopologicalSortUtils.h - Topological sort utilities ------*- C++ -*-===//
+//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,13 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Transforms/TopologicalSortUtils.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
+#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/RegionGraphTraits.h"
+
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SetVector.h"
using namespace mlir;
@@ -146,3 +151,93 @@ bool mlir::computeTopologicalSorting(
return allOpsScheduled;
}
+
+SetVector<Block *> mlir::getBlocksSortedByDominance(Region ®ion) {
+ // For each block that has not been visited yet (i.e. that has no
+ // predecessors), add it to the list as well as its successors.
+ SetVector<Block *> blocks;
+ for (Block &b : region) {
+ if (blocks.count(&b) == 0) {
+ llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+ blocks.insert(traversal.begin(), traversal.end());
+ }
+ }
+ assert(blocks.size() == region.getBlocks().size() &&
+ "some blocks are not sorted");
+
+ return blocks;
+}
+
+/// Computes the common ancestor region of all operations in `ops`. Remembers
+/// all the traversed regions in `traversedRegions`.
+static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
+ DenseSet<Region *> &traversedRegions) {
+ // Map to count the number of times a region was encountered.
+ llvm::DenseMap<Region *, size_t> regionCounts;
+ size_t expectedCount = ops.size();
+
+ // Walk the region tree for each operation towards the root and add to the
+ // region count.
+ Region *res = nullptr;
+ for (Operation *op : ops) {
+ Region *current = op->getParentRegion();
+ while (current) {
+ // Insert or get the count.
+ auto it = regionCounts.try_emplace(current, 0).first;
+ size_t count = ++it->getSecond();
+ if (count == expectedCount) {
+ res = current;
+ break;
+ }
+ current = current->getParentRegion();
+ }
+ }
+ auto firstRange = llvm::make_first_range(regionCounts);
+ traversedRegions.insert(firstRange.begin(), firstRange.end());
+ return res;
+}
+
+/// Topologically traverses `region` and insers all encountered operations in
+/// `toSort` into the result. Recursively traverses regions when they are
+/// present in `relevantRegions`.
+static void topoSortRegion(Region ®ion,
+ const DenseSet<Region *> &relevantRegions,
+ const SetVector<Operation *> &toSort,
+ SetVector<Operation *> &result) {
+ SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
+ for (Block *block : sortedBlocks) {
+ for (Operation &op : *block) {
+ if (toSort.contains(&op))
+ result.insert(&op);
+ for (Region &subRegion : op.getRegions()) {
+ // Skip regions that do not contain operations from `toSort`.
+ if (!relevantRegions.contains(®ion))
+ continue;
+ topoSortRegion(subRegion, relevantRegions, toSort, result);
+ }
+ }
+ }
+}
+
+SetVector<Operation *>
+mlir::topologicalSort(const SetVector<Operation *> &toSort) {
+ if (toSort.size() <= 1)
+ return toSort;
+
+ assert(llvm::all_of(toSort,
+ [&](Operation *op) { return toSort.count(op) == 1; }) &&
+ "expected only unique set entries");
+
+ // First, find the root region to start the recursive traversal through the
+ // IR.
+ DenseSet<Region *> relevantRegions;
+ Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
+ assert(rootRegion && "expected all ops to have a common ancestor");
+
+ // Sort all element in `toSort` by recursively traversing the IR.
+ SetVector<Operation *> result;
+ topoSortRegion(*rootRegion, relevantRegions, toSort, result);
+ assert(result.size() == toSort.size() &&
+ "expected all operations to be present in the result");
+ return result;
+}
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 332f0a2eecfcf..4496c2bc5fe8b 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -15,6 +15,7 @@
#include <type_traits>
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 84ae4b52dcf4e..7f3e43d0b4cd3 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index eeda245ce969f..d9cf85e4aecab 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/BuiltinOps.h"
@@ -19,7 +20,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
-#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 34b6903f8da07..9d125b7f11809 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index cf3257c8b9b87..1ec0736ec08bf 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -16,6 +16,7 @@
#include "AttrKindDetail.h"
#include "DebugTranslation.h"
#include "LoopAnnotationTranslation.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
@@ -33,7 +34,6 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
-#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index e2e240ad865ce..a452cc3fae8ac 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -9,6 +9,7 @@
#include "mlir/Transforms/Mem2Reg.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
@@ -16,7 +17,6 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
-#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 67cbade07bc94..39f7256fb789d 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -9,6 +9,7 @@
#include "mlir/Transforms/SROA.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
diff --git a/mlir/lib/Transforms/TopologicalSort.cpp b/mlir/lib/Transforms/TopologicalSort.cpp
index 1219968fb3692..528f6ef676020 100644
--- a/mlir/lib/Transforms/TopologicalSort.cpp
+++ b/mlir/lib/Transforms/TopologicalSort.cpp
@@ -8,8 +8,8 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/RegionKindInterface.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
namespace mlir {
#define GEN_PASS_DEF_TOPOLOGICALSORT
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index d6aac0e2da4f5..b5788c679edc4 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -10,7 +10,6 @@ add_mlir_library(MLIRTransformUtils
LoopInvariantCodeMotionUtils.cpp
OneToNTypeConversion.cpp
RegionUtils.cpp
- TopologicalSortUtils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 192f59b353295..b5e641d39fc0a 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/RegionUtils.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
@@ -15,11 +16,9 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
-#include "llvm/ADT/SmallSet.h"
#include <deque>
@@ -836,19 +835,3 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
mergedIdenticalBlocks);
}
-
-SetVector<Block *> mlir::getBlocksSortedByDominance(Region ®ion) {
- // For each block that has not been visited yet (i.e. that has no
- // predecessors), add it to the list as well as its successors.
- SetVector<Block *> blocks;
- for (Block &b : region) {
- if (blocks.count(&b) == 0) {
- llvm::ReversePostOrderTraversal<Block *> traversal(&b);
- blocks.insert(traversal.begin(), traversal.end());
- }
- }
- assert(blocks.size() == region.getBlocks().size() &&
- "some blocks are not sorted");
-
- return blocks;
-}
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index c2eb2b893cea4..b3c0a06c96fea 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -8,12 +8,12 @@
#include "mlir/Transforms/ViewOpGraph.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/IndentedOstream.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/GraphWriter.h"
#include <map>
diff --git a/mlir/test/Analysis/test-topoligical-sort.mlir b/mlir/test/Analysis/test-topoligical-sort.mlir
index 8608586402055..150aff854fc8f 100644
--- a/mlir/test/Analysis/test-topoligical-sort.mlir
+++ b/mlir/test/Analysis/test-topoligical-sort.mlir
@@ -1,21 +1,38 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-print-topological-sort))" 2>&1 | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-print-topological-sort))" --split-input-file | FileCheck %s
-// CHECK-LABEL: Testing : region
-// CHECK: arith.addi {{.*}} : index
-// CHECK-NEXT: scf.for
-// CHECK: } {__test_sort_original_idx__ = 2 : i64}
-// CHECK-NEXT: arith.addi {{.*}} : i32
-// CHECK-NEXT: arith.subi {{.*}} : i32
-func.func @region(
- %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index,
- %arg4 : i32, %arg5 : i32, %arg6 : i32,
- %buffer : memref<i32>) {
- %0 = arith.addi %arg4, %arg5 {__test_sort_original_idx__ = 0} : i32
- %idx = arith.addi %arg0, %arg1 {__test_sort_original_idx__ = 3} : index
- scf.for %arg7 = %idx to %arg2 step %arg3 {
- %2 = arith.addi %0, %arg5 : i32
- %3 = arith.subi %2, %arg6 {__test_sort_original_idx__ = 1} : i32
- memref.store %3, %buffer[] : memref<i32>
- } {__test_sort_original_idx__ = 2}
+// CHECK-LABEL: single_element
+func.func @single_element() {
+ // CHECK: test_sort_index = 0
+ return {test_to_sort}
+}
+
+// -----
+
+// CHECK-LABEL: @simple_region
+func.func @simple_region(%cond: i1) {
+ // CHECK: test_sort_index = 0
+ %0 = arith.constant {test_to_sort} 42 : i32
+ scf.if %cond {
+ %...
[truncated]
|
// Skip regions that do not contain operations from `toSort`. | ||
if (!relevantRegions.contains(®ion)) | ||
continue; | ||
topoSortRegion(subRegion, relevantRegions, toSort, result); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you replace this recursion with an iterative algorithm please?
These recursions are ticking bombs...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be good to me once this and Mehdi's comment is addressed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing this!
8c6a17c
to
9ea5fb9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
Looks good modulo nits.
/// constructed with. This function will destroy the internal state of the | ||
/// instance. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// constructed with. This function will destroy the internal state of the | |
/// instance. | |
/// constructed with. |
nit: We could implement this but currently it is not implemented?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a purely internal class that is not accessible from the outside. I therefore decided to not add clearing logic.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG with other folks comments' addressed. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Moved to Analysis on May 22, 2024. See llvm/llvm-project#92563 and llvm/llvm-project@b00e0c1
This PR attempts to consolidate the different topological sort utilities into one place. It adds them to the analysis folder because the
SliceAnalysis
uses some of these.There are now two different sorting strategies:
This additionally reimplements the region aware topological sorting because the previous implementation had an exponential space complexity.
I'm open to suggestions on how to combine this further or how to fuse the test passes.