@@ -16,19 +16,46 @@ using namespace mlir::sparse_tensor::ir_detail;
16
16
// `DimLvlExpr` implementation.
17
17
// ===----------------------------------------------------------------------===//
18
18
19
+ Var DimLvlExpr::castAnyVar () const {
20
+ assert (expr && " uninitialized DimLvlExpr" );
21
+ const auto var = dyn_castAnyVar ();
22
+ assert (var && " expected DimLvlExpr to be a Var" );
23
+ return *var;
24
+ }
25
+
26
+ std::optional<Var> DimLvlExpr::dyn_castAnyVar () const {
27
+ if (const auto s = expr.dyn_cast_or_null <AffineSymbolExpr>())
28
+ return SymVar (s);
29
+ if (const auto x = expr.dyn_cast_or_null <AffineDimExpr>())
30
+ return Var (getAllowedVarKind (), x);
31
+ return std::nullopt;
32
+ }
33
+
19
34
SymVar DimLvlExpr::castSymVar () const {
20
35
return SymVar (expr.cast <AffineSymbolExpr>());
21
36
}
22
37
38
+ std::optional<SymVar> DimLvlExpr::dyn_castSymVar () const {
39
+ if (const auto s = expr.dyn_cast_or_null <AffineSymbolExpr>())
40
+ return SymVar (s);
41
+ return std::nullopt;
42
+ }
43
+
23
44
Var DimLvlExpr::castDimLvlVar () const {
24
45
return Var (getAllowedVarKind (), expr.cast <AffineDimExpr>());
25
46
}
26
47
48
+ std::optional<Var> DimLvlExpr::dyn_castDimLvlVar () const {
49
+ if (const auto x = expr.dyn_cast_or_null <AffineDimExpr>())
50
+ return Var (getAllowedVarKind (), x);
51
+ return std::nullopt;
52
+ }
53
+
27
54
int64_t DimLvlExpr::castConstantValue () const {
28
55
return expr.cast <AffineConstantExpr>().getValue ();
29
56
}
30
57
31
- std::optional<int64_t > DimLvlExpr::tryGetConstantValue () const {
58
+ std::optional<int64_t > DimLvlExpr::dyn_castConstantValue () const {
32
59
const auto k = expr.dyn_cast_or_null <AffineConstantExpr>();
33
60
return k ? std::make_optional (k.getValue ()) : std::nullopt;
34
61
}
@@ -98,7 +125,7 @@ static std::optional<MatchNeg> matchNeg(DimLvlExpr expr) {
98
125
return MatchNeg{DimLvlExpr{expr.getExprKind (), AffineExpr ()}, val};
99
126
}
100
127
if (op == AffineExprKind::Mul)
101
- if (const auto rval = rhs.tryGetConstantValue (); rval && *rval < 0 )
128
+ if (const auto rval = rhs.dyn_castConstantValue (); rval && *rval < 0 )
102
129
return MatchNeg{lhs, *rval};
103
130
return std::nullopt;
104
131
}
0 commit comments