Skip to content

Commit cba2e43

Browse files
authored
Improve array ops and equality tests (#193)
* Avoid relying on inlining, demand analysis, and luck to avoid suspending array lookups. * Make `Eq` and `Eq1` instances take advantage of the near-structural nature of `HashMap` equality. * Add a function for traversing arrays in `ST`, and a rule to rewrite to it.
1 parent d672a11 commit cba2e43

File tree

4 files changed

+117
-21
lines changed

4 files changed

+117
-21
lines changed

Data/HashMap/Array.hs

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ module Data.HashMap.Array
2828
, insert
2929
, insertM
3030
, delete
31+
, sameArray1
3132

3233
, unsafeFreeze
3334
, unsafeThaw
@@ -81,6 +82,7 @@ import qualified Prelude
8182
#endif
8283

8384
import Data.HashMap.Unsafe (runST)
85+
import Control.Monad ((>=>))
8486

8587

8688
#if __GLASGOW_HASKELL__ >= 710
@@ -170,11 +172,25 @@ instance Show a => Show (Array a) where
170172

171173
-- Determines whether two arrays have the same memory address.
172174
-- This is more reliable than testing pointer equality on the
173-
-- Array wrappers, but it's still somewhat bogus.
175+
-- Array wrappers, but it's still slightly bogus.
174176
unsafeSameArray :: Array a -> Array b -> Bool
175177
unsafeSameArray (Array xs) (Array ys) =
176178
tagToEnum# (unsafeCoerce# reallyUnsafePtrEquality# xs ys)
177179

180+
sameArray1 :: (a -> b -> Bool) -> Array a -> Array b -> Bool
181+
sameArray1 eq !xs0 !ys0
182+
| lenxs /= lenys = False
183+
| otherwise = go 0 xs0 ys0
184+
where
185+
go !k !xs !ys
186+
| k == lenxs = True
187+
| (# x #) <- index# xs k
188+
, (# y #) <- index# ys k
189+
= eq x y && go (k + 1) xs ys
190+
191+
!lenxs = length xs0
192+
!lenys = length ys0
193+
178194
length :: Array a -> Int
179195
length ary = I# (sizeofArray# (unArray ary))
180196
{-# INLINE length #-}
@@ -259,6 +275,12 @@ index ary _i@(I# i#) =
259275
case indexArray# (unArray ary) i# of (# b #) -> b
260276
{-# INLINE index #-}
261277

278+
index# :: Array a -> Int -> (# a #)
279+
index# ary _i@(I# i#) =
280+
CHECK_BOUNDS("index#", length ary, _i)
281+
indexArray# (unArray ary) i#
282+
{-# INLINE index# #-}
283+
262284
indexM :: Array a -> Int -> ST s a
263285
indexM ary _i@(I# i#) =
264286
CHECK_BOUNDS("indexM", length ary, _i)
@@ -361,16 +383,20 @@ foldl' :: (b -> a -> b) -> b -> Array a -> b
361383
foldl' f = \ z0 ary0 -> go ary0 (length ary0) 0 z0
362384
where
363385
go ary n i !z
364-
| i >= n = z
365-
| otherwise = go ary n (i+1) (f z (index ary i))
386+
| i >= n = z
387+
| otherwise
388+
= case index# ary i of
389+
(# x #) -> go ary n (i+1) (f z x)
366390
{-# INLINE foldl' #-}
367391

368392
foldr :: (a -> b -> b) -> b -> Array a -> b
369393
foldr f = \ z0 ary0 -> go ary0 (length ary0) 0 z0
370394
where
371395
go ary n i z
372-
| i >= n = z
373-
| otherwise = f (index ary i) (go ary n (i+1) z)
396+
| i >= n = z
397+
| otherwise
398+
= case index# ary i of
399+
(# x #) -> f x (go ary n (i+1) z)
374400
{-# INLINE foldr #-}
375401

376402
undefinedElem :: a
@@ -412,7 +438,8 @@ map f = \ ary ->
412438
go ary mary i n
413439
| i >= n = return mary
414440
| otherwise = do
415-
write mary i $ f (index ary i)
441+
x <- indexM ary i
442+
write mary i $ f x
416443
go ary mary (i+1) n
417444
{-# INLINE map #-}
418445

@@ -427,7 +454,8 @@ map' f = \ ary ->
427454
go ary mary i n
428455
| i >= n = return mary
429456
| otherwise = do
430-
write mary i $! f (index ary i)
457+
x <- indexM ary i
458+
write mary i $! f x
431459
go ary mary (i+1) n
432460
{-# INLINE map' #-}
433461

@@ -448,7 +476,27 @@ toList = foldr (:) []
448476
traverse :: Applicative f => (a -> f b) -> Array a -> f (Array b)
449477
traverse f = \ ary -> fromList (length ary) `fmap`
450478
Traversable.traverse f (toList ary)
451-
{-# INLINE traverse #-}
479+
{-# INLINE [1] traverse #-}
480+
481+
-- Traversing in ST, we don't need to make a list; we
482+
-- can just do it directly.
483+
traverseST :: (a -> ST s b) -> Array a -> ST s (Array b)
484+
traverseST f = \ ary0 ->
485+
let
486+
!len = length ary0
487+
go k mary
488+
| k == len = return mary
489+
| otherwise = do
490+
x <- indexM ary0 k
491+
y <- f x
492+
write mary k y
493+
go (k + 1) mary
494+
in new_ len >>= (go 0 >=> unsafeFreeze)
495+
{-# INLINE traverseST #-}
496+
497+
{-# RULES
498+
"traverse/ST" forall f. traverse f = traverseST f
499+
#-}
452500

453501
filter :: (a -> Bool) -> Array a -> Array a
454502
filter p = \ ary ->

Data/HashMap/Base.hs

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ module Data.HashMap.Base
9696
, updateOrConcatWithKey
9797
, filterMapAux
9898
, equalKeys
99+
, equalKeys1
99100
, lookupRecordCollision
100101
, LookupRes(..)
101102
, insert'
@@ -269,18 +270,37 @@ instance Traversable (HashMap k) where
269270

270271
#if MIN_VERSION_base(4,9,0)
271272
instance Eq2 HashMap where
272-
liftEq2 = equal
273+
liftEq2 = equal2
273274

274275
instance Eq k => Eq1 (HashMap k) where
275-
liftEq = equal (==)
276+
liftEq = equal1
276277
#endif
277278

278279
instance (Eq k, Eq v) => Eq (HashMap k v) where
279-
(==) = equal (==) (==)
280-
281-
equal :: (k -> k' -> Bool) -> (v -> v' -> Bool)
280+
(==) = equal1 (==)
281+
282+
-- We rely on there being no Empty constructors in the tree!
283+
-- This ensures that two equal HashMaps will have the same
284+
-- shape, modulo the order of entries in Collisions.
285+
equal1 :: Eq k
286+
=> (v -> v' -> Bool)
287+
-> HashMap k v -> HashMap k v' -> Bool
288+
equal1 eq = go
289+
where
290+
go Empty Empty = True
291+
go (BitmapIndexed bm1 ary1) (BitmapIndexed bm2 ary2)
292+
= bm1 == bm2 && A.sameArray1 go ary1 ary2
293+
go (Leaf h1 l1) (Leaf h2 l2) = h1 == h2 && leafEq l1 l2
294+
go (Full ary1) (Full ary2) = A.sameArray1 go ary1 ary2
295+
go (Collision h1 ary1) (Collision h2 ary2)
296+
= h1 == h2 && isPermutationBy leafEq (A.toList ary1) (A.toList ary2)
297+
go _ _ = False
298+
299+
leafEq (L k1 v1) (L k2 v2) = k1 == k2 && eq v1 v2
300+
301+
equal2 :: (k -> k' -> Bool) -> (v -> v' -> Bool)
282302
-> HashMap k v -> HashMap k' v' -> Bool
283-
equal eqk eqv t1 t2 = go (toList' t1 []) (toList' t2 [])
303+
equal2 eqk eqv t1 t2 = go (toList' t1 []) (toList' t2 [])
284304
where
285305
-- If the two trees are the same, then their lists of 'Leaf's and
286306
-- 'Collision's read from left to right should be the same (modulo the
@@ -339,8 +359,8 @@ cmp cmpk cmpv t1 t2 = go (toList' t1 []) (toList' t2 [])
339359
leafCompare (L k v) (L k' v') = cmpk k k' `mappend` cmpv v v'
340360

341361
-- Same as 'equal' but doesn't compare the values.
342-
equalKeys :: (k -> k' -> Bool) -> HashMap k v -> HashMap k' v' -> Bool
343-
equalKeys eq t1 t2 = go (toList' t1 []) (toList' t2 [])
362+
equalKeys1 :: (k -> k' -> Bool) -> HashMap k v -> HashMap k' v' -> Bool
363+
equalKeys1 eq t1 t2 = go (toList' t1 []) (toList' t2 [])
344364
where
345365
go (Leaf k1 l1 : tl1) (Leaf k2 l2 : tl2)
346366
| k1 == k2 && leafEq l1 l2
@@ -354,6 +374,22 @@ equalKeys eq t1 t2 = go (toList' t1 []) (toList' t2 [])
354374

355375
leafEq (L k _) (L k' _) = eq k k'
356376

377+
-- Same as 'equal1' but doesn't compare the values.
378+
equalKeys :: Eq k => HashMap k v -> HashMap k v' -> Bool
379+
equalKeys = go
380+
where
381+
go :: Eq k => HashMap k v -> HashMap k v' -> Bool
382+
go Empty Empty = True
383+
go (BitmapIndexed bm1 ary1) (BitmapIndexed bm2 ary2)
384+
= bm1 == bm2 && A.sameArray1 go ary1 ary2
385+
go (Leaf h1 l1) (Leaf h2 l2) = h1 == h2 && leafEq l1 l2
386+
go (Full ary1) (Full ary2) = A.sameArray1 go ary1 ary2
387+
go (Collision h1 ary1) (Collision h2 ary2)
388+
= h1 == h2 && isPermutationBy leafEq (A.toList ary1) (A.toList ary2)
389+
go _ _ = False
390+
391+
leafEq (L k1 _) (L k2 _) = k1 == k2
392+
357393
#if MIN_VERSION_hashable(1,2,5)
358394
instance H.Hashable2 HashMap where
359395
liftHashWithSalt2 hk hv salt hm = go salt (toList' hm [])
@@ -1746,8 +1782,11 @@ updateOrConcatWith f = updateOrConcatWithKey (const f)
17461782

17471783
updateOrConcatWithKey :: Eq k => (k -> v -> v -> v) -> A.Array (Leaf k v) -> A.Array (Leaf k v) -> A.Array (Leaf k v)
17481784
updateOrConcatWithKey f ary1 ary2 = A.run $ do
1785+
-- TODO: instead of mapping and then folding, should we traverse?
1786+
-- We'll have to be careful to avoid allocating pairs or similar.
1787+
17491788
-- first: look up the position of each element of ary2 in ary1
1750-
let indices = A.map (\(L k _) -> indexOf k ary1) ary2
1789+
let indices = A.map' (\(L k _) -> indexOf k ary1) ary2
17511790
-- that tells us how large the overlap is:
17521791
-- count number of Nothing constructors
17531792
let nOnly2 = A.foldl' (\n -> maybe (n+1) (const n)) 0 indices

Data/HashSet/Base.hs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ module Data.HashSet.Base
7777

7878
import Control.DeepSeq (NFData(..))
7979
import Data.Data hiding (Typeable)
80-
import Data.HashMap.Base (HashMap, foldrWithKey, equalKeys)
80+
import Data.HashMap.Base (HashMap, foldrWithKey, equalKeys, equalKeys1)
8181
import Data.Hashable (Hashable(hashWithSalt))
8282
#if __GLASGOW_HASKELL__ >= 711
8383
import Data.Semigroup (Semigroup(..))
@@ -120,12 +120,12 @@ instance (NFData a) => NFData (HashSet a) where
120120
{-# INLINE rnf #-}
121121

122122
instance (Eq a) => Eq (HashSet a) where
123-
HashSet a == HashSet b = equalKeys (==) a b
123+
HashSet a == HashSet b = equalKeys a b
124124
{-# INLINE (==) #-}
125125

126126
#if MIN_VERSION_base(4,9,0)
127127
instance Eq1 HashSet where
128-
liftEq eq (HashSet a) (HashSet b) = equalKeys eq a b
128+
liftEq eq (HashSet a) (HashSet b) = equalKeys1 eq a b
129129
#endif
130130

131131
instance (Ord a) => Ord (HashSet a) where

unordered-containers.cabal

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ copyright: 2010-2014 Johan Tibell
1818
2010 Edward Z. Yang
1919
category: Data
2020
build-type: Simple
21-
cabal-version: >=1.8
21+
cabal-version: >=1.10
2222
extra-source-files: CHANGES.md
2323
tested-with: GHC==8.4.1, GHC==8.2.2, GHC==8.0.2, GHC==7.10.3, GHC==7.8.4
2424

@@ -45,6 +45,8 @@ library
4545
deepseq >= 1.1,
4646
hashable >= 1.0.1.1 && < 1.3
4747

48+
default-language: Haskell2010
49+
4850
other-extensions:
4951
RoleAnnotations,
5052
UnboxedTuples,
@@ -76,6 +78,7 @@ test-suite hashmap-lazy-properties
7678
test-framework-quickcheck2 >= 0.2.9,
7779
unordered-containers
7880

81+
default-language: Haskell2010
7982
ghc-options: -Wall
8083
cpp-options: -DASSERTS
8184

@@ -93,6 +96,7 @@ test-suite hashmap-strict-properties
9396
test-framework-quickcheck2 >= 0.2.9,
9497
unordered-containers
9598

99+
default-language: Haskell2010
96100
ghc-options: -Wall
97101
cpp-options: -DASSERTS -DSTRICT
98102

@@ -110,6 +114,7 @@ test-suite hashset-properties
110114
test-framework-quickcheck2 >= 0.2.9,
111115
unordered-containers
112116

117+
default-language: Haskell2010
113118
ghc-options: -Wall
114119
cpp-options: -DASSERTS
115120

@@ -127,6 +132,7 @@ test-suite list-tests
127132
test-framework >= 0.3.3,
128133
test-framework-quickcheck2 >= 0.2.9
129134

135+
default-language: Haskell2010
130136
ghc-options: -Wall
131137
cpp-options: -DASSERTS
132138

@@ -145,6 +151,7 @@ test-suite regressions
145151
test-framework-quickcheck2,
146152
unordered-containers
147153

154+
default-language: Haskell2010
148155
ghc-options: -Wall
149156
cpp-options: -DASSERTS
150157

@@ -163,6 +170,7 @@ test-suite strictness-properties
163170
test-framework-quickcheck2 >= 0.2.9,
164171
unordered-containers
165172

173+
default-language: Haskell2010
166174
ghc-options: -Wall
167175
cpp-options: -DASSERTS
168176

@@ -200,6 +208,7 @@ benchmark benchmarks
200208
mtl,
201209
random
202210

211+
default-language: Haskell2010
203212
ghc-options: -Wall -O2 -rtsopts -fwarn-tabs -ferror-spans
204213
if flag(debug)
205214
cpp-options: -DASSERTS

0 commit comments

Comments
 (0)