Skip to content

Commit 2b13162

Browse files
authored
gguf-py : add support for sub_type (in arrays) in GGUFWriter add_key_value method (#13561)
1 parent 54a2c7a commit 2b13162

File tree

5 files changed

+29
-16
lines changed

5 files changed

+29
-16
lines changed

gguf-py/gguf/gguf_writer.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class TensorInfo:
4949
class GGUFValue:
5050
value: Any
5151
type: GGUFValueType
52+
sub_type: GGUFValueType | None = None
5253

5354

5455
class WriterState(Enum):
@@ -238,7 +239,7 @@ def write_kv_data_to_file(self) -> None:
238239

239240
for key, val in kv_data.items():
240241
kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
241-
kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
242+
kv_bytes += self._pack_val(val.value, val.type, add_vtype=True, sub_type=val.sub_type)
242243

243244
fout.write(kv_bytes)
244245

@@ -268,11 +269,11 @@ def write_ti_data_to_file(self) -> None:
268269
fout.flush()
269270
self.state = WriterState.TI_DATA
270271

271-
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
272+
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None:
272273
if any(key in kv_data for kv_data in self.kv_data):
273274
raise ValueError(f'Duplicated key name {key!r}')
274275

275-
self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
276+
self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)
276277

277278
def add_uint8(self, key: str, val: int) -> None:
278279
self.add_key_value(key,val, GGUFValueType.UINT8)
@@ -1022,7 +1023,7 @@ def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
10221023
pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
10231024
return struct.pack(f'{pack_prefix}{fmt}', value)
10241025

1025-
def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes:
1026+
def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: GGUFValueType | None = None) -> bytes:
10261027
kv_data = bytearray()
10271028

10281029
if add_vtype:
@@ -1043,7 +1044,9 @@ def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes:
10431044
if len(val) == 0:
10441045
raise ValueError("Invalid GGUF metadata array. Empty array")
10451046

1046-
if isinstance(val, bytes):
1047+
if sub_type is not None:
1048+
ltype = sub_type
1049+
elif isinstance(val, bytes):
10471050
ltype = GGUFValueType.UINT8
10481051
else:
10491052
ltype = GGUFValueType.get_type(val[0])

gguf-py/gguf/scripts/gguf_editor_gui.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,27 +1521,34 @@ def save_file(self):
15211521
continue
15221522

15231523
# Apply changes if any
1524+
sub_type = None
15241525
if field.name in self.metadata_changes:
15251526
value_type, value = self.metadata_changes[field.name]
15261527
if value_type == GGUFValueType.ARRAY:
15271528
# Handle array values
1528-
element_type, array_values = value
1529-
writer.add_array(field.name, array_values)
1530-
else:
1531-
writer.add_key_value(field.name, value, value_type)
1529+
sub_type, value = value
15321530
else:
15331531
# Copy original value
15341532
value = field.contents()
1535-
if value is not None and field.types:
1536-
writer.add_key_value(field.name, value, field.types[0])
1533+
value_type = field.types[0]
1534+
if value_type == GGUFValueType.ARRAY:
1535+
sub_type = field.types[-1]
1536+
1537+
if value is not None:
1538+
writer.add_key_value(field.name, value, value_type, sub_type=sub_type)
15371539

15381540
# Add new metadata
15391541
for key, (value_type, value) in self.metadata_changes.items():
15401542
# Skip if the key already existed (we handled it above)
15411543
if self.reader.get_field(key) is not None:
15421544
continue
15431545

1544-
writer.add_key_value(key, value, value_type)
1546+
sub_type = None
1547+
if value_type == GGUFValueType.ARRAY:
1548+
# Handle array values
1549+
sub_type, value = value
1550+
1551+
writer.add_key_value(key, value, value_type, sub_type=sub_type)
15451552

15461553
# Add tensors (including data)
15471554
for tensor in self.reader.tensors:

gguf-py/gguf/scripts/gguf_new_metadata.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class MetadataDetails(NamedTuple):
2424
type: gguf.GGUFValueType
2525
value: Any
2626
description: str = ''
27+
sub_type: gguf.GGUFValueType | None = None
2728

2829

2930
def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
@@ -57,7 +58,9 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
5758
logger.debug(f'Removing {field.name}')
5859
continue
5960

60-
old_val = MetadataDetails(field.types[0], field.contents())
61+
val_type = field.types[0]
62+
sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None
63+
old_val = MetadataDetails(val_type, field.contents(), sub_type=sub_type)
6164
val = new_metadata.get(field.name, old_val)
6265

6366
if field.name in new_metadata:
@@ -67,7 +70,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
6770
logger.debug(f'Copying {field.name}')
6871

6972
if val.value is not None:
70-
writer.add_key_value(field.name, val.value, val.type)
73+
writer.add_key_value(field.name, val.value, val.type, sub_type=sub_type if val.sub_type is None else val.sub_type)
7174

7275
if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
7376
logger.debug('Adding chat template(s)')

gguf-py/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "gguf"
3-
version = "0.16.3"
3+
version = "0.17.0"
44
description = "Read and write ML models in GGUF for GGML"
55
authors = ["GGML <[email protected]>"]
66
packages = [
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
numpy~=1.26.4
22
PySide6~=6.9.0
3-
gguf>=0.16.0
3+
gguf>=0.17.0

0 commit comments

Comments
 (0)