|
18 | 18 | #include "flang/Optimizer/Builder/Todo.h"
|
19 | 19 | #include "flang/Optimizer/Dialect/FIROps.h"
|
20 | 20 | #include "flang/Optimizer/Dialect/FIRType.h"
|
| 21 | +#include "flang/Optimizer/HLFIR/HLFIRDialect.h" |
21 | 22 | #include "flang/Optimizer/Support/FatalError.h"
|
22 | 23 | #include "mlir/Dialect/Arith/IR/Arith.h"
|
23 | 24 | #include "mlir/Dialect/Func/IR/FuncOps.h"
|
@@ -144,13 +145,133 @@ using AddrGeneratorTy = llvm::function_ref<mlir::Value(
|
144 | 145 | mlir::Value)>;
|
145 | 146 |
|
146 | 147 | // Produces a loop nest for a Minloc intrinsic.
|
147 |
| -void genMinMaxlocReductionLoop(fir::FirOpBuilder &builder, mlir::Value array, |
148 |
| - InitValGeneratorTy initVal, |
149 |
| - MinlocBodyOpGeneratorTy genBody, |
150 |
| - fir::AddrGeneratorTy getAddrFn, unsigned rank, |
151 |
| - mlir::Type elementType, mlir::Location loc, |
152 |
| - mlir::Type maskElemType, mlir::Value resultArr, |
153 |
| - bool maskMayBeLogicalScalar); |
| 148 | +inline void genMinMaxlocReductionLoop( |
| 149 | + fir::FirOpBuilder &builder, mlir::Value array, |
| 150 | + fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody, |
| 151 | + fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType, |
| 152 | + mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr, |
| 153 | + bool maskMayBeLogicalScalar) { |
| 154 | + mlir::IndexType idxTy = builder.getIndexType(); |
| 155 | + |
| 156 | + mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); |
| 157 | + |
| 158 | + fir::SequenceType::Shape flatShape(rank, |
| 159 | + fir::SequenceType::getUnknownExtent()); |
| 160 | + mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType); |
| 161 | + mlir::Type boxArrTy = fir::BoxType::get(arrTy); |
| 162 | + array = builder.create<fir::ConvertOp>(loc, boxArrTy, array); |
| 163 | + |
| 164 | + mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType()); |
| 165 | + mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1); |
| 166 | + mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0); |
| 167 | + mlir::Value flagRef = builder.createTemporary(loc, resultElemType); |
| 168 | + builder.create<fir::StoreOp>(loc, zero, flagRef); |
| 169 | + |
| 170 | + mlir::Value init = initVal(builder, loc, elementType); |
| 171 | + llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds; |
| 172 | + |
| 173 | + assert(rank > 0 && "rank cannot be zero"); |
| 174 | + mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); |
| 175 | + |
| 176 | + // Compute all the upper bounds before the loop nest. |
| 177 | + // It is not strictly necessary for performance, since the loop nest |
| 178 | + // does not have any store operations and any LICM optimization |
| 179 | + // should be able to optimize the redundancy. |
| 180 | + for (unsigned i = 0; i < rank; ++i) { |
| 181 | + mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i); |
| 182 | + auto dims = |
| 183 | + builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx); |
| 184 | + mlir::Value len = dims.getResult(1); |
| 185 | + // We use C indexing here, so len-1 as loopcount |
| 186 | + mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one); |
| 187 | + bounds.push_back(loopCount); |
| 188 | + } |
| 189 | + // Create a loop nest consisting of OP operations. |
| 190 | + // Collect the loops' induction variables into indices array, |
| 191 | + // which will be used in the innermost loop to load the input |
| 192 | + // array's element. |
| 193 | + // The loops are generated such that the innermost loop processes |
| 194 | + // the 0 dimension. |
| 195 | + llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices; |
| 196 | + for (unsigned i = rank; 0 < i; --i) { |
| 197 | + mlir::Value step = one; |
| 198 | + mlir::Value loopCount = bounds[i - 1]; |
| 199 | + auto loop = |
| 200 | + builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false, |
| 201 | + /*finalCountValue=*/false, init); |
| 202 | + init = loop.getRegionIterArgs()[0]; |
| 203 | + indices.push_back(loop.getInductionVar()); |
| 204 | + // Set insertion point to the loop body so that the next loop |
| 205 | + // is inserted inside the current one. |
| 206 | + builder.setInsertionPointToStart(loop.getBody()); |
| 207 | + } |
| 208 | + |
| 209 | + // Reverse the indices such that they are ordered as: |
| 210 | + // <dim-0-idx, dim-1-idx, ...> |
| 211 | + std::reverse(indices.begin(), indices.end()); |
| 212 | + mlir::Value reductionVal = |
| 213 | + genBody(builder, loc, elementType, array, flagRef, init, indices); |
| 214 | + |
| 215 | + // Unwind the loop nest and insert ResultOp on each level |
| 216 | + // to return the updated value of the reduction to the enclosing |
| 217 | + // loops. |
| 218 | + for (unsigned i = 0; i < rank; ++i) { |
| 219 | + auto result = builder.create<fir::ResultOp>(loc, reductionVal); |
| 220 | + // Proceed to the outer loop. |
| 221 | + auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp()); |
| 222 | + reductionVal = loop.getResult(0); |
| 223 | + // Set insertion point after the loop operation that we have |
| 224 | + // just processed. |
| 225 | + builder.setInsertionPointAfter(loop.getOperation()); |
| 226 | + } |
| 227 | + // End of loop nest. The insertion point is after the outermost loop. |
| 228 | + if (maskMayBeLogicalScalar) { |
| 229 | + if (fir::IfOp ifOp = |
| 230 | + mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) { |
| 231 | + builder.create<fir::ResultOp>(loc, reductionVal); |
| 232 | + builder.setInsertionPointAfter(ifOp); |
| 233 | + // Redefine flagSet to escape scope of ifOp |
| 234 | + flagSet = builder.createIntegerConstant(loc, resultElemType, 1); |
| 235 | + reductionVal = ifOp.getResult(0); |
| 236 | + } |
| 237 | + } |
| 238 | + |
| 239 | + // Check for case where array was full of max values. |
| 240 | + // flag will be 0 if mask was never true, 1 if mask was true as some point, |
| 241 | + // this is needed to avoid catching cases where we didn't access any elements |
| 242 | + // e.g. mask=.FALSE. |
| 243 | + mlir::Value flagValue = |
| 244 | + builder.create<fir::LoadOp>(loc, resultElemType, flagRef); |
| 245 | + mlir::Value flagCmp = builder.create<mlir::arith::CmpIOp>( |
| 246 | + loc, mlir::arith::CmpIPredicate::eq, flagValue, flagSet); |
| 247 | + fir::IfOp ifMaskTrueOp = |
| 248 | + builder.create<fir::IfOp>(loc, flagCmp, /*withElseRegion=*/false); |
| 249 | + builder.setInsertionPointToStart(&ifMaskTrueOp.getThenRegion().front()); |
| 250 | + |
| 251 | + mlir::Value testInit = initVal(builder, loc, elementType); |
| 252 | + fir::IfOp ifMinSetOp; |
| 253 | + if (elementType.isa<mlir::FloatType>()) { |
| 254 | + mlir::Value cmp = builder.create<mlir::arith::CmpFOp>( |
| 255 | + loc, mlir::arith::CmpFPredicate::OEQ, testInit, reductionVal); |
| 256 | + ifMinSetOp = builder.create<fir::IfOp>(loc, cmp, |
| 257 | + /*withElseRegion*/ false); |
| 258 | + } else { |
| 259 | + mlir::Value cmp = builder.create<mlir::arith::CmpIOp>( |
| 260 | + loc, mlir::arith::CmpIPredicate::eq, testInit, reductionVal); |
| 261 | + ifMinSetOp = builder.create<fir::IfOp>(loc, cmp, |
| 262 | + /*withElseRegion*/ false); |
| 263 | + } |
| 264 | + builder.setInsertionPointToStart(&ifMinSetOp.getThenRegion().front()); |
| 265 | + |
| 266 | + // Load output array with 1s instead of 0s |
| 267 | + for (unsigned int i = 0; i < rank; ++i) { |
| 268 | + mlir::Value index = builder.createIntegerConstant(loc, idxTy, i); |
| 269 | + mlir::Value resultElemAddr = |
| 270 | + getAddrFn(builder, loc, resultElemType, resultArr, index); |
| 271 | + builder.create<fir::StoreOp>(loc, flagSet, resultElemAddr); |
| 272 | + } |
| 273 | + builder.setInsertionPointAfter(ifMaskTrueOp); |
| 274 | +} |
154 | 275 |
|
155 | 276 | } // namespace fir
|
156 | 277 |
|
|
0 commit comments