diff --git a/Data/Primitive/Array.hs b/Data/Primitive/Array.hs index fc0301c9..2e740f8d 100644 --- a/Data/Primitive/Array.hs +++ b/Data/Primitive/Array.hs @@ -23,8 +23,8 @@ module Data.Primitive.Array ( cloneArray, cloneMutableArray, sizeofArray, sizeofMutableArray, fromListN, fromList, - mapArray', - traverseArrayP + imapArray, mapArray', imapArray', + itraverseArray, traverseArrayP, itraverseArrayP ) where import Control.Monad.Primitive @@ -533,6 +533,43 @@ traverseArray f = \ !ary -> #-} #endif +-- | Traverse an array with an index. When applicable, 'itraverseArrayP' +-- will likely be more efficient. +-- +-- @since 0.6.4.1 +itraverseArray + :: Applicative f + => (Int -> a -> f b) + -> Array a + -> f (Array b) +itraverseArray f = \ !ary -> + let + !len = sizeofArray ary + go !i + | i == len = pure $ STA $ \mary -> unsafeFreezeArray (MutableArray mary) + | (# x #) <- indexArray## ary i + = liftA2 (\b (STA m) -> STA $ \mary -> + writeArray (MutableArray mary) i b >> m mary) + (f i x) (go (i + 1)) + in if len == 0 + then pure emptyArray + else runSTA len <$> go 0 +{-# INLINE [1] itraverseArray #-} + +{-# RULES +"itraverse/ST" forall (f :: Int -> a -> ST s b). itraverseArray f = + itraverseArrayP f +"itraverse/IO" forall (f :: Int -> a -> IO b). itraverseArray f = + itraverseArrayP f + #-} +#if MIN_VERSION_base(4,8,0) +{-# RULES +"itraverse/Id" forall (f :: Int -> a -> Identity b). itraverseArray f = + (coerce :: (Array a -> Array (Identity b)) + -> Array a -> Identity (Array b)) (imapArray f) + #-} +#endif + -- | This is the fastest, most straightforward way to traverse -- an array, but it only works correctly with a sufficiently -- "affine" 'PrimMonad' instance. In particular, it must only produce @@ -560,7 +597,55 @@ traverseArrayP f = \ !ary -> go 0 mary {-# INLINE traverseArrayP #-} +-- | This is the fastest, most straightforward way to traverse +-- an array, but it only works correctly with a sufficiently +-- "affine" 'PrimMonad' instance. In particular, it must only produce +-- *one* result array. 'Control.Monad.Trans.List.ListT'-transformed +-- monads, for example, will not work right at all. +-- +-- @since 0.6.4.1 +itraverseArrayP + :: PrimMonad m + => (Int -> a -> m b) + -> Array a + -> m (Array b) +itraverseArrayP f = \ !ary -> + let + !sz = sizeofArray ary + go !i !mary + | i == sz + = unsafeFreezeArray mary + | otherwise + = do + a <- indexArrayM ary i + b <- f i a + writeArray mary i b + go (i + 1) mary + in do + mary <- newArray sz badTraverseValue + go 0 mary +{-# INLINE itraverseArrayP #-} + +-- | Lazy map over the elements of the array with an index. +-- +-- @since 0.6.4.1 +imapArray :: (Int -> a -> b) -> Array a -> Array b +imapArray f a = + createArray (sizeofArray a) (die "mapArray'" "impossible") $ \mb -> + let go i | i == sizeofArray a + = return () + | otherwise + = do x <- indexArrayM a i + -- We use indexArrayM here so that we will perform the + -- indexing eagerly even if f is lazy. + let !y = f i x + writeArray mb i y >> go (i+1) + in go 0 +{-# INLINE imapArray #-} + -- | Strict map over the elements of the array. +-- +-- @since 0.6.4.1 mapArray' :: (a -> b) -> Array a -> Array b mapArray' f a = createArray (sizeofArray a) (die "mapArray'" "impossible") $ \mb -> @@ -575,6 +660,23 @@ mapArray' f a = in go 0 {-# INLINE mapArray' #-} +-- | Strict map over the elements of the array with an index. +-- +-- @since 0.6.4.1 +imapArray' :: (Int -> a -> b) -> Array a -> Array b +imapArray' f a = + createArray (sizeofArray a) (die "mapArray'" "impossible") $ \mb -> + let go i | i == sizeofArray a + = return () + | otherwise + = do x <- indexArrayM a i + -- We use indexArrayM here so that we will perform the + -- indexing eagerly even if f is lazy. + let !y = f i x + writeArray mb i y >> go (i+1) + in go 0 +{-# INLINE imapArray' #-} + arrayFromListN :: Int -> [a] -> Array a arrayFromListN n l = createArray n (die "fromListN" "uninitialized element") $ \sma -> diff --git a/Data/Primitive/SmallArray.hs b/Data/Primitive/SmallArray.hs index c737adbd..d29d6c26 100644 --- a/Data/Primitive/SmallArray.hs +++ b/Data/Primitive/SmallArray.hs @@ -58,7 +58,11 @@ module Data.Primitive.SmallArray , smallArrayFromList , smallArrayFromListN , mapSmallArray' + , imapSmallArray + , imapSmallArray' + , itraverseSmallArray , traverseSmallArrayP + , itraverseSmallArrayP ) where @@ -437,7 +441,42 @@ traverseSmallArrayP f (SmallArray ar) = SmallArray `liftM` traverseArrayP f ar #endif {-# INLINE traverseSmallArrayP #-} +-- | This is the fastest, most straightforward way to traverse +-- an array with an index, but it only works correctly with a sufficiently +-- "affine" 'PrimMonad' instance. In particular, it must only produce +-- *one* result array. 'Control.Monad.Trans.List.ListT'-transformed +-- monads, for example, will not work right at all. +-- +-- @since 0.6.4.1 +itraverseSmallArrayP + :: PrimMonad m + => (Int -> a -> m b) + -> SmallArray a + -> m (SmallArray b) +#if HAVE_SMALL_ARRAY +itraverseSmallArrayP f = \ !ary -> + let + !sz = sizeofSmallArray ary + go !i !mary + | i == sz + = unsafeFreezeSmallArray mary + | otherwise + = do + a <- indexSmallArrayM ary i + b <- f i a + writeSmallArray mary i b + go (i + 1) mary + in do + mary <- newSmallArray sz badTraverseValue + go 0 mary +#else +itraverseSmallArrayP f (SmallArray ar) = SmallArray `liftM` itraverseArrayP f ar +#endif +{-# INLINE itraverseSmallArrayP #-} + -- | Strict map over the elements of the array. +-- +-- @since 0.6.4.1 mapSmallArray' :: (a -> b) -> SmallArray a -> SmallArray b #if HAVE_SMALL_ARRAY mapSmallArray' f sa = createSmallArray (length sa) (die "mapSmallArray'" "impossible") $ \smb -> @@ -451,6 +490,37 @@ mapSmallArray' f (SmallArray ar) = SmallArray (mapArray' f ar) #endif {-# INLINE mapSmallArray' #-} +-- | Lazy indexed map over the elements of the array. +-- +-- @since 0.6.4.1 +imapSmallArray :: (Int -> a -> b) -> SmallArray a -> SmallArray b +#if HAVE_SMALL_ARRAY +imapSmallArray f sa = createSmallArray (length sa) (die "mapSmallArray" "impossible") $ \smb -> + fix ? 0 $ \go i -> + when (i < length sa) $ do + x <- indexSmallArrayM sa i + writeSmallArray smb i (f i x) *> go (i+1) +#else +imapSmallArray f (SmallArray ar) = SmallArray (imapArray f ar) +#endif +{-# INLINE imapSmallArray #-} + +-- | Strict indexed map over the elements of the array. +-- +-- @since 0.6.4.1 +imapSmallArray' :: (Int -> a -> b) -> SmallArray a -> SmallArray b +#if HAVE_SMALL_ARRAY +imapSmallArray' f sa = createSmallArray (length sa) (die "imapSmallArray'" "impossible") $ \smb -> + fix ? 0 $ \go i -> + when (i < length sa) $ do + x <- indexSmallArrayM sa i + let !y = f i x + writeSmallArray smb i y *> go (i+1) +#else +imapSmallArray' f (SmallArray ar) = SmallArray (imapArray' f ar) +#endif +{-# INLINE imapSmallArray' #-} + #ifndef HAVE_SMALL_ARRAY runSmallArray :: (forall s. ST s (SmallMutableArray s a)) @@ -705,6 +775,36 @@ traverseSmallArray f = \ !ary -> -> SmallArray a -> Identity (SmallArray b)) (fmap f) #-} +-- | Traverse a 'SmallArray' using the indices. When applicable, +-- 'itraverseSmallArrayP' will likely be more efficient. +-- +-- @since 0.6.4.1 +itraverseSmallArray + :: Applicative f + => (Int -> a -> f b) -> SmallArray a -> f (SmallArray b) +itraverseSmallArray f = \ !ary -> + let + !len = sizeofSmallArray ary + go !i + | i == len + = pure $ STA $ \mary -> unsafeFreezeSmallArray (SmallMutableArray mary) + | (# x #) <- indexSmallArray## ary i + = liftA2 (\b (STA m) -> STA $ \mary -> + writeSmallArray (SmallMutableArray mary) i b >> m mary) + (f i x) (go (i + 1)) + in if len == 0 + then pure emptySmallArray + else runSTA len <$> go 0 +{-# INLINE [1] itraverseSmallArray #-} + +{-# RULES +"itraverse/ST" forall (f :: Int -> a -> ST s b). itraverseSmallArray f = itraverseSmallArrayP f +"itraverse/IO" forall (f :: Int -> a -> IO b). itraverseSmallArray f = itraverseSmallArrayP f +"itraverse/Id" forall (f :: Int -> a -> Identity b). itraverseSmallArray f = + (coerce :: (SmallArray a -> SmallArray (Identity b)) + -> SmallArray a -> Identity (SmallArray b)) (imapSmallArray f) + #-} + instance Functor SmallArray where fmap f sa = createSmallArray (length sa) (die "fmap" "impossible") $ \smb -> diff --git a/changelog.md b/changelog.md index 12ad835e..8f60ea63 100644 --- a/changelog.md +++ b/changelog.md @@ -7,7 +7,7 @@ * Implement `isByteArrayPinned` and `isMutableByteArrayPinned`. * Add `Eq1`, `Ord1`, `Show1`, and `Read1` instances for `Array` and - `SmallArray`. + `SmallArray`. Add indexed maps and traversals. * Improve the test suite. This includes having property tests for typeclasses from `base` such as `Eq`, `Ord`, `Functor`, `Applicative`,