Skip to content

Commit 5c050e8

Browse files
authored
implemented slice function in std:vector
1 parent 90c2bc4 commit 5c050e8

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

llama_cpp/llama_grammar.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,8 @@ def __add__(self, value: int) -> "std.vector[T].iterator":
244244
def __sub__(self, value: int) -> "std.vector[T].iterator":
245245
return self.__class__(self._vector, self._index - value)
246246

247-
def __init__(self):
247+
def __init__(self, *args, **kwargs):
248+
super().__init__(*args, **kwargs)
248249
self._version = 0
249250

250251
def modify(self):
@@ -309,7 +310,7 @@ def insert(
309310
first: "std.vector[T].iterator",
310311
last: "std.vector[T].iterator",
311312
) -> None:
312-
self[pos._index : pos._index] = list(
313+
self[pos._index:pos._index] = list(
313314
islice(first._vector, first._index, last._index)
314315
)
315316

@@ -319,6 +320,24 @@ def begin(self) -> "std.vector[T].iterator":
319320
def end(self) -> "std.vector[T].iterator":
320321
return self.iterator(self, self.size())
321322

323+
def __getitem__(self, index):
324+
if isinstance(index, slice):
325+
return std.vector(super().__getitem__(index))
326+
return super().__getitem__(index)
327+
328+
def __setitem__(self, index, value):
329+
self.modify()
330+
if isinstance(index, slice):
331+
if isinstance(value, std.vector):
332+
value = list(value)
333+
super().__setitem__(index, value)
334+
else:
335+
super().__setitem__(index, value)
336+
337+
def __delitem__(self, index):
338+
self.modify()
339+
super().__delitem__(index)
340+
322341
class map(Generic[T, U], OrderedDict[T, U]):
323342
"""C++ implementation of std::map."""
324343

@@ -410,7 +429,6 @@ def begin(self) -> "std.map[T, U].iterator[T, U]":
410429
def end(self) -> "std.map[T, U].iterator[T, U]":
411430
return self.iterator(self, Sentinel())
412431

413-
414432
# // grammar element type
415433
# enum llama_gretype {
416434
# // end of rule definition
@@ -824,7 +842,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
824842

825843

826844
previous_elements = out_elements[last_sym_start:]
827-
print("type-1 ", type(out_elements))
845+
828846
if min_times == 0:
829847
out_elements.resize(last_sym_start)
830848
else:
@@ -835,8 +853,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
835853
last_rec_rule_id = 0 # type: int
836854
n_opt = 1 if max_times < 0 else max_times - min_times # type: int
837855
rec_rule = previous_elements # type: List[LlamaGrammarElement]
838-
print("type1", type(rec_rule))
839-
print('ahhhhhhhhh')
856+
840857
for i in range(n_opt):
841858
rec_rule = previous_elements
842859
rec_rule.resize(len(previous_elements))
@@ -1263,6 +1280,7 @@ def print_rule(
12631280
# print_grammar_char(file, elem.value);
12641281
# break;
12651282
# }
1283+
12661284
for i, elem in enumerate(rule[:-1]):
12671285
case = elem.type # type: llama_gretype
12681286
if case is llama_gretype.LLAMA_GRETYPE_END:

0 commit comments

Comments
 (0)