Skip to content

Commit f9485b1

Browse files
authored
Merge pull request scala/scala#7248 from joshlemer/setbuilder-2
optimize s.c.i.HashSetBuilder
2 parents c726dae + 79ed012 commit f9485b1

File tree

4 files changed

+306
-33
lines changed

4 files changed

+306
-33
lines changed

library/src/scala/collection/immutable/ChampHashMap.scala

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -725,19 +725,15 @@ private final class BitmapIndexedMapNode[K, +V](
725725

726726
override def copy(): BitmapIndexedMapNode[K, V] = {
727727
val contentClone = new Array[Any](content.length)
728-
var i = 0
729-
val dataIndices = bitCount(dataMap) * 2
730-
while (i < dataIndices) {
731-
contentClone(i) = content(i)
732-
i += 1
733-
}
728+
val dataIndices = bitCount(dataMap) * TupleLength
729+
Array.copy(content, 0, contentClone, 0, dataIndices)
730+
var i = dataIndices
734731
while (i < content.length) {
735732
contentClone(i) = content(i).asInstanceOf[MapNode[K, V]].copy()
736733
i += 1
737734
}
738735
new BitmapIndexedMapNode[K, V](dataMap, nodeMap, contentClone, originalHashes.clone(), size)
739736
}
740-
741737
}
742738

