From 39f7bf1d963f4c97240549db1278c4ccf3af6965 Mon Sep 17 00:00:00 2001 From: meooow25 Date: Sun, 17 Nov 2024 18:55:27 +0530 Subject: [PATCH] Make fromListN functions good consumers ...in terms of list fusion. --- Data/Primitive/Array.hs | 20 +++++++++++++++----- Data/Primitive/ByteArray.hs | 13 ++++++++----- Data/Primitive/PrimArray.hs | 13 ++++++++----- Data/Primitive/SmallArray.hs | 11 ++++++----- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/Data/Primitive/Array.hs b/Data/Primitive/Array.hs index c527c60..e8a5812 100644 --- a/Data/Primitive/Array.hs +++ b/Data/Primitive/Array.hs @@ -586,18 +586,28 @@ mapArray' f a = -- | Create an array from a list of a known length. If the length -- of the list does not match the given length, this throws an exception. + +-- Note [fromListN] +-- ~~~~~~~~~~~~~~~~ +-- We want arrayFromListN to be a "good consumer" in list fusion, so we define +-- the function using foldr and inline it to help fire fusion rules. +-- If fusion occurs with a "good producer", it may reduce to a fold on some +-- structure. In certain cases (such as for Data.Set) GHC is not be able to +-- optimize the index to an unboxed Int# (see GHC #24628), so we explicitly use +-- an Int# here. arrayFromListN :: Int -> [a] -> Array a +{-# INLINE arrayFromListN #-} arrayFromListN n l = createArray n (die "fromListN" "uninitialized element") $ \sma -> - let go !ix [] = if ix == n + let z ix# = if I# ix# == n then return () else die "fromListN" "list length less than specified size" - go !ix (x : xs) = if ix < n + f x k = GHC.Exts.oneShot $ \ix# -> if I# ix# < n then do - writeArray sma ix x - go (ix+1) xs + writeArray sma (I# ix#) x + k (ix# +# 1#) else die "fromListN" "list length greater than specified size" - in go 0 l + in foldr f z l 0# -- | Create an array from a list. arrayFromList :: [a] -> Array a diff --git a/Data/Primitive/ByteArray.hs b/Data/Primitive/ByteArray.hs index 4cd935b..e6abdba 100644 --- a/Data/Primitive/ByteArray.hs +++ b/Data/Primitive/ByteArray.hs @@ -378,17 +378,20 @@ byteArrayFromList xs = byteArrayFromListN (length xs) xs -- | Create a 'ByteArray' from a list of a known length. If the length -- of the list does not match the given length, this throws an exception. + +-- See Note [fromListN] in Data.Primitive.Array byteArrayFromListN :: forall a. Prim a => Int -> [a] -> ByteArray +{-# INLINE byteArrayFromListN #-} byteArrayFromListN n ys = createByteArray (n * sizeOfType @a) $ \marr -> - let go !ix [] = if ix == n + let z ix# = if I# ix# == n then return () else die "byteArrayFromListN" "list length less than specified size" - go !ix (x : xs) = if ix < n + f x k = GHC.Exts.oneShot $ \ix# -> if I# ix# < n then do - writeByteArray marr ix x - go (ix + 1) xs + writeByteArray marr (I# ix#) x + k (ix# +# 1#) else die "byteArrayFromListN" "list length greater than specified size" - in go 0 ys + in foldr f z ys 0# unI# :: Int -> Int# unI# (I# n#) = n# diff --git a/Data/Primitive/PrimArray.hs b/Data/Primitive/PrimArray.hs index 5c2b9e0..9a2f761 100644 --- a/Data/Primitive/PrimArray.hs +++ b/Data/Primitive/PrimArray.hs @@ -234,17 +234,20 @@ primArrayFromList vs = primArrayFromListN (L.length vs) vs -- | Create a 'PrimArray' from a list of a known length. If the length -- of the list does not match the given length, this throws an exception. + +-- See Note [fromListN] in Data.Primitive.Array primArrayFromListN :: forall a. Prim a => Int -> [a] -> PrimArray a +{-# INLINE primArrayFromListN #-} primArrayFromListN len vs = createPrimArray len $ \arr -> - let go [] !ix = if ix == len + let z ix# = if I# ix# == len then return () else die "fromListN" "list length less than specified size" - go (a : as) !ix = if ix < len + f a k = GHC.Exts.oneShot $ \ix# -> if I# ix# < len then do - writePrimArray arr ix a - go as (ix + 1) + writePrimArray arr (I# ix#) a + k (ix# +# 1#) else die "fromListN" "list length greater than specified size" - in go vs 0 + in foldr f z vs 0# -- | Convert a 'PrimArray' to a list. {-# INLINE primArrayToList #-} diff --git a/Data/Primitive/SmallArray.hs b/Data/Primitive/SmallArray.hs index 74c8fbf..cf3d1da 100644 --- a/Data/Primitive/SmallArray.hs +++ b/Data/Primitive/SmallArray.hs @@ -924,18 +924,19 @@ instance (Typeable s, Typeable a) => Data (SmallMutableArray s a) where -- | Create a 'SmallArray' from a list of a known length. If the length -- of the list does not match the given length, this throws an exception. smallArrayFromListN :: Int -> [a] -> SmallArray a +{-# INLINE smallArrayFromListN #-} smallArrayFromListN n l = createSmallArray n (die "smallArrayFromListN" "uninitialized element") $ \sma -> - let go !ix [] = if ix == n + let z ix# = if I# ix# == n then return () else die "smallArrayFromListN" "list length less than specified size" - go !ix (x : xs) = if ix < n + f x k = GHC.Exts.oneShot $ \ix# -> if I# ix# < n then do - writeSmallArray sma ix x - go (ix + 1) xs + writeSmallArray sma (I# ix#) x + k (ix# +# 1#) else die "smallArrayFromListN" "list length greater than specified size" - in go 0 l + in foldr f z l 0# -- | Create a 'SmallArray' from a list. smallArrayFromList :: [a] -> SmallArray a