Skip to content

Improve array ops and equality tests #193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions Data/HashMap/Array.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ module Data.HashMap.Array
, insert
, insertM
, delete
, sameArray1

, unsafeFreeze
, unsafeThaw
Expand Down Expand Up @@ -81,6 +82,7 @@ import qualified Prelude
#endif

import Data.HashMap.Unsafe (runST)
import Control.Monad ((>=>))


#if __GLASGOW_HASKELL__ >= 710
Expand Down Expand Up @@ -170,11 +172,25 @@ instance Show a => Show (Array a) where

-- Determines whether two arrays have the same memory address.
-- This is more reliable than testing pointer equality on the
-- Array wrappers, but it's still somewhat bogus.
-- Array wrappers, but it's still slightly bogus.
unsafeSameArray :: Array a -> Array b -> Bool
unsafeSameArray (Array xs) (Array ys) =
tagToEnum# (unsafeCoerce# reallyUnsafePtrEquality# xs ys)

sameArray1 :: (a -> b -> Bool) -> Array a -> Array b -> Bool
sameArray1 eq !xs0 !ys0
| lenxs /= lenys = False
| otherwise = go 0 xs0 ys0
where
go !k !xs !ys
| k == lenxs = True
| (# x #) <- index# xs k
, (# y #) <- index# ys k
= eq x y && go (k + 1) xs ys

!lenxs = length xs0
!lenys = length ys0

length :: Array a -> Int
length ary = I# (sizeofArray# (unArray ary))
{-# INLINE length #-}
Expand Down Expand Up @@ -259,6 +275,12 @@ index ary _i@(I# i#) =
case indexArray# (unArray ary) i# of (# b #) -> b
{-# INLINE index #-}

index# :: Array a -> Int -> (# a #)
index# ary _i@(I# i#) =
CHECK_BOUNDS("index#", length ary, _i)
indexArray# (unArray ary) i#
{-# INLINE index# #-}

indexM :: Array a -> Int -> ST s a
indexM ary _i@(I# i#) =
CHECK_BOUNDS("indexM", length ary, _i)
Expand Down Expand Up @@ -361,16 +383,20 @@ foldl' :: (b -> a -> b) -> b -> Array a -> b
foldl' f = \ z0 ary0 -> go ary0 (length ary0) 0 z0
where
go ary n i !z
| i >= n = z
| otherwise = go ary n (i+1) (f z (index ary i))
| i >= n = z
| otherwise
= case index# ary i of
(# x #) -> go ary n (i+1) (f z x)
{-# INLINE foldl' #-}

foldr :: (a -> b -> b) -> b -> Array a -> b
foldr f = \ z0 ary0 -> go ary0 (length ary0) 0 z0
where
go ary n i z
| i >= n = z
| otherwise = f (index ary i) (go ary n (i+1) z)
| i >= n = z
| otherwise
= case index# ary i of
(# x #) -> f x (go ary n (i+1) z)
{-# INLINE foldr #-}

undefinedElem :: a
Expand Down Expand Up @@ -412,7 +438,8 @@ map f = \ ary ->
go ary mary i n
| i >= n = return mary
| otherwise = do
write mary i $ f (index ary i)
x <- indexM ary i
write mary i $ f x
go ary mary (i+1) n
{-# INLINE map #-}

Expand All @@ -427,7 +454,8 @@ map' f = \ ary ->
go ary mary i n
| i >= n = return mary
| otherwise = do
write mary i $! f (index ary i)
x <- indexM ary i
write mary i $! f x
go ary mary (i+1) n
{-# INLINE map' #-}

Expand All @@ -448,7 +476,27 @@ toList = foldr (:) []
traverse :: Applicative f => (a -> f b) -> Array a -> f (Array b)
traverse f = \ ary -> fromList (length ary) `fmap`
Traversable.traverse f (toList ary)
{-# INLINE traverse #-}
{-# INLINE [1] traverse #-}

-- Traversing in ST, we don't need to make a list; we
-- can just do it directly.
traverseST :: (a -> ST s b) -> Array a -> ST s (Array b)
traverseST f = \ ary0 ->
let
!len = length ary0
go k mary
| k == len = return mary
| otherwise = do
x <- indexM ary0 k
y <- f x
write mary k y
go (k + 1) mary
in new_ len >>= (go 0 >=> unsafeFreeze)
{-# INLINE traverseST #-}

{-# RULES
"traverse/ST" forall f. traverse f = traverseST f
#-}

filter :: (a -> Bool) -> Array a -> Array a
filter p = \ ary ->
Expand Down
57 changes: 48 additions & 9 deletions Data/HashMap/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ module Data.HashMap.Base
, updateOrConcatWithKey
, filterMapAux
, equalKeys
, equalKeys1
, lookupRecordCollision
, LookupRes(..)
, insert'
Expand Down Expand Up @@ -269,18 +270,37 @@ instance Traversable (HashMap k) where

#if MIN_VERSION_base(4,9,0)
instance Eq2 HashMap where
liftEq2 = equal
liftEq2 = equal2

instance Eq k => Eq1 (HashMap k) where
liftEq = equal (==)
liftEq = equal1
#endif

instance (Eq k, Eq v) => Eq (HashMap k v) where
(==) = equal (==) (==)

equal :: (k -> k' -> Bool) -> (v -> v' -> Bool)
(==) = equal1 (==)

-- We rely on there being no Empty constructors in the tree!
-- This ensures that two equal HashMaps will have the same
-- shape, modulo the order of entries in Collisions.
equal1 :: Eq k
=> (v -> v' -> Bool)
-> HashMap k v -> HashMap k v' -> Bool
equal1 eq = go
where
go Empty Empty = True
go (BitmapIndexed bm1 ary1) (BitmapIndexed bm2 ary2)
= bm1 == bm2 && A.sameArray1 go ary1 ary2
go (Leaf h1 l1) (Leaf h2 l2) = h1 == h2 && leafEq l1 l2
go (Full ary1) (Full ary2) = A.sameArray1 go ary1 ary2
go (Collision h1 ary1) (Collision h2 ary2)
= h1 == h2 && isPermutationBy leafEq (A.toList ary1) (A.toList ary2)
go _ _ = False

leafEq (L k1 v1) (L k2 v2) = k1 == k2 && eq v1 v2

equal2 :: (k -> k' -> Bool) -> (v -> v' -> Bool)
-> HashMap k v -> HashMap k' v' -> Bool
equal eqk eqv t1 t2 = go (toList' t1 []) (toList' t2 [])
equal2 eqk eqv t1 t2 = go (toList' t1 []) (toList' t2 [])
where
-- If the two trees are the same, then their lists of 'Leaf's and
-- 'Collision's read from left to right should be the same (modulo the
Expand Down Expand Up @@ -339,8 +359,8 @@ cmp cmpk cmpv t1 t2 = go (toList' t1 []) (toList' t2 [])
leafCompare (L k v) (L k' v') = cmpk k k' `mappend` cmpv v v'

-- Same as 'equal' but doesn't compare the values.
equalKeys :: (k -> k' -> Bool) -> HashMap k v -> HashMap k' v' -> Bool
equalKeys eq t1 t2 = go (toList' t1 []) (toList' t2 [])
equalKeys1 :: (k -> k' -> Bool) -> HashMap k v -> HashMap k' v' -> Bool
equalKeys1 eq t1 t2 = go (toList' t1 []) (toList' t2 [])
where
go (Leaf k1 l1 : tl1) (Leaf k2 l2 : tl2)
| k1 == k2 && leafEq l1 l2
Expand All @@ -354,6 +374,22 @@ equalKeys eq t1 t2 = go (toList' t1 []) (toList' t2 [])

leafEq (L k _) (L k' _) = eq k k'

-- Same as 'equal1' but doesn't compare the values.
equalKeys :: Eq k => HashMap k v -> HashMap k v' -> Bool
equalKeys = go
where
go :: Eq k => HashMap k v -> HashMap k v' -> Bool
go Empty Empty = True
go (BitmapIndexed bm1 ary1) (BitmapIndexed bm2 ary2)
= bm1 == bm2 && A.sameArray1 go ary1 ary2
go (Leaf h1 l1) (Leaf h2 l2) = h1 == h2 && leafEq l1 l2
go (Full ary1) (Full ary2) = A.sameArray1 go ary1 ary2
go (Collision h1 ary1) (Collision h2 ary2)
= h1 == h2 && isPermutationBy leafEq (A.toList ary1) (A.toList ary2)
go _ _ = False

leafEq (L k1 _) (L k2 _) = k1 == k2

#if MIN_VERSION_hashable(1,2,5)
instance H.Hashable2 HashMap where
liftHashWithSalt2 hk hv salt hm = go salt (toList' hm [])
Expand Down Expand Up @@ -1746,8 +1782,11 @@ updateOrConcatWith f = updateOrConcatWithKey (const f)

updateOrConcatWithKey :: Eq k => (k -> v -> v -> v) -> A.Array (Leaf k v) -> A.Array (Leaf k v) -> A.Array (Leaf k v)
updateOrConcatWithKey f ary1 ary2 = A.run $ do
-- TODO: instead of mapping and then folding, should we traverse?
-- We'll have to be careful to avoid allocating pairs or similar.

-- first: look up the position of each element of ary2 in ary1
let indices = A.map (\(L k _) -> indexOf k ary1) ary2
let indices = A.map' (\(L k _) -> indexOf k ary1) ary2
-- that tells us how large the overlap is:
-- count number of Nothing constructors
let nOnly2 = A.foldl' (\n -> maybe (n+1) (const n)) 0 indices
Expand Down
6 changes: 3 additions & 3 deletions Data/HashSet/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ module Data.HashSet.Base

import Control.DeepSeq (NFData(..))
import Data.Data hiding (Typeable)
import Data.HashMap.Base (HashMap, foldrWithKey, equalKeys)
import Data.HashMap.Base (HashMap, foldrWithKey, equalKeys, equalKeys1)
import Data.Hashable (Hashable(hashWithSalt))
#if __GLASGOW_HASKELL__ >= 711
import Data.Semigroup (Semigroup(..))
Expand Down Expand Up @@ -120,12 +120,12 @@ instance (NFData a) => NFData (HashSet a) where
{-# INLINE rnf #-}

instance (Eq a) => Eq (HashSet a) where
HashSet a == HashSet b = equalKeys (==) a b
HashSet a == HashSet b = equalKeys a b
{-# INLINE (==) #-}

#if MIN_VERSION_base(4,9,0)
instance Eq1 HashSet where
liftEq eq (HashSet a) (HashSet b) = equalKeys eq a b
liftEq eq (HashSet a) (HashSet b) = equalKeys1 eq a b
#endif

instance (Ord a) => Ord (HashSet a) where
Expand Down
11 changes: 10 additions & 1 deletion unordered-containers.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ copyright: 2010-2014 Johan Tibell
2010 Edward Z. Yang
category: Data
build-type: Simple
cabal-version: >=1.8
cabal-version: >=1.10
extra-source-files: CHANGES.md
tested-with: GHC==8.4.1, GHC==8.2.2, GHC==8.0.2, GHC==7.10.3, GHC==7.8.4

Expand All @@ -45,6 +45,8 @@ library
deepseq >= 1.1,
hashable >= 1.0.1.1 && < 1.3

default-language: Haskell2010

other-extensions:
RoleAnnotations,
UnboxedTuples,
Expand Down Expand Up @@ -76,6 +78,7 @@ test-suite hashmap-lazy-properties
test-framework-quickcheck2 >= 0.2.9,
unordered-containers

default-language: Haskell2010
ghc-options: -Wall
cpp-options: -DASSERTS

Expand All @@ -93,6 +96,7 @@ test-suite hashmap-strict-properties
test-framework-quickcheck2 >= 0.2.9,
unordered-containers

default-language: Haskell2010
ghc-options: -Wall
cpp-options: -DASSERTS -DSTRICT

Expand All @@ -110,6 +114,7 @@ test-suite hashset-properties
test-framework-quickcheck2 >= 0.2.9,
unordered-containers

default-language: Haskell2010
ghc-options: -Wall
cpp-options: -DASSERTS

Expand All @@ -127,6 +132,7 @@ test-suite list-tests
test-framework >= 0.3.3,
test-framework-quickcheck2 >= 0.2.9

default-language: Haskell2010
ghc-options: -Wall
cpp-options: -DASSERTS

Expand All @@ -145,6 +151,7 @@ test-suite regressions
test-framework-quickcheck2,
unordered-containers

default-language: Haskell2010
ghc-options: -Wall
cpp-options: -DASSERTS

Expand All @@ -163,6 +170,7 @@ test-suite strictness-properties
test-framework-quickcheck2 >= 0.2.9,
unordered-containers

default-language: Haskell2010
ghc-options: -Wall
cpp-options: -DASSERTS

Expand Down Expand Up @@ -200,6 +208,7 @@ benchmark benchmarks
mtl,
random

default-language: Haskell2010
ghc-options: -Wall -O2 -rtsopts -fwarn-tabs -ferror-spans
if flag(debug)
cpp-options: -DASSERTS
Expand Down