Add MonadZip Seq and Tree instances
authorDavid Feuer <David.Feuer@gmail.com>
Wed, 28 Dec 2016 16:41:56 +0000 (11:41 -0500)
committerDavid Feuer <David.Feuer@gmail.com>
Fri, 30 Dec 2016 05:41:02 +0000 (00:41 -0500)
Add `MonadZip` instances for `Data.Sequence.Seq` and
`Data.Tree.Tree`.

Fixes #374

Data/Sequence/Internal.hs
Data/Tree.hs
tests/seq-properties.hs

index d2bfa04..e44dc9f 100644 (file)
@@ -256,6 +256,9 @@ import Data.Word (Word)
 #endif
 
 import Utils.Containers.Internal.StrictPair (StrictPair (..), toPair)
+#if MIN_VERSION_base(4,4,0)
+import Control.Monad.Zip (MonadZip (..))
+#endif
 
 default ()
 
@@ -4080,6 +4083,109 @@ getSingleton _ = error "getSingleton: Not a singleton."
 -- Zipping
 ------------------------------------------------------------------------
 
+-- MonadZip appeared in base 4.4.0
+#if MIN_VERSION_base(4,4,0)
+-- We use a custom definition of munzip to *try* to avoid retaining
+-- memory longer than necessary. Using the default definition, if
+-- we write
+--
+-- let (xs,ys) = munzip zs
+-- in xs `deepseq` (... ys ...)
+--
+-- then ys will retain the entire zs sequence until ys itself is fully
+-- forced. This implementation attempts to use the selector thunk
+-- optimization to prevent that. Unfortunately, that optimization is
+-- fragile, so we can't actually guarantee anything. If someone finds
+-- a leak, we can try to throw explicit bindings and NOINLINE pragmas
+-- around and see if that fixes it.
+instance MonadZip Seq where
+  mzipWith = zipWith
+  munzip = unzipWith id
+
+class UnzipWith f where
+  unzipWith :: (x -> (a, b)) -> f x -> (f a, f b)
+
+instance UnzipWith Elem where
+#if __GLASGOW_HASKELL__ >= 708
+  unzipWith = coerce
+#else
+  unzipWith f (Elem a) = case f a of (x, y) -> (Elem x, Elem y)
+#endif
+
+-- We're super-lazy here for the sake of efficiency. We want to be able to
+-- reach any element of either result in logarithmic time. If we pattern
+-- match strictly, we'll end up building entire 2-3 trees at once, which
+-- would take linear time.
+instance UnzipWith Node where
+  unzipWith f (Node2 s x y) =
+    case (f x, f y) of
+      (~(x1, x2), ~(y1, y2)) -> (Node2 s x1 y1, Node2 s x2 y2)
+  unzipWith f (Node3 s x y z) =
+    case (f x, f y, f z) of
+      (~(x1, x2), ~(y1, y2), ~(z1, z2)) -> (Node3 s x1 y1 z1, Node3 s x2 y2 z2)
+
+-- We're strict here for the sake of efficiency. The Node instance
+-- is lazy, so we don't particularly need to add an extra thunk on top
+-- of each node. See the note at the Seq instance for an explanation
+-- of why the Digit (Elem a) case is handled specially.
+instance UnzipWith Digit where
+  unzipWith f (One x) =
+    case f x of
+      (x1, x2) -> (One x1, One x2)
+  unzipWith f (Two x y) =
+    case (f x, f y) of
+      ((x1, x2), (y1, y2)) -> (Two x1 y1, Two x2 y2)
+  unzipWith f (Three x y z) =
+    case (f x, f y, f z) of
+      ((x1, x2), (y1, y2), (z1, z2)) -> (Three x1 y1 z1, Three x2 y2 z2)
+  unzipWith f (Four x y z w) =
+    case (f x, f y, f z, f w) of
+      ((x1, x2), (y1, y2), (z1, z2), (w1, w2)) -> (Four x1 y1 z1 w1, Four x2 y2 z2 w2)
+
+instance UnzipWith FingerTree where
+  unzipWith _ EmptyT = (EmptyT, EmptyT)
+  unzipWith f (Single x) = case f x of
+    (x1, x2) -> (Single x1, Single x2)
+  unzipWith f (Deep s pr m sf) =
+    case unzipWith f pr of { (pr1, pr2) ->
+    case unzipWith f sf of { (sf1, sf2) ->
+    case unzipWith (unzipWith f) m of { ~(m1, m2) ->
+      (Deep s pr1 m1 sf1, Deep s pr2 m2 sf2)}}}
+
+-- We need to handle the top level of the sequence specially, to make unzipping behave
+-- well in the presence of undefined elements. For example, what do we want from
+--
+-- munzip [(1,2), undefined, (5,6)]?
+--
+-- The argument could be represented as
+--
+-- Seq $ Deep 3 (One (Elem (1,2))) EmptyT (Two undefined (Elem (5,6)))
+--
+-- or as
+--
+-- Seq $ Deep 3 (Two (Elem (1,2)) undefined) EmptyT (One (Elem (5,6)))
+--
+-- We don't want the tree balance to determine whether we get
+--
+-- ([1, undefined, undefined], [2, undefined, undefined])
+--
+-- or
+--
+-- ([undefined, undefined, 5], [undefined, undefined, 6])
+--
+-- so we pretty much have to be completely lazy in the elements. We could
+-- do this by adding extra laziness to the Digit instance or to the Elem instance,
+-- but either of those would give unnecessary extra laziness lower in the tree.
+instance UnzipWith Seq where
+  unzipWith _f (Seq EmptyT) = (empty, empty)
+  unzipWith f (Seq (Single (Elem x))) = case f x of ~(a, b) -> (singleton a, singleton b)
+  unzipWith f (Seq (Deep s pr m sf)) =
+    case unzipWith (\(Elem x) -> case f x of ~(a, b) -> (Elem a, Elem b)) pr of { (pr1, pr2) ->
+    case unzipWith (\(Elem x) -> case f x of ~(a, b) -> (Elem a, Elem b)) sf of { (sf1, sf2) ->
+    case unzipWith (unzipWith (unzipWith f)) m of { ~(m1, m2) ->
+      (Seq (Deep s pr1 m1 sf1), Seq (Deep s pr2 m2 sf2))}}}
+#endif
+
 -- | /O(min(n1,n2))/.  'zip' takes two sequences and returns a sequence
 -- of corresponding pairs.  If one input is short, excess elements are
 -- discarded from the right end of the longer sequence.
