|
| 1 | +//===- HomomorphismSimplification.h -----------------------------*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_ |
| 10 | +#define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_ |
| 11 | + |
| 12 | +#include "mlir/IR/IRMapping.h" |
| 13 | +#include "mlir/IR/PatternMatch.h" |
| 14 | +#include "mlir/IR/Value.h" |
| 15 | +#include "mlir/Support/LLVM.h" |
| 16 | +#include "mlir/Support/LogicalResult.h" |
| 17 | +#include "llvm/ADT/SmallVector.h" |
| 18 | +#include "llvm/Support/Casting.h" |
| 19 | +#include <iterator> |
| 20 | +#include <optional> |
| 21 | +#include <type_traits> |
| 22 | +#include <utility> |
| 23 | + |
| 24 | +namespace mlir { |
| 25 | + |
| 26 | +// If `h` is an homomorphism with respect to the source algebraic structure |
| 27 | +// induced by function `s` and the target algebraic structure induced by |
| 28 | +// function `t`, transforms `s(h(x1), h(x2) ..., h(xn))` into |
| 29 | +// `h(t(x1, x2, ..., xn))`. |
| 30 | +// |
| 31 | +// Functors: |
| 32 | +// --------- |
| 33 | +// `GetHomomorphismOpOperandFn`: `(Operation*) -> OpOperand*` |
| 34 | +// Returns the operand relevant to the homomorphism. |
| 35 | +// There may be other operands that are not relevant. |
| 36 | +// |
| 37 | +// `GetHomomorphismOpResultFn`: `(Operation*) -> OpResult` |
| 38 | +// Returns the result relevant to the homomorphism. |
| 39 | +// |
| 40 | +// `GetSourceAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> |
| 41 | +// void` Populates into the vector the operands relevant to the homomorphism. |
| 42 | +// |
| 43 | +// `GetSourceAlgebraicOpResultFn`: `(Operation*) -> OpResult` |
| 44 | +// Return the result of the source algebraic operation relevant to the |
| 45 | +// homomorphism. |
| 46 | +// |
| 47 | +// `GetTargetAlgebraicOpResultFn`: `(Operation*) -> OpResult` |
| 48 | +// Return the result of the target algebraic operation relevant to the |
| 49 | +// homomorphism. |
| 50 | +// |
| 51 | +// `IsHomomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool` |
| 52 | +// Check if the operation is an homomorphism of the required type. |
| 53 | +// Additionally if the optional is present checks if the operations are |
| 54 | +// compatible homomorphisms. |
| 55 | +// |
| 56 | +// `IsSourceAlgebraicOpFn`: `(Operation*) -> bool` |
| 57 | +// Check if the operation is an operation of the algebraic structure. |
| 58 | +// |
| 59 | +// `CreateTargetAlgebraicOpFn`: `(Operation*, IRMapping& operandsRemapping, |
| 60 | +// PatternRewriter &rewriter) -> Operation*` |
| 61 | +template <typename GetHomomorphismOpOperandFn, |
| 62 | + typename GetHomomorphismOpResultFn, |
| 63 | + typename GetSourceAlgebraicOpOperandsFn, |
| 64 | + typename GetSourceAlgebraicOpResultFn, |
| 65 | + typename GetTargetAlgebraicOpResultFn, typename IsHomomorphismOpFn, |
| 66 | + typename IsSourceAlgebraicOpFn, typename CreateTargetAlgebraicOpFn> |
| 67 | +struct HomomorphismSimplification : public RewritePattern { |
| 68 | + template <typename GetHomomorphismOpOperandFnArg, |
| 69 | + typename GetHomomorphismOpResultFnArg, |
| 70 | + typename GetSourceAlgebraicOpOperandsFnArg, |
| 71 | + typename GetSourceAlgebraicOpResultFnArg, |
| 72 | + typename GetTargetAlgebraicOpResultFnArg, |
| 73 | + typename IsHomomorphismOpFnArg, typename IsSourceAlgebraicOpFnArg, |
| 74 | + typename CreateTargetAlgebraicOpFnArg, |
| 75 | + typename... RewritePatternArgs> |
| 76 | + HomomorphismSimplification( |
| 77 | + GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand, |
| 78 | + GetHomomorphismOpResultFnArg &&getHomomorphismOpResult, |
| 79 | + GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands, |
| 80 | + GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult, |
| 81 | + GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult, |
| 82 | + IsHomomorphismOpFnArg &&isHomomorphismOp, |
| 83 | + IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp, |
| 84 | + CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn, |
| 85 | + RewritePatternArgs &&...args) |
| 86 | + : RewritePattern(std::forward<RewritePatternArgs>(args)...), |
| 87 | + getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>( |
| 88 | + getHomomorphismOpOperand)), |
| 89 | + getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>( |
| 90 | + getHomomorphismOpResult)), |
| 91 | + getSourceAlgebraicOpOperands( |
| 92 | + std::forward<GetSourceAlgebraicOpOperandsFnArg>( |
| 93 | + getSourceAlgebraicOpOperands)), |
| 94 | + getSourceAlgebraicOpResult( |
| 95 | + std::forward<GetSourceAlgebraicOpResultFnArg>( |
| 96 | + getSourceAlgebraicOpResult)), |
| 97 | + getTargetAlgebraicOpResult( |
| 98 | + std::forward<GetTargetAlgebraicOpResultFnArg>( |
| 99 | + getTargetAlgebraicOpResult)), |
| 100 | + isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)), |
| 101 | + isSourceAlgebraicOp( |
| 102 | + std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)), |
| 103 | + createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>( |
| 104 | + createTargetAlgebraicOpFn)) {} |
| 105 | + |
| 106 | + LogicalResult matchAndRewrite(Operation *op, |
| 107 | + PatternRewriter &rewriter) const override { |
| 108 | + SmallVector<OpOperand *> algebraicOpOperands; |
| 109 | + if (failed(matchOp(op, algebraicOpOperands))) { |
| 110 | + return failure(); |
| 111 | + } |
| 112 | + return rewriteOp(op, algebraicOpOperands, rewriter); |
| 113 | + } |
| 114 | + |
| 115 | +private: |
| 116 | + LogicalResult |
| 117 | + matchOp(Operation *sourceAlgebraicOp, |
| 118 | + SmallVector<OpOperand *> &sourceAlgebraicOpOperands) const { |
| 119 | + if (!isSourceAlgebraicOp(sourceAlgebraicOp)) { |
| 120 | + return failure(); |
| 121 | + } |
| 122 | + sourceAlgebraicOpOperands.clear(); |
| 123 | + getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands); |
| 124 | + if (sourceAlgebraicOpOperands.empty()) { |
| 125 | + return failure(); |
| 126 | + } |
| 127 | + |
| 128 | + Operation *firstHomomorphismOp = |
| 129 | + sourceAlgebraicOpOperands.front()->get().getDefiningOp(); |
| 130 | + if (!firstHomomorphismOp || |
| 131 | + !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) { |
| 132 | + return failure(); |
| 133 | + } |
| 134 | + OpResult firstHomomorphismOpResult = |
| 135 | + getHomomorphismOpResult(firstHomomorphismOp); |
| 136 | + if (firstHomomorphismOpResult != sourceAlgebraicOpOperands.front()->get()) { |
| 137 | + return failure(); |
| 138 | + } |
| 139 | + |
| 140 | + for (auto operand : sourceAlgebraicOpOperands) { |
| 141 | + Operation *homomorphismOp = operand->get().getDefiningOp(); |
| 142 | + if (!homomorphismOp || |
| 143 | + !isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) { |
| 144 | + return failure(); |
| 145 | + } |
| 146 | + } |
| 147 | + return success(); |
| 148 | + } |
| 149 | + |
| 150 | + LogicalResult |
| 151 | + rewriteOp(Operation *sourceAlgebraicOp, |
| 152 | + const SmallVector<OpOperand *> &sourceAlgebraicOpOperands, |
| 153 | + PatternRewriter &rewriter) const { |
| 154 | + IRMapping irMapping; |
| 155 | + for (auto operand : sourceAlgebraicOpOperands) { |
| 156 | + Operation *homomorphismOp = operand->get().getDefiningOp(); |
| 157 | + irMapping.map(operand->get(), |
| 158 | + getHomomorphismOpOperand(homomorphismOp)->get()); |
| 159 | + } |
| 160 | + Operation *targetAlgebraicOp = |
| 161 | + createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter); |
| 162 | + |
| 163 | + irMapping.clear(); |
| 164 | + assert(!sourceAlgebraicOpOperands.empty()); |
| 165 | + Operation *firstHomomorphismOp = |
| 166 | + sourceAlgebraicOpOperands[0]->get().getDefiningOp(); |
| 167 | + irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->get(), |
| 168 | + getTargetAlgebraicOpResult(targetAlgebraicOp)); |
| 169 | + Operation *newHomomorphismOp = |
| 170 | + rewriter.clone(*firstHomomorphismOp, irMapping); |
| 171 | + rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp), |
| 172 | + getHomomorphismOpResult(newHomomorphismOp)); |
| 173 | + return success(); |
| 174 | + } |
| 175 | + |
| 176 | + GetHomomorphismOpOperandFn getHomomorphismOpOperand; |
| 177 | + GetHomomorphismOpResultFn getHomomorphismOpResult; |
| 178 | + GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands; |
| 179 | + GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult; |
| 180 | + GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult; |
| 181 | + IsHomomorphismOpFn isHomomorphismOp; |
| 182 | + IsSourceAlgebraicOpFn isSourceAlgebraicOp; |
| 183 | + CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn; |
| 184 | +}; |
| 185 | + |
| 186 | +} // namespace mlir |
| 187 | + |
| 188 | +#endif // MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_ |
0 commit comments