|
1 | 1 | #include "flang/Optimizer/Dialect/FIRDialect.h"
|
2 | 2 | #include "flang/Optimizer/Dialect/FIROps.h"
|
3 | 3 | #include "flang/Optimizer/Dialect/FIRType.h"
|
| 4 | +#include "flang/Optimizer/HLFIR/HLFIROps.h" |
4 | 5 | #include "flang/Optimizer/Support/InternalNames.h"
|
5 | 6 | #include "flang/Optimizer/Transforms/Passes.h"
|
6 | 7 | #include "mlir/Dialect/Func/IR/FuncOps.h"
|
@@ -99,6 +100,20 @@ class OMPEarlyOutliningPass
|
99 | 100 | return;
|
100 | 101 | }
|
101 | 102 |
|
| 103 | + // Clone into the outlined function all hlfir.declare ops that define inputs |
| 104 | + // to the target region and set up remapping of its inputs and outputs. |
| 105 | + if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>( |
| 106 | + varPtr.getDefiningOp())) { |
| 107 | + auto clone = llvm::cast<hlfir::DeclareOp>( |
| 108 | + cloneArgAndChildren(builder, declareOp, inputs, newInputs)); |
| 109 | + mlir::Value newBase = clone.getBase(); |
| 110 | + mlir::Value newOrigBase = clone.getOriginalBase(); |
| 111 | + mapInfoMap.map(varPtr, newOrigBase); |
| 112 | + valueMap.map(declareOp.getBase(), newBase); |
| 113 | + valueMap.map(declareOp.getOriginalBase(), newOrigBase); |
| 114 | + return; |
| 115 | + } |
| 116 | + |
102 | 117 | if (isAddressOfGlobalDeclareTarget(varPtr)) {
|
103 | 118 | fir::AddrOfOp addrOp =
|
104 | 119 | mlir::dyn_cast<fir::AddrOfOp>(varPtr.getDefiningOp());
|
@@ -127,19 +142,46 @@ class OMPEarlyOutliningPass
|
127 | 142 | llvm::SetVector<mlir::Value> inputs;
|
128 | 143 | mlir::Region &targetRegion = targetOp.getRegion();
|
129 | 144 | mlir::getUsedValuesDefinedAbove(targetRegion, inputs);
|
130 |
| - |
131 |
| - // filter out declareTarget and map entries which are specially handled |
| 145 | + |
| 146 | + // Collect all map info. Even non-used maps must be collected to avoid ICEs. |
| 147 | + for (mlir::Value oper : targetOp->getOperands()) { |
| 148 | + if (auto mapEntry = |
| 149 | + mlir::dyn_cast<mlir::omp::MapInfoOp>(oper.getDefiningOp())) { |
| 150 | + if (!inputs.contains(mapEntry.getVarPtr())) |
| 151 | + inputs.insert(mapEntry.getVarPtr()); |
| 152 | + } |
| 153 | + } |
| 154 | + |
| 155 | + // Filter out declare-target and map entries which are specially handled |
132 | 156 | // at the moment, so we do not wish these to end up as function arguments
|
133 | 157 | // which would just be more noise in the IR.
|
| 158 | + llvm::SmallVector<mlir::Value> blockArgs; |
134 | 159 | for (llvm::SetVector<mlir::Value>::iterator iter = inputs.begin(); iter != inputs.end();) {
|
135 | 160 | if (mlir::isa_and_nonnull<mlir::omp::MapInfoOp>(iter->getDefiningOp()) ||
|
136 | 161 | isAddressOfGlobalDeclareTarget(*iter)) {
|
137 | 162 | iter = inputs.erase(iter);
|
| 163 | + } else if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>( |
| 164 | + iter->getDefiningOp())) { |
| 165 | + // Gather hlfir.declare arguments to be added later, after the |
| 166 | + // hlfir.declare operation itself has been removed as an input. |
| 167 | + blockArgs.push_back(declareOp.getMemref()); |
| 168 | + if (mlir::Value shape = declareOp.getShape()) |
| 169 | + blockArgs.push_back(shape); |
| 170 | + for (mlir::Value typeParam : declareOp.getTypeparams()) |
| 171 | + blockArgs.push_back(typeParam); |
| 172 | + iter = inputs.erase(iter); |
138 | 173 | } else {
|
139 | 174 | ++iter;
|
140 | 175 | }
|
141 | 176 | }
|
142 | 177 |
|
| 178 | + // Add function arguments to the list of inputs if they are used by an |
| 179 | + // hlfir.declare operation. |
| 180 | + for (mlir::Value arg : blockArgs) { |
| 181 | + if (!arg.getDefiningOp() && !inputs.contains(arg)) |
| 182 | + inputs.insert(arg); |
| 183 | + } |
| 184 | + |
143 | 185 | // Create new function and initialize
|
144 | 186 | mlir::FunctionType funcType = builder.getFunctionType(
|
145 | 187 | mlir::TypeRange(inputs.getArrayRef()), mlir::TypeRange());
|
@@ -218,7 +260,7 @@ class OMPEarlyOutliningPass
|
218 | 260 | return newFunc;
|
219 | 261 | }
|
220 | 262 |
|
221 |
| - // Returns true if a target region was found int the function. |
| 263 | + // Returns true if a target region was found in the function. |
222 | 264 | bool outlineTargetOps(mlir::OpBuilder &builder,
|
223 | 265 | mlir::func::FuncOp &functionOp,
|
224 | 266 | mlir::ModuleOp &moduleOp,
|
|
0 commit comments