index 89dd42b..5a9ad20 100644 (file)
@@ -61,6 +61,10 @@ import GHC.Generics (Generic, Generic1)
 import GHC.Generics (Generic)
 #endif
 
+#if MIN_VERSION_base(4,4,0)
+import Control.Monad.Zip (MonadZip (..))
+#endif
+
 #if MIN_VERSION_base(4,8,0)
 import Data.Coerce
 #endif
@@ -163,6 +167,15 @@ instance Foldable Tree where
 instance NFData a => NFData (Tree a) where
     rnf (Node x ts) = rnf x `seq` rnf ts
 
+#if MIN_VERSION_base(4,4,0)
+instance MonadZip Tree where
+  mzipWith f (Node a as) (Node b bs)
+    = Node (f a b) (mzipWith (mzipWith f) as bs)
+
+  munzip (Node (a, b) ts) = (Node a as, Node b bs)
+    where (as, bs) = munzip (map munzip ts)
+#endif
+
 -- | Neat 2-dimensional drawing of a tree.
 drawTree :: Tree String -> String
 drawTree  = unlines . draw
index f325f3f..e162bc4 100644 (file)
@@ -1,6 +1,8 @@
 {-# LANGUAGE CPP #-}
 {-# LANGUAGE PatternGuards #-}
 
+#include "containers.h"
+
 import Data.Sequence.Internal
   ( Sized (..)
   , Seq (Seq)
@@ -38,6 +40,10 @@ import Test.QuickCheck.Property
 import Test.QuickCheck.Function
 import Test.Framework
 import Test.Framework.Providers.QuickCheck2
+#if MIN_VERSION_base(4,4,0)
+import Control.Monad.Zip (MonadZip (..))
+#endif
+import Control.DeepSeq (deepseq)
 
 
 main :: IO ()
@@ -121,6 +127,11 @@ main = defaultMain
        , testProperty "zipWith3" prop_zipWith3
        , testProperty "zip4" prop_zip4
        , testProperty "zipWith4" prop_zipWith4
+#if MIN_VERSION_base(4,4,0)
+       , testProperty "mzip-naturality" prop_mzipNaturality
+       , testProperty "mzip-preservation" prop_mzipPreservation
+       , testProperty "munzip-lazy" prop_munzipLazy
+#endif
        , testProperty "<*>" prop_ap
        , testProperty "*>" prop_then
        , testProperty "cycleTaking" prop_cycleTaking
@@ -249,6 +260,20 @@ toListList' xss = toList' xss >>= mapM toList'
 toListPair' :: (Seq a, Seq b) -> Maybe ([a], [b])
 toListPair' (xs, ys) = (,) <$> toList' xs <*> toList' ys
 
+-- Extra "polymorphic" test type
+newtype D = D{ unD :: Integer }
+  deriving ( Eq )
+
+instance Show D where
+  showsPrec n (D x) = showsPrec n x
+
+instance Arbitrary D where
+  arbitrary    = (D . (+1) . abs) `fmap` arbitrary
+  shrink (D x) = [ D x' | x' <- shrink x, x' > 0 ]
+
+instance CoArbitrary D where
+  coarbitrary = coarbitrary . unD
+
 -- instances
 
 prop_fmap :: Seq Int -> Bool
@@ -686,6 +711,35 @@ prop_zipWith4 xs ys zs ts =
     toList' (zipWith4 f xs ys zs ts) ~= Data.List.zipWith4 f (toList xs) (toList ys) (toList zs) (toList ts)
   where f = (,,,)
 
+#if MIN_VERSION_base(4,4,0)
+-- This comes straight from the MonadZip documentation
+prop_mzipNaturality :: Fun A C -> Fun B D -> Seq A -> Seq B -> Property
+prop_mzipNaturality f g sa sb =
+  fmap (apply f *** apply g) (mzip sa sb) ===
+  mzip (apply f <$> sa) (apply g <$> sb)
+
+-- This is a slight optimization of the MonadZip preservation
+-- law that works because sequences don't have any decorations.
+prop_mzipPreservation :: Fun A B -> Seq A -> Property
+prop_mzipPreservation f sa =
+  let sb = fmap (apply f) sa
+  in munzip (mzip sa sb) === (sa, sb)
+
+-- We want to ensure that
+--
+-- munzip xs = xs `seq` (fmap fst x, fmap snd x)
+--
+-- even in the presence of bottoms (alternatives are all balance-
+-- fragile).
+prop_munzipLazy :: Seq (Integer, B) -> Bool
+prop_munzipLazy pairs = deepseq ((`seq` ()) <$> repaired) True
+  where
+    partialpairs = mapWithIndex (\i a -> update i err pairs) pairs
+    firstPieces = fmap (fst . munzip) partialpairs
+    repaired = mapWithIndex (\i s -> update i 10000 s) firstPieces
+    err = error "munzip isn't lazy enough"
+#endif
+
 -- Applicative operations
 
 prop_ap :: Seq A -> Seq B -> Bool