Skip to content

Commit a5f4e80

Browse files
support new dense vector quantization in 8.16
1 parent: 0dd69f8 · commit: a5f4e80

File tree

4 files changed: 123 additions and 4 deletions.

elasticsearch_dsl/field.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -389,13 +389,23 @@ def _deserialize(self, data: Any) -> float:
389389
return float(data)
390390

391391

392-
class DenseVector(Float):
392+
class DenseVector(Field):
393393
name = "dense_vector"
394+
_coerce = True
394395

395396
def __init__(self, **kwargs: Any):
396-
kwargs["multi"] = True
397+
self._element_type = kwargs.get("element_type", "float")
398+
if self._element_type in ["float", "byte"]:
399+
kwargs["multi"] = True
397400
super().__init__(**kwargs)
398401

402+
def _deserialize(self, data: Any) -> Any:
403+
if self._element_type == "float":
404+
return float(data)
405+
elif self._element_type == "byte":
406+
return int(data)
407+
return data
408+
399409

400410
class SparseVector(Field):
401411
name = "sparse_vector"

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,5 +11,6 @@ filterwarnings =
1111
error
1212
ignore:Legacy index templates are deprecated in favor of composable templates.:elasticsearch.exceptions.ElasticsearchWarning
1313
ignore:datetime.datetime.utcfromtimestamp\(\) is deprecated and scheduled for removal in a future version..*:DeprecationWarning
14+
default:enable_cleanup_closed ignored.*:DeprecationWarning
1415
markers =
1516
sync: mark a test as performing I/O without asyncio.

tests/test_integration/_async/test_document.py

Lines changed: 56 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323

2424
from datetime import datetime
2525
from ipaddress import ip_address
26-
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Union
26+
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Tuple, Union
2727

2828
import pytest
2929
from elasticsearch import AsyncElasticsearch, ConflictError, NotFoundError
@@ -37,6 +37,7 @@
3737
Binary,
3838
Boolean,
3939
Date,
40+
DenseVector,
4041
Double,
4142
InnerDoc,
4243
Ip,
@@ -795,3 +796,57 @@ async def gen3() -> AsyncIterator[Union[Doc, Dict[str, Any]]]:
795796
"age": 45,
796797
"languages": ["es"],
797798
}
799+
800+
801+
@pytest.mark.asyncio
802+
async def test_legacy_dense_vector(
803+
async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
804+
) -> None:
805+
if es_version >= (8, 16):
806+
pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")
807+
808+
class Doc(AsyncDocument):
809+
float_vector: List[float] = mapped_field(DenseVector(dims=3))
810+
811+
class Index:
812+
name = "vectors"
813+
814+
await Doc._index.delete(ignore_unavailable=True)
815+
await Doc.init()
816+
817+
doc = Doc(float_vector=[1.0, 1.2, 2.3])
818+
await doc.save(refresh=True)
819+
820+
docs = await Doc.search().execute()
821+
assert len(docs) == 1
822+
assert docs[0].float_vector == doc.float_vector
823+
824+
825+
@pytest.mark.asyncio
826+
async def test_dense_vector(
827+
async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
828+
) -> None:
829+
if es_version < (8, 16):
830+
pytest.skip("this test requires Elasticsearch 8.16 or newer")
831+
832+
class Doc(AsyncDocument):
833+
float_vector: List[float] = mapped_field(DenseVector())
834+
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
835+
bit_vector: str = mapped_field(DenseVector(element_type="bit"))
836+
837+
class Index:
838+
name = "vectors"
839+
840+
await Doc._index.delete(ignore_unavailable=True)
841+
await Doc.init()
842+
843+
doc = Doc(
844+
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
845+
)
846+
await doc.save(refresh=True)
847+
848+
docs = await Doc.search().execute()
849+
assert len(docs) == 1
850+
assert docs[0].float_vector == doc.float_vector
851+
assert docs[0].byte_vector == doc.byte_vector
852+
assert docs[0].bit_vector == doc.bit_vector

tests/test_integration/_sync/test_document.py

Lines changed: 54 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323

2424
from datetime import datetime
2525
from ipaddress import ip_address
26-
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Union
26+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple, Union
2727

2828
import pytest
2929
from elasticsearch import ConflictError, Elasticsearch, NotFoundError
@@ -35,6 +35,7 @@
3535
Binary,
3636
Boolean,
3737
Date,
38+
DenseVector,
3839
Document,
3940
Double,
4041
InnerDoc,
@@ -789,3 +790,55 @@ def gen3() -> Iterator[Union[Doc, Dict[str, Any]]]:
789790
"age": 45,
790791
"languages": ["es"],
791792
}
793+
794+
795+
@pytest.mark.sync
796+
def test_legacy_dense_vector(
797+
client: Elasticsearch, es_version: Tuple[int, ...]
798+
) -> None:
799+
if es_version >= (8, 16):
800+
pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")
801+
802+
class Doc(Document):
803+
float_vector: List[float] = mapped_field(DenseVector(dims=3))
804+
805+
class Index:
806+
name = "vectors"
807+
808+
Doc._index.delete(ignore_unavailable=True)
809+
Doc.init()
810+
811+
doc = Doc(float_vector=[1.0, 1.2, 2.3])
812+
doc.save(refresh=True)
813+
814+
docs = Doc.search().execute()
815+
assert len(docs) == 1
816+
assert docs[0].float_vector == doc.float_vector
817+
818+
819+
@pytest.mark.sync
820+
def test_dense_vector(client: Elasticsearch, es_version: Tuple[int, ...]) -> None:
821+
if es_version < (8, 16):
822+
pytest.skip("this test requires Elasticsearch 8.16 or newer")
823+
824+
class Doc(Document):
825+
float_vector: List[float] = mapped_field(DenseVector())
826+
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
827+
bit_vector: str = mapped_field(DenseVector(element_type="bit"))
828+
829+
class Index:
830+
name = "vectors"
831+
832+
Doc._index.delete(ignore_unavailable=True)
833+
Doc.init()
834+
835+
doc = Doc(
836+
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
837+
)
838+
doc.save(refresh=True)
839+
840+
docs = Doc.search().execute()
841+
assert len(docs) == 1
842+
assert docs[0].float_vector == doc.float_vector
843+
assert docs[0].byte_vector == doc.byte_vector
844+
assert docs[0].bit_vector == doc.bit_vector

0 commit comments

Comments (0)