12
12
13
13
#include " mlir/Dialect/SCF/Utils/Utils.h"
14
14
#include " mlir/Analysis/SliceAnalysis.h"
15
+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
15
16
#include " mlir/Dialect/Arith/IR/Arith.h"
16
17
#include " mlir/Dialect/Arith/Utils/Utils.h"
17
18
#include " mlir/Dialect/Func/IR/FuncOps.h"
@@ -671,9 +672,26 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
671
672
return success ();
672
673
}
673
674
675
+ Range emitNormalizedLoopBoundsForIndexType (RewriterBase &rewriter, Location loc,
676
+ OpFoldResult lb, OpFoldResult ub,
677
+ OpFoldResult step) {
678
+ Range normalizedLoopBounds;
679
+ normalizedLoopBounds.offset = rewriter.getIndexAttr (0 );
680
+ normalizedLoopBounds.stride = rewriter.getIndexAttr (1 );
681
+ AffineExpr s0, s1, s2;
682
+ bindSymbols (rewriter.getContext (), s0, s1, s2);
683
+ AffineExpr e = (s1 - s0).ceilDiv (s2);
684
+ normalizedLoopBounds.size =
685
+ affine::makeComposedFoldedAffineApply (rewriter, loc, e, {lb, ub, step});
686
+ return normalizedLoopBounds;
687
+ }
688
+
674
689
Range mlir::emitNormalizedLoopBounds (RewriterBase &rewriter, Location loc,
675
690
OpFoldResult lb, OpFoldResult ub,
676
691
OpFoldResult step) {
692
+ if (getType (lb) == rewriter.getIndexType ()) {
693
+ return emitNormalizedLoopBoundsForIndexType (rewriter, loc, lb, ub, step);
694
+ }
677
695
// For non-index types, generate `arith` instructions
678
696
// Check if the loop is already known to have a constant zero lower bound or
679
697
// a constant one step.
@@ -714,9 +732,38 @@ Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
714
732
return {newLowerBound, newUpperBound, newStep};
715
733
}
716
734
735
+ static void denormalizeInductionVariableForIndexType (RewriterBase &rewriter,
736
+ Location loc,
737
+ Value normalizedIv,
738
+ OpFoldResult origLb,
739
+ OpFoldResult origStep) {
740
+ AffineExpr d0, s0, s1;
741
+ bindSymbols (rewriter.getContext (), s0, s1);
742
+ bindDims (rewriter.getContext (), d0);
743
+ AffineExpr e = d0 * s1 + s0;
744
+ OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply (
745
+ rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
746
+ Value denormalizedIvVal =
747
+ getValueOrCreateConstantIndexOp (rewriter, loc, denormalizedIv);
748
+ SmallPtrSet<Operation *, 1 > preservedUses;
749
+ // If an `affine.apply` operation is generated for denormalization, the use
750
+ // of `origLb` in those ops must not be replaced. These arent not generated
751
+ // when `orig_lb == 0` and `orig_step == 1`.
752
+ if (!isConstantIntValue (origLb, 0 ) || !isConstantIntValue (origStep, 1 )) {
753
+ if (Operation *preservedUse = denormalizedIvVal.getDefiningOp ()) {
754
+ preservedUses.insert (preservedUse);
755
+ }
756
+ }
757
+ rewriter.replaceAllUsesExcept (normalizedIv, denormalizedIvVal, preservedUses);
758
+ }
759
+
717
760
void mlir::denormalizeInductionVariable (RewriterBase &rewriter, Location loc,
718
761
Value normalizedIv, OpFoldResult origLb,
719
762
OpFoldResult origStep) {
763
+ if (getType (origLb) == rewriter.getIndexType ()) {
764
+ return denormalizeInductionVariableForIndexType (rewriter, loc, normalizedIv,
765
+ origLb, origStep);
766
+ }
720
767
Value denormalizedIv;
721
768
SmallPtrSet<Operation *, 2 > preserve;
722
769
bool isStepOne = isConstantIntValue (origStep, 1 );
@@ -739,10 +786,29 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
739
786
rewriter.replaceAllUsesExcept (normalizedIv, denormalizedIv, preserve);
740
787
}
741
788
789
+ static OpFoldResult getProductOfIndexes (RewriterBase &rewriter, Location loc,
790
+ ArrayRef<OpFoldResult> values) {
791
+ assert (!values.empty () && " unexecpted empty array" );
792
+ AffineExpr s0, s1;
793
+ bindSymbols (rewriter.getContext (), s0, s1);
794
+ AffineExpr mul = s0 * s1;
795
+ OpFoldResult products = rewriter.getIndexAttr (1 );
796
+ for (auto v : values) {
797
+ products = affine::makeComposedFoldedAffineApply (
798
+ rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
799
+ }
800
+ return products;
801
+ }
802
+
742
803
// / Helper function to multiply a sequence of values.
743
804
static Value getProductOfIntsOrIndexes (RewriterBase &rewriter, Location loc,
744
805
ArrayRef<Value> values) {
745
806
assert (!values.empty () && " unexpected empty list" );
807
+ if (getType (values.front ()) == rewriter.getIndexType ()) {
808
+ SmallVector<OpFoldResult> ofrs = getAsOpFoldResult (values);
809
+ OpFoldResult product = getProductOfIndexes (rewriter, loc, ofrs);
810
+ return getValueOrCreateConstantIndexOp (rewriter, loc, product);
811
+ }
746
812
std::optional<Value> productOf;
747
813
for (auto v : values) {
748
814
auto vOne = getConstantIntValue (v);
@@ -757,7 +823,7 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
757
823
if (!productOf) {
758
824
productOf = rewriter
759
825
.create <arith::ConstantOp>(
760
- loc, rewriter.getOneAttr (values.front (). getType ( )))
826
+ loc, rewriter.getOneAttr (getType ( values.front ())))
761
827
.getResult ();
762
828
}
763
829
return productOf.value ();
@@ -774,6 +840,16 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
774
840
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2 >>
775
841
delinearizeInductionVariable (RewriterBase &rewriter, Location loc,
776
842
Value linearizedIv, ArrayRef<Value> ubs) {
843
+
844
+ if (linearizedIv.getType () == rewriter.getIndexType ()) {
845
+ Operation *delinearizedOp =
846
+ rewriter.create <affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
847
+ ubs);
848
+ auto resultVals = llvm::map_to_vector (
849
+ delinearizedOp->getResults (), [](OpResult r) -> Value { return r; });
850
+ return {resultVals, SmallPtrSet<Operation *, 2 >{delinearizedOp}};
851
+ }
852
+
777
853
SmallVector<Value> delinearizedIvs (ubs.size ());
778
854
SmallPtrSet<Operation *, 2 > preservedUsers;
779
855
0 commit comments