Improve `Foldable` instance for `Array`
authorDavid Feuer <David.Feuer@gmail.com>
Thu, 13 Nov 2014 08:10:57 +0000 (09:10 +0100)
committerHerbert Valerio Riedel <hvr@gnu.org>
Thu, 13 Nov 2014 08:16:49 +0000 (09:16 +0100)
Previously, `Array`s were simply converted to lists, and the list
methods used. That works acceptably well for `foldr` and `foldr1`, but
not so sensibly for most other things. Left folds ended up "twisted" the
way they are for lists, leading to surprising performance
characteristics.

Moreover, this implements `length` and `null` so they check the array
size directly.

Finally, a test is added to the testsuite ensuring the overridden
`Foldable` methods agree with their expected default semantics.

Addresses #9763

Reviewed By: hvr, austin

Differential Revision: https://phabricator.haskell.org/D459

libraries/base/Data/Foldable.hs
libraries/base/GHC/Arr.hs
libraries/base/tests/all.T
libraries/base/tests/foldableArray.hs [new file with mode: 0644]
libraries/base/tests/foldableArray.stdout [new file with mode: 0644]

index 8d31b9a..8ad8c2f 100644 (file)
@@ -56,7 +56,10 @@ import Data.Monoid
 import Data.Ord
 import Data.Proxy
 
-import GHC.Arr  ( Array(..), Ix(..), elems )
+import GHC.Arr  ( Array(..), Ix(..), elems, numElements,
+                  foldlElems, foldrElems,
+                  foldlElems', foldrElems',
+                  foldl1Elems, foldr1Elems)
 import GHC.Base hiding ( foldr )
 import GHC.Num  ( Num(..) )
 
@@ -252,10 +255,15 @@ instance Foldable ((,) a) where
     foldr f z (_, y) = f y z
 
 instance Ix i => Foldable (Array i) where
-    foldr f z = List.foldr f z . elems
-    foldl f z = List.foldl f z . elems
-    foldr1 f = List.foldr1 f . elems
-    foldl1 f = List.foldl1 f . elems
+    foldr = foldrElems
+    foldl = foldlElems
+    foldl' = foldlElems'
+    foldr' = foldrElems'
+    foldl1 = foldl1Elems
+    foldr1 = foldr1Elems
+    toList = elems
+    length = numElements
+    null a = numElements a == 0
 
 instance Foldable Proxy where
     foldMap _ _ = mempty
