@@ -16,21 +16,6 @@ using namespace mlir::sparse_tensor::ir_detail;
16
16
// `DimLvlExpr` implementation.
17
17
// ===----------------------------------------------------------------------===//
18
18
19
- Var DimLvlExpr::castAnyVar () const {
20
- assert (expr && " uninitialized DimLvlExpr" );
21
- const auto var = dyn_castAnyVar ();
22
- assert (var && " expected DimLvlExpr to be a Var" );
23
- return *var;
24
- }
25
-
26
- std::optional<Var> DimLvlExpr::dyn_castAnyVar () const {
27
- if (const auto s = dyn_cast_or_null<AffineSymbolExpr>(expr))
28
- return SymVar (s);
29
- if (const auto x = dyn_cast_or_null<AffineDimExpr>(expr))
30
- return Var (getAllowedVarKind (), x);
31
- return std::nullopt;
32
- }
33
-
34
19
SymVar DimLvlExpr::castSymVar () const {
35
20
return SymVar (llvm::cast<AffineSymbolExpr>(expr));
36
21
}
@@ -51,30 +36,6 @@ std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const {
51
36
return std::nullopt;
52
37
}
53
38
54
- int64_t DimLvlExpr::castConstantValue () const {
55
- return llvm::cast<AffineConstantExpr>(expr).getValue ();
56
- }
57
-
58
- std::optional<int64_t > DimLvlExpr::dyn_castConstantValue () const {
59
- const auto k = dyn_cast_or_null<AffineConstantExpr>(expr);
60
- return k ? std::make_optional (k.getValue ()) : std::nullopt;
61
- }
62
-
63
- bool DimLvlExpr::hasConstantValue (int64_t val) const {
64
- const auto k = dyn_cast_or_null<AffineConstantExpr>(expr);
65
- return k && k.getValue () == val;
66
- }
67
-
68
- DimLvlExpr DimLvlExpr::getLHS () const {
69
- const auto binop = dyn_cast_or_null<AffineBinaryOpExpr>(expr);
70
- return DimLvlExpr (kind, binop ? binop.getLHS () : nullptr );
71
- }
72
-
73
- DimLvlExpr DimLvlExpr::getRHS () const {
74
- const auto binop = dyn_cast_or_null<AffineBinaryOpExpr>(expr);
75
- return DimLvlExpr (kind, binop ? binop.getRHS () : nullptr );
76
- }
77
-
78
39
std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr>
79
40
DimLvlExpr::unpackBinop () const {
80
41
const auto ak = getAffineKind ();
@@ -84,115 +45,6 @@ DimLvlExpr::unpackBinop() const {
84
45
return {lhs, ak, rhs};
85
46
}
86
47
87
- void DimLvlExpr::dump () const {
88
- print (llvm::errs ());
89
- llvm::errs () << " \n " ;
90
- }
91
- std::string DimLvlExpr::str () const {
92
- std::string str;
93
- llvm::raw_string_ostream os (str);
94
- print (os);
95
- return os.str ();
96
- }
97
- void DimLvlExpr::print (AsmPrinter &printer) const {
98
- print (printer.getStream ());
99
- }
100
- void DimLvlExpr::print (llvm::raw_ostream &os) const {
101
- if (!expr)
102
- os << " <<NULL AFFINE EXPR>>" ;
103
- else
104
- printWeak (os);
105
- }
106
-
107
- namespace {
108
- struct MatchNeg final : public std::pair<DimLvlExpr, int64_t > {
109
- using Base = std::pair<DimLvlExpr, int64_t >;
110
- using Base::Base;
111
- constexpr DimLvlExpr getLHS () const { return first; }
112
- constexpr int64_t getRHS () const { return second; }
113
- };
114
- } // namespace
115
-
116
- static std::optional<MatchNeg> matchNeg (DimLvlExpr expr) {
117
- const auto [lhs, op, rhs] = expr.unpackBinop ();
118
- if (op == AffineExprKind::Constant) {
119
- const auto val = expr.castConstantValue ();
120
- if (val < 0 )
121
- return MatchNeg{DimLvlExpr{expr.getExprKind (), AffineExpr ()}, val};
122
- }
123
- if (op == AffineExprKind::Mul)
124
- if (const auto rval = rhs.dyn_castConstantValue (); rval && *rval < 0 )
125
- return MatchNeg{lhs, *rval};
126
- return std::nullopt;
127
- }
128
-
129
- // A heavily revised version of `AsmPrinter::Impl::printAffineExprInternal`.
130
- void DimLvlExpr::printAffineExprInternal (
131
- llvm::raw_ostream &os, BindingStrength enclosingTightness) const {
132
- const char *binopSpelling = nullptr ;
133
- switch (getAffineKind ()) {
134
- case AffineExprKind::SymbolId:
135
- os << castSymVar ();
136
- return ;
137
- case AffineExprKind::DimId:
138
- os << castDimLvlVar ();
139
- return ;
140
- case AffineExprKind::Constant:
141
- os << castConstantValue ();
142
- return ;
143
- case AffineExprKind::Add:
144
- binopSpelling = " + " ; // N.B., this is unused
145
- break ;
146
- case AffineExprKind::Mul:
147
- binopSpelling = " * " ;
148
- break ;
149
- case AffineExprKind::FloorDiv:
150
- binopSpelling = " floordiv " ;
151
- break ;
152
- case AffineExprKind::CeilDiv:
153
- binopSpelling = " ceildiv " ;
154
- break ;
155
- case AffineExprKind::Mod:
156
- binopSpelling = " mod " ;
157
- break ;
158
- }
159
-
160
- if (enclosingTightness == BindingStrength::Strong)
161
- os << ' (' ;
162
-
163
- const auto [lhs, op, rhs] = unpackBinop ();
164
- if (op == AffineExprKind::Mul && rhs.hasConstantValue (-1 )) {
165
- // Pretty print `(lhs * -1)` as "-lhs".
166
- os << ' -' ;
167
- lhs.printStrong (os);
168
- } else if (op != AffineExprKind::Add) {
169
- // Default rule for tightly binding binary operators.
170
- // (Including `Mul` that didn't match the previous rule.)
171
- lhs.printStrong (os);
172
- os << binopSpelling;
173
- rhs.printStrong (os);
174
- } else {
175
- // Combination of all the special rules for addition/subtraction.
176
- lhs.printWeak (os);
177
- const auto rx = matchNeg (rhs);
178
- os << (rx ? " - " : " + " );
179
- const auto &rlhs = rx ? rx->getLHS () : rhs;
180
- const auto rrhs = rx ? rx->getRHS () : -1 ; // value irrelevant when `!rx`
181
- const bool nonunit = rrhs != -1 ; // value irrelevant when `!rx`
182
- const bool isStrong =
183
- rx && rlhs && (nonunit || rlhs.getAffineKind () == AffineExprKind::Add);
184
- if (rlhs)
185
- rlhs.printAffineExprInternal (os, BindingStrength{isStrong});
186
- if (rx && rlhs && nonunit)
187
- os << " * " ;
188
- if (rx && (!rlhs || nonunit))
189
- os << -rrhs;
190
- }
191
-
192
- if (enclosingTightness == BindingStrength::Strong)
193
- os << ' )' ;
194
- }
195
-
196
48
// ===----------------------------------------------------------------------===//
197
49
// `DimSpec` implementation.
198
50
// ===----------------------------------------------------------------------===//
@@ -206,31 +58,6 @@ bool DimSpec::isValid(Ranks const &ranks) const {
206
58
return ranks.isValid (var) && (!expr || ranks.isValid (expr));
207
59
}
208
60
209
- void DimSpec::dump () const {
210
- print (llvm::errs (), /* wantElision=*/ false );
211
- llvm::errs () << " \n " ;
212
- }
213
- std::string DimSpec::str (bool wantElision) const {
214
- std::string str;
215
- llvm::raw_string_ostream os (str);
216
- print (os, wantElision);
217
- return os.str ();
218
- }
219
- void DimSpec::print (AsmPrinter &printer, bool wantElision) const {
220
- print (printer.getStream (), wantElision);
221
- }
222
- void DimSpec::print (llvm::raw_ostream &os, bool wantElision) const {
223
- os << var;
224
- if (expr && (!wantElision || !elideExpr))
225
- os << " = " << expr;
226
- if (slice) {
227
- os << " : " ;
228
- // Call `SparseTensorDimSliceAttr::print` directly, to avoid
229
- // printing the mnemonic.
230
- slice.print (os);
231
- }
232
- }
233
-
234
61
// ===----------------------------------------------------------------------===//
235
62
// `LvlSpec` implementation.
236
63
// ===----------------------------------------------------------------------===//
@@ -246,26 +73,6 @@ bool LvlSpec::isValid(Ranks const &ranks) const {
246
73
return ranks.isValid (var) && ranks.isValid (expr);
247
74
}
248
75
249
- void LvlSpec::dump () const {
250
- print (llvm::errs (), /* wantElision=*/ false );
251
- llvm::errs () << " \n " ;
252
- }
253
- std::string LvlSpec::str (bool wantElision) const {
254
- std::string str;
255
- llvm::raw_string_ostream os (str);
256
- print (os, wantElision);
257
- return os.str ();
258
- }
259
- void LvlSpec::print (AsmPrinter &printer, bool wantElision) const {
260
- print (printer.getStream (), wantElision);
261
- }
262
- void LvlSpec::print (llvm::raw_ostream &os, bool wantElision) const {
263
- if (!wantElision || !elideVar)
264
- os << var << " = " ;
265
- os << expr;
266
- os << " : " << toMLIRString (type);
267
- }
268
-
269
76
// ===----------------------------------------------------------------------===//
270
77
// `DimLvlMap` implementation.
271
78
// ===----------------------------------------------------------------------===//
@@ -334,51 +141,4 @@ AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
334
141
return map;
335
142
}
336
143
337
- void DimLvlMap::dump () const {
338
- print (llvm::errs (), /* wantElision=*/ false );
339
- llvm::errs () << " \n " ;
340
- }
341
- std::string DimLvlMap::str (bool wantElision) const {
342
- std::string str;
343
- llvm::raw_string_ostream os (str);
344
- print (os, wantElision);
345
- return os.str ();
346
- }
347
- void DimLvlMap::print (AsmPrinter &printer, bool wantElision) const {
348
- print (printer.getStream (), wantElision);
349
- }
350
- void DimLvlMap::print (llvm::raw_ostream &os, bool wantElision) const {
351
- // Symbolic identifiers.
352
- // NOTE: Unlike `AffineMap` we place the SymVar bindings before the DimVar
353
- // bindings, since the SymVars may occur within DimExprs and thus this
354
- // ordering helps reduce potential user confusion about the scope of bidings
355
- // (since it means SymVars and DimVars both bind-forward in the usual way,
356
- // whereas only LvlVars have different binding rules).
357
- if (symRank != 0 ) {
358
- os << " [s0" ;
359
- for (unsigned i = 1 ; i < symRank; ++i)
360
- os << " , s" << i;
361
- os << ' ]' ;
362
- }
363
-
364
- // LvlVar forward-declarations.
365
- if (mustPrintLvlVars) {
366
- os << ' {' ;
367
- llvm::interleaveComma (
368
- lvlSpecs, os, [&](LvlSpec const &spec) { os << spec.getBoundVar (); });
369
- os << " } " ;
370
- }
371
-
372
- // Dimension specifiers.
373
- os << ' (' ;
374
- llvm::interleaveComma (
375
- dimSpecs, os, [&](DimSpec const &spec) { spec.print (os, wantElision); });
376
- os << " ) -> (" ;
377
- // Level specifiers.
378
- wantElision = wantElision && !mustPrintLvlVars;
379
- llvm::interleaveComma (
380
- lvlSpecs, os, [&](LvlSpec const &spec) { spec.print (os, wantElision); });
381
- os << ' )' ;
382
- }
383
-
384
144
// ===----------------------------------------------------------------------===//
0 commit comments