Write a liftA2 for Seq (#397)
authorDavid Feuer <David.Feuer@gmail.com>
Wed, 8 Feb 2017 18:21:22 +0000 (13:21 -0500)
committerGitHub <noreply@github.com>
Wed, 8 Feb 2017 18:21:22 +0000 (13:21 -0500)
* Use a custom `liftA2` implementation for Data.Sequence for
  base 4.10.

* Write RULES for `liftA2`.

Data/Sequence/Internal.hs
tests/Makefile
tests/seq-properties.hs

index 30cab4c..865c9e8 100644 (file)
@@ -174,6 +174,7 @@ module Data.Sequence.Internal (
     traverseWithIndex, -- :: Applicative f => (Int -> a -> f b) -> Seq a -> f (Seq b)
     reverse,        -- :: Seq a -> Seq a
     intersperse,    -- :: a -> Seq a -> Seq a
+    liftA2Seq,      -- :: (a -> b -> c) -> Seq a -> Seq b -> Seq c
     -- ** Zips
     zip,            -- :: Seq a -> Seq b -> Seq (a, b)
     zipWith,        -- :: (a -> b -> c) -> Seq a -> Seq b -> Seq c
@@ -432,24 +433,41 @@ instance Monad Seq where
 instance Applicative Seq where
     pure = singleton
     xs *> ys = cycleNTimes (length xs) ys
+    (<*>) = apSeq
+#if MIN_VERSION_base(4,10,0)
+    liftA2 = liftA2Seq
+#endif
 
-    fs <*> xs@(Seq xsFT) = case viewl fs of
-      EmptyL -> empty
-      firstf :< fs' -> case viewr fs' of
-        EmptyR -> fmap firstf xs
-        Seq fs''FT :> lastf -> case rigidify xsFT of
-             RigidEmpty -> empty
-             RigidOne (Elem x) -> fmap ($x) fs
-             RigidTwo (Elem x1) (Elem x2) ->
-                Seq $ ap2FT firstf fs''FT lastf (x1, x2)
-             RigidThree (Elem x1) (Elem x2) (Elem x3) ->
-                Seq $ ap3FT firstf fs''FT lastf (x1, x2, x3)
-             RigidFull r@(Rigid s pr _m sf) -> Seq $
-                   Deep (s * length fs)
-                        (fmap (fmap firstf) (nodeToDigit pr))
-                        (aptyMiddle (fmap firstf) (fmap lastf) fmap fs''FT r)
-                        (fmap (fmap lastf) (nodeToDigit sf))
+apSeq :: Seq (a -> b) -> Seq a -> Seq b
+apSeq fs xs@(Seq xsFT) = case viewl fs of
+  EmptyL -> empty
+  firstf :< fs' -> case viewr fs' of
+    EmptyR -> fmap firstf xs
+    Seq fs''FT :> lastf -> case rigidify xsFT of
+         RigidEmpty -> empty
+         RigidOne (Elem x) -> fmap ($x) fs
+         RigidTwo (Elem x1) (Elem x2) ->
+            Seq $ ap2FT firstf fs''FT lastf (x1, x2)
+         RigidThree (Elem x1) (Elem x2) (Elem x3) ->
+            Seq $ ap3FT firstf fs''FT lastf (x1, x2, x3)
+         RigidFull r@(Rigid s pr _m sf) -> Seq $
+               Deep (s * length fs)
+                    (fmap (fmap firstf) (nodeToDigit pr))
+                    (aptyMiddle (fmap firstf) (fmap lastf) fmap fs''FT r)
+                    (fmap (fmap lastf) (nodeToDigit sf))
+{-# NOINLINE [1] apSeq #-}
 
+{-# RULES
+"ap/fmap" forall f xs ys . apSeq (fmapSeq f xs) ys = liftA2Seq f xs ys
+"fmap/ap" forall f gs xs . fmapSeq f (gs `apSeq` xs) =
+                             liftA2Seq (\g x -> f (g x)) gs xs
+"fmap/liftA2" forall f g m n . fmapSeq f (liftA2Seq g m n) =
+                       liftA2Seq (\x y -> f (g x y)) m n
+"liftA2/fmap1" forall f g m n . liftA2Seq f (fmapSeq g m) n =
+                       liftA2Seq (\x y -> f (g x) y) m n
+"liftA2/fmap2" forall f g m n . liftA2Seq f m (fmapSeq g n) =
+                       liftA2Seq (\x y -> f x (g y)) m n
+ #-}
 
 ap2FT :: (a -> b) -> FingerTree (Elem (a->b)) -> (a -> b) -> (a,a) -> FingerTree (Elem b)
 ap2FT firstf fs lastf (x,y) =
@@ -464,6 +482,46 @@ ap3FT firstf fs lastf (x,y,z) = Deep (size fs * 3 + 6)
                         (mapMulFT 3 (\(Elem f) -> Node3 3 (Elem (f x)) (Elem (f y)) (Elem (f z))) fs)
                         (Three (Elem $ lastf x) (Elem $ lastf y) (Elem $ lastf z))
 
+lift2FT :: (a -> b -> c) -> a -> FingerTree (Elem a) -> a -> (b,b) -> FingerTree (Elem c)
+lift2FT f firstx xs lastx (y1,y2) =
+                 Deep (size xs * 2 + 4)
+                      (Two (Elem $ f firstx y1) (Elem $ f firstx y2))
+                      (mapMulFT 2 (\(Elem x) -> Node2 2 (Elem (f x y1)) (Elem (f x y2))) xs)
+                      (Two (Elem $ f lastx y1) (Elem $ f lastx y2))
+
+lift3FT :: (a -> b -> c) -> a -> FingerTree (Elem a) -> a -> (b,b,b) -> FingerTree (Elem c)
+lift3FT f firstx xs lastx (y1,y2,y3) =
+                 Deep (size xs * 3 + 6)
+                      (Three (Elem $ f firstx y1) (Elem $ f firstx y2) (Elem $ f firstx y3))
+                      (mapMulFT 3 (\(Elem x) -> Node3 3 (Elem (f x y1)) (Elem (f x y2)) (Elem (f x y3))) xs)
+                      (Three (Elem $ f lastx y1) (Elem $ f lastx y2) (Elem $ f lastx y3))
+
+liftA2Seq :: (a -> b -> c) -> Seq a -> Seq b -> Seq c
+liftA2Seq f xs ys@(Seq ysFT) = case viewl xs of
+  EmptyL -> empty
+  firstx :< xs' -> case viewr xs' of
+    EmptyR -> f firstx <$> ys
+    Seq xs''FT :> lastx -> case rigidify ysFT of
+      RigidEmpty -> empty
+      RigidOne (Elem y) -> fmap (\x -> f x y) xs
+      RigidTwo (Elem y1) (Elem y2) ->
+        Seq $ lift2FT f firstx xs''FT lastx (y1, y2)
+      RigidThree (Elem y1) (Elem y2) (Elem y3) ->
+        Seq $ lift3FT f firstx xs''FT lastx (y1, y2, y3)
+      RigidFull r@(Rigid s pr _m sf) -> Seq $
+        Deep (s * length xs)
+             (fmap (fmap (f firstx)) (nodeToDigit pr))
+             (aptyMiddle (fmap (f firstx)) (fmap (f lastx)) (lift_elem f) xs''FT r)
+             (fmap (fmap (f lastx)) (nodeToDigit sf))
+  where
+    lift_elem :: (a -> b -> c) -> a -> Elem b -> Elem c
+#if __GLASGOW_HASKELL__ >= 708
+    lift_elem = coerce
+#else
+    lift_elem f x (Elem y) = Elem (f x y)
+#endif
+{-# NOINLINE [1] liftA2Seq #-}
+
 
 data Rigidified a = RigidEmpty
                   | RigidOne a
@@ -514,12 +572,12 @@ type Digit23 a = Node a
 -- class, but as it is we have to build up 'map23' explicitly through the
 -- recursion.
 aptyMiddle
-  :: (c -> d)
-     -> (c -> d)
-     -> ((a -> b) -> c -> d)
-     -> FingerTree (Elem (a -> b))
-     -> Rigid c
-     -> FingerTree (Node d)
+  :: (b -> c)
+     -> (b -> c)
+     -> (a -> b -> c)
+     -> FingerTree (Elem a)
+     -> Rigid b
+     -> FingerTree (Node c)
 
 -- Not at the bottom yet
 
index 69d08ed..231c863 100644 (file)
@@ -8,10 +8,10 @@
 all:
 
 %-properties: %-properties.hs force
-       ghc -O2 -DTESTING $< -i.. -o $@ -outputdir tmp
+       ghc -I../include -O2 -DTESTING $< -i.. -o $@ -outputdir tmp
 
 %-strict-properties: %-properties.hs force
-       ghc -O2 -DTESTING -DSTRICT $< -o $@ -i.. -outputdir tmp
+       ghc -I../include -O2 -DTESTING -DSTRICT $< -o $@ -i.. -outputdir tmp
 
 .PHONY: force clean
 force:
index e162bc4..35cdab2 100644 (file)
@@ -16,7 +16,7 @@ import Data.Sequence.Internal
 
 import Data.Sequence
 
-import Control.Applicative (Applicative(..))
+import Control.Applicative (Applicative(..), liftA2)
 import Control.Arrow ((***))
 import Control.Monad.Trans.State.Strict
 import Data.Array (listArray)
@@ -133,6 +133,8 @@ main = defaultMain
        , testProperty "munzip-lazy" prop_munzipLazy
 #endif
        , testProperty "<*>" prop_ap
+       , testProperty "<*> NOINLINE" prop_ap_NOINLINE
+       , testProperty "liftA2" prop_liftA2
        , testProperty "*>" prop_then
        , testProperty "cycleTaking" prop_cycleTaking
        , testProperty "intersperse" prop_intersperse
@@ -746,6 +748,20 @@ prop_ap :: Seq A -> Seq B -> Bool
 prop_ap xs ys =
     toList' ((,) <$> xs <*> ys) ~= ( (,) <$> toList xs <*> toList ys )
 
+prop_ap_NOINLINE :: Seq A -> Seq B -> Bool
+prop_ap_NOINLINE xs ys =
+    toList' (((,) <$> xs) `apNOINLINE` ys) ~= ( (,) <$> toList xs <*> toList ys )
+
+{-# NOINLINE apNOINLINE #-}
+apNOINLINE :: Seq (a -> b) -> Seq a -> Seq b
+apNOINLINE fs xs = fs <*> xs
+
+prop_liftA2 :: Seq A -> Seq B -> Property
+prop_liftA2 xs ys = valid q .&&.
+    toList q === liftA2 (,) (toList xs) (toList ys)
+  where
+    q = liftA2 (,) xs ys
+
 prop_then :: Seq A -> Seq B -> Bool
 prop_then xs ys =
     toList' (xs *> ys) ~= (toList xs *> toList ys)