//===- UnstructuredControlFlow.h - Op Interface Helpers ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
#define MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"

//===----------------------------------------------------------------------===//
// Helpers for Unstructured Control Flow
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

namespace detail {
/// Return a list of operands that are forwarded to the given block argument.
/// I.e., find all predecessors of the block argument's owner and gather the
/// operands that are equivalent to the block argument.
SmallVector<OpOperand *> getCallerOpOperands(BlockArgument bbArg);
} // namespace detail

| 29 | +/// A template that provides a default implementation of `getAliasingOpOperands` |
| 30 | +/// for ops that support unstructured control flow within their regions. |
| 31 | +template <typename ConcreteModel, typename ConcreteOp> |
| 32 | +struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel |
| 33 | + : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> { |
| 34 | + |
| 35 | + FailureOr<BaseMemRefType> |
| 36 | + getBufferType(Operation *op, Value value, const BufferizationOptions &options, |
| 37 | + SmallVector<Value> &invocationStack) const { |
| 38 | + // Note: The user may want to override this function for OpResults in |
| 39 | + // case the bufferized result type is different from the bufferized type of |
| 40 | + // the aliasing OpOperand (if any). |
| 41 | + if (isa<OpResult>(value)) |
| 42 | + return bufferization::detail::defaultGetBufferType(value, options, |
| 43 | + invocationStack); |
| 44 | + |
| 45 | + // Compute the buffer type of the block argument by computing the bufferized |
| 46 | + // operand types of all forwarded values. If these are all the same type, |
| 47 | + // take that type. Otherwise, take only the memory space and fall back to a |
| 48 | + // buffer type with a fully dynamic layout map. |
| 49 | + BaseMemRefType bufferType; |
| 50 | + auto tensorType = cast<TensorType>(value.getType()); |
| 51 | + for (OpOperand *opOperand : |
| 52 | + detail::getCallerOpOperands(cast<BlockArgument>(value))) { |
| 53 | + |
| 54 | + // If the forwarded operand is already on the invocation stack, we ran |
| 55 | + // into a loop and this operand cannot be used to compute the bufferized |
| 56 | + // type. |
| 57 | + if (llvm::find(invocationStack, opOperand->get()) != |
| 58 | + invocationStack.end()) |
| 59 | + continue; |
| 60 | + |
| 61 | + // Compute the bufferized type of the forwarded operand. |
| 62 | + BaseMemRefType callerType; |
| 63 | + if (auto memrefType = |
| 64 | + dyn_cast<BaseMemRefType>(opOperand->get().getType())) { |
| 65 | + // The operand was already bufferized. Take its type directly. |
| 66 | + callerType = memrefType; |
| 67 | + } else { |
| 68 | + FailureOr<BaseMemRefType> maybeCallerType = |
| 69 | + bufferization::getBufferType(opOperand->get(), options, |
| 70 | + invocationStack); |
| 71 | + if (failed(maybeCallerType)) |
| 72 | + return failure(); |
| 73 | + callerType = *maybeCallerType; |
| 74 | + } |
| 75 | + |
| 76 | + if (!bufferType) { |
| 77 | + // This is the first buffer type that we computed. |
| 78 | + bufferType = callerType; |
| 79 | + continue; |
| 80 | + } |
| 81 | + |
| 82 | + if (bufferType == callerType) |
| 83 | + continue; |
| 84 | + |
| 85 | + // If the computed buffer type does not match the computed buffer type |
| 86 | + // of the earlier forwarded operands, fall back to a buffer type with a |
| 87 | + // fully dynamic layout map. |
| 88 | +#ifndef NDEBUG |
| 89 | + if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) { |
| 90 | + assert(bufferType.hasRank() && callerType.hasRank() && |
| 91 | + "expected ranked memrefs"); |
| 92 | + assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(), |
| 93 | + rankedTensorType.getShape()}) && |
| 94 | + "expected same shape"); |
| 95 | + } else { |
| 96 | + assert(!bufferType.hasRank() && !callerType.hasRank() && |
| 97 | + "expected unranked memrefs"); |
| 98 | + } |
| 99 | +#endif // NDEBUG |
| 100 | + |
| 101 | + if (bufferType.getMemorySpace() != callerType.getMemorySpace()) |
| 102 | + return op->emitOpError("incoming operands of block argument have " |
| 103 | + "inconsistent memory spaces"); |
| 104 | + |
| 105 | + bufferType = getMemRefTypeWithFullyDynamicLayout( |
| 106 | + tensorType, bufferType.getMemorySpace()); |
| 107 | + } |
| 108 | + |
| 109 | + if (!bufferType) |
| 110 | + return op->emitOpError("could not infer buffer type of block argument"); |
| 111 | + |
| 112 | + return bufferType; |
| 113 | + } |
| 114 | + |
| 115 | +protected: |
| 116 | + /// Assuming that `bbArg` is a block argument of a block that belongs to the |
| 117 | + /// given `op`, return all OpOperands of users of this block that are |
| 118 | + /// aliasing with the given block argument. |
| 119 | + AliasingOpOperandList |
| 120 | + getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, |
| 121 | + const AnalysisState &state) const { |
| 122 | + assert(bbArg.getOwner()->getParentOp() == op && "invalid bbArg"); |
| 123 | + |
| 124 | + // Gather aliasing OpOperands of all operations (callers) that link to |
| 125 | + // this block. |
| 126 | + AliasingOpOperandList result; |
| 127 | + for (OpOperand *opOperand : detail::getCallerOpOperands(bbArg)) |
| 128 | + result.addAlias( |
| 129 | + {opOperand, BufferRelation::Equivalent, /*isDefinite=*/false}); |
| 130 | + |
| 131 | + return result; |
| 132 | + } |
| 133 | +}; |
| 134 | + |
| 135 | +/// A template that provides a default implementation of `getAliasingValues` |
| 136 | +/// for ops that implement the `BranchOpInterface`. |
| 137 | +template <typename ConcreteModel, typename ConcreteOp> |
| 138 | +struct BranchOpBufferizableOpInterfaceExternalModel |
| 139 | + : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> { |
| 140 | + AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
| 141 | + const AnalysisState &state) const { |
| 142 | + AliasingValueList result; |
| 143 | + auto branchOp = cast<BranchOpInterface>(op); |
| 144 | + auto operandNumber = opOperand.getOperandNumber(); |
| 145 | + |
| 146 | + // Gather aliasing block arguments of blocks to which this op may branch to. |
| 147 | + for (const auto &it : llvm::enumerate(op->getSuccessors())) { |
| 148 | + Block *block = it.value(); |
| 149 | + SuccessorOperands operands = branchOp.getSuccessorOperands(it.index()); |
| 150 | + assert(operands.getProducedOperandCount() == 0 && |
| 151 | + "produced operands not supported"); |
| 152 | + if (operands.getForwardedOperands().empty()) |
| 153 | + continue; |
| 154 | + // The first and last operands that are forwarded to this successor. |
| 155 | + int64_t firstOperandIndex = |
| 156 | + operands.getForwardedOperands().getBeginOperandIndex(); |
| 157 | + int64_t lastOperandIndex = |
| 158 | + firstOperandIndex + operands.getForwardedOperands().size(); |
| 159 | + bool matchingDestination = operandNumber >= firstOperandIndex && |
| 160 | + operandNumber < lastOperandIndex; |
| 161 | + // A branch op may have multiple successors. Find the ones that correspond |
| 162 | + // to this OpOperand. (There is usually only one.) |
| 163 | + if (!matchingDestination) |
| 164 | + continue; |
| 165 | + // Compute the matching block argument of the destination block. |
| 166 | + BlockArgument bbArg = |
| 167 | + block->getArgument(operandNumber - firstOperandIndex); |
| 168 | + result.addAlias( |
| 169 | + {bbArg, BufferRelation::Equivalent, /*isDefinite=*/false}); |
| 170 | + } |
| 171 | + |
| 172 | + return result; |
| 173 | + } |
| 174 | +}; |

} // namespace bufferization
} // namespace mlir

#endif // MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_