743739
private final class HashCollisionMapNode[K, +V ](
@@ -1142,7 +1138,6 @@ private[immutable] final class HashMapBuilder[K, V] extends Builder[(K, V), Hash
11421138

11431139
/** Copy elements to new mutable structure */
11441140
private def copyElems(): Unit = {
1145-
aliased = null
11461141
rootNode = rootNode.copy()
11471142
}
11481143

library/src/scala/collection/immutable/ChampHashSet.scala

Lines changed: 249 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import Hashing.improve
99
import java.lang.Integer.{bitCount, numberOfTrailingZeros}
1010
import java.lang.System.arraycopy
1111

12+
import scala.collection.immutable.Set.Set4
1213
import scala.util.hashing.MurmurHash3
14+
import scala.runtime.Statics.releaseFence
1315

1416
/** This class implements immutable sets using a Compressed Hash-Array Mapped Prefix-tree.
1517
* See paper https://michael.steindorfer.name/publications/oopsla15.pdf for more details.
@@ -26,6 +28,8 @@ final class HashSet[A] private[immutable] (val rootNode: SetNode[A], val cachedJ
2628
with SetOps[A, HashSet, HashSet[A]]
2729
with StrictOptimizedIterableOps[A, HashSet, HashSet[A]] {
2830

31+
releaseFence()
32+
2933
override def iterableFactory: IterableFactory[HashSet] = HashSet
3034

3135
override def knownSize: Int = rootNode.size
@@ -68,6 +72,13 @@ final class HashSet[A] private[immutable] (val rootNode: SetNode[A], val cachedJ
6872
else this
6973
}
7074

75+
override def concat(that: IterableOnce[A]): HashSet[A] = {
76+
val builder = iterableFactory.newBuilder[A]
77+
builder ++= this
78+
builder ++= that
79+
builder.result()
80+
}
81+
7182
override def tail: HashSet[A] = this - head
7283

7384
override def init: HashSet[A] = this - last
@@ -214,9 +225,15 @@ private[immutable] sealed abstract class SetNode[A] extends Node[SetNode[A]] {
214225

215226
def subsetOf(that: SetNode[A], shift: Int): Boolean
216227

228+
def copy(): SetNode[A]
217229
}
218230

219-
private final class BitmapIndexedSetNode[A](val dataMap: Int, val nodeMap: Int, val content: Array[Any], val originalHashes: Array[Int], val size: Int) extends SetNode[A] {
231+
private final class BitmapIndexedSetNode[A](
232+
var dataMap: Int,
233+
var nodeMap: Int,
234+
var content: Array[Any],
235+
var originalHashes: Array[Int],
236+
var size: Int) extends SetNode[A] {
220237

221238
import Node._
222239
import SetNode._
@@ -564,8 +581,8 @@ private final class BitmapIndexedSetNode[A](val dataMap: Int, val nodeMap: Int,
564581
(this eq node) ||
565582
(this.nodeMap == node.nodeMap) &&
566583
(this.dataMap == node.dataMap) &&
567-
java.util.Arrays.equals(this.originalHashes, node.originalHashes) &&
568-
deepContentEquality(this.content, node.content, content.length)
584+
java.util.Arrays.equals(this.originalHashes, node.originalHashes) &&
585+
deepContentEquality(this.content, node.content, content.length)
569586
case _ => false
570587
}
571588

@@ -588,13 +605,24 @@ private final class BitmapIndexedSetNode[A](val dataMap: Int, val nodeMap: Int,
588605
override def hashCode(): Int =
589606
throw new UnsupportedOperationException("Trie nodes do not support hashing.")
590607

608+
override def copy(): BitmapIndexedSetNode[A] = {
609+
val contentClone = new Array[Any](content.length)
610+
val dataIndices = bitCount(dataMap)
611+
Array.copy(content, 0, contentClone, 0, dataIndices)
612+
var i = dataIndices
613+
while (i < content.length) {
614+
contentClone(i) = content(i).asInstanceOf[SetNode[A]].copy()
615+
i += 1
616+
}
617+
new BitmapIndexedSetNode[A](dataMap, nodeMap, contentClone, originalHashes.clone(), size)
618+
}
591619
}
592620

593-
private final class HashCollisionSetNode[A](val originalHash: Int, val hash: Int, val content: Vector[A]) extends SetNode[A] {
621+
private final class HashCollisionSetNode[A](val originalHash: Int, val hash: Int, var content: Vector[Any]) extends SetNode[A] {
594622

595623
import Node._
596624

597-
require(content.size >= 2)
625+
require(content.length >= 2)
598626

599627
def contains(element: A, originalHash: Int, hash: Int, shift: Int): Boolean =
600628
this.hash == hash && content.contains(element)
@@ -621,7 +649,7 @@ private final class HashCollisionSetNode[A](val originalHash: Int, val hash: Int
621649
// assert(updatedContent.size == content.size - 1)
622650

623651
updatedContent.size match {
624-
case 1 => new BitmapIndexedSetNode[A](bitposFrom(maskFrom(hash, 0)), 0, updatedContent.toArray, Array(originalHash), 1)
652+
case 1 => new BitmapIndexedSetNode[A](bitposFrom(maskFrom(hash, 0)), 0, Array(updatedContent(0)), Array(originalHash), 1)
625653
case _ => new HashCollisionSetNode[A](originalHash, hash, updatedContent)
626654
}
627655
}
@@ -635,23 +663,28 @@ private final class HashCollisionSetNode[A](val originalHash: Int, val hash: Int
635663

636664
def hasPayload: Boolean = true
637665

638-
def payloadArity: Int = content.size
666+
def payloadArity: Int = content.length
639667

640-
def getPayload(index: Int): A = content(index)
668+
def getPayload(index: Int): A = content(index).asInstanceOf[A]
641669

642670
override def getHash(index: Int): Int = originalHash
643671

644672
def sizePredicate: Int = SizeMoreThanOne
645673

646-
def size: Int = content.size
674+
def size: Int = content.length
647675

648-
def foreach[U](f: A => U): Unit = content.foreach(f)
676+
def foreach[U](f: A => U): Unit = {
677+
var i = 0
678+
while (i < content.length) {
679+
f(getPayload(i))
680+
i += 1
681+
}
682+
}
649683

650684
def subsetOf(that: SetNode[A], shift: Int): Boolean = if (this eq that) true else that match {
651685
case node: BitmapIndexedSetNode[A] => false
652-
case node: HashCollisionSetNode[A] => {
686+
case node: HashCollisionSetNode[A] =>
653687
this.payloadArity <= node.payloadArity && this.content.forall(node.content.contains)
654-
}
655688
}
656689

657690
override def equals(that: Any): Boolean =
@@ -667,6 +700,8 @@ private final class HashCollisionSetNode[A](val originalHash: Int, val hash: Int
667700
override def hashCode(): Int =
668701
throw new UnsupportedOperationException("Trie nodes do not support hashing.")
669702

703+
override def copy() = new HashCollisionSetNode[A](originalHash, hash, content)
704+
670705
}
671706

672707
private final class SetIterator[A](rootNode: SetNode[A])
@@ -740,16 +775,211 @@ object HashSet extends IterableFactory[HashSet] {
740775
case _ => (newBuilder[A] ++= source).result()
741776
}
742777

743-
def newBuilder[A]: Builder[A, HashSet[A]] =
744-
new ImmutableBuilder[A, HashSet[A]](empty) {
745-
def addOne(element: A): this.type = {
746-
elems = elems + element
747-
this
748-
}
749-
}
778+
def newBuilder[A]: Builder[A, HashSet[A]] = new HashSetBuilder
750779

751780
// scalac generates a `readReplace` method to discard the deserialized state (see https://github.com/scala/bug/issues/10412).
752781
// This prevents it from serializing it in the first place:
753782
private[this] def writeObject(out: ObjectOutputStream): Unit = ()
754783
private[this] def readObject(in: ObjectInputStream): Unit = ()
755784
}
785+
786+
private[collection] final class HashSetBuilder[A] extends Builder[A, HashSet[A]] {
787+
import Node._
788+
import SetNode._
789+
790+
private def newEmptyRootNode = new BitmapIndexedSetNode[A](0, 0, Array(), Array(), 0)
791+
792+
/** The last given out HashSet as a return value of `result()`, if any, otherwise null.
793+
* Indicates that on next add, the elements should be copied to an identical structure, before continuing
794+
* mutations. */
795+
private var aliased: HashSet[A] = _
796+
797+
private def isAliased: Boolean = aliased != null
798+
799+
/** The root node of the partially build hashmap */
800+
private var rootNode: SetNode[A] = newEmptyRootNode
801+
802+
/** The cached hash of the partially-built hashmap */
803+
private var hash: Int = 0
804+
805+
/** Inserts element `elem` into array `as` at index `ix`, shifting right the trailing elems */
806+
private def insertElement(as: Array[Int], ix: Int, elem: Int): Array[Int] = {
807+
if (ix < 0) throw new ArrayIndexOutOfBoundsException
808+
if (ix > as.length) throw new ArrayIndexOutOfBoundsException
809+
val result = new Array[Int](as.length + 1)
810+
arraycopy(as, 0, result, 0, ix)
811+
result(ix) = elem
812+
arraycopy(as, ix, result, ix + 1, as.length - ix)
813+
result
814+
}
815+
816+
/** Inserts key-value into the bitmapIndexMapNode. Requires that this is a new key-value pair */
817+
private def insertValue[A1 >: A](bm: BitmapIndexedSetNode[A], bitpos: Int, key: A, originalHash: Int, keyHash: Int): Unit = {
818+
val dataIx = bm.dataIndex(bitpos)
819+
val idx = TupleLength * dataIx
820+
821+
val src = bm.content
822+
val dst = new Array[Any](src.length + TupleLength)
823+
824+
// copy 'src' and insert 2 element(s) at position 'idx'
825+
arraycopy(src, 0, dst, 0, idx)
826+
dst(idx) = key
827+
arraycopy(src, idx, dst, idx + TupleLength, src.length - idx)
828+
829+
val dstHashes = insertElement(bm.originalHashes, dataIx, originalHash)
830+
831+
bm.dataMap = bm.dataMap | bitpos
832+
bm.content = dst
833+
bm.originalHashes = dstHashes
834+
bm.size += 1
835+
}
836+
837+
/** Removes element at index `ix` from array `as`, shifting the trailing elements right */
838+
private def removeElement(as: Array[Int], ix: Int): Array[Int] = {
839+
if (ix < 0) throw new ArrayIndexOutOfBoundsException
840+
if (ix > as.length - 1) throw new ArrayIndexOutOfBoundsException
841+
val result = new Array[Int](as.length - 1)
842+
arraycopy(as, 0, result, 0, ix)
843+
arraycopy(as, ix + 1, result, ix, as.length - ix - 1)
844+
result
845+
}
846+
847+
/** Mutates `bm` to replace inline data at bit position `bitpos` with node `node` */
848+
private def migrateFromInlineToNode(bm: BitmapIndexedSetNode[A], bitpos: Int, node: SetNode[A]): Unit = {
849+
val dataIx = bm.dataIndex(bitpos)
850+
val idxOld = TupleLength * dataIx
851+
val idxNew = bm.content.length - TupleLength - bm.nodeIndex(bitpos)
852+
853+
val src = bm.content
854+
val dst = new Array[Any](src.length - TupleLength + 1)
855+
856+
// copy 'src' and remove 2 element(s) at position 'idxOld' and
857+
// insert 1 element(s) at position 'idxNew'
858+
// assert(idxOld <= idxNew)
859+
arraycopy(src, 0, dst, 0, idxOld)
860+
arraycopy(src, idxOld + TupleLength, dst, idxOld, idxNew - idxOld)
861+
dst(idxNew) = node
862+
arraycopy(src, idxNew + TupleLength, dst, idxNew + 1, src.length - idxNew - TupleLength)
863+
864+
val dstHashes = removeElement(bm.originalHashes, dataIx)
865+
866+
bm.dataMap ^= bitpos
867+
bm.nodeMap |= bitpos
868+
bm.content = dst
869+
bm.originalHashes = dstHashes
870+
bm.size = bm.size - 1 + node.size
871+
}
872+
873+
/** Mutates `bm` to replace inline data at bit position `bitpos` with updated key/value */
874+
private def setValue[A1 >: A](bm: BitmapIndexedSetNode[A], bitpos: Int, elem: A): Unit = {
875+
val dataIx = bm.dataIndex(bitpos)
876+
val idx = TupleLength * dataIx
877+
bm.content(idx) = elem
878+
}
879+
880+
def update(setNode: SetNode[A], element: A, originalHash: Int, elementHash: Int, shift: Int): Unit =
881+
setNode match {
882+
case bm: BitmapIndexedSetNode[A] =>
883+
val mask = maskFrom(elementHash, shift)
884+
val bitpos = bitposFrom(mask)
885+
886+
if ((bm.dataMap & bitpos) != 0) {
887+
val index = indexFrom(bm.dataMap, mask, bitpos)
888+
val element0 = bm.getPayload(index)
889+
val element0UnimprovedHash = bm.getHash(index)
890+
891+
if (element0UnimprovedHash == originalHash && element0 == element) {
892+
setValue(bm, bitpos, element0)
893+
} else {
894+
val element0Hash = improve(element0UnimprovedHash)
895+
val subNodeNew = bm.mergeTwoKeyValPairs(element0, element0UnimprovedHash, element0Hash, element, originalHash, elementHash, shift + BitPartitionSize)
896+
hash += elementHash
897+
migrateFromInlineToNode(bm, bitpos, subNodeNew)
898+
}
899+
} else if ((bm.nodeMap & bitpos) != 0) {
900+
val index = indexFrom(bm.nodeMap, mask, bitpos)
901+
val subNode = bm.getNode(index)
902+
val beforeSize = subNode.size
903+
update(subNode, element, originalHash, elementHash, shift + BitPartitionSize)
904+
bm.size += subNode.size - beforeSize
905+
} else {
906+
insertValue(bm, bitpos, element, originalHash, elementHash)
907+
hash += elementHash
908+
}
909+
case hc: HashCollisionSetNode[A] =>
910+
val index = hc.content.indexOf(element)
911+
if (index < 0) {
912+
hash += elementHash
913+
hc.content = hc.content.appended(element)
914+
} else {
915+
hc.content = hc.content.updated(index, element)
916+
}
917+
}
918+
919+
/** If currently referencing aliased structure, copy elements to new mutable structure */
920+
private def ensureUnaliased():Unit = {
921+
if (isAliased) copyElems()
922+
aliased = null
923+
}
924+
925+
/** Copy elements to new mutable structure */
926+
private def copyElems(): Unit = {
927+
rootNode = rootNode.copy()
928+
}
929+
930+
override def result(): HashSet[A] =
931+
if (rootNode.size == 0) {
932+
HashSet.empty
933+
} else if (aliased != null) {
934+
aliased
935+
} else {
936+
aliased = new HashSet(rootNode, hash)
937+
releaseFence()
938+
aliased
939+
}
940+
941+
override def addOne(elem: A): this.type = {
942+
ensureUnaliased()
943+
val h = elem.##
944+
val im = improve(h)
945+
update(rootNode, elem, h, im, 0)
946+
this
947+
}
948+
949+
override def addAll(xs: IterableOnce[A]) = {
950+
ensureUnaliased()
951+
xs match {
952+
case hm: HashSet[A] =>
953+
new ChampBaseIterator(hm.rootNode) {
954+
while(hasNext) {
955+
val originalHash = currentValueNode.getHash(currentValueCursor)
956+
update(
957+
setNode = rootNode,
958+
element = currentValueNode.getPayload(currentValueCursor),
959+
originalHash = originalHash,
960+
elementHash = improve(originalHash),
961+
shift = 0
962+
)
963+
currentValueCursor += 1
964+
}
965+
}
966+
case other =>
967+
val it = other.iterator
968+
while(it.hasNext) addOne(it.next())
969+
}
970+
971+
this
972+
}
973+
974+
override def clear(): Unit = {
975+
aliased = null
976+
if (rootNode.size > 0) {
977+
// if rootNode is empty, we will not have given it away anyways, we instead give out the reused Set.empty
978+
rootNode = newEmptyRootNode
979+
}
980+
hash = 0
981+
}
982+
983+
private[collection] def size: Int = rootNode.size
984+
}
985+

0 commit comments

Comments
 (0)