Skip to content

Commit 6d5ca70

Browse files
committed
feat(embedding): implement in-memory and disk-synchronized embedding search indices #200
This commit introduces two classes, `InMemoryEmbeddingSearchIndex` and `DiskSynchronizedEmbeddingSearchIndex`, which implement the `EmbeddingSearchIndex` interface. These classes provide methods for addingfeat, updating(embed,ding and): deleting add embedding entries in,-memory as and well disk-sync as searchinged for embedding search the closest index embeddings to This commit a introduces given a query new embedding. in The-memory ` and diskIn-sMemoryynchronizedEmbed embeddingding searchSearch indexIndex.` The stores in all-memory embeddings index in stores memory embeddings, in while memory the and ` supportsDisk concurrentS readynchronized operationsEmbed,ding whileSearch theIndex disk`-s synchronynchronizedizes index index maintains changes index with synchronization disk with storage disk. storage Additionally., Both the commit indices implement includes the a Embed `dingLockedSearchSequenceWrapperIndex interface`, class providing to methods safely for iterate adding over entries embeddings, under saving a/loading lock from, disk as, well finding as closest utility embeddings functions, for and calculating more embedding. similarity Additionally and, normalization the. Locked OverallSequence,Wrapper these ensures classes thread provide-safe efficient iteration and over thread the-safe index ways. to manage and search embedding indices.
1 parent 2af87cc commit 6d5ca70

File tree

7 files changed

+812
-0
lines changed

7 files changed

