Skip to content

Commit 7466b8a

Browse files
committed
feat(//core/conversion/evaluators): A whole bunch of new evaluators
Adds evaluators for: - aten::eq - aten::ne - aten::lt - aten::gt - aten::le - aten::ge - aten::add - aten::sub - aten::mul - aten::Bool - aten::Float - aten::__not__ - aten::__is__ - aten::__isnot__ - aten::numel - aten::dim - aten::div - aten::floordiv - aten::floor - aten::warn - prim::min - prim::max - prim::shape - prim::unchecked_cast - prim::Uninitalized - prim::RaiseException Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 6cce381 commit 7466b8a

File tree

5 files changed

+519
-45
lines changed

5 files changed

+519
-45
lines changed

core/conversion/evaluators/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ cc_library(
1515
srcs = [
1616
"NodeEvaluatorRegistry.cpp",
1717
"prim.cpp",
18-
"aten.cpp"
18+
"aten.cpp",
19+
"eval_macros.h"
1920
],
2021
deps = [
2122
"//core/util:prelude",

core/conversion/evaluators/NodeEvaluatorRegistry.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class NodeEvaluatorRegistry {
3030
public:
3131
void RegisterEvaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
3232
LOG_DEBUG("Registering evaluator for " << node_kind.toQualString());
33+
auto iter = evaluator_lut_.find(node_kind);
34+
if (iter != evaluator_lut_.end()) {
35+
TRTORCH_THROW_ERROR("Attempting to override already registered evaluator " << node_kind.toQualString() << ", merge implementations instead");
36+
}
3337
evaluator_lut_[node_kind] = std::move(eval_reg);
3438
}
3539

0 commit comments

Comments
 (0)