@@ -9,7 +9,9 @@ import Hashing.improve
9
9
import java .lang .Integer .{bitCount , numberOfTrailingZeros }
10
10
import java .lang .System .arraycopy
11
11
12
+ import scala .collection .immutable .Set .Set4
12
13
import scala .util .hashing .MurmurHash3
14
+ import scala .runtime .Statics .releaseFence
13
15
14
16
/** This class implements immutable sets using a Compressed Hash-Array Mapped Prefix-tree.
15
17
* See paper https://michael.steindorfer.name/publications/oopsla15.pdf for more details.
@@ -26,6 +28,8 @@ final class HashSet[A] private[immutable] (val rootNode: SetNode[A], val cachedJ
26
28
with SetOps [A , HashSet , HashSet [A ]]
27
29
with StrictOptimizedIterableOps [A , HashSet , HashSet [A ]] {
28
30
31
+ releaseFence()
32
+
29
33
override def iterableFactory : IterableFactory [HashSet ] = HashSet
30
34
31
35
override def knownSize : Int = rootNode.size
@@ -68,6 +72,13 @@ final class HashSet[A] private[immutable] (val rootNode: SetNode[A], val cachedJ
68
72
else this
69
73
}
70
74
75
+ override def concat (that : IterableOnce [A ]): HashSet [A ] = {
76
+ val builder = iterableFactory.newBuilder[A ]
77
+ builder ++= this
78
+ builder ++= that
79
+ builder.result()
80
+ }
81
+
71
82
override def tail : HashSet [A ] = this - head
72
83
73
84
override def init : HashSet [A ] = this - last
@@ -142,9 +153,15 @@ private[immutable] sealed abstract class SetNode[A] extends Node[SetNode[A]] {
142
153
143
154
def subsetOf (that : SetNode [A ], shift : Int ): Boolean
144
155
156
+ def copy (): SetNode [A ]
145
157
}
146
158
147
- private final class BitmapIndexedSetNode [A ](val dataMap : Int , val nodeMap : Int , val content : Array [Any ], val originalHashes : Array [Int ], val size : Int ) extends SetNode [A ] {
159
+ private final class BitmapIndexedSetNode [A ](
160
+ var dataMap : Int ,
161
+ var nodeMap : Int ,
162
+ var content : Array [Any ],
163
+ var originalHashes : Array [Int ],
164
+ var size : Int ) extends SetNode [A ] {
148
165
149
166
import Node ._
150
167
import SetNode ._
@@ -492,8 +509,8 @@ private final class BitmapIndexedSetNode[A](val dataMap: Int, val nodeMap: Int,
492
509
(this eq node) ||
493
510
(this .nodeMap == node.nodeMap) &&
494
511
(this .dataMap == node.dataMap) &&
495
- java.util.Arrays .equals(this .originalHashes, node.originalHashes) &&
496
- deepContentEquality(this .content, node.content, content.length)
512
+ java.util.Arrays .equals(this .originalHashes, node.originalHashes) &&
513
+ deepContentEquality(this .content, node.content, content.length)
497
514
case _ => false
498
515
}
499
516
@@ -516,13 +533,24 @@ private final class BitmapIndexedSetNode[A](val dataMap: Int, val nodeMap: Int,
516
533
override def hashCode (): Int =
517
534
throw new UnsupportedOperationException (" Trie nodes do not support hashing." )
518
535
536
+ override def copy (): BitmapIndexedSetNode [A ] = {
537
+ val contentClone = new Array [Any ](content.length)
538
+ val dataIndices = bitCount(dataMap)
539
+ Array .copy(content, 0 , contentClone, 0 , dataIndices)
540
+ var i = dataIndices
541
+ while (i < content.length) {
542
+ contentClone(i) = content(i).asInstanceOf [SetNode [A ]].copy()
543
+ i += 1
544
+ }
545
+ new BitmapIndexedSetNode [A ](dataMap, nodeMap, contentClone, originalHashes.clone(), size)
546
+ }
519
547
}
520
548
521
- private final class HashCollisionSetNode [A ](val originalHash : Int , val hash : Int , val content : Vector [A ]) extends SetNode [A ] {
549
+ private final class HashCollisionSetNode [A ](val originalHash : Int , val hash : Int , var content : Vector [Any ]) extends SetNode [A ] {
522
550
523
551
import Node ._
524
552
525
- require(content.size >= 2 )
553
+ require(content.length >= 2 )
526
554
527
555
def contains (element : A , originalHash : Int , hash : Int , shift : Int ): Boolean =
528
556
this .hash == hash && content.contains(element)
@@ -549,7 +577,7 @@ private final class HashCollisionSetNode[A](val originalHash: Int, val hash: Int
549
577
// assert(updatedContent.size == content.size - 1)
550
578
551
579
updatedContent.size match {
552
- case 1 => new BitmapIndexedSetNode [A ](bitposFrom(maskFrom(hash, 0 )), 0 , updatedContent.toArray , Array (originalHash), 1 )
580
+ case 1 => new BitmapIndexedSetNode [A ](bitposFrom(maskFrom(hash, 0 )), 0 , Array ( updatedContent( 0 )) , Array (originalHash), 1 )
553
581
case _ => new HashCollisionSetNode [A ](originalHash, hash, updatedContent)
554
582
}
555
583
}
@@ -563,23 +591,28 @@ private final class HashCollisionSetNode[A](val originalHash: Int, val hash: Int
563
591
564
592
def hasPayload : Boolean = true
565
593
566
- def payloadArity : Int = content.size
594
+ def payloadArity : Int = content.length
567
595
568
- def getPayload (index : Int ): A = content(index)
596
+ def getPayload (index : Int ): A = content(index). asInstanceOf [ A ]
569
597
570
598
override def getHash (index : Int ): Int = originalHash
571
599
572
600
def sizePredicate : Int = SizeMoreThanOne
573
601
574
- def size : Int = content.size
602
+ def size : Int = content.length
575
603
576
- def foreach [U ](f : A => U ): Unit = content.foreach(f)
604
+ def foreach [U ](f : A => U ): Unit = {
605
+ var i = 0
606
+ while (i < content.length) {
607
+ f(getPayload(i))
608
+ i += 1
609
+ }
610
+ }
577
611
578
612
def subsetOf (that : SetNode [A ], shift : Int ): Boolean = if (this eq that) true else that match {
579
613
case node : BitmapIndexedSetNode [A ] => false
580
- case node : HashCollisionSetNode [A ] => {
614
+ case node : HashCollisionSetNode [A ] =>
581
615
this .payloadArity <= node.payloadArity && this .content.forall(node.content.contains)
582
- }
583
616
}
584
617
585
618
override def equals (that : Any ): Boolean =
@@ -595,6 +628,8 @@ private final class HashCollisionSetNode[A](val originalHash: Int, val hash: Int
595
628
override def hashCode (): Int =
596
629
throw new UnsupportedOperationException (" Trie nodes do not support hashing." )
597
630
631
+ override def copy () = new HashCollisionSetNode [A ](originalHash, hash, content)
632
+
598
633
}
599
634
600
635
private final class SetIterator [A ](rootNode : SetNode [A ])
@@ -668,16 +703,211 @@ object HashSet extends IterableFactory[HashSet] {
668
703
case _ => (newBuilder[A ] ++= source).result()
669
704
}
670
705
671
- def newBuilder [A ]: Builder [A , HashSet [A ]] =
672
- new ImmutableBuilder [A , HashSet [A ]](empty) {
673
- def addOne (element : A ): this .type = {
674
- elems = elems + element
675
- this
676
- }
677
- }
706
+ def newBuilder [A ]: Builder [A , HashSet [A ]] = new HashSetBuilder
678
707
679
708
// scalac generates a `readReplace` method to discard the deserialized state (see https://github.com/scala/bug/issues/10412).
680
709
// This prevents it from serializing it in the first place:
681
710
private [this ] def writeObject (out : ObjectOutputStream ): Unit = ()
682
711
private [this ] def readObject (in : ObjectInputStream ): Unit = ()
683
712
}
713
+
714
+ private [collection] final class HashSetBuilder [A ] extends Builder [A , HashSet [A ]] {
715
+ import Node ._
716
+ import SetNode ._
717
+
718
+ private def newEmptyRootNode = new BitmapIndexedSetNode [A ](0 , 0 , Array (), Array (), 0 )
719
+
720
+ /** The last given out HashSet as a return value of `result()`, if any, otherwise null.
721
+ * Indicates that on next add, the elements should be copied to an identical structure, before continuing
722
+ * mutations. */
723
+ private var aliased : HashSet [A ] = _
724
+
725
+ private def isAliased : Boolean = aliased != null
726
+
727
+ /** The root node of the partially build hashmap */
728
+ private var rootNode : SetNode [A ] = newEmptyRootNode
729
+
730
+ /** The cached hash of the partially-built hashmap */
731
+ private var hash : Int = 0
732
+
733
+ /** Inserts element `elem` into array `as` at index `ix`, shifting right the trailing elems */
734
+ private def insertElement (as : Array [Int ], ix : Int , elem : Int ): Array [Int ] = {
735
+ if (ix < 0 ) throw new ArrayIndexOutOfBoundsException
736
+ if (ix > as.length) throw new ArrayIndexOutOfBoundsException
737
+ val result = new Array [Int ](as.length + 1 )
738
+ arraycopy(as, 0 , result, 0 , ix)
739
+ result(ix) = elem
740
+ arraycopy(as, ix, result, ix + 1 , as.length - ix)
741
+ result
742
+ }
743
+
744
+ /** Inserts key-value into the bitmapIndexMapNode. Requires that this is a new key-value pair */
745
+ private def insertValue [A1 >: A ](bm : BitmapIndexedSetNode [A ], bitpos : Int , key : A , originalHash : Int , keyHash : Int ): Unit = {
746
+ val dataIx = bm.dataIndex(bitpos)
747
+ val idx = TupleLength * dataIx
748
+
749
+ val src = bm.content
750
+ val dst = new Array [Any ](src.length + TupleLength )
751
+
752
+ // copy 'src' and insert 2 element(s) at position 'idx'
753
+ arraycopy(src, 0 , dst, 0 , idx)
754
+ dst(idx) = key
755
+ arraycopy(src, idx, dst, idx + TupleLength , src.length - idx)
756
+
757
+ val dstHashes = insertElement(bm.originalHashes, dataIx, originalHash)
758
+
759
+ bm.dataMap = bm.dataMap | bitpos
760
+ bm.content = dst
761
+ bm.originalHashes = dstHashes
762
+ bm.size += 1
763
+ }
764
+
765
+ /** Removes element at index `ix` from array `as`, shifting the trailing elements right */
766
+ private def removeElement (as : Array [Int ], ix : Int ): Array [Int ] = {
767
+ if (ix < 0 ) throw new ArrayIndexOutOfBoundsException
768
+ if (ix > as.length - 1 ) throw new ArrayIndexOutOfBoundsException
769
+ val result = new Array [Int ](as.length - 1 )
770
+ arraycopy(as, 0 , result, 0 , ix)
771
+ arraycopy(as, ix + 1 , result, ix, as.length - ix - 1 )
772
+ result
773
+ }
774
+
775
+ /** Mutates `bm` to replace inline data at bit position `bitpos` with node `node` */
776
+ private def migrateFromInlineToNode (bm : BitmapIndexedSetNode [A ], bitpos : Int , node : SetNode [A ]): Unit = {
777
+ val dataIx = bm.dataIndex(bitpos)
778
+ val idxOld = TupleLength * dataIx
779
+ val idxNew = bm.content.length - TupleLength - bm.nodeIndex(bitpos)
780
+
781
+ val src = bm.content
782
+ val dst = new Array [Any ](src.length - TupleLength + 1 )
783
+
784
+ // copy 'src' and remove 2 element(s) at position 'idxOld' and
785
+ // insert 1 element(s) at position 'idxNew'
786
+ // assert(idxOld <= idxNew)
787
+ arraycopy(src, 0 , dst, 0 , idxOld)
788
+ arraycopy(src, idxOld + TupleLength , dst, idxOld, idxNew - idxOld)
789
+ dst(idxNew) = node
790
+ arraycopy(src, idxNew + TupleLength , dst, idxNew + 1 , src.length - idxNew - TupleLength )
791
+
792
+ val dstHashes = removeElement(bm.originalHashes, dataIx)
793
+
794
+ bm.dataMap ^= bitpos
795
+ bm.nodeMap |= bitpos
796
+ bm.content = dst
797
+ bm.originalHashes = dstHashes
798
+ bm.size = bm.size - 1 + node.size
799
+ }
800
+
801
+ /** Mutates `bm` to replace inline data at bit position `bitpos` with updated key/value */
802
+ private def setValue [A1 >: A ](bm : BitmapIndexedSetNode [A ], bitpos : Int , elem : A ): Unit = {
803
+ val dataIx = bm.dataIndex(bitpos)
804
+ val idx = TupleLength * dataIx
805
+ bm.content(idx) = elem
806
+ }
807
+
808
+ def update (setNode : SetNode [A ], element : A , originalHash : Int , elementHash : Int , shift : Int ): Unit =
809
+ setNode match {
810
+ case bm : BitmapIndexedSetNode [A ] =>
811
+ val mask = maskFrom(elementHash, shift)
812
+ val bitpos = bitposFrom(mask)
813
+
814
+ if ((bm.dataMap & bitpos) != 0 ) {
815
+ val index = indexFrom(bm.dataMap, mask, bitpos)
816
+ val element0 = bm.getPayload(index)
817
+ val element0UnimprovedHash = bm.getHash(index)
818
+
819
+ if (element0UnimprovedHash == originalHash && element0 == element) {
820
+ setValue(bm, bitpos, element0)
821
+ } else {
822
+ val element0Hash = improve(element0UnimprovedHash)
823
+ val subNodeNew = bm.mergeTwoKeyValPairs(element0, element0UnimprovedHash, element0Hash, element, originalHash, elementHash, shift + BitPartitionSize )
824
+ hash += elementHash
825
+ migrateFromInlineToNode(bm, bitpos, subNodeNew)
826
+ }
827
+ } else if ((bm.nodeMap & bitpos) != 0 ) {
828
+ val index = indexFrom(bm.nodeMap, mask, bitpos)
829
+ val subNode = bm.getNode(index)
830
+ val beforeSize = subNode.size
831
+ update(subNode, element, originalHash, elementHash, shift + BitPartitionSize )
832
+ bm.size += subNode.size - beforeSize
833
+ } else {
834
+ insertValue(bm, bitpos, element, originalHash, elementHash)
835
+ hash += elementHash
836
+ }
837
+ case hc : HashCollisionSetNode [A ] =>
838
+ val index = hc.content.indexOf(element)
839
+ if (index < 0 ) {
840
+ hash += elementHash
841
+ hc.content = hc.content.appended(element)
842
+ } else {
843
+ hc.content = hc.content.updated(index, element)
844
+ }
845
+ }
846
+
847
+ /** If currently referencing aliased structure, copy elements to new mutable structure */
848
+ private def ensureUnaliased (): Unit = {
849
+ if (isAliased) copyElems()
850
+ aliased = null
851
+ }
852
+
853
+ /** Copy elements to new mutable structure */
854
+ private def copyElems (): Unit = {
855
+ rootNode = rootNode.copy()
856
+ }
857
+
858
+ override def result (): HashSet [A ] =
859
+ if (rootNode.size == 0 ) {
860
+ HashSet .empty
861
+ } else if (aliased != null ) {
862
+ aliased
863
+ } else {
864
+ aliased = new HashSet (rootNode, hash)
865
+ releaseFence()
866
+ aliased
867
+ }
868
+
869
+ override def addOne (elem : A ): this .type = {
870
+ ensureUnaliased()
871
+ val h = elem.##
872
+ val im = improve(h)
873
+ update(rootNode, elem, h, im, 0 )
874
+ this
875
+ }
876
+
877
+ override def addAll (xs : IterableOnce [A ]) = {
878
+ ensureUnaliased()
879
+ xs match {
880
+ case hm : HashSet [A ] =>
881
+ new ChampBaseIterator (hm.rootNode) {
882
+ while (hasNext) {
883
+ val originalHash = currentValueNode.getHash(currentValueCursor)
884
+ update(
885
+ setNode = rootNode,
886
+ element = currentValueNode.getPayload(currentValueCursor),
887
+ originalHash = originalHash,
888
+ elementHash = improve(originalHash),
889
+ shift = 0
890
+ )
891
+ currentValueCursor += 1
892
+ }
893
+ }
894
+ case other =>
895
+ val it = other.iterator
896
+ while (it.hasNext) addOne(it.next())
897
+ }
898
+
899
+ this
900
+ }
901
+
902
+ override def clear (): Unit = {
903
+ aliased = null
904
+ if (rootNode.size > 0 ) {
905
+ // if rootNode is empty, we will not have given it away anyways, we instead give out the reused Set.empty
906
+ rootNode = newEmptyRootNode
907
+ }
908
+ hash = 0
909
+ }
910
+
911
+ private [collection] def size : Int = rootNode.size
912
+ }
913
+
0 commit comments