Skip to content

[MLIR][Affine] NFC. Move misplaced MDG init method #71665

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions mlir/lib/Dialect/Affine/Analysis/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -104,6 +106,127 @@ void Node::getLoadAndStoreMemrefSet(
}
}

// Initializes the data dependence graph by walking operations in `block`.
// Assigns each node in the graph a node id based on program order in 'f'.
bool MemRefDependenceGraph::init() {
LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
// Map from a memref to the set of ids of the nodes that have ops accessing
// the memref.
DenseMap<Value, SetVector<unsigned>> memrefAccesses;

DenseMap<Operation *, unsigned> forToNodeMap;
for (Operation &op : block) {
if (dyn_cast<AffineForOp>(op)) {
// Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
LoopNestStateCollector collector;
collector.collect(&op);
// Return false if a region holding op other than 'affine.for' and
// 'affine.if' was found (not currently supported).
if (collector.hasNonAffineRegionOp)
return false;
Node node(nextNodeId++, &op);
for (auto *opInst : collector.loadOpInsts) {
node.loads.push_back(opInst);
auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
for (auto *opInst : collector.storeOpInsts) {
node.stores.push_back(opInst);
auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
forToNodeMap[&op] = node.id;
nodes.insert({node.id, node});
} else if (dyn_cast<AffineReadOpInterface>(op)) {
// Create graph node for top-level load op.
Node node(nextNodeId++, &op);
node.loads.push_back(&op);
auto memref = cast<AffineReadOpInterface>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (dyn_cast<AffineWriteOpInterface>(op)) {
// Create graph node for top-level store op.
Node node(nextNodeId++, &op);
node.stores.push_back(&op);
auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (op.getNumRegions() != 0) {
// Return false if another region is found (not currently supported).
return false;
} else if (op.getNumResults() > 0 && !op.use_empty()) {
// Create graph node for top-level producer of SSA values, which
// could be used by loop nest nodes.
Node node(nextNodeId++, &op);
nodes.insert({node.id, node});
} else if (isa<CallOpInterface>(op)) {
// Create graph node for top-level Call Op that takes any argument of
// memref type. Call Op that returns one or more memref type results
// is already taken care of, by the previous conditions.
if (llvm::any_of(op.getOperandTypes(),
[&](Type t) { return isa<MemRefType>(t); })) {
Node node(nextNodeId++, &op);
nodes.insert({node.id, node});
}
} else if (hasEffect<MemoryEffects::Write, MemoryEffects::Free>(&op)) {
// Create graph node for top-level op, which could have a memory write
// side effect.
Node node(nextNodeId++, &op);
nodes.insert({node.id, node});
}
}

for (auto &idAndNode : nodes) {
LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
<< *(idAndNode.second.op) << "\n");
(void)idAndNode;
}

// Add dependence edges between nodes which produce SSA values and their
// users. Load ops can be considered as the ones producing SSA values.
for (auto &idAndNode : nodes) {
const Node &node = idAndNode.second;
// Stores don't define SSA values, skip them.
if (!node.stores.empty())
continue;
Operation *opInst = node.op;
for (Value value : opInst->getResults()) {
for (Operation *user : value.getUsers()) {
// Ignore users outside of the block.
if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
&block)
continue;
SmallVector<AffineForOp, 4> loops;
getAffineForIVs(*user, &loops);
if (loops.empty())
continue;
assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping");
unsigned userLoopNestId = forToNodeMap[loops[0]];
addEdge(node.id, userLoopNestId, value);
}
}
}

// Walk memref access lists and add graph edges between dependent nodes.
for (auto &memrefAndList : memrefAccesses) {
unsigned n = memrefAndList.second.size();
for (unsigned i = 0; i < n; ++i) {
unsigned srcId = memrefAndList.second[i];
bool srcHasStore =
getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
for (unsigned j = i + 1; j < n; ++j) {
unsigned dstId = memrefAndList.second[j];
bool dstHasStore =
getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
if (srcHasStore || dstHasStore)
addEdge(srcId, dstId, memrefAndList.first);
}
}
}
return true;
}

// Returns the graph node for 'id'.
Node *MemRefDependenceGraph::getNode(unsigned id) {
auto it = nodes.find(id);
Expand Down
122 changes: 0 additions & 122 deletions mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -230,127 +229,6 @@ static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
}
}

// Initializes the data dependence graph by walking operations in `block`.
// Assigns each node in the graph a node id based on program order in 'f'.
bool MemRefDependenceGraph::init() {
LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
// Map from a memref to the set of ids of the nodes that have ops accessing
// the memref.
DenseMap<Value, SetVector<unsigned>> memrefAccesses;

DenseMap<Operation *, unsigned> forToNodeMap;
for (Operation &op : block) {
if (dyn_cast<AffineForOp>(op)) {
// Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
LoopNestStateCollector collector;
collector.collect(&op);
// Return false if a region holding op other than 'affine.for' and
// 'affine.if' was found (not currently supported).
if (collector.hasNonAffineRegionOp)
return false;
Node node(nextNodeId++, &op);
for (auto *opInst : collector.loadOpInsts) {
node.loads.push_back(opInst);
auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
for (auto *opInst : collector.storeOpInsts) {
node.stores.push_back(opInst);
auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
forToNodeMap[&op] = node.id;
nodes.insert({node.id, node});
} else if (dyn_cast<AffineReadOpInterface>(op)) {
// Create graph node for top-level load op.
Node node(nextNodeId++, &op);
node.loads.push_back(&op);
auto memref = cast<AffineReadOpInterface>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (dyn_cast<AffineWriteOpInterface>(op)) {
// Create graph node for top-level store op.
Node node(nextNodeId++, &op);
node.stores.push_back(&op);
auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (op.getNumRegions() != 0) {
// Return false if another region is found (not currently supported).
return false;
} else if (op.getNumResults() > 0 && !op.use_empty()) {
// Create graph node for top-level producer of SSA values, which
// could be used by loop nest nodes.
Node node(nextNodeId++, &op);
nodes.insert({node.id, node});
} else if (isa<CallOpInterface>(op)) {
// Create graph node for top-level Call Op that takes any argument of
// memref type. Call Op that returns one or more memref type results
// is already taken care of, by the previous conditions.
if (llvm::any_of(op.getOperandTypes(),
[&](Type t) { return isa<MemRefType>(t); })) {
Node node(nextNodeId++, &op);
nodes.insert({node.id, node});
}
} else if (hasEffect<MemoryEffects::Write, MemoryEffects::Free>(&op)) {
// Create graph node for top-level op, which could have a memory write
// side effect.
Node node(nextNodeId++, &op);
nodes.insert({node.id, node});
}
}

for (auto &idAndNode : nodes) {
LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
<< *(idAndNode.second.op) << "\n");
(void)idAndNode;
}

// Add dependence edges between nodes which produce SSA values and their
// users. Load ops can be considered as the ones producing SSA values.
for (auto &idAndNode : nodes) {
const Node &node = idAndNode.second;
// Stores don't define SSA values, skip them.
if (!node.stores.empty())
continue;
Operation *opInst = node.op;
for (Value value : opInst->getResults()) {
for (Operation *user : value.getUsers()) {
// Ignore users outside of the block.
if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
&block)
continue;
SmallVector<AffineForOp, 4> loops;
getAffineForIVs(*user, &loops);
if (loops.empty())
continue;
assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping");
unsigned userLoopNestId = forToNodeMap[loops[0]];
addEdge(node.id, userLoopNestId, value);
}
}
}

// Walk memref access lists and add graph edges between dependent nodes.
for (auto &memrefAndList : memrefAccesses) {
unsigned n = memrefAndList.second.size();
for (unsigned i = 0; i < n; ++i) {
unsigned srcId = memrefAndList.second[i];
bool srcHasStore =
getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
for (unsigned j = i + 1; j < n; ++j) {
unsigned dstId = memrefAndList.second[j];
bool dstHasStore =
getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
if (srcHasStore || dstHasStore)
addEdge(srcId, dstId, memrefAndList.first);
}
}
}
return true;
}

// Sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the
// outermost (while again preserving relative order among them).
Expand Down