Skip to content

Commit 24d888d

Browse files
committed
Monai bug-brats_mri_segmentation fixes
1 parent cf1f637 commit 24d888d

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,20 @@ auto aten_registrations TORCHTRT_UNUSED =
223223
{c10::Symbol::fromQualString("aten::slice"),
224224
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
225225
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
226-
227226
int64_t start = 0;
227+
int64_t end = 9223372036854775807;
228228
auto startIVal = args.at(n->input(1)).IValue();
229+
auto endIVal = args.at(n->input(2)).IValue();
230+
229231
if (!startIVal->isNone()) {
230232
start = args.at(n->input(1)).unwrapToInt();
231233
}
232-
int64_t end = args.at(n->input(2)).unwrapToInt();
234+
if (!endIVal->isNone()) {
235+
end = args.at(n->input(2)).unwrapToInt();
236+
}
237+
if (start > end) {
238+
LOG_DEBUG("The end should be greater than start");
239+
}
233240
int64_t step = args.at(n->input(3)).unwrapToInt();
234241

235242
const int64_t list_size = list.size();
@@ -253,8 +260,9 @@ auto aten_registrations TORCHTRT_UNUSED =
253260

254261
return sliced_list;
255262
},
256-
EvalOptions().validSchemas(
257-
{"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})})
263+
EvalOptions().validSchemas({"aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> (t[])"})})
264+
// EvalOptions().validSchemas(
265+
// {"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})})
258266
.evaluator(
259267
{c10::Symbol::fromQualString("aten::len"),
260268
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
@@ -896,8 +904,14 @@ auto aten_registrations TORCHTRT_UNUSED =
896904
auto step = args.at(n->input(2)).unwrapToInt();
897905
return start + idx * step;
898906
},
899-
EvalOptions().validSchemas({"aten::__derive_index(int idx, int start, int step) -> int"})});
900-
907+
EvalOptions().validSchemas({"aten::__derive_index(int idx, int start, int step) -> int"})})
908+
.evaluator(
909+
{c10::Symbol::fromQualString("aten::list"),
910+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
911+
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
912+
return list.copy();
913+
},
914+
EvalOptions().validSchemas({"aten::list.t(t[] l) -> (t[])"})});
901915
} // namespace
902916
} // namespace evaluators
903917
} // namespace conversion

0 commit comments

Comments
 (0)