@@ -9,7 +9,9 @@ import Hashing.improve
9
9
import java .lang .Integer .{bitCount , numberOfTrailingZeros }
10
10
import java .lang .System .arraycopy
11
11
12
+ import scala .collection .immutable .Set .Set4
12
13
import scala .util .hashing .MurmurHash3
14
+ import scala .runtime .Statics .releaseFence
13
15
14
16
/** This class implements immutable sets using a Compressed Hash-Array Mapped Prefix-tree.
15
17
* See paper https://michael.steindorfer.name/publications/oopsla15.pdf for more details.
@@ -26,6 +28,8 @@ final class HashSet[A] private[immutable] (val rootNode: SetNode[A], val cachedJ
26
28
with SetOps [A , HashSet , HashSet [A ]]
27
29
with StrictOptimizedIterableOps [A , HashSet , HashSet [A ]] {
28
30
31
+ releaseFence()
32
+
29
33
override def iterableFactory : IterableFactory [HashSet ] = HashSet
30
34
31
35
override def knownSize : Int = rootNode.size
@@ -68,6 +72,13 @@ final class HashSet[A] private[immutable] (val rootNode: SetNode[A], val cachedJ
68
72
else this
69
73
}
70
74
75
/** Builds a new `HashSet` from the elements of this set together with those of `that`.
  *
  * A builder is obtained from `iterableFactory`, fed this set first and `that` second,
  * and its result is returned.
  *
  * @param that the additional elements to include
  * @return the set produced by the builder after both inputs have been added
  */
override def concat(that: IterableOnce[A]): HashSet[A] =
  iterableFactory.newBuilder[A].addAll(this).addAll(that).result()
