12
12
13
13
#include " mlir/Dialect/SCF/Utils/Utils.h"
14
14
#include " mlir/Analysis/SliceAnalysis.h"
15
+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
15
16
#include " mlir/Dialect/Arith/IR/Arith.h"
16
17
#include " mlir/Dialect/Arith/Utils/Utils.h"
17
18
#include " mlir/Dialect/Func/IR/FuncOps.h"
@@ -671,9 +672,26 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
671
672
return success ();
672
673
}
673
674
675
+ Range emitNormalizedLoopBoundsForIndexType (RewriterBase &rewriter, Location loc,
676
+ OpFoldResult lb, OpFoldResult ub,
677
+ OpFoldResult step) {
678
+ Range normalizedLoopBounds;
679
+ normalizedLoopBounds.offset = rewriter.getIndexAttr (0 );
680
+ normalizedLoopBounds.stride = rewriter.getIndexAttr (1 );
681
+ AffineExpr s0, s1, s2;
682
+ bindSymbols (rewriter.getContext (), s0, s1, s2);
683
+ AffineExpr e = (s1 - s0).ceilDiv (s2);
684
+ normalizedLoopBounds.size =
685
+ affine::makeComposedFoldedAffineApply (rewriter, loc, e, {lb, ub, step});
686
+ return normalizedLoopBounds;
687
+ }
688
+
674
689
Range mlir::emitNormalizedLoopBounds (RewriterBase &rewriter, Location loc,
675
690
OpFoldResult lb, OpFoldResult ub,
676
691
OpFoldResult step) {
692
+ if (getType (lb) == rewriter.getIndexType ()) {
693
+ return emitNormalizedLoopBoundsForIndexType (rewriter, loc, lb, ub, step);
694
+ }
677
695
// For non-index types, generate `arith` instructions
678
696
// Check if the loop is already known to have a constant zero lower bound or
679
697
// a constant one step.
@@ -714,9 +732,35 @@ Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
714
732
return {newLowerBound, newUpperBound, newStep};
715
733
}
716
734
735
+ static void denormalizeInductionVariableForIndexType (RewriterBase &rewriter,
736
+ Location loc,
737
+ Value normalizedIv,
738
+ OpFoldResult origLb,
739
+ OpFoldResult origStep) {
740
+ AffineExpr d0, s0, s1;
741
+ bindSymbols (rewriter.getContext (), s0, s1);
742
+ bindDims (rewriter.getContext (), d0);
743
+ AffineExpr e = d0 * s1 + s0;
744
+ OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply (
745
+ rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
746
+ Value denormalizedIvVal =
747
+ getValueOrCreateConstantIndexOp (rewriter, loc, denormalizedIv);
748
+ SmallPtrSet<Operation *, 1 > preservedUses;
749
+ if (!isConstantIntValue (origLb, 0 ) || !isConstantIntValue (origStep, 1 )) {
750
+ if (Operation *preservedUse = denormalizedIvVal.getDefiningOp ()) {
751
+ preservedUses.insert (preservedUse);
752
+ }
753
+ }
754
+ rewriter.replaceAllUsesExcept (normalizedIv, denormalizedIvVal, preservedUses);
755
+ }
756
+
717
757
void mlir::denormalizeInductionVariable (RewriterBase &rewriter, Location loc,
718
758
Value normalizedIv, OpFoldResult origLb,
719
759
OpFoldResult origStep) {
760
+ if (getType (origLb) == rewriter.getIndexType ()) {
761
+ return denormalizeInductionVariableForIndexType (rewriter, loc, normalizedIv,
762
+ origLb, origStep);
763
+ }
720
764
Value denormalizedIv;
721
765
SmallPtrSet<Operation *, 2 > preserve;
722
766
bool isStepOne = isConstantIntValue (origStep, 1 );
@@ -739,10 +783,29 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
739
783
rewriter.replaceAllUsesExcept (normalizedIv, denormalizedIv, preserve);
740
784
}
741
785
786
+ static OpFoldResult getProductOfIndexes (RewriterBase &rewriter, Location loc,
787
+ ArrayRef<OpFoldResult> values) {
788
+ assert (!values.empty () && " unexecpted empty array" );
789
+ AffineExpr s0, s1;
790
+ bindSymbols (rewriter.getContext (), s0, s1);
791
+ AffineExpr mul = s0 * s1;
792
+ OpFoldResult products = rewriter.getIndexAttr (1 );
793
+ for (auto v : values) {
794
+ products = affine::makeComposedFoldedAffineApply (
795
+ rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
796
+ }
797
+ return products;
798
+ }
799
+
742
800
// / Helper function to multiply a sequence of values.
743
801
static Value getProductOfIntsOrIndexes (RewriterBase &rewriter, Location loc,
744
802
ArrayRef<Value> values) {
745
803
assert (!values.empty () && " unexpected empty list" );
804
+ if (getType (values.front ()) == rewriter.getIndexType ()) {
805
+ SmallVector<OpFoldResult> ofrs = getAsOpFoldResult (values);
806
+ OpFoldResult product = getProductOfIndexes (rewriter, loc, ofrs);
807
+ return getValueOrCreateConstantIndexOp (rewriter, loc, product);
808
+ }
746
809
std::optional<Value> productOf;
747
810
for (auto v : values) {
748
811
auto vOne = getConstantIntValue (v);
@@ -757,7 +820,7 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
757
820
if (!productOf) {
758
821
productOf = rewriter
759
822
.create <arith::ConstantOp>(
760
- loc, rewriter.getOneAttr (values.front (). getType ( )))
823
+ loc, rewriter.getOneAttr (getType ( values.front ())))
761
824
.getResult ();
762
825
}
763
826
return productOf.value ();
@@ -774,6 +837,16 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
774
837
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2 >>
775
838
delinearizeInductionVariable (RewriterBase &rewriter, Location loc,
776
839
Value linearizedIv, ArrayRef<Value> ubs) {
840
+
841
+ if (linearizedIv.getType () == rewriter.getIndexType ()) {
842
+ Operation *delinearizedOp =
843
+ rewriter.create <affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
844
+ ubs);
845
+ auto resultVals = llvm::map_to_vector (
846
+ delinearizedOp->getResults (), [](OpResult r) -> Value { return r; });
847
+ return {resultVals, SmallPtrSet<Operation *, 2 >{delinearizedOp}};
848
+ }
849
+
777
850
SmallVector<Value> delinearizedIvs (ubs.size ());
778
851
SmallPtrSet<Operation *, 2 > preservedUsers;
779
852
0 commit comments