@@ -36,6 +36,7 @@ using namespace mlir::edsc::intrinsics;
36
36
using namespace mlir ::linalg;
37
37
38
38
// Debug-output helper: yields llvm::dbgs() after streaming a bracketed
// DEBUG_TYPE prefix, so callers can continue the line with `<<`.
#define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE << " ]: " )
39
+
39
40
// ===----------------------------------------------------------------------===//
40
41
// Transformations exposed as rewrite patterns.
41
42
// ===----------------------------------------------------------------------===//
@@ -235,3 +236,177 @@ LogicalResult mlir::linalg::applyStagedPatterns(
235
236
}
236
237
return success ();
237
238
}
239
+
240
+ // / Traverse `e` and return an AffineExpr where all occurrences of `dim` have
241
+ // / been replaced by either:
242
+ // / - `min` if `positivePath` is true when we reach an occurrence of `dim`
243
+ // / - `max` if `positivePath` is true when we reach an occurrence of `dim`
244
+ // / `positivePath` is negated each time we hit a multiplicative or divisive
245
+ // / binary op with a constant negative coefficient.
246
+ static AffineExpr substWithMin (AffineExpr e, AffineExpr dim, AffineExpr min,
247
+ AffineExpr max, bool positivePath = true ) {
248
+ if (e == dim)
249
+ return positivePath ? min : max;
250
+ if (auto bin = e.dyn_cast <AffineBinaryOpExpr>()) {
251
+ AffineExpr lhs = bin.getLHS ();
252
+ AffineExpr rhs = bin.getRHS ();
253
+ if (bin.getKind () == mlir::AffineExprKind::Add)
254
+ return substWithMin (lhs, dim, min, max, positivePath) +
255
+ substWithMin (rhs, dim, min, max, positivePath);
256
+
257
+ auto c1 = bin.getLHS ().dyn_cast <AffineConstantExpr>();
258
+ auto c2 = bin.getRHS ().dyn_cast <AffineConstantExpr>();
259
+ if (c1 && c1.getValue () < 0 )
260
+ return getAffineBinaryOpExpr (
261
+ bin.getKind (), c1, substWithMin (rhs, dim, min, max, !positivePath));
262
+ if (c2 && c2.getValue () < 0 )
263
+ return getAffineBinaryOpExpr (
264
+ bin.getKind (), substWithMin (lhs, dim, min, max, !positivePath), c2);
265
+ return getAffineBinaryOpExpr (
266
+ bin.getKind (), substWithMin (lhs, dim, min, max, positivePath),
267
+ substWithMin (rhs, dim, min, max, positivePath));
268
+ }
269
+ return e;
270
+ }
271
+
272
+ // / Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
273
+ // / `ubVal` to `dims` and `stepVal` to `symbols`.
274
+ // / Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`)
275
+ // / with positions matching the newly appended values. Substitute occurrences of
276
+ // / `dimExpr` by either the min expression (i.e. `%lb`) or the max expression
277
+ // / (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`), depending on whether
278
+ // / the induction variable is used with a positive or negative coefficient.
279
+ static AffineExpr substituteLoopInExpr (AffineExpr expr, AffineExpr dimExpr,
280
+ Value lbVal, Value ubVal, Value stepVal,
281
+ SmallVectorImpl<Value> &dims,
282
+ SmallVectorImpl<Value> &symbols) {
283
+ MLIRContext *ctx = lbVal.getContext ();
284
+ AffineExpr lb = getAffineDimExpr (dims.size (), ctx);
285
+ dims.push_back (lbVal);
286
+ AffineExpr ub = getAffineDimExpr (dims.size (), ctx);
287
+ dims.push_back (ubVal);
288
+ AffineExpr step = getAffineSymbolExpr (symbols.size (), ctx);
289
+ symbols.push_back (stepVal);
290
+ LLVM_DEBUG (DBGS () << " Before: " << expr << " \n " );
291
+ AffineExpr ee = substWithMin (expr, dimExpr, lb,
292
+ lb + step * ((ub - 1 ) - lb).floorDiv (step));
293
+ LLVM_DEBUG (DBGS () << " After: " << expr << " \n " );
294
+ return ee;
295
+ }
296
+
297
+ // / Traverse the `dims` and substitute known min or max expressions in place of
298
+ // / induction variables in `exprs`.
299
+ static AffineMap substitute (AffineMap map, SmallVectorImpl<Value> &dims,
300
+ SmallVectorImpl<Value> &symbols) {
301
+ auto exprs = llvm::to_vector<4 >(map.getResults ());
302
+ for (AffineExpr &expr : exprs) {
303
+ bool substituted = true ;
304
+ while (substituted) {
305
+ substituted = false ;
306
+ for (unsigned dimIdx = 0 ; dimIdx < dims.size (); ++dimIdx) {
307
+ Value dim = dims[dimIdx];
308
+ AffineExpr dimExpr = getAffineDimExpr (dimIdx, expr.getContext ());
309
+ LLVM_DEBUG (DBGS () << " Subst: " << dim << " @ " << dimExpr << " \n " );
310
+ AffineExpr substitutedExpr;
311
+ if (auto forOp = scf::getForInductionVarOwner (dim))
312
+ substitutedExpr = substituteLoopInExpr (
313
+ expr, dimExpr, forOp.lowerBound (), forOp.upperBound (),
314
+ forOp.step (), dims, symbols);
315
+
316
+ if (auto parallelForOp = scf::getParallelForInductionVarOwner (dim))
317
+ for (unsigned idx = 0 , e = parallelForOp.getNumLoops (); idx < e;
318
+ ++idx)
319
+ substitutedExpr = substituteLoopInExpr (
320
+ expr, dimExpr, parallelForOp.lowerBound ()[idx],
321
+ parallelForOp.upperBound ()[idx], parallelForOp.step ()[idx],
322
+ dims, symbols);
323
+
324
+ if (!substitutedExpr)
325
+ continue ;
326
+
327
+ substituted = (substitutedExpr != expr);
328
+ expr = substitutedExpr;
329
+ }
330
+ }
331
+
332
+ // Cleanup and simplify the results.
333
+ // This needs to happen outside of the loop iterating on dims.size() since
334
+ // it modifies dims.
335
+ SmallVector<Value, 4 > operands (dims.begin (), dims.end ());
336
+ operands.append (symbols.begin (), symbols.end ());
337
+ auto map = AffineMap::get (dims.size (), symbols.size (), exprs,
338
+ exprs.front ().getContext ());
339
+
340
+ LLVM_DEBUG (DBGS () << " Map to simplify: " << map << " \n " );
341
+
342
+ // Pull in affine.apply operations and compose them fully into the
343
+ // result.
344
+ fullyComposeAffineMapAndOperands (&map, &operands);
345
+ canonicalizeMapAndOperands (&map, &operands);
346
+ map = simplifyAffineMap (map);
347
+ // Assign the results.
348
+ exprs.assign (map.getResults ().begin (), map.getResults ().end ());
349
+ dims.assign (operands.begin (), operands.begin () + map.getNumDims ());
350
+ symbols.assign (operands.begin () + map.getNumDims (), operands.end ());
351
+
352
+ LLVM_DEBUG (DBGS () << " Map simplified: " << map << " \n " );
353
+ }
354
+
355
+ assert (!exprs.empty () && " Unexpected empty exprs" );
356
+ return AffineMap::get (dims.size (), symbols.size (), exprs, map.getContext ());
357
+ }
358
+
359
+ LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite (
360
+ AffineMinOp minOp, PatternRewriter &rewriter) const {
361
+ LLVM_DEBUG (DBGS () << " Canonicalize AffineMinSCF: " << *minOp.getOperation ()
362
+ << " \n " );
363
+
364
+ SmallVector<Value, 4 > dims (minOp.getDimOperands ()),
365
+ symbols (minOp.getSymbolOperands ());
366
+ AffineMap map = substitute (minOp.getAffineMap (), dims, symbols);
367
+
368
+ LLVM_DEBUG (DBGS () << " Resulting map: " << map << " \n " );
369
+
370
+ // Check whether any of the expressions, when subtracted from all other
371
+ // expressions, produces only >= 0 constants. If so, it is the min.
372
+ for (auto e : minOp.getAffineMap ().getResults ()) {
373
+ LLVM_DEBUG (DBGS () << " Candidate min: " << e << " \n " );
374
+ if (!e.isSymbolicOrConstant ())
375
+ continue ;
376
+
377
+ auto isNonPositive = [](AffineExpr e) {
378
+ if (auto cst = e.dyn_cast <AffineConstantExpr>())
379
+ return cst.getValue () < 0 ;
380
+ return true ;
381
+ };
382
+
383
+ // Build the subMap and check everything is statically known to be
384
+ // positive.
385
+ SmallVector<AffineExpr, 4 > subExprs;
386
+ subExprs.reserve (map.getNumResults ());
387
+ for (auto ee : map.getResults ())
388
+ subExprs.push_back (ee - e);
389
+ MLIRContext *ctx = minOp.getContext ();
390
+ AffineMap subMap = simplifyAffineMap (
391
+ AffineMap::get (map.getNumDims (), map.getNumSymbols (), subExprs, ctx));
392
+ LLVM_DEBUG (DBGS () << " simplified subMap: " << subMap << " \n " );
393
+ if (llvm::any_of (subMap.getResults (), isNonPositive))
394
+ continue ;
395
+
396
+ // Static min found.
397
+ if (auto cst = e.dyn_cast <AffineConstantExpr>()) {
398
+ rewriter.replaceOpWithNewOp <ConstantIndexOp>(minOp, cst.getValue ());
399
+ } else {
400
+ auto resultMap = AffineMap::get (0 , map.getNumSymbols (), {e}, ctx);
401
+ SmallVector<Value, 4 > resultOperands = dims;
402
+ resultOperands.append (symbols.begin (), symbols.end ());
403
+ canonicalizeMapAndOperands (&resultMap, &resultOperands);
404
+ resultMap = simplifyAffineMap (resultMap);
405
+ rewriter.replaceOpWithNewOp <AffineApplyOp>(minOp, resultMap,
406
+ resultOperands);
407
+ }
408
+ return success ();
409
+ }
410
+
411
+ return failure ();
412
+ }
0 commit comments