Skip to content

Commit 78921a6

Browse files
wrengraartbik
authored andcommitted
[mlir][sparse] Add more helper methods for converting DimLvlExpr to Var
These new methods help clean up some code for doing LvlExpr-analysis during DimExpr-inference. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D157647
1 parent 8211050 commit 78921a6

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,46 @@ using namespace mlir::sparse_tensor::ir_detail;
1616
// `DimLvlExpr` implementation.
1717
//===----------------------------------------------------------------------===//
1818

19+
Var DimLvlExpr::castAnyVar() const {
20+
assert(expr && "uninitialized DimLvlExpr");
21+
const auto var = dyn_castAnyVar();
22+
assert(var && "expected DimLvlExpr to be a Var");
23+
return *var;
24+
}
25+
26+
std::optional<Var> DimLvlExpr::dyn_castAnyVar() const {
27+
if (const auto s = expr.dyn_cast_or_null<AffineSymbolExpr>())
28+
return SymVar(s);
29+
if (const auto x = expr.dyn_cast_or_null<AffineDimExpr>())
30+
return Var(getAllowedVarKind(), x);
31+
return std::nullopt;
32+
}
33+
1934
SymVar DimLvlExpr::castSymVar() const {
2035
return SymVar(expr.cast<AffineSymbolExpr>());
2136
}
2237

38+
std::optional<SymVar> DimLvlExpr::dyn_castSymVar() const {
39+
if (const auto s = expr.dyn_cast_or_null<AffineSymbolExpr>())
40+
return SymVar(s);
41+
return std::nullopt;
42+
}
43+
2344
Var DimLvlExpr::castDimLvlVar() const {
2445
return Var(getAllowedVarKind(), expr.cast<AffineDimExpr>());
2546
}
2647

48+
std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const {
49+
if (const auto x = expr.dyn_cast_or_null<AffineDimExpr>())
50+
return Var(getAllowedVarKind(), x);
51+
return std::nullopt;
52+
}
53+
2754
int64_t DimLvlExpr::castConstantValue() const {
2855
return expr.cast<AffineConstantExpr>().getValue();
2956
}
3057

31-
std::optional<int64_t> DimLvlExpr::tryGetConstantValue() const {
58+
std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const {
3259
const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
3360
return k ? std::make_optional(k.getValue()) : std::nullopt;
3461
}
@@ -98,7 +125,7 @@ static std::optional<MatchNeg> matchNeg(DimLvlExpr expr) {
98125
return MatchNeg{DimLvlExpr{expr.getExprKind(), AffineExpr()}, val};
99126
}
100127
if (op == AffineExprKind::Mul)
101-
if (const auto rval = rhs.tryGetConstantValue(); rval && *rval < 0)
128+
if (const auto rval = rhs.dyn_castConstantValue(); rval && *rval < 0)
102129
return MatchNeg{lhs, *rval};
103130
return std::nullopt;
104131
}

mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,14 @@ class DimLvlExpr {
105105
// TODO(wrengr): Most if not all of these don't actually need to be
106106
// methods, they could be free-functions instead.
107107
//
108+
Var castAnyVar() const;
109+
std::optional<Var> dyn_castAnyVar() const;
108110
SymVar castSymVar() const;
111+
std::optional<SymVar> dyn_castSymVar() const;
109112
Var castDimLvlVar() const;
113+
std::optional<Var> dyn_castDimLvlVar() const;
110114
int64_t castConstantValue() const;
111-
std::optional<int64_t> tryGetConstantValue() const;
115+
std::optional<int64_t> dyn_castConstantValue() const;
112116
bool hasConstantValue(int64_t val) const;
113117
DimLvlExpr getLHS() const;
114118
DimLvlExpr getRHS() const;
@@ -155,6 +159,12 @@ class DimExpr final : public DimLvlExpr {
155159
return expr->getExprKind() == Kind;
156160
}
157161
constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
162+
163+
LvlVar castLvlVar() const { return castDimLvlVar().cast<LvlVar>(); }
164+
std::optional<LvlVar> dyn_castLvlVar() const {
165+
const auto var = dyn_castDimLvlVar();
166+
return var ? std::make_optional(var->cast<LvlVar>()) : std::nullopt;
167+
}
158168
};
159169
static_assert(IsZeroCostAbstraction<DimExpr>);
160170

@@ -169,6 +179,12 @@ class LvlExpr final : public DimLvlExpr {
169179
return expr->getExprKind() == Kind;
170180
}
171181
constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
182+
183+
DimVar castDimVar() const { return castDimLvlVar().cast<DimVar>(); }
184+
std::optional<DimVar> dyn_castDimVar() const {
185+
const auto var = dyn_castDimLvlVar();
186+
return var ? std::make_optional(var->cast<DimVar>()) : std::nullopt;
187+
}
172188
};
173189
static_assert(IsZeroCostAbstraction<LvlExpr>);
174190

0 commit comments

Comments
 (0)