Skip to content

Commit 34ed07e

Browse files
[mlir][sparse] Parser cleanup (#69792)
Removed TODOs, FIXMEs and long notes that are more suited for design doc.
1 parent 70982ef commit 34ed07e

File tree

8 files changed

+31
-351
lines changed

8 files changed

+31
-351
lines changed

mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,6 @@ std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const {
6060
return k ? std::make_optional(k.getValue()) : std::nullopt;
6161
}
6262

63-
// This helper method is akin to `AffineExpr::operator==(int64_t)`
64-
// except it uses a different implementation, namely the implementation
65-
// used within `AsmPrinter::Impl::printAffineExprInternal`.
66-
//
67-
// wrengr guesses that `AsmPrinter::Impl::printAffineExprInternal` uses
68-
// this implementation because it avoids constructing the intermediate
69-
// `AffineConstantExpr(val)` and thus should in theory be a bit faster.
70-
// However, if it is indeed faster, then the `AffineExpr::operator==`
71-
// method should be updated to do this instead. And if it isn't any
72-
// faster, then we should be using `AffineExpr::operator==` instead.
7363
bool DimLvlExpr::hasConstantValue(int64_t val) const {
7464
const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
7565
return k && k.getValue() == val;
@@ -216,12 +206,6 @@ bool DimSpec::isValid(Ranks const &ranks) const {
216206
return ranks.isValid(var) && (!expr || ranks.isValid(expr));
217207
}
218208

219-
bool DimSpec::isFunctionOf(VarSet const &vars) const {
220-
return vars.occursIn(expr);
221-
}
222-
223-
void DimSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
224-
225209
void DimSpec::dump() const {
226210
print(llvm::errs(), /*wantElision=*/false);
227211
llvm::errs() << "\n";
@@ -262,12 +246,6 @@ bool LvlSpec::isValid(Ranks const &ranks) const {
262246
return ranks.isValid(var) && ranks.isValid(expr);
263247
}
264248

265-
bool LvlSpec::isFunctionOf(VarSet const &vars) const {
266-
return vars.occursIn(expr);
267-
}
268-
269-
void LvlSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
270-
271249
void LvlSpec::dump() const {
272250
print(llvm::errs(), /*wantElision=*/false);
273251
llvm::errs() << "\n";
@@ -301,19 +279,6 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
301279
// below cannot cause OOB errors.
302280
assert(isWF());
303281

304-
// TODO: Second, we need to infer/validate the `lvlToDim` mapping.
305-
// Along the way we should set every `DimSpec::elideExpr` according
306-
// to whether the given expression is inferable or not. Notably, this
307-
// needs to happen before the code for setting every `LvlSpec::elideVar`,
308-
// since if the LvlVar is only used in elided DimExpr, then the
309-
// LvlVar should also be elided.
310-
// NOTE: Be sure to use `DimLvlMap::setDimExpr` for setting the new exprs,
311-
// to ensure that we maintain the invariant established by `isWF` above.
312-
313-
// Third, we set every `LvlSpec::elideVar` according to whether that
314-
// LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
315-
// NOTE: The invariant established by `isWF` ensures that the following
316-
// calls to `VarSet::add` cannot raise OOB errors.
317282
VarSet usedVars(getRanks());
318283
for (const auto &dimSpec : dimSpecs)
319284
if (!dimSpec.canElideExpr())

mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h

Lines changed: 2 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,6 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
// FIXME(wrengr): The `DimLvlMap` class must be public so that it can
9-
// be named as the storage representation of the parameter for the tblgen
10-
// defn of STEA. We may well need to make the other classes public too,
11-
// so that the rest of the compiler can use them when necessary.
12-
//===----------------------------------------------------------------------===//
138

149
#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
1510
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
@@ -23,16 +18,8 @@ namespace sparse_tensor {
2318
namespace ir_detail {
2419

2520
//===----------------------------------------------------------------------===//
26-
// TODO(wrengr): Give this enum a better name, so that it fits together
27-
// with the name of the `DimLvlExpr` class (which may also want a better
28-
// name). Perhaps make this a nested-type too.
29-
//
30-
// NOTE: In the future we will extend this enum to include "counting
31-
// expressions" required for supporting ITPACK/ELL. Therefore the current
32-
// underlying-type and representation values should not be relied upon.
3321
enum class ExprKind : bool { Dimension = false, Level = true };
3422

35-
// TODO(wrengr): still needs a better name....
3623
constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
3724
using VK = std::underlying_type_t<VarKind>;
3825
return VarKind{2 * static_cast<VK>(!to_underlying(ek))};
@@ -41,19 +28,8 @@ static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level &&
4128
getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);
4229

4330
//===----------------------------------------------------------------------===//
44-
// TODO(wrengr): The goal of this class is to capture a proof that
45-
// we've verified that the given `AffineExpr` only has variables of the
46-
// appropriate kind(s). So we need to actually prove/verify that in the
47-
// ctor or all its callsites!
4831
class DimLvlExpr {
4932
private:
50-
// FIXME(wrengr): Per <https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html>,
51-
// the `kind` field should be private and const. However, beware
52-
// that if we mark any field as `const` or if the fields have differing
53-
// `private`/`protected` privileges then the `IsZeroCostAbstraction`
54-
// assertion will fail!
55-
// (Also, iirc, if we end up moving the `expr` to the subclasses
56-
// instead, that'll also cause `IsZeroCostAbstraction` to fail.)
5733
ExprKind kind;
5834
AffineExpr expr;
5935

@@ -100,11 +76,6 @@ class DimLvlExpr {
10076
//
10177
// Getters for handling `AffineExpr` subclasses.
10278
//
103-
// TODO(wrengr): is there any way to make these typesafe without too much
104-
// templating?
105-
// TODO(wrengr): Most if not all of these don't actually need to be
106-
// methods, they could be free-functions instead.
107-
//
10879
Var castAnyVar() const;
10980
std::optional<Var> dyn_castAnyVar() const;
11081
SymVar castSymVar() const;
@@ -131,9 +102,6 @@ class DimLvlExpr {
131102
// Variant of `mlir::AsmPrinter::Impl::BindingStrength`
132103
enum class BindingStrength : bool { Weak = false, Strong = true };
133104

134-
// TODO(wrengr): Does our version of `printAffineExprInternal` really
135-
// need to be a method, or could it be a free-function instead? (assuming
136-
// `BindingStrength` goes with it).
137105
void printAffineExprInternal(llvm::raw_ostream &os,
138106
BindingStrength enclosingTightness) const;
139107
void printStrong(llvm::raw_ostream &os) const {
@@ -145,12 +113,7 @@ class DimLvlExpr {
145113
};
146114
static_assert(IsZeroCostAbstraction<DimLvlExpr>);
147115

148-
// FUTURE_CL(wrengr): It would be nice to have the subclasses override
149-
// `getRHS`, `getLHS`, `unpackBinop`, and `castDimLvlVar` to give them
150-
// the proper covariant return types.
151-
//
152116
class DimExpr final : public DimLvlExpr {
153-
// FIXME(wrengr): These two are needed for the current RTTI implementation.
154117
friend class DimLvlExpr;
155118
constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
156119

@@ -170,7 +133,6 @@ class DimExpr final : public DimLvlExpr {
170133
static_assert(IsZeroCostAbstraction<DimExpr>);
171134

172135
class LvlExpr final : public DimLvlExpr {
173-
// FIXME(wrengr): These two are needed for the current RTTI implementation.
174136
friend class DimLvlExpr;
175137
constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
176138

@@ -189,7 +151,6 @@ class LvlExpr final : public DimLvlExpr {
189151
};
190152
static_assert(IsZeroCostAbstraction<LvlExpr>);
191153

192-
// FIXME(wrengr): See comments elsewhere re RTTI implementation issues/questions
193154
template <typename U>
194155
constexpr bool DimLvlExpr::isa() const {
195156
if constexpr (std::is_same_v<U, DimExpr>)
@@ -247,18 +208,12 @@ class DimSpec final {
247208
/// the result of this predicate.
248209
[[nodiscard]] bool isValid(Ranks const &ranks) const;
249210

250-
// TODO(wrengr): Use it or loose it.
251-
bool isFunctionOf(Var var) const;
252-
bool isFunctionOf(VarSet const &vars) const;
253-
void getFreeVars(VarSet &vars) const;
254-
255211
std::string str(bool wantElision = true) const;
256212
void print(llvm::raw_ostream &os, bool wantElision = true) const;
257213
void print(AsmPrinter &printer, bool wantElision = true) const;
258214
void dump() const;
259215
};
260-
// Although this class is more than just a newtype/wrapper, we do want
261-
// to ensure that storing them into `SmallVector` is efficient.
216+
262217
static_assert(IsZeroCostAbstraction<DimSpec>);
263218

264219
//===----------------------------------------------------------------------===//
@@ -270,13 +225,6 @@ class LvlSpec final {
270225
/// whereas the `DimLvlMap` ctor will reset this as appropriate.
271226
bool elideVar = false;
272227
/// The level-expression.
273-
//
274-
// NOTE: For now we use `LvlExpr` because all level-expressions must be
275-
// `AffineExpr`; however, in the future we will also want to allow "counting
276-
// expressions", and potentially other kinds of non-affine level-expressions.
277-
// Which kinds of `DimLvlExpr` are allowed will depend on the `DimLevelType`,
278-
// so we may consider defining another class for pairing those two together
279-
// to ensure that the pair is well-formed.
280228
LvlExpr expr;
281229
/// The level-type (== level-format + lvl-properties).
282230
DimLevelType type;
@@ -298,23 +246,14 @@ class LvlSpec final {
298246

299247
/// Checks whether the variables bound/used by this spec are valid
300248
/// with respect to the given ranks.
301-
//
302-
// NOTE: Once we introduce "counting expressions" this will need
303-
// a more sophisticated implementation than `DimSpec::isValid` does.
304249
[[nodiscard]] bool isValid(Ranks const &ranks) const;
305250

306-
// TODO(wrengr): Use it or loose it.
307-
bool isFunctionOf(Var var) const;
308-
bool isFunctionOf(VarSet const &vars) const;
309-
void getFreeVars(VarSet &vars) const;
310-
311251
std::string str(bool wantElision = true) const;
312252
void print(llvm::raw_ostream &os, bool wantElision = true) const;
313253
void print(AsmPrinter &printer, bool wantElision = true) const;
314254
void dump() const;
315255
};
316-
// Although this class is more than just a newtype/wrapper, we do want
317-
// to ensure that storing them into `SmallVector` is efficient.
256+
318257
static_assert(IsZeroCostAbstraction<LvlSpec>);
319258

320259
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp

Lines changed: 4 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -44,42 +44,20 @@ OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
4444
VarInfo::ID &varID,
4545
bool &didCreate) {
4646
// Save the current location so that we can have error messages point to
47-
// the right place. Note that `Parser::emitWrongTokenError` starts off
48-
// with the same location as `AsmParserImpl::getCurrentLocation` returns;
49-
// however, `Parser` will then do some various munging with the location
50-
// before actually using it, so `AsmParser::emitError` can't quite be used
51-
// as a drop-in replacement for `Parser::emitWrongTokenError`.
47+
// the right place.
5248
const auto loc = parser.getCurrentLocation();
53-
54-
// Several things to note.
55-
// (1) the `Parser::isCurrentTokenAKeyword` method checks the exact
56-
// same conditions as the `AffineParser.cpp`-static free-function
57-
// `isIdentifier` which is used by `AffineParser::parseBareIdExpr`.
58-
// (2) the `{Parser,AsmParserImpl}::parseOptionalKeyword(StringRef*)`
59-
// methods do the same song and dance about using
60-
// `isCurrentTokenAKeyword`, `getTokenSpelling`, et `consumeToken` as we
61-
// would want to do if we could use the `Parser` class directly. It
62-
// doesn't provide the nice error handling we want, but we can work around
63-
// that.
6449
StringRef name;
6550
if (failed(parser.parseOptionalKeyword(&name))) {
66-
// If not actually optional, then `emitError`.
6751
ERROR_IF(!isOptional, "expected bare identifier")
68-
// If is actually optional, then return the null `OptionalParseResult`.
6952
return std::nullopt;
7053
}
7154

72-
// I don't know if we need to worry about the possibility of the caller
73-
// recovering from error and then reusing the `DimLvlMapParser` for subsequent
74-
// `parseVar`, but I'm erring on the side of caution by distinguishing
75-
// all three possible creation policies.
7655
if (const auto res = env.lookupOrCreate(creationPolicy, name, loc, vk)) {
7756
varID = res->first;
7857
didCreate = res->second;
7958
return success();
8059
}
81-
// TODO(wrengr): these error messages make sense for our intended usage,
82-
// but not in general; but it's unclear how best to factor that part out.
60+
8361
switch (creationPolicy) {
8462
case Policy::MustNot:
8563
return parser.emitError(loc, "use of undeclared identifier '" + name + "'");
@@ -167,8 +145,6 @@ FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
167145
FAILURE_IF_FAILED(parseDimSpecList())
168146
FAILURE_IF_FAILED(parser.parseArrow())
169147
FAILURE_IF_FAILED(parseLvlSpecList())
170-
// TODO(wrengr): Try to improve the error messages from
171-
// `VarEnv::emitErrorIfAnyUnbound`.
172148
InFlightDiagnostic ifd = env.emitErrorIfAnyUnbound(parser);
173149
if (failed(ifd))
174150
return ifd;
@@ -182,29 +158,6 @@ ParseResult DimLvlMapParser::parseSymbolBindingList() {
182158
" in symbol binding list");
183159
}
184160

185-
// FIXME: The forward-declaration of level-vars is a stop-gap workaround
186-
// so that we can reuse `AsmParser::parseAffineExpr` in the definition of
187-
// `DimLvlMapParser::parseDimSpec`. (In particular, note that all the
188-
// variables must be bound before entering `AsmParser::parseAffineExpr`,
189-
// since that method requires every variable to already have a fixed/known
190-
// `Var::Num`.)
191-
//
192-
// However, the forward-declaration list duplicates information which is
193-
// already encoded by the level-var bindings in `parseLvlSpecList` (namely:
194-
// the names of the variables themselves, and the order in which the names
195-
// are bound). This redundancy causes bad UX, and also means we must be
196-
// sure to verify consistency between the two sources of information.
197-
//
198-
// Therefore, it would be best to remove the forward-declaration list from
199-
// the syntax. This can be achieved by implementing our own version of
200-
// `AffineParser::parseAffineExpr` which calls
201-
// `parseVarUsage(_,requireKnown=false)` for variables and stores the resulting
202-
// `VarInfo::ID` in the expression tree (instead of demanding it be resolved to
203-
// some `Var::Num` immediately). This would also enable us to use the `VarEnv`
204-
// directly, rather than building the `{dims,lvls}AndSymbols` lists on the
205-
// side, and thus would also enable us to avoid the O(n^2) behavior of copying
206-
// `DimLvlParser::{dims,lvls}AndSymbols` into `AffineParser::dimsAndSymbols`
207-
// every time `AsmParser::parseAffineExpr` is called.
208161
ParseResult DimLvlMapParser::parseLvlVarBindingList() {
209162
return parser.parseCommaSeparatedList(
210163
OpAsmParser::Delimiter::OptionalBraces,
@@ -233,9 +186,6 @@ ParseResult DimLvlMapParser::parseDimSpec() {
233186
AffineExpr affine;
234187
if (succeeded(parser.parseOptionalEqual())) {
235188
// Parse the dim affine expr, with only any lvl-vars in scope.
236-
// FIXME(wrengr): This still has the O(n^2) behavior of copying
237-
// our `lvlsAndSymbols` into the `AffineParser::dimsAndSymbols`
238-
// field every time `parseDimSpec` is called.
239189
FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
240190
}
241191
DimExpr expr{affine};
@@ -304,9 +254,6 @@ static inline Twine nth(Var::Num n) {
304254
}
305255
}
306256

307-
// NOTE: This is factored out as a separate method only because `Var`
308-
// lacks a default-ctor, which makes this conditional difficult to inline
309-
// at the one call-site.
310257
FailureOr<LvlVar>
311258
DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
312259
// Nothing to parse, just bind an unnamed variable.
@@ -336,17 +283,14 @@ DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
336283
}
337284

338285
ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
339-
// Parse the optional lvl-var binding. (Actually, `requireLvlVarBinding`
340-
// specifies whether that "optional" is actually Must or MustNot.)
286+
// Parse the optional lvl-var binding. `requireLvlVarBinding`
287+
// specifies whether that "optional" is actually Must or MustNot.
341288
const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
342289
FAILURE_IF_FAILED(varRes)
343290
const LvlVar var = *varRes;
344291

345292
// Parse the lvl affine expr, with only the dim-vars in scope.
346293
AffineExpr affine;
347-
// FIXME(wrengr): This still has the O(n^2) behavior of copying
348-
// our `dimsAndSymbols` into the `AffineParser::dimsAndSymbols`
349-
// field every time `parseLvlSpec` is called.
350294
FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
351295
LvlExpr expr{affine};
352296

mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,39 +42,32 @@ class DimLvlMapParser final {
4242
FailureOr<DimLvlMap> parseDimLvlMap();
4343

4444
private:
45-
/// The core code for parsing `Var`. This method abstracts out a lot
46-
/// of complex details to avoid code duplication; however, client code
47-
/// should prefer using `parseVarUsage` and `parseVarBinding` rather than
48-
/// calling this method directly.
45+
/// Client code should prefer using `parseVarUsage`
46+
/// and `parseVarBinding` rather than calling this method directly.
4947
OptionalParseResult parseVar(VarKind vk, bool isOptional,
5048
Policy creationPolicy, VarInfo::ID &id,
5149
bool &didCreate);
5250

53-
/// Parse a variable occurence which is a *use* of that variable.
54-
/// The `requireKnown` parameter specifies how to handle the case of
55-
/// encountering a valid variable name which is currently unused: when
56-
/// `requireKnown=true`, an error is raised; when `requireKnown=false`,
51+
/// Parses a variable occurence which is a *use* of that variable.
52+
/// When a valid variable name is currently unused, if
53+
/// `requireKnown=true`, an error is raised; if `requireKnown=false`,
5754
/// a new unbound variable will be created.
58-
///
59-
/// NOTE: Just because a variable is *known* (i.e., the name has been
60-
/// associated with an `VarInfo::ID`), does not mean that the variable
61-
/// is actually *in scope*.
6255
FailureOr<VarInfo::ID> parseVarUsage(VarKind vk, bool requireKnown);
6356

64-
/// Parse a variable occurence which is a *binding* of that variable.
57+
/// Parses a variable occurence which is a *binding* of that variable.
6558
/// The `requireKnown` parameter is for handling the binding of
6659
/// forward-declared variables.
6760
FailureOr<VarInfo::ID> parseVarBinding(VarKind vk, bool requireKnown = false);
6861

69-
/// Parse an optional variable binding. When the next token is
62+
/// Parses an optional variable binding. When the next token is
7063
/// not a valid variable name, this will bind a new unnamed variable.
7164
/// The returned `bool` indicates whether a variable name was parsed.
7265
FailureOr<std::pair<Var, bool>>
7366
parseOptionalVarBinding(VarKind vk, bool requireKnown = false);
7467

7568
/// Binds the given variable: both updating the `VarEnv` itself, and
76-
/// also updating the `{dims,lvls}AndSymbols` lists (which will be passed
77-
/// to `AsmParser::parseAffineExpr`). This method is already called by the
69+
/// the `{dims,lvls}AndSymbols` lists (which will be passed
70+
/// to `AsmParser::parseAffineExpr`). This method is already called by the
7871
/// `parseVarBinding`/`parseOptionalVarBinding` methods, therefore should
7972
/// not need to be called elsewhere.
8073
Var bindVar(llvm::SMLoc loc, VarInfo::ID id);

0 commit comments

Comments
 (0)