Skip to content

Commit b85ceaa

Browse files
committed
HashMap#concat uses structural sharing to achieve sublinear performance
HashMap#concat uses structural sharing to achieve sublinear performance

- Apply @retronym's suggestions
- Revert breaking HashMap#concat changes
- In HashMap#concat, rely on before/after size when merging leftNode+rightData to determine if newHash should be changed
- Lots of progress on more Map concat micro optimizations
- Remove tzeroes and lzeroes
- Remove unused if/else clause
- Add comment explaining that hash collisions are never at the same level as bitmapIndexedMapNodes
- Use mergeTwoKeyValPairs to create new node in HashMap#concat
- Remove testing code
1 parent 89e2683 commit b85ceaa

File tree

2 files changed

+323
-5
lines changed

2 files changed

+323
-5
lines changed

library/src/scala/collection/immutable/ChampCommon.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ private[immutable] final object Node {
2121

2222
final val SizeMoreThanOne = 2
2323

24+
// Maximum number of children of a trie node: 2^BitPartitionSize.
final val BranchingFactor = 1 << BitPartitionSize
25+
2426
// Extracts the trie index for the level given by `shift` from `hash` by masking with BitPartitionMask.
final def maskFrom(hash: Int, shift: Int): Int = (hash >>> shift) & BitPartitionMask
2527

2628
// Converts a trie index (`mask`, produced by `maskFrom`) into its single-bit bitmap position.
final def bitposFrom(mask: Int): Int = 1 << mask

library/src/scala/collection/immutable/ChampHashMap.scala

Lines changed: 321 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,329 @@ final class HashMap[K, +V] private[immutable] (private[immutable] val rootNode:
105105
}
106106

107107
/** Concatenates `this` with `that`, entries of `that` winning on duplicate keys.
  *
  * When `that` is also a `HashMap`, the two tries are merged node-by-node with maximal
  * structural sharing, which is sublinear when large subtrees can be reused. For any
  * other `IterableOnce` we fall back to rebuilding via a builder.
  */
override def concat[V1 >: V](that: scala.IterableOnce[(K, V1)]): HashMap[K, V1] = {
  if (this eq that.asInstanceOf[AnyRef]) {
    return this
  }

  /** Fall back to slow (linear) concatenation when `that` is not a `HashMap`. */
  def slowConcat = mapFactory.newBuilder[K, V1].addAll(this).addAll(that).result()

  that match {
    case hm: HashMap[K, V1] =>
      // We start from the assumption that the key sets are disjoint, and then subtract the
      // improved hash of each key in `this` that is overwritten by the same key in `that`.
      var newHash = cachedJavaKeySetHashCode + hm.cachedJavaKeySetHashCode

      /** Recursively, immutably concatenates two MapNodes, node-by-node, so as to visit the
        * fewest nodes possible. Returns `right` itself whenever `left` contributes nothing,
        * to maximise structural sharing. Side effect: keeps `newHash` up to date as
        * overwritten keys are discovered.
        */
      def concat(left: MapNode[K, V1], right: MapNode[K, V1], shift: Int): MapNode[K, V1] = {
        // Whenever possible, we would like to return early when it is known that `left` will
        // not contribute anything to the final result. However, this is only allowed at the
        // top level of the trie, because below the top level we must traverse the entire node
        // in order to keep `newHash` up to date (see above). At the top level an early return
        // is safe because the caller detects `hm.rootNode` and uses `hm`'s already-cached hash
        // instead of `newHash`.
        val canReturnEarly = shift == 0
        if (canReturnEarly && (left eq right)) {
          right
        } else left match {
          case leftBm: BitmapIndexedMapNode[K, V] =>
            // If the merge turns out not to differ from `right`, we return `right` itself to
            // improve sharing.
            var anyChangesMadeSoFar = false

            right match {
              case rightBm: BitmapIndexedMapNode[K, V1] =>
                val allMap = leftBm.dataMap | rightBm.dataMap | leftBm.nodeMap | rightBm.nodeMap

                // minimumBitPos is inclusive -- the bit position of the first index holding data or a node
                val minimumBitPos: Int = Node.bitposFrom(Integer.numberOfTrailingZeros(allMap))
                // maximumBitPos is inclusive -- the bit position of the last index holding data or a node.
                // It cannot be exclusive, because the worst-case exclusive upper bound (32) is not
                // representable as an Int bit position.
                val maximumBitPos: Int = Node.bitposFrom(Node.BranchingFactor - Integer.numberOfLeadingZeros(allMap) - 1)

                // Classification bitmaps: for each occupied index, which left/right combination
                // applies. Disjoint by construction; filled in by the first pass below.
                var leftNodeRightNode = 0
                var leftDataRightNode = 0
                var leftNodeRightData = 0
                var leftDataOnly = 0
                var rightDataOnly = 0
                var leftNodeOnly = 0
                var rightNodeOnly = 0
                var leftDataRightDataMigrateToNode = 0
                var leftDataRightDataRightOverwrites = 0

                // Bit positions of sub-nodes that will be created by merging two *distinct* data
                // entries that collide at this level.
                var dataToNodeMigrationTargets = 0

                // First pass: classify every occupied index, allocating nothing.
                {
                  var bitpos = minimumBitPos
                  var leftIdx = 0
                  var rightIdx = 0
                  var finished = false

                  while (!finished) {

                    if ((bitpos & leftBm.dataMap) != 0) {
                      if ((bitpos & rightBm.dataMap) != 0) {
                        if (leftBm.getKey(leftIdx) == rightBm.getKey(rightIdx)) {
                          // same key on both sides: right's entry overwrites left's
                          leftDataRightDataRightOverwrites |= bitpos
                        } else {
                          // distinct keys colliding at this level: they migrate into a sub-node
                          leftDataRightDataMigrateToNode |= bitpos
                          dataToNodeMigrationTargets |= Node.bitposFrom(Node.maskFrom(improve(leftBm.getHash(leftIdx)), shift))
                        }
                        rightIdx += 1
                      } else if ((bitpos & rightBm.nodeMap) != 0) {
                        leftDataRightNode |= bitpos
                      } else {
                        leftDataOnly |= bitpos
                      }
                      leftIdx += 1
                    } else if ((bitpos & leftBm.nodeMap) != 0) {
                      if ((bitpos & rightBm.dataMap) != 0) {
                        leftNodeRightData |= bitpos
                        rightIdx += 1
                      } else if ((bitpos & rightBm.nodeMap) != 0) {
                        leftNodeRightNode |= bitpos
                      } else {
                        leftNodeOnly |= bitpos
                      }
                    } else if ((bitpos & rightBm.dataMap) != 0) {
                      rightDataOnly |= bitpos
                      rightIdx += 1
                    } else if ((bitpos & rightBm.nodeMap) != 0) {
                      rightNodeOnly |= bitpos
                    }

                    if (bitpos == maximumBitPos) {
                      finished = true
                    } else {
                      bitpos = bitpos << 1
                    }
                  }
                }

                val newDataMap = leftDataOnly | rightDataOnly | leftDataRightDataRightOverwrites

                val newNodeMap =
                  leftNodeRightNode |
                    leftDataRightNode |
                    leftNodeRightData |
                    leftNodeOnly |
                    rightNodeOnly |
                    dataToNodeMigrationTargets

                if (canReturnEarly && (newDataMap == (rightDataOnly | leftDataRightDataRightOverwrites)) && (newNodeMap == rightNodeOnly)) {
                  // nothing from `left` will make it into the result -- return early.
                  // NOTE: `newHash` bookkeeping is skipped here; the top-level caller compensates
                  // by returning `hm` itself (with its correct cached hash) when it sees
                  // `hm.rootNode` come back.
                  return right
                }

                val newDataSize = bitCount(newDataMap)
                val newContentSize = (MapNode.TupleLength * newDataSize) + bitCount(newNodeMap)

                val result = new BitmapIndexedMapNode[K, V1](
                  dataMap = newDataMap,
                  nodeMap = newNodeMap,
                  content = new Array[Any](newContentSize),
                  originalHashes = new Array[Int](newDataSize),
                  size = 0
                )

                // Second pass: populate `result` according to the classification bitmaps.
                // Data entries fill `content` from the front; nodes fill it from the back.
                {
                  var leftDataIdx = 0
                  var rightDataIdx = 0
                  var leftNodeIdx = 0
                  var rightNodeIdx = 0

                  val nextShift = shift + Node.BitPartitionSize

                  var compressedDataIdx = 0
                  var compressedNodeIdx = 0

                  var bitpos = minimumBitPos
                  var finished = false

                  while (!finished) {

                    if ((bitpos & leftNodeRightNode) != 0) {
                      // sub-node on both sides: recurse
                      val rightNode = rightBm.getNode(rightNodeIdx)
                      val newNode = concat(leftBm.getNode(leftNodeIdx), rightNode, nextShift)
                      if (rightNode ne newNode) {
                        anyChangesMadeSoFar = true
                      }
                      result.content(newContentSize - compressedNodeIdx - 1) = newNode
                      compressedNodeIdx += 1
                      rightNodeIdx += 1
                      leftNodeIdx += 1
                      result.size += newNode.size

                    } else if ((bitpos & leftDataRightNode) != 0) {
                      // left data entry vs right sub-node: insert the entry unless the node
                      // already contains the key (right wins, and the left key was
                      // double-counted in `newHash`)
                      val newNode = {
                        val n = rightBm.getNode(rightNodeIdx)
                        val leftKey = leftBm.getKey(leftDataIdx)
                        val leftValue = leftBm.getValue(leftDataIdx)
                        val leftOriginalHash = leftBm.getHash(leftDataIdx)
                        val leftImproved = improve(leftOriginalHash)

                        // TODO: Implement MapNode#updatedIfNotContains
                        val updated = if (n.containsKey(leftKey, leftOriginalHash, leftImproved, nextShift)) {
                          newHash -= leftImproved
                          n
                        } else {
                          n.updated(leftKey, leftValue, leftOriginalHash, leftImproved, nextShift)
                        }

                        if (updated ne n) {
                          anyChangesMadeSoFar = true
                        }

                        updated
                      }

                      result.content(newContentSize - compressedNodeIdx - 1) = newNode
                      compressedNodeIdx += 1
                      rightNodeIdx += 1
                      leftDataIdx += 1
                      result.size += newNode.size

                    } else if ((bitpos & leftNodeRightData) != 0) {
                      // left sub-node vs right data entry: the right entry always goes in; if the
                      // size did not grow, it replaced a left key that was double-counted in `newHash`
                      anyChangesMadeSoFar = true
                      val newNode = {
                        val n = leftBm.getNode(leftNodeIdx)
                        val rightKey = rightBm.getKey(rightDataIdx)
                        val rightValue = rightBm.getValue(rightDataIdx)
                        val rightOriginalHash = rightBm.getHash(rightDataIdx)
                        val rightImproved = improve(rightOriginalHash)

                        val updated = n.updated(rightKey, rightValue, rightOriginalHash, rightImproved, nextShift)

                        if (updated.size == n.size) {
                          newHash -= rightImproved
                        }

                        updated
                      }

                      result.content(newContentSize - compressedNodeIdx - 1) = newNode
                      compressedNodeIdx += 1
                      leftNodeIdx += 1
                      rightDataIdx += 1
                      result.size += newNode.size

                    } else if ((bitpos & leftDataOnly) != 0) {
                      anyChangesMadeSoFar = true
                      result.content(MapNode.TupleLength * compressedDataIdx) =
                        leftBm.getKey(leftDataIdx).asInstanceOf[AnyRef]
                      result.content(MapNode.TupleLength * compressedDataIdx + 1) =
                        leftBm.getValue(leftDataIdx).asInstanceOf[AnyRef]
                      result.originalHashes(compressedDataIdx) = leftBm.originalHashes(leftDataIdx)

                      compressedDataIdx += 1
                      leftDataIdx += 1
                      result.size += 1
                    } else if ((bitpos & rightDataOnly) != 0) {
                      result.content(MapNode.TupleLength * compressedDataIdx) =
                        rightBm.getKey(rightDataIdx).asInstanceOf[AnyRef]
                      result.content(MapNode.TupleLength * compressedDataIdx + 1) =
                        rightBm.getValue(rightDataIdx).asInstanceOf[AnyRef]
                      result.originalHashes(compressedDataIdx) = rightBm.originalHashes(rightDataIdx)

                      compressedDataIdx += 1
                      rightDataIdx += 1
                      result.size += 1
                    } else if ((bitpos & leftNodeOnly) != 0) {
                      anyChangesMadeSoFar = true
                      val newNode = leftBm.getNode(leftNodeIdx)
                      result.content(newContentSize - compressedNodeIdx - 1) = newNode
                      compressedNodeIdx += 1
                      leftNodeIdx += 1
                      result.size += newNode.size
                    } else if ((bitpos & rightNodeOnly) != 0) {
                      val newNode = rightBm.getNode(rightNodeIdx)
                      result.content(newContentSize - compressedNodeIdx - 1) = newNode
                      compressedNodeIdx += 1
                      rightNodeIdx += 1
                      result.size += newNode.size
                    } else if ((bitpos & leftDataRightDataMigrateToNode) != 0) {
                      // two distinct keys collided at this level: merge them into a new sub-node
                      anyChangesMadeSoFar = true
                      val newNode = {
                        val leftOriginalHash = leftBm.getHash(leftDataIdx)
                        val rightOriginalHash = rightBm.getHash(rightDataIdx)

                        rightBm.mergeTwoKeyValPairs(
                          leftBm.getKey(leftDataIdx), leftBm.getValue(leftDataIdx), leftOriginalHash, improve(leftOriginalHash),
                          rightBm.getKey(rightDataIdx), rightBm.getValue(rightDataIdx), rightOriginalHash, improve(rightOriginalHash),
                          nextShift
                        )
                      }

                      result.content(newContentSize - compressedNodeIdx - 1) = newNode
                      compressedNodeIdx += 1
                      leftDataIdx += 1
                      rightDataIdx += 1
                      result.size += newNode.size
                    } else if ((bitpos & leftDataRightDataRightOverwrites) != 0) {
                      // same key on both sides: keep right's entry
                      result.content(MapNode.TupleLength * compressedDataIdx) =
                        rightBm.getKey(rightDataIdx).asInstanceOf[AnyRef]
                      result.content(MapNode.TupleLength * compressedDataIdx + 1) =
                        rightBm.getValue(rightDataIdx).asInstanceOf[AnyRef]
                      result.originalHashes(compressedDataIdx) = rightBm.originalHashes(rightDataIdx)

                      compressedDataIdx += 1
                      rightDataIdx += 1
                      result.size += 1

                      // the overwritten left key was double-counted in `newHash`
                      newHash -= improve(leftBm.getHash(leftDataIdx))
                      leftDataIdx += 1
                    }

                    if (bitpos == maximumBitPos) {
                      finished = true
                    } else {
                      bitpos = bitpos << 1
                    }
                  }
                }

                if (anyChangesMadeSoFar) result else right

              case rightHc: HashCollisionMapNode[K, V1] =>
                // should never happen -- hash collision nodes are never at the same level as
                // BitmapIndexedMapNodes
                var current: MapNode[K, V1] = leftBm
                rightHc.content.foreach { case (k, v) =>
                  current = current.updated(k, v, rightHc.originalHash, rightHc.hash, shift)
                }
                current
            }
          case leftHc: HashCollisionMapNode[K, V] => right match {
            case rightBm: BitmapIndexedMapNode[K, V1] =>
              // should never happen -- hash collision nodes are never at the same level as
              // BitmapIndexedMapNodes
              var current: MapNode[K, V1] = rightBm
              leftHc.content.foreach { case (k, v) =>
                current = current.updated(k, v, leftHc.originalHash, leftHc.hash, shift)
              }
              current
            case rightHc: HashCollisionMapNode[K, V1] =>
              // Both nodes hold keys with the same improved hash: fold right's entries into
              // left's, right winning on duplicate keys.
              var result: MapNode[K, V1] = leftHc
              var i = 0
              val improved = improve(leftHc.originalHash)
              // BUG FIX: the loop iterates over `rightHc`'s entries, so it must be bounded by
              // `rightHc.size` (was `leftHc.size`, which mis-merged whenever the sizes differed).
              while (i < rightHc.size) {
                val sizeBefore = result.size
                result = result.updated(rightHc.getKey(i), rightHc.getValue(i), rightHc.originalHash, improved, shift)
                // BUG FIX: a duplicate key means the left occurrence was double-counted in
                // `newHash`; detect it via the before/after size and remove its contribution.
                if (result.size == sizeBefore) {
                  newHash -= improved
                }
                i += 1
              }
              result
          }
        }
      }

      val newRootNode = concat(rootNode, hm.rootNode, 0)
      if (newRootNode eq rootNode) this
      // BUG FIX: when the merge returns `hm`'s root unchanged (including via the early-return
      // paths, which skip the `newHash` bookkeeping), the result *is* `hm` -- return it with its
      // own correct cached hash instead of wrapping it with a possibly-stale `newHash`.
      else if (newRootNode eq hm.rootNode) hm
      else new HashMap(newRootNode, newHash)
    case _ =>
      slowConcat
  }
}
114429

430+
115431
// Drops the first entry (in iteration order); equivalent to `this - head._1`.
override def tail: HashMap[K, V] = {
  val (firstKey, _) = head
  this - firstKey
}
116432

117433
// Drops the last entry (in iteration order); equivalent to `this - last._1`.
override def init: HashMap[K, V] = {
  val (lastKey, _) = last
  this - lastKey
}

0 commit comments

Comments
 (0)