@@ -400,54 +400,35 @@ LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
400
400
assert (other.getNumSymbols () == getNumSymbolIds () && " symbol mismatch" );
401
401
402
402
std::vector<SmallVector<int64_t , 8 >> flatExprs;
403
- FlatAffineConstraints localCst;
404
- if (failed (getFlattenedAffineExprs (other, &flatExprs, &localCst))) {
405
- LLVM_DEBUG (llvm::dbgs ()
406
- << " composition unimplemented for semi-affine maps\n " );
403
+ if (failed (flattenAlignedMapAndMergeLocals (other, &flatExprs)))
407
404
return failure ();
408
- }
409
405
assert (flatExprs.size () == other.getNumResults ());
410
406
411
- // Add localCst information.
412
- if (localCst.getNumLocalIds () > 0 ) {
413
- unsigned numLocalIds = getNumLocalIds ();
414
- // Insert local dims of localCst at the beginning.
415
- for (unsigned l = 0 , e = localCst.getNumLocalIds (); l < e; ++l)
416
- addLocalId (0 );
417
- // Insert local dims of `this` at the end of localCst.
418
- for (unsigned l = 0 ; l < numLocalIds; ++l)
419
- localCst.addLocalId (localCst.getNumLocalIds ());
420
- // Dimensions of localCst and this constraint set match. Append localCst to
421
- // this constraint set.
422
- append (localCst);
423
- }
424
-
425
407
// Add dimensions corresponding to the map's results.
426
408
for (unsigned t = 0 , e = other.getNumResults (); t < e; t++) {
427
409
addDimId (0 );
428
410
}
429
411
430
412
// We add one equality for each result connecting the result dim of the map to
431
413
// the other identifiers.
432
- // For eg : if the expression is 16*i0 + i1, and this is the r^th
414
+ // E.g. : if the expression is 16*i0 + i1, and this is the r^th
433
415
// iteration/result of the value map, we are adding the equality:
434
- // d_r - 16*i0 - i1 = 0. Hence , when flattening say (i0 + 1, i0 + 8*i2), we
435
- // add two equalities overall : d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
416
+ // d_r - 16*i0 - i1 = 0. Similarly , when flattening (i0 + 1, i0 + 8*i2), we
417
+ // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
436
418
for (unsigned r = 0 , e = flatExprs.size (); r < e; r++) {
437
419
const auto &flatExpr = flatExprs[r];
438
420
assert (flatExpr.size () >= other.getNumInputs () + 1 );
439
421
440
- // eqToAdd is the equality corresponding to the flattened affine expression.
441
422
SmallVector<int64_t , 8 > eqToAdd (getNumCols (), 0 );
442
423
// Set the coefficient for this result to one.
443
424
eqToAdd[r] = 1 ;
444
425
445
426
// Dims and symbols.
446
427
for (unsigned i = 0 , f = other.getNumInputs (); i < f; i++) {
447
- // Negate ' eq[r]' since the newly added dimension will be set to this one.
428
+ // Negate ` eq[r]` since the newly added dimension will be set to this one.
448
429
eqToAdd[e + i] = -flatExpr[i];
449
430
}
450
- // Local vars common to eq and localCst are at the beginning.
431
+ // Local columns of `eq` are at the beginning.
451
432
unsigned j = getNumDimIds () + getNumSymbolIds ();
452
433
unsigned end = flatExpr.size () - 1 ;
453
434
for (unsigned i = other.getNumInputs (); i < end; i++, j++) {
@@ -1872,27 +1853,14 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
1872
1853
}
1873
1854
}
1874
1855
1875
- LogicalResult FlatAffineConstraints::addLowerOrUpperBound (unsigned pos,
1876
- AffineMap boundMap,
1877
- bool eq, bool lower) {
1878
- assert (boundMap.getNumDims () == getNumDimIds () && " dim mismatch" );
1879
- assert (boundMap.getNumSymbols () == getNumSymbolIds () && " symbol mismatch" );
1880
- assert (pos < getNumDimAndSymbolIds () && " invalid position" );
1881
-
1882
- // Equality follows the logic of lower bound except that we add an equality
1883
- // instead of an inequality.
1884
- assert ((!eq || boundMap.getNumResults () == 1 ) && " single result expected" );
1885
- if (eq)
1886
- lower = true ;
1887
-
1888
- std::vector<SmallVector<int64_t , 8 >> flatExprs;
1856
+ LogicalResult FlatAffineConstraints::flattenAlignedMapAndMergeLocals (
1857
+ AffineMap map, std::vector<SmallVector<int64_t , 8 >> *flattenedExprs) {
1889
1858
FlatAffineConstraints localCst;
1890
- if (failed (getFlattenedAffineExprs (boundMap, &flatExprs , &localCst))) {
1859
+ if (failed (getFlattenedAffineExprs (map, flattenedExprs , &localCst))) {
1891
1860
LLVM_DEBUG (llvm::dbgs ()
1892
1861
<< " composition unimplemented for semi-affine maps\n " );
1893
1862
return failure ();
1894
1863
}
1895
- assert (flatExprs.size () == boundMap.getNumResults ());
1896
1864
1897
1865
// Add localCst information.
1898
1866
if (localCst.getNumLocalIds () > 0 ) {
@@ -1908,6 +1876,27 @@ LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
1908
1876
append (localCst);
1909
1877
}
1910
1878
1879
+ return success ();
1880
+ }
1881
+
1882
+ LogicalResult FlatAffineConstraints::addLowerOrUpperBound (unsigned pos,
1883
+ AffineMap boundMap,
1884
+ bool eq, bool lower) {
1885
+ assert (boundMap.getNumDims () == getNumDimIds () && " dim mismatch" );
1886
+ assert (boundMap.getNumSymbols () == getNumSymbolIds () && " symbol mismatch" );
1887
+ assert (pos < getNumDimAndSymbolIds () && " invalid position" );
1888
+
1889
+ // Equality follows the logic of lower bound except that we add an equality
1890
+ // instead of an inequality.
1891
+ assert ((!eq || boundMap.getNumResults () == 1 ) && " single result expected" );
1892
+ if (eq)
1893
+ lower = true ;
1894
+
1895
+ std::vector<SmallVector<int64_t , 8 >> flatExprs;
1896
+ if (failed (flattenAlignedMapAndMergeLocals (boundMap, &flatExprs)))
1897
+ return failure ();
1898
+ assert (flatExprs.size () == boundMap.getNumResults ());
1899
+
1911
1900
// Add one (in)equality for each result.
1912
1901
for (const auto &flatExpr : flatExprs) {
1913
1902
SmallVector<int64_t > ineq (getNumCols (), 0 );
@@ -1921,7 +1910,7 @@ LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
1921
1910
if (ineq[pos] != 0 )
1922
1911
continue ;
1923
1912
ineq[pos] = lower ? 1 : -1 ;
1924
- // Local vars common to eq and localCst are at the beginning.
1913
+ // Local columns of `ineq` are at the beginning.
1925
1914
unsigned j = getNumDimIds () + getNumSymbolIds ();
1926
1915
unsigned end = flatExpr.size () - 1 ;
1927
1916
for (unsigned i = boundMap.getNumInputs (); i < end; i++, j++) {
0 commit comments