81
+
71
82
override def tail : HashSet [A ] = this - head
72
83
73
84
override def init : HashSet [A ] = this - last
@@ -214,9 +225,15 @@ private[immutable] sealed abstract class SetNode[A] extends Node[SetNode[A]] {
214
225
215
226
def subsetOf (that : SetNode [A ], shift : Int ): Boolean
216
227
228
+ def copy (): SetNode [A ]
217
229
}
218
230
219
- private final class BitmapIndexedSetNode [A ](val dataMap : Int , val nodeMap : Int , val content : Array [Any ], val originalHashes : Array [Int ], val size : Int ) extends SetNode [A ] {
231
+ private final class BitmapIndexedSetNode [A ](
232
+ var dataMap : Int ,
233
+ var nodeMap : Int ,
234
+ var content : Array [Any ],
235
+ var originalHashes : Array [Int ],
236
+ var size : Int ) extends SetNode [A ] {
220
237
221
238
import Node ._
222
239
import SetNode ._
@@ -564,8 +581,8 @@ private final class BitmapIndexedSetNode[A](val dataMap: Int, val nodeMap: Int,
564
581
(this eq node) ||
565
582
(this .nodeMap == node.nodeMap) &&
566
583
(this .dataMap == node.dataMap) &&
567
- java.util.Arrays .equals(this .originalHashes, node.originalHashes) &&
568
- deepContentEquality(this .content, node.content, content.length)
584
+ java.util.Arrays .equals(this .originalHashes, node.originalHashes) &&
585
+ deepContentEquality(this .content, node.content, content.length)
569
586
case _ => false
570
587
}
571
588
@@ -588,13 +605,24 @@ private final class BitmapIndexedSetNode[A](val dataMap: Int, val nodeMap: Int,
588
605
override def hashCode (): Int =
589
606
throw new UnsupportedOperationException (" Trie nodes do not support hashing." )
590
607
608
/** Produces a structurally independent copy of this node.
  *
  * The first `bitCount(dataMap)` slots of `content` are carried over as the same
  * references; every later slot is treated as a `SetNode[A]` and copied recursively.
  * `originalHashes` is cloned, while `dataMap`, `nodeMap` and `size` are passed on unchanged.
  */
override def copy(): BitmapIndexedSetNode[A] = {
  val duplicate = content.clone()
  // Slots before `firstNodeSlot` keep the references that clone() already copied.
  val firstNodeSlot = bitCount(dataMap)
  var slot = firstNodeSlot
  while (slot < duplicate.length) {
    duplicate(slot) = content(slot).asInstanceOf[SetNode[A]].copy()
    slot += 1
  }
  new BitmapIndexedSetNode[A](dataMap, nodeMap, duplicate, originalHashes.clone(), size)
}
591
619
}
592
620
593
- private final class HashCollisionSetNode [A ](val originalHash : Int , val hash : Int , val content : Vector [A ]) extends SetNode [A ] {
621
+ private final class HashCollisionSetNode [A ](val originalHash : Int , val hash : Int , var content : Vector [Any ]) extends SetNode [A ] {
594
622
595
623
import Node ._
596
624
597
- require(content.size >= 2 )
625
+ require(content.length >= 2 )
598
626
599
627
def contains (element : A , originalHash : Int , hash : Int , shift : Int ): Boolean =
600
628
this .hash == hash && content.contains(element)
@@ -621,7 +649,7 @@ private final class HashCollisionSetNode[A](val originalHash: Int, val hash: Int
621
649
// assert(updatedContent.size == content.size - 1)
622
650
623
651
updatedContent.size match {
624
- case 1 => new BitmapIndexedSetNode [A ](bitposFrom(maskFrom(hash, 0 )), 0 , updatedContent.toArray , Array (originalHash), 1 )
652
+ case 1 => new BitmapIndexedSetNode [A ](bitposFrom(maskFrom(hash, 0 )), 0 , Array ( updatedContent( 0 )) , Array (originalHash), 1 )
625
653
case _ => new HashCollisionSetNode [A ](originalHash, hash, updatedContent)
626
654
}
627
655
}
@@ -635,23 +663,28 @@ private final class HashCollisionSetNode[A](val originalHash: Int, val hash: Int
635
663
636
664
def hasPayload : Boolean = true
637
665
638
- def payloadArity : Int = content.size
666
+ def payloadArity : Int = content.length
639
667
640
- def getPayload (index : Int ): A = content(index)
668
+ def getPayload (index : Int ): A = content(index). asInstanceOf [ A ]
641
669
642
670
override def getHash (index : Int ): Int = originalHash
643
671
644
672
def sizePredicate : Int = SizeMoreThanOne
645
673
646
- def size : Int = content.size
674
+ def size : Int = content.length
647
675
648
- def foreach [U ](f : A => U ): Unit = content.foreach(f)
676
/** Applies `f` to every element stored in this collision node, in `content` order.
  *
  * Elements are held as `Any` in `content` and cast back to `A` before being handed to `f`.
  */
def foreach[U](f: A => U): Unit = {
  val elements = content.iterator
  while (elements.hasNext) f(elements.next().asInstanceOf[A])
}
649
683
650
684
/** Tests whether every element of this collision node is also held by `that`.
  *
  * Returns `true` immediately when `that` is this very node, `false` when `that` is a
  * `BitmapIndexedSetNode`, and otherwise compares against the other collision node's content.
  * The `shift` argument is not used by this implementation.
  */
def subsetOf(that: SetNode[A], shift: Int): Boolean =
  (this eq that) || (that match {
    case _: BitmapIndexedSetNode[A] => false
    case other: HashCollisionSetNode[A] =>
      // Cheap arity comparison first, then the element-wise membership check.
      payloadArity <= other.payloadArity && content.forall(other.content.contains)
  })
656
689
657
690
override def equals (that : Any ): Boolean =
@@ -667,6 +700,8 @@ private final class HashCollisionSetNode[A](val originalHash: Int, val hash: Int
667
700
override def hashCode (): Int =
668
701
throw new UnsupportedOperationException (" Trie nodes do not support hashing." )
669
702
703
/** Returns a new collision node with the same `originalHash`, `hash` and `content`.
  *
  * The `content` vector is passed to the new node as the same reference; no element is copied.
  *
  * @return a distinct `HashCollisionSetNode` instance with identical field values
  */
override def copy(): HashCollisionSetNode[A] =
  new HashCollisionSetNode[A](originalHash, hash, content)
704
+
670
705
}
671
706
672
707
private final class SetIterator [A ](rootNode : SetNode [A ])
@@ -740,16 +775,211 @@ object HashSet extends IterableFactory[HashSet] {
740
775
case _ => (newBuilder[A ] ++= source).result()
741
776
}
742
777
743
/** Creates a new, empty `HashSetBuilder` for assembling a `HashSet[A]`. */
def newBuilder[A]: Builder[A, HashSet[A]] = new HashSetBuilder[A]
750
779
751
780
// scalac generates a `readReplace` method to discard the deserialized state (see https://github.com/scala/bug/issues/10412).
752
781
// This prevents it from serializing it in the first place:
753
782
private [this ] def writeObject (out : ObjectOutputStream ): Unit = ()
754
783
private [this ] def readObject (in : ObjectInputStream ): Unit = ()
755
784
}
785
+
786
+ private [collection] final class HashSetBuilder [A ] extends Builder [A , HashSet [A ]] {
787
+ import Node ._
788
+ import SetNode ._
789
+
790
/** Allocates a fresh bitmap-indexed node with zeroed bitmaps, empty arrays and size 0. */
private def newEmptyRootNode: BitmapIndexedSetNode[A] =
  new BitmapIndexedSetNode[A](dataMap = 0, nodeMap = 0, content = Array(), originalHashes = Array(), size = 0)
791
+
792
+ /** The last given out HashSet as a return value of `result()`, if any, otherwise null.
793
+ * Indicates that on next add, the elements should be copied to an identical structure, before continuing
794
+ * mutations. */
795
+ private var aliased : HashSet [A ] = _
796
+
797
/** True while `aliased` holds a reference, i.e. it is not null. */
private def isAliased: Boolean = aliased ne null
798
+
799
+ /** The root node of the partially built hash set */
800
+ private var rootNode : SetNode [A ] = newEmptyRootNode
801
+
802
+ /** The cached hash of the partially built hash set */
803
+ private var hash : Int = 0
804
+
805
/** Returns a copy of `as` that is one slot longer, with `elem` stored at index `ix`.
  *
  * Elements that were at index `ix` or later in `as` end up one position further right.
  * The input array is not modified.
  *
  * @param as   the source array
  * @param ix   the insertion index, between 0 and `as.length` inclusive
  * @param elem the value to insert
  * @throws ArrayIndexOutOfBoundsException if `ix` is negative or greater than `as.length`
  */
private def insertElement(as: Array[Int], ix: Int, elem: Int): Array[Int] = {
  if (ix < 0 || ix > as.length) throw new ArrayIndexOutOfBoundsException
  val grown = new Array[Int](as.length + 1)
  val trailing = as.length - ix
  arraycopy(as, 0, grown, 0, ix)
  arraycopy(as, ix, grown, ix + 1, trailing)
  grown(ix) = elem
  grown
}
815
+
816
/** Mutates `bm` so that it additionally stores `key` inline at bit position `bitpos`.
  *
  * A new content array is allocated with `TupleLength` extra slots, `key` is written at the
  * data index derived from `bitpos`, and `originalHash` is inserted at the same data index
  * of `bm.originalHashes`. `bm.dataMap` gains `bitpos` and `bm.size` grows by one.
  *
  * The `keyHash` argument and the `A1` type parameter are not used by this method.
  * NOTE(review): callers are expected to pass an element not yet stored in `bm` — confirm.
  */
private def insertValue[A1 >: A](bm: BitmapIndexedSetNode[A], bitpos: Int, key: A, originalHash: Int, keyHash: Int): Unit = {
  val slot = bm.dataIndex(bitpos)
  val at = TupleLength * slot
  val before = bm.content

  // Copy `before` into a larger array, leaving a gap of TupleLength slots at `at` for the new element.
  val after = new Array[Any](before.length + TupleLength)
  arraycopy(before, 0, after, 0, at)
  arraycopy(before, at, after, at + TupleLength, before.length - at)
  after(at) = key

  val hashes = insertElement(bm.originalHashes, slot, originalHash)

  // All new state is computed before the node is touched.
  bm.dataMap |= bitpos
  bm.content = after
  bm.originalHashes = hashes
  bm.size += 1
}
836
+
837
/** Returns a copy of `as` that is one slot shorter, with the element at index `ix` left out.
  *
  * Elements that followed index `ix` in `as` end up one position further left.
  * The input array is not modified.
  *
  * @param as the source array
  * @param ix the index to drop, between 0 and `as.length - 1` inclusive
  * @throws ArrayIndexOutOfBoundsException if `ix` is negative or not a valid index of `as`
  */
private def removeElement(as: Array[Int], ix: Int): Array[Int] = {
  if (ix < 0 || ix >= as.length) throw new ArrayIndexOutOfBoundsException
  val shrunk = new Array[Int](as.length - 1)
  arraycopy(as, 0, shrunk, 0, ix)
  arraycopy(as, ix + 1, shrunk, ix, shrunk.length - ix)
  shrunk
}
846
+
847
/** Mutates `bm` so that the inline element at bit position `bitpos` is replaced by the sub-node `node`.
  *
  * A new content array is built in which the `TupleLength` inline slot(s) at the old data
  * position are removed and `node` is written at the position computed from `bm.nodeIndex(bitpos)`.
  * The matching entry of `bm.originalHashes` is removed, `bitpos` moves from `bm.dataMap` to
  * `bm.nodeMap`, and `bm.size` is adjusted by `node.size - 1`.
  */
private def migrateFromInlineToNode(bm: BitmapIndexedSetNode[A], bitpos: Int, node: SetNode[A]): Unit = {
  // Both indices are derived from the bitmaps as they are *before* the mutation below.
  val hashSlot = bm.dataIndex(bitpos)
  val before = bm.content
  val removeAt = TupleLength * hashSlot
  val insertAt = before.length - TupleLength - bm.nodeIndex(bitpos)

  val after = new Array[Any](before.length - TupleLength + 1)

  // Three segments: everything before the removed slot, the stretch between the removed
  // slot and the insertion point, and everything after the insertion point.
  // The original code notes the expectation removeAt <= insertAt.
  arraycopy(before, 0, after, 0, removeAt)
  arraycopy(before, removeAt + TupleLength, after, removeAt, insertAt - removeAt)
  arraycopy(before, insertAt + TupleLength, after, insertAt + 1, before.length - insertAt - TupleLength)
  after(insertAt) = node

  val remainingHashes = removeElement(bm.originalHashes, hashSlot)

  bm.dataMap ^= bitpos
  bm.nodeMap |= bitpos
  bm.content = after
  bm.originalHashes = remainingHashes
  bm.size = bm.size + node.size - 1
}
872
+
873
/** Overwrites, in place, the inline element of `bm` at bit position `bitpos` with `elem`.
  *
  * The target slot is `TupleLength * bm.dataIndex(bitpos)` in `bm.content`; no array is
  * reallocated and no bitmap, hash or size is changed. The `A1` type parameter is unused.
  */
private def setValue[A1 >: A](bm: BitmapIndexedSetNode[A], bitpos: Int, elem: A): Unit =
  bm.content(TupleLength * bm.dataIndex(bitpos)) = elem
879
+
880
/** Adds `element` to the subtree rooted at `setNode`, mutating the nodes in place.
  *
  * For a bitmap-indexed node the bit position is derived from `elementHash` and `shift`:
  *  - if inline data already occupies that position and equals `element` (same original hash
  *    and `==`), the stored element is kept;
  *  - if inline data occupies it but differs, both elements are merged into a new sub-node
  *    which replaces the inline slot;
  *  - if a sub-node occupies it, the call recurses one level deeper and this node's size is
  *    adjusted by the sub-node's size change;
  *  - otherwise the element is inserted inline.
  *
  * For a hash-collision node the element is appended only when it is not already present,
  * so an equal element that is already stored is kept, matching the bitmap-indexed case.
  *
  * Whenever a new element is actually stored, `elementHash` is added to the builder's
  * running `hash`.
  *
  * @param setNode      the node to update
  * @param element      the element to add
  * @param originalHash the element's unimproved hash code
  * @param elementHash  the improved hash derived from `originalHash`
  * @param shift        the bit offset into `elementHash` used at this tree level
  */
def update(setNode: SetNode[A], element: A, originalHash: Int, elementHash: Int, shift: Int): Unit =
  setNode match {
    case bm: BitmapIndexedSetNode[A] =>
      val mask = maskFrom(elementHash, shift)
      val bitpos = bitposFrom(mask)

      if ((bm.dataMap & bitpos) != 0) {
        val index = indexFrom(bm.dataMap, mask, bitpos)
        val element0 = bm.getPayload(index)
        val element0UnimprovedHash = bm.getHash(index)

        if (element0UnimprovedHash == originalHash && element0 == element) {
          // Already present: the stored element is written back, so it is retained.
          setValue(bm, bitpos, element0)
        } else {
          val element0Hash = improve(element0UnimprovedHash)
          val subNodeNew = bm.mergeTwoKeyValPairs(element0, element0UnimprovedHash, element0Hash, element, originalHash, elementHash, shift + BitPartitionSize)
          hash += elementHash
          migrateFromInlineToNode(bm, bitpos, subNodeNew)
        }
      } else if ((bm.nodeMap & bitpos) != 0) {
        val index = indexFrom(bm.nodeMap, mask, bitpos)
        val subNode = bm.getNode(index)
        val beforeSize = subNode.size
        update(subNode, element, originalHash, elementHash, shift + BitPartitionSize)
        // Propagate whatever growth happened below into this node's cached size.
        bm.size += subNode.size - beforeSize
      } else {
        insertValue(bm, bitpos, element, originalHash, elementHash)
        hash += elementHash
      }
    case hc: HashCollisionSetNode[A] =>
      // Only a genuinely new element changes the node; an equal stored element is left alone,
      // which also avoids allocating a replacement vector.
      if (!hc.content.contains(element)) {
        hash += elementHash
        hc.content = hc.content.appended(element)
      }
  }
918
+
919
/** Makes the builder's node structure safe to mutate again.
  *
  * When `aliased` is set, the nodes are first copied via `copyElems()` and `aliased` is then
  * reset to null; when `aliased` is already null nothing changes.
  */
private def ensureUnaliased(): Unit =
  if (aliased != null) {
    copyElems()
    aliased = null
  }
924
+
925
/** Replaces `rootNode` with the result of calling `copy()` on it. */
private def copyElems(): Unit = {
  val detached = rootNode.copy()
  rootNode = detached
}
929
+
930
/** Returns the set built so far.
  *
  * An empty builder yields `HashSet.empty`. Otherwise a `HashSet` wrapping the current
  * `rootNode` and `hash` is created on first request, remembered in `aliased`, and the same
  * instance is returned on later requests for as long as `aliased` stays set.
  * `releaseFence()` is called right after the new set is constructed.
  */
override def result(): HashSet[A] =
  if (rootNode.size == 0) HashSet.empty
  else {
    if (!isAliased) {
      aliased = new HashSet(rootNode, hash)
      releaseFence()
    }
    aliased
  }
940
+
941
/** Adds a single element to the set under construction.
  *
  * The structure is un-aliased first, then the element is inserted starting from the root
  * with its `##` hash and the improved form of that hash.
  *
  * @param elem the element to add
  * @return this builder
  */
override def addOne(elem: A): this.type = {
  ensureUnaliased()
  val rawHash = elem.##
  update(setNode = rootNode, element = elem, originalHash = rawHash, elementHash = improve(rawHash), shift = 0)
  this
}
948
+
949
/** Adds every element of `xs` to the set under construction.
  *
  * When `xs` is itself a `HashSet`, its trie is walked directly so that the original hash
  * stored with each element can be reused instead of calling `##` again. Any other
  * source is consumed through its iterator, one `addOne` call per element.
  *
  * @param xs the elements to add
  * @return this builder
  */
override def addAll(xs: IterableOnce[A]): this.type = {
  ensureUnaliased()
  xs match {
    case hs: HashSet[A] =>
      // The anonymous subclass runs the traversal loop in its initializer; the iterator
      // instance itself is discarded. `rootNode` and `update` refer to this builder.
      new ChampBaseIterator(hs.rootNode) {
        while (hasNext) {
          val originalHash = currentValueNode.getHash(currentValueCursor)
          update(
            setNode = rootNode,
            element = currentValueNode.getPayload(currentValueCursor),
            originalHash = originalHash,
            elementHash = improve(originalHash),
            shift = 0
          )
          currentValueCursor += 1
        }
      }
    case other =>
      val it = other.iterator
      while (it.hasNext) addOne(it.next())
  }

  this
}
973
+
974
/** Resets the builder to its empty state.
  *
  * The running hash is zeroed and `aliased` is dropped. The root node is replaced with a new
  * empty node only when it currently holds elements.
  */
override def clear(): Unit = {
  hash = 0
  aliased = null
  // An empty root is never handed out (result() returns HashSet.empty instead), so it can be reused.
  if (rootNode.size > 0) rootNode = newEmptyRootNode
}
982
+
983
/** Number of elements currently held, as reported by the root node. */
private[collection] def size: Int = this.rootNode.size
984
+ }
985
+
0 commit comments