+812
-0
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
2+
package cc.unitmesh.devti.embedding
3+
4+
5+
import com.intellij.concurrency.ConcurrentCollectionFactory
6+
import com.intellij.util.containers.CollectionFactory
7+
import kotlinx.coroutines.coroutineScope
8+
import kotlinx.coroutines.ensureActive
9+
import java.nio.file.Path
10+
import java.util.concurrent.locks.ReentrantReadWriteLock
11+
import kotlin.concurrent.read
12+
import kotlin.concurrent.write
13+
14+
/**
15+
* Concurrent [EmbeddingSearchIndex] that synchronizes all index change operations with disk and
16+
* allows simultaneous read operations from multiple consumers.
17+
* Incremental operations do not rewrite the whole storage file with embeddings.
18+
* Instead, they change only the corresponding sections in the file.
19+
*/
20+
class DiskSynchronizedEmbeddingSearchIndex(val root: Path, limit: Int? = null) : EmbeddingSearchIndex {
21+
private var indexToId: MutableMap<Int, String> = CollectionFactory.createSmallMemoryFootprintMap()
22+
private var idToEntry: MutableMap<String, IndexEntry> = CollectionFactory.createSmallMemoryFootprintMap()
23+
private val uncheckedIds: MutableSet<String> = ConcurrentCollectionFactory.createConcurrentSet()
24+
var changed: Boolean = false
25+
26+
private val lock = ReentrantReadWriteLock()
27+
28+
private val fileManager = LocalEmbeddingIndexFileManager(root)
29+
30+
override var limit = limit
31+
set(value) = lock.write {
32+
if (value != null) {
33+
// Shrink index if necessary:
34+
while (idToEntry.size > value) {
35+
delete(indexToId[idToEntry.size - 1]!!, all = true, shouldSaveIds = false)
36+
}
37+
saveIds()
38+
}
39+
field = value
40+
}
41+
42+
internal data class IndexEntry(
43+
var index: Int,
44+
var count: Int,
45+
val embedding: FloatArray
46+
)
47+
48+
override val size: Int get() = lock.read { idToEntry.size }
49+
50+
override operator fun contains(id: String): Boolean = lock.read {
51+
uncheckedIds.remove(id)
52+
id in idToEntry
53+
}
54+
55+
override fun clear() = lock.write {
56+
indexToId.clear()
57+
idToEntry.clear()
58+
uncheckedIds.clear()
59+
changed = false
60+
}
61+
62+
override fun onIndexingStart() {
63+
uncheckedIds.clear()
64+
uncheckedIds.addAll(idToEntry.keys)
65+
}
66+
67+
override fun onIndexingFinish() = lock.write {
68+
if (uncheckedIds.size > 0) changed = true
69+
uncheckedIds.forEach {
70+
delete(it, all = true, shouldSaveIds = false)
71+
}
72+
uncheckedIds.clear()
73+
}
74+
75+
override suspend fun addEntries(values: Iterable<Pair<String, FloatArray>>,
76+
shouldCount: Boolean) = coroutineScope {
77+
lock.write {
78+
for ((id, embedding) in values) {
79+
ensureActive()
80+
val entry = idToEntry.getOrPut(id) {
81+
changed = true
82+
if (limit != null && idToEntry.size >= limit!!) return@write
83+
val index = idToEntry.size
84+
indexToId[index] = id
85+
IndexEntry(index, 0, embedding)
86+
}
87+
if (shouldCount || entry.count == 0) {
88+
entry.count += 1
89+
}
90+
}
91+
}
92+
}
93+
94+
override suspend fun saveToDisk() = lock.read { save() }
95+
96+
override suspend fun loadFromDisk() = coroutineScope {
97+
val (ids, embeddings) = fileManager.loadIndex() ?: return@coroutineScope
98+
val idToIndex = ids.withIndex().associate { it.value to it.index }
99+
val idToEmbedding = (ids zip embeddings).toMap()
100+
ensureActive()
101+
lock.write {
102+
ensureActive()
103+
indexToId = CollectionFactory.createSmallMemoryFootprintMap(ids.withIndex().associate { it.index to it.value })
104+
idToEntry = CollectionFactory.createSmallMemoryFootprintMap(
105+
ids.associateWith { IndexEntry(idToIndex[it]!!, 0, idToEmbedding[it]!!) }
106+
)
107+
}
108+
}
109+
110+
override fun findClosest(searchEmbedding: FloatArray, topK: Int, similarityThreshold: Double?): List<ScoredText> = lock.read {
111+
return idToEntry.mapValues { it.value.embedding }.findClosest(searchEmbedding, topK, similarityThreshold)
112+
}
113+
114+
override fun streamFindClose(searchEmbedding: FloatArray, similarityThreshold: Double?): Sequence<ScoredText> {
115+
return LockedSequenceWrapper(lock::readLock) {
116+
this.idToEntry // manually use the receiver here to make sure the property is not captured by reference
117+
.asSequence()
118+
.map { it.key to it.value.embedding }
119+
.streamFindClose(searchEmbedding, similarityThreshold)
120+
}
121+
}
122+
123+
override fun estimateMemoryUsage() = fileManager.embeddingSizeInBytes.toLong() * size
124+
125+
override fun estimateLimitByMemory(memory: Long): Int {
126+
return (memory / fileManager.embeddingSizeInBytes).toInt()
127+
}
128+
129+
override fun checkCanAddEntry(): Boolean = lock.read {
130+
return limit == null || idToEntry.size < limit!!
131+
}
132+
133+
private suspend fun save() = coroutineScope {
134+
val ids = idToEntry.toList().sortedBy { it.second.index }.map { it.first }
135+
val embeddings = ids.map { idToEntry[it]!!.embedding }
136+
fileManager.saveIndex(ids, embeddings)
137+
}
138+
139+
fun deleteEntry(id: String) = lock.write {
140+
delete(id)
141+
}
142+
143+
fun addEntry(id: String, embedding: FloatArray) = lock.write {
144+
add(id, embedding)
145+
}
146+
147+
/* Optimization for consequent delete and add operations */
148+
fun updateEntry(id: String, newId: String, embedding: FloatArray) = lock.write {
149+
if (id !in idToEntry) return
150+
if (idToEntry[id]!!.count == 1 && newId !in this) {
151+
val index = idToEntry[id]!!.index
152+
fileManager[index] = embedding
153+
154+
idToEntry.remove(id)
155+
idToEntry[newId] = IndexEntry(index, 1, embedding)
156+
indexToId[index] = newId
157+
158+
saveIds()
159+
}
160+
else {
161+
// Do not apply optimization
162+
delete(id)
163+
add(newId, embedding)
164+
}
165+
}
166+
167+
private fun add(id: String, embedding: FloatArray, shouldCount: Boolean = false) {
168+
val entry = idToEntry.getOrPut(id) {
169+
changed = true
170+
if (limit != null && idToEntry.size >= limit!!) return@add
171+
val index = idToEntry.size
172+
fileManager[index] = embedding
173+
indexToId[index] = id
174+
IndexEntry(index, 0, embedding)
175+
}
176+
if (shouldCount || entry.count == 0) {
177+
entry.count += 1
178+
if (entry.count == 1) {
179+
saveIds()
180+
}
181+
}
182+
}
183+
184+
private fun delete(id: String, all: Boolean = false, shouldSaveIds: Boolean = true) {
185+
val entry = idToEntry[id] ?: return
186+
entry.count -= 1
187+
if (!all && entry.count > 0) return
188+
189+
val lastIndex = idToEntry.size - 1
190+
val index = entry.index
191+
192+
val movedId = indexToId[lastIndex]!!
193+
194+
fileManager.removeAtIndex(index)
195+
indexToId[index] = movedId
196+
indexToId.remove(lastIndex)
197+
198+
idToEntry[movedId]!!.index = index
199+
idToEntry.remove(id)
200+
201+
if (shouldSaveIds) saveIds()
202+
}
203+
204+
private fun saveIds() {
205+
fileManager.saveIds(idToEntry.toList().sortedBy { it.second.index }.map { it.first })
206+
}
207+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
2+
package cc.unitmesh.devti.embedding
3+
4+
import kotlin.collections.asSequence
5+
import kotlin.math.sqrt
6+
import kotlin.sequences.filter
7+
import kotlin.sequences.map
8+
import kotlin.sequences.sortedByDescending
9+
import kotlin.sequences.take
10+
import kotlin.sequences.toList
11+
import kotlin.to
12+
13+
interface EmbeddingSearchIndex {
14+
val size: Int
15+
var limit: Int?
16+
17+
operator fun contains(id: String): Boolean
18+
fun clear()
19+
20+
fun onIndexingStart()
21+
fun onIndexingFinish()
22+
23+
suspend fun addEntries(values: Iterable<Pair<String, FloatArray>>, shouldCount: Boolean = false)
24+
25+
suspend fun saveToDisk()
26+
suspend fun loadFromDisk()
27+
28+
fun findClosest(searchEmbedding: FloatArray, topK: Int, similarityThreshold: Double? = null): List<ScoredText>
29+
fun streamFindClose(searchEmbedding: FloatArray, similarityThreshold: Double? = null): Sequence<ScoredText>
30+
31+
fun estimateMemoryUsage(): Long
32+
fun estimateLimitByMemory(memory: Long): Int
33+
fun checkCanAddEntry(): Boolean
34+
}
35+
36+
internal fun Map<String, FloatArray>.findClosest(
37+
searchEmbedding: FloatArray,
38+
topK: Int, similarityThreshold: Double?,
39+
): List<ScoredText> {
40+
return asSequence()
41+
.map {
42+
it.key to searchEmbedding.times(it.value)
43+
}
44+
.filter { (_, similarity) -> if (similarityThreshold != null) similarity > similarityThreshold else true }
45+
.sortedByDescending { (_, similarity) -> similarity }
46+
.take(topK)
47+
.map { (id, similarity) -> ScoredText(id, similarity.toDouble()) }
48+
.toList()
49+
}
50+
51+
internal fun Sequence<Pair<String, FloatArray>>.streamFindClose(
52+
queryEmbedding: FloatArray,
53+
similarityThreshold: Double?,
54+
): Sequence<ScoredText> {
55+
return map { (id, embedding) -> id to queryEmbedding.times(embedding) }
56+
.filter { similarityThreshold == null || it.second > similarityThreshold }
57+
.map { (id, similarity) -> ScoredText(id, similarity.toDouble()) }
58+
}
59+
60+
fun FloatArray.times(other: FloatArray): Float {
61+
require(this.size == other.size) {
62+
"Embeddings must have the same size, but got ${this.size} and ${other.size}"
63+
}
64+
return this.zip(other).map { (a, b) -> a * b }.sum()
65+
}
66+
67+
fun FloatArray.normalized(): FloatArray {
68+
val norm = sqrt(this.times(this))
69+
val normalizedValues = this.map { it / norm }
70+
return normalizedValues.toFloatArray()
71+
}
72+
73+
fun FloatArray.cosine(other: FloatArray): Float {
74+
require(this.size == other.size) { "Embeddings must have the same size" }
75+
val dot = this.times(other)
76+
val norm = sqrt(this.times(this)) * sqrt(other.times(other))
77+
return dot / norm
78+
}

0 commit comments

Comments
 (0)