@@ -46,9 +46,15 @@ struct AffineExprFlattener : public SimpleAffineExprFlattener {
46
46
// inequalities.
47
47
IntegerPolyhedron localVarCst;
48
48
49
- AffineExprFlattener (unsigned nDims, unsigned nSymbols)
49
+ AffineExprFlattener (unsigned nDims, unsigned nSymbols,
50
+ bool addConservativeSemiAffineBounds = false )
50
51
: SimpleAffineExprFlattener(nDims, nSymbols),
51
- localVarCst (PresburgerSpace::getSetSpace(nDims, nSymbols)) {}
52
+ localVarCst (PresburgerSpace::getSetSpace(nDims, nSymbols)),
53
+ addConservativeSemiAffineBounds(addConservativeSemiAffineBounds) {}
54
+
55
+ bool hasUnhandledSemiAffineExpressions () const {
56
+ return unhandledSemiAffineExpressions;
57
+ }
52
58
53
59
private:
54
60
// Add a local variable (needed to flatten a mod, floordiv, ceildiv expr).
@@ -63,35 +69,61 @@ struct AffineExprFlattener : public SimpleAffineExprFlattener {
63
69
// Update localVarCst.
64
70
localVarCst.addLocalFloorDiv (dividend, divisor);
65
71
}
72
+
73
+ // Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
74
+ // expr) when the rhs is a symbolic expression. The local identifier added
75
+ // may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
76
+ // function of other identifiers, coefficients of which are specified in the
77
+ // lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
78
+ // symbolic rhs expression. `localExpr` is the simplified tree expression
79
+ // (AffineExpr) corresponding to the quantifier.
80
+ void addLocalIdSemiAffine (AffineExpr localExpr, ArrayRef<int64_t > lhs,
81
+ ArrayRef<int64_t > rhs) override {
82
+ SimpleAffineExprFlattener::addLocalIdSemiAffine (localExpr, lhs, rhs);
83
+ if (!addConservativeSemiAffineBounds) {
84
+ unhandledSemiAffineExpressions = true ;
85
+ return ;
86
+ }
87
+ if (localExpr.getKind () == AffineExprKind::Mod) {
88
+ localVarCst.addLocalModConservativeBounds (lhs, rhs);
89
+ return ;
90
+ }
91
+ // TODO: Support other semi-affine expressions.
92
+ unhandledSemiAffineExpressions = true ;
93
+ }
94
+
95
+ bool addConservativeSemiAffineBounds = false ;
96
+ bool unhandledSemiAffineExpressions = false ;
66
97
};
67
98
68
99
} // namespace
69
100
70
101
// Flattens the expressions in map. Returns failure if 'expr' was unable to be
71
102
// flattened. For example two specific cases:
72
- // 1. semi-affine expressions not handled yet .
103
+ // 1. an unhandled semi-affine expressions is found .
73
104
// 2. has poison expression (i.e., division by zero).
74
105
static LogicalResult
75
106
getFlattenedAffineExprs (ArrayRef<AffineExpr> exprs, unsigned numDims,
76
107
unsigned numSymbols,
77
108
std::vector<SmallVector<int64_t , 8 >> *flattenedExprs,
78
- FlatLinearConstraints *localVarCst) {
109
+ FlatLinearConstraints *localVarCst,
110
+ bool addConservativeSemiAffineBounds = false ) {
79
111
if (exprs.empty ()) {
80
112
if (localVarCst)
81
113
*localVarCst = FlatLinearConstraints (numDims, numSymbols);
82
114
return success ();
83
115
}
84
116
85
- AffineExprFlattener flattener (numDims, numSymbols);
117
+ AffineExprFlattener flattener (numDims, numSymbols,
118
+ addConservativeSemiAffineBounds);
86
119
// Use the same flattener to simplify each expression successively. This way
87
120
// local variables / expressions are shared.
88
121
for (auto expr : exprs) {
89
- if (!expr.isPureAffine ())
90
- return failure ();
91
- // has poison expression
92
122
auto flattenResult = flattener.walkPostOrder (expr);
93
123
if (failed (flattenResult))
94
124
return failure ();
125
+ if (flattener.hasUnhandledSemiAffineExpressions ())
126
+ return failure ();
95
127
}
96
128
97
129
assert (flattener.operandExprStack .size () == exprs.size ());
@@ -106,33 +138,33 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
106
138
}
107
139
108
140
// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
109
- // be flattened (semi-affine expressions not handled yet).
110
- LogicalResult
111
- mlir::getFlattenedAffineExpr (AffineExpr expr, unsigned numDims,
112
- unsigned numSymbols,
113
- SmallVectorImpl<int64_t > *flattenedExpr,
114
- FlatLinearConstraints *localVarCst) {
141
+ // be flattened (an unhandled semi-affine was found).
142
+ LogicalResult mlir::getFlattenedAffineExpr (
143
+ AffineExpr expr, unsigned numDims, unsigned numSymbols,
144
+ SmallVectorImpl<int64_t > *flattenedExpr, FlatLinearConstraints *localVarCst,
145
+ bool addConservativeSemiAffineBounds) {
115
146
std::vector<SmallVector<int64_t , 8 >> flattenedExprs;
116
- LogicalResult ret = ::getFlattenedAffineExprs ({expr}, numDims, numSymbols,
117
- &flattenedExprs, localVarCst);
147
+ LogicalResult ret =
148
+ ::getFlattenedAffineExprs ({expr}, numDims, numSymbols, &flattenedExprs,
149
+ localVarCst, addConservativeSemiAffineBounds);
118
150
*flattenedExpr = flattenedExprs[0 ];
119
151
return ret;
120
152
}
121
153
122
154
// / Flattens the expressions in map. Returns failure if 'expr' was unable to be
123
- // / flattened (i.e., semi-affine expressions not handled yet ).
155
+ // / flattened (i.e., an unhandled semi-affine was found ).
124
156
LogicalResult mlir::getFlattenedAffineExprs (
125
157
AffineMap map, std::vector<SmallVector<int64_t , 8 >> *flattenedExprs,
126
- FlatLinearConstraints *localVarCst) {
158
+ FlatLinearConstraints *localVarCst, bool addConservativeSemiAffineBounds ) {
127
159
if (map.getNumResults () == 0 ) {
128
160
if (localVarCst)
129
161
*localVarCst =
130
162
FlatLinearConstraints (map.getNumDims (), map.getNumSymbols ());
131
163
return success ();
132
164
}
133
- return ::getFlattenedAffineExprs (map. getResults (), map. getNumDims (),
134
- map.getNumSymbols (), flattenedExprs,
135
- localVarCst);
165
+ return ::getFlattenedAffineExprs (
166
+ map. getResults (), map. getNumDims (), map.getNumSymbols (), flattenedExprs,
167
+ localVarCst, addConservativeSemiAffineBounds );
136
168
}
137
169
138
170
LogicalResult mlir::getFlattenedAffineExprs (
@@ -641,9 +673,11 @@ void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num,
641
673
}
642
674
643
675
LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals (
644
- AffineMap map, std::vector<SmallVector<int64_t , 8 >> *flattenedExprs) {
676
+ AffineMap map, std::vector<SmallVector<int64_t , 8 >> *flattenedExprs,
677
+ bool addConservativeSemiAffineBounds) {
645
678
FlatLinearConstraints localCst;
646
- if (failed (getFlattenedAffineExprs (map, flattenedExprs, &localCst))) {
679
+ if (failed (getFlattenedAffineExprs (map, flattenedExprs, &localCst,
680
+ addConservativeSemiAffineBounds))) {
647
681
LLVM_DEBUG (llvm::dbgs ()
648
682
<< " composition unimplemented for semi-affine maps\n " );
649
683
return failure ();
@@ -664,9 +698,9 @@ LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals(
664
698
return success ();
665
699
}
666
700
667
- LogicalResult FlatLinearConstraints::addBound (BoundType type, unsigned pos,
668
- AffineMap boundMap,
669
- bool isClosedBound ) {
701
+ LogicalResult FlatLinearConstraints::addBound (
702
+ BoundType type, unsigned pos, AffineMap boundMap, bool isClosedBound ,
703
+ AddConservativeSemiAffineBounds addSemiAffineBounds ) {
670
704
assert (boundMap.getNumDims () == getNumDimVars () && " dim mismatch" );
671
705
assert (boundMap.getNumSymbols () == getNumSymbolVars () && " symbol mismatch" );
672
706
assert (pos < getNumDimAndSymbolVars () && " invalid position" );
@@ -680,7 +714,9 @@ LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos,
680
714
bool lower = type == BoundType::LB || type == BoundType::EQ;
681
715
682
716
std::vector<SmallVector<int64_t , 8 >> flatExprs;
683
- if (failed (flattenAlignedMapAndMergeLocals (boundMap, &flatExprs)))
717
+ if (failed (flattenAlignedMapAndMergeLocals (
718
+ boundMap, &flatExprs,
719
+ addSemiAffineBounds == AddConservativeSemiAffineBounds::Yes)))
684
720
return failure ();
685
721
assert (flatExprs.size () == boundMap.getNumResults ());
686
722
@@ -716,9 +752,11 @@ LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos,
716
752
return success ();
717
753
}
718
754
719
- LogicalResult FlatLinearConstraints::addBound (BoundType type, unsigned pos,
720
- AffineMap boundMap) {
721
- return addBound (type, pos, boundMap, /* isClosedBound=*/ type != BoundType::UB);
755
+ LogicalResult FlatLinearConstraints::addBound (
756
+ BoundType type, unsigned pos, AffineMap boundMap,
757
+ AddConservativeSemiAffineBounds addSemiAffineBounds) {
758
+ return addBound (type, pos, boundMap,
759
+ /* isClosedBound=*/ type != BoundType::UB, addSemiAffineBounds);
722
760
}
723
761
724
762
// / Compute an explicit representation for local vars. For all systems coming
0 commit comments