index 0d50993..67702ea 100644 (file)
@@ -1,6 +1,5 @@
 {-# LANGUAGE Unsafe #-}
 {-# LANGUAGE NoImplicitPrelude, MagicHash, UnboxedTuples #-}
-{-# OPTIONS_GHC -funbox-strict-fields #-}
 {-# OPTIONS_HADDOCK hide #-}
 
 -----------------------------------------------------------------------------
@@ -30,6 +29,8 @@ module GHC.Arr (
         newSTArray, boundsSTArray,
         readSTArray, writeSTArray,
         freezeSTArray, thawSTArray,
+        foldlElems, foldlElems', foldl1Elems,
+        foldrElems, foldrElems', foldr1Elems,
 
         -- * Unsafe operations
         fill, done,
@@ -557,6 +558,62 @@ elems :: Ix i => Array i e -> [e]
 elems arr@(Array _ _ n _) =
     [unsafeAt arr i | i <- [0 .. n - 1]]
 
+-- | A right fold over the elements
+{-# INLINABLE foldrElems #-}
+foldrElems :: Ix i => (a -> b -> b) -> b -> Array i a -> b
+foldrElems f b0 = \ arr@(Array _ _ n _) ->
+  let
+    go i | i == n    = b0
+         | otherwise = f (unsafeAt arr i) (go (i+1))
+  in go 0
+
+-- | A left fold over the elements
+{-# INLINABLE foldlElems #-}
+foldlElems :: Ix i => (b -> a -> b) -> b -> Array i a -> b
+foldlElems f b0 = \ arr@(Array _ _ n _) ->
+  let
+    go i | i == (-1) = b0
+         | otherwise = f (go (i-1)) (unsafeAt arr i)
+  in go (n-1)
+
+-- | A strict right fold over the elements
+{-# INLINABLE foldrElems' #-}
+foldrElems' :: Ix i => (a -> b -> b) -> b -> Array i a -> b
+foldrElems' f b0 = \ arr@(Array _ _ n _) ->
+  let
+    go i a | i == (-1) = a
+           | otherwise = go (i-1) (f (unsafeAt arr i) $! a)
+  in go (n-1) b0
+
+-- | A strict left fold over the elements
+{-# INLINABLE foldlElems' #-}
+foldlElems' :: Ix i => (b -> a -> b) -> b -> Array i a -> b
+foldlElems' f b0 = \ arr@(Array _ _ n _) ->
+  let
+    go i a | i == n    = a
+           | otherwise = go (i+1) (a `seq` f a (unsafeAt arr i))
+  in go 0 b0
+
+-- | A left fold over the elements with no starting value
+{-# INLINABLE foldl1Elems #-}
+foldl1Elems :: Ix i => (a -> a -> a) -> Array i a -> a
+foldl1Elems f = \ arr@(Array _ _ n _) ->
+  let
+    go i | i == 0    = unsafeAt arr 0
+         | otherwise = f (go (i-1)) (unsafeAt arr i)
+  in
+    if n == 0 then error "foldl1: empty Array" else go (n-1)
+
+-- | A right fold over the elements with no starting value
+{-# INLINABLE foldr1Elems #-}
+foldr1Elems :: Ix i => (a -> a -> a) -> Array i a -> a
+foldr1Elems f = \ arr@(Array _ _ n _) ->
+  let
+    go i | i == n-1  = unsafeAt arr i
+         | otherwise = f (unsafeAt arr i) (go (i + 1))
+  in
+    if n == 0 then error "foldr1: empty Array" else go 0
+
 -- | The list of associations of an array in index order.
 {-# INLINE assocs #-}
 assocs :: Ix i => Array i e -> [(i, e)]
index d4005b7..fa8ecd3 100644 (file)
@@ -83,6 +83,7 @@ test('enum03',                when(fast(), skip), compile_and_run, ['-cpp'])
 test('enum04',                 normal, compile_and_run, [''])
 test('exceptionsrun001',       normal, compile_and_run, [''])
 test('exceptionsrun002',       normal, compile_and_run, [''])
+test('foldableArray',   normal, compile_and_run, [''])
 test('list001' ,       when(fast(), skip), compile_and_run, [''])
 test('list002', when(fast(), skip), compile_and_run, [''])
 test('list003', when(fast(), skip), compile_and_run, [''])
diff --git a/libraries/base/tests/foldableArray.hs b/libraries/base/tests/foldableArray.hs
new file mode 100644 (file)
index 0000000..5a5041f
--- /dev/null
@@ -0,0 +1,129 @@
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE DeriveDataTypeable #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE CPP #-}
+module Main where
+import Prelude hiding (foldr, foldl, foldl', foldr1, foldl1, length, null, sum,
+                       product, all, any, and, or)
+import Data.Foldable
+import Control.Exception
+import Data.Array
+import Data.Foldable
+import Data.Typeable
+import Data.Either
+import Control.Applicative
+import Control.DeepSeq
+#if __GLASGOW_HASKELL__ < 709
+import qualified Data.List as L
+#else
+import qualified Data.OldList as L
+#endif
+
+data BadElementException = BadFirst | BadLast deriving (Show, Typeable, Eq)
+instance Exception BadElementException
+
+newtype ForceDefault f a = ForceDefault (f a)
+instance Foldable f => Foldable (ForceDefault f) where
+  foldMap f (ForceDefault c) = foldMap f c
+
+goodLists, badFronts, badBacks :: [[Integer]]
+goodLists = [[0..n] | n <- [(-1)..5]]
+badFronts = map (throw BadFirst :) goodLists
+badBacks  = map (++ [throw BadLast]) goodLists
+doubleBads = map (\l -> throw BadFirst : l ++ [throw BadLast]) goodLists
+lists =
+        goodLists
+        ++ badFronts
+        ++ badBacks
+        ++ doubleBads
+
+makeArray xs = array (1::Int, length xs) (zip [1..] xs)
+
+arrays = map makeArray lists
+goodArrays = map makeArray goodLists
+
+
+strictCons x y = x + 10*y
+rightLazyCons x y = x
+leftLazyCons x y = y
+
+conses :: [Integer -> Integer -> Integer]
+conses = [(+), strictCons, rightLazyCons, leftLazyCons]
+
+runOneRight :: forall f . Foldable f =>
+                             (forall a b . (a -> b -> b) -> b -> f a -> b) ->
+                             (Integer -> Integer -> Integer) -> f Integer ->
+                             IO (Either BadElementException Integer)
+runOneRight fol f container = try (evaluate (fol f 12 container))
+
+runOne1 :: forall f . Foldable f => (forall a . (a -> a -> a) -> f a -> a) ->
+                              (Integer -> Integer -> Integer) -> f Integer ->
+                              IO (Either BadElementException Integer)
+runOne1 fol f container = try (evaluate (fol f container))
+
+runOneLeft :: forall f . Foldable f =>
+                             (forall a b . (b -> a -> b) -> b -> f a -> b) ->
+                              (Integer -> Integer -> Integer) -> f Integer ->
+                              IO (Either BadElementException Integer)
+runOneLeft fol f container = try (evaluate (fol f 13 container))
+
+runWithAllRight :: forall f . Foldable f =>
+                          (forall a b . (a -> b -> b) -> b -> f a -> b) ->
+                          [f Integer] -> IO [Either BadElementException Integer]
+runWithAllRight fol containers =
+       mapM (uncurry (runOneRight fol)) [(f,c) | f <- conses, c <- containers]
+
+runWithAll1 :: forall f . Foldable f =>
+                        (forall a . (a -> a -> a) -> f a -> a) ->
+                        [f Integer] -> IO [Either BadElementException Integer]
+runWithAll1 fol containers =
+          mapM (uncurry (runOne1 fol)) [(f,c) | f <- conses, c <- containers]
+
+runWithAllLeft :: forall f . Foldable f =>
+                          (forall a b . (b -> a -> b) -> b -> f a -> b) ->
+                          [f Integer] -> IO [Either BadElementException Integer]
+runWithAllLeft fol containers = mapM (uncurry (runOneLeft fol))
+                              [(f,c) | f <- map flip conses, c <- containers]
+
+testWithAllRight :: forall f . Foldable f =>
+                 (forall a b . (a -> b -> b) -> b -> f a -> b) ->
+                  (forall a b . (a -> b -> b) -> b -> ForceDefault f a -> b) ->
+                   [f Integer] -> IO Bool
+testWithAllRight fol1 fol2 containers = (==) <$>
+       runWithAllRight fol1 containers <*>
+           runWithAllRight fol2 (map ForceDefault containers)
+
+testWithAllLeft :: forall f . Foldable f =>
+                   (forall a b . (b -> a -> b) -> b -> f a -> b) ->
+                   (forall a b . (b -> a -> b) -> b -> ForceDefault f a -> b) ->
+                       [f Integer] -> IO Bool
+testWithAllLeft fol1 fol2 containers = (==) <$>
+      runWithAllLeft fol1 containers <*>
+         runWithAllLeft fol2 (map ForceDefault containers)
+
+
+testWithAll1 :: forall f . Foldable f =>
+                        (forall a . (a -> a -> a) -> f a -> a) ->
+                        (forall a . (a -> a -> a) -> ForceDefault f a -> a) ->
+                                               [f Integer] -> IO Bool
+testWithAll1 fol1 fol2 containers =
+  (==) <$> runWithAll1 fol1 containers
+            <*> runWithAll1 fol2 (map ForceDefault containers)
+
+checkup f g cs = map f cs == map g (map ForceDefault cs)
+
+main = do
+         testWithAllRight foldr foldr arrays >>= print
+         testWithAllRight foldr' foldr' arrays >>= print
+         testWithAllLeft foldl foldl arrays >>= print
+         testWithAllLeft foldl' foldl' arrays >>= print
+         testWithAll1 foldl1 foldl1 (filter (not . null) arrays) >>= print
+         testWithAll1 foldr1 foldr1 (filter (not . null) arrays) >>= print
+         -- we won't bother with the fancy laziness tests for the rest
+         print $ checkup length length goodArrays
+         print $ checkup sum sum goodArrays
+         print $ checkup product product goodArrays
+         print $ checkup maximum maximum $ filter (not . null) goodArrays
+         print $ checkup minimum minimum $ filter (not . null) goodArrays
+         print $ checkup toList toList goodArrays
+         print $ checkup null null arrays
diff --git a/libraries/base/tests/foldableArray.stdout b/libraries/base/tests/foldableArray.stdout
new file mode 100644 (file)
index 0000000..50aa4a9
--- /dev/null
@@ -0,0 +1,13 @@
+True
+True
+True
+True
+True
+True
+True
+True
+True
+True
+True
+True
+True