Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 63 additions & 134 deletions src/ArrayFire/Data.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ import Foreign.Marshal hiding (void)
import Foreign.Ptr (Ptr)
import Foreign.Storable
import System.IO.Unsafe
import Unsafe.Coerce

import Data.Bits

Expand All @@ -60,7 +59,7 @@ import ArrayFire.Arith
-- [1 1 1 1]
-- -1
bitNot
:: (AFType a, Bits a)
:: forall a. (AFType a, Bits a, Integral a)
=> Array a
-> Array a
bitNot arr = arr `bitXor` ones
Expand All @@ -72,148 +71,78 @@ bitNot arr = arr `bitXor` ones
, fromIntegral d2
, fromIntegral d3
]
(complement zeroBits)
(fromIntegral (complement (zeroBits :: a)))

-- | Creates an 'Array' from a scalar value from given dimensions
--
-- >>> constant @Double [2,2] 2.0
-- ArrayFire Array
-- [2 2 1 1]
-- 2.0000 2.0000
-- 2.0000 2.0000
-- | Creates a constant 'Array' filled with a 'Double' scalar.
-- ArrayFire converts the value to the element type internally.
-- Use 'constantComplex' for complex arrays, 'constantLong' / 'constantULong'
-- for 64-bit integer arrays where the value exceeds 2^53.
constant
:: forall a . AFType a
=> [Int]
-- ^ Dimensions
-> a
-- ^ Scalar value
:: forall a. AFType a
=> [Int] -- ^ Dimensions
-> Double -- ^ Scalar value
-> Array a
{-# NOINLINE constant #-}
constant dims val =
case dtyp of
x | x == c64 ->
cast $ constantComplex dims (unsafeCoerce val :: Complex Double)
| x == c32 ->
cast $ constantComplex dims (unsafeCoerce val :: Complex Float)
| x == s64 ->
cast $ constantLong dims (unsafeCoerce val :: Int)
| x == u64 ->
cast $ constantULong dims (unsafeCoerce val :: Word64)
| x == s32 ->
constant' dims (fromIntegral (unsafeCoerce val :: Int32) :: Double)
| x == s16 ->
constant' dims (fromIntegral (unsafeCoerce val :: Int16) :: Double)
| x == u32 ->
constant' dims (fromIntegral (unsafeCoerce val :: Word32) :: Double)
| x == u8 ->
constant' dims (fromIntegral (unsafeCoerce val :: Word8) :: Double)
| x == u16 ->
constant' dims (fromIntegral (unsafeCoerce val :: Word16) :: Double)
| x == f64 ->
constant' dims (unsafeCoerce val :: Double)
| x == b8 ->
constant' dims (fromIntegral (unsafeCoerce val :: CBool) :: Double)
| x == f32 ->
constant' dims (realToFrac (unsafeCoerce val :: Float))
| otherwise -> error "constant: Invalid array fire type"
unsafePerformIO . mask_ $ do
ptr <- calloca $ \ptrPtr -> do
withArray (fromIntegral <$> dims) $ \dimArray -> do
throwAFError =<< af_constant ptrPtr val n dimArray dtyp
peek ptrPtr
Array <$> newForeignPtr af_release_array_finalizer ptr
where
n = fromIntegral (length dims)
dtyp = afType (Proxy @a)

-- Creates the array directly with the target dtype: @af_constant@ takes
-- the value as a C double for every non-complex, non-64-bit-integral
-- dtype. Routing through an f64 array and casting (as this used to do)
-- fails with AF_ERR_NO_DBL on OpenCL devices without fp64 support and
-- changes b8 semantics (the cast normalises non-zero values to 1).
constant'
:: [Int]
-- ^ Dimensions
-> Double
-- ^ Scalar value
-> Array a
constant' dims' val' =
unsafePerformIO . mask_ $ do
ptr <- calloca $ \ptrPtr -> do
withArray (fromIntegral <$> dims') $ \dimArray -> do
throwAFError =<< af_constant ptrPtr val' n dimArray dtyp
peek ptrPtr
Array <$>
newForeignPtr
af_release_array_finalizer
ptr
where
n = fromIntegral (length dims')

-- | Creates an 'Array (Complex Double)' from a scalar val'ue
--
-- @
-- >>> constantComplex [2,2] (2.0 :+ 2.0)
-- @
--
constantComplex
:: forall arr . (Real arr, AFType (Complex arr))
=> [Int]
-- ^ Dimensions
-> Complex arr
-- ^ Scalar val'ue
-> Array (Complex arr)
constantComplex dims' ((realToFrac -> x) :+ (realToFrac -> y)) = unsafePerformIO . mask_ $ do
ptr <- calloca $ \ptrPtr -> do
withArray (fromIntegral <$> dims') $ \dimArray -> do
throwAFError =<< af_constant_complex ptrPtr x y n dimArray typ
peek ptrPtr
Array <$>
newForeignPtr
af_release_array_finalizer
ptr
where
n = fromIntegral (length dims')
typ = afType (Proxy @(Complex arr))
-- | Creates a constant complex 'Array' from a 'Complex' scalar.
constantComplex
:: forall r. (Real r, AFType (Complex r))
=> [Int] -- ^ Dimensions
-> Complex r -- ^ Scalar value
-> Array (Complex r)
{-# NOINLINE constantComplex #-}
constantComplex dims ((realToFrac -> x) :+ (realToFrac -> y)) =
unsafePerformIO . mask_ $ do
ptr <- calloca $ \ptrPtr -> do
withArray (fromIntegral <$> dims) $ \dimArray -> do
throwAFError =<< af_constant_complex ptrPtr x y n dimArray typ
peek ptrPtr
Array <$> newForeignPtr af_release_array_finalizer ptr
where
n = fromIntegral (length dims)
typ = afType (Proxy @(Complex r))

-- | Creates an 'Array Int64' from a scalar val'ue
--
-- @
-- >>> constantLong [2,2] 2.0
-- @
--
constantLong
:: [Int]
-- ^ Dimensions
-> Int
-- ^ Scalar val'ue
-> Array Int
constantLong dims' val' = unsafePerformIO . mask_ $ do
ptr <- calloca $ \ptrPtr -> do
withArray (fromIntegral <$> dims') $ \dimArray -> do
throwAFError =<< af_constant_long ptrPtr (fromIntegral val') n dimArray
peek ptrPtr
Array <$>
newForeignPtr
af_release_array_finalizer
ptr
where
n = fromIntegral (length dims')
-- | Creates a constant 'Array' of 64-bit signed integers.
-- Preserves the full integer value without 'Double' rounding.
constantLong
:: [Int] -- ^ Dimensions
-> Int -- ^ Scalar value
-> Array Int
{-# NOINLINE constantLong #-}
constantLong dims val =
unsafePerformIO . mask_ $ do
ptr <- calloca $ \ptrPtr -> do
withArray (fromIntegral <$> dims) $ \dimArray -> do
throwAFError =<< af_constant_long ptrPtr (fromIntegral val) n dimArray
peek ptrPtr
Array <$> newForeignPtr af_release_array_finalizer ptr
where n = fromIntegral (length dims)

-- | Creates an 'Array Word64' from a scalar val'ue
--
-- @
-- >>> constantULong [2,2] 2.0
-- @
--
constantULong
:: [Int]
-> Word64
-> Array Word64
constantULong dims' val' = unsafePerformIO . mask_ $ do
ptr <- calloca $ \ptrPtr -> do
withArray (fromIntegral <$> dims') $ \dimArray -> do
throwAFError =<< af_constant_ulong ptrPtr (fromIntegral val') n dimArray
peek ptrPtr
Array <$>
newForeignPtr
af_release_array_finalizer
ptr
where
n = fromIntegral (length dims')
-- | Creates a constant 'Array' of 64-bit unsigned integers.
-- Preserves the full integer value without 'Double' rounding.
constantULong
:: [Int] -- ^ Dimensions
-> Word64 -- ^ Scalar value
-> Array Word64
{-# NOINLINE constantULong #-}
constantULong dims val =
unsafePerformIO . mask_ $ do
ptr <- calloca $ \ptrPtr -> do
withArray (fromIntegral <$> dims) $ \dimArray -> do
throwAFError =<< af_constant_ulong ptrPtr (fromIntegral val) n dimArray
peek ptrPtr
Array <$> newForeignPtr af_release_array_finalizer ptr
where n = fromIntegral (length dims)

-- | Creates a range of values in an Array
--
Expand Down
55 changes: 14 additions & 41 deletions src/ArrayFire/Internal/Types.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -126,47 +126,20 @@ newtype Window = Window (ForeignPtr ())
class Storable a => AFType a where
afType :: Proxy a -> AFDtype

instance AFType Double where
afType Proxy = f64

instance AFType Float where
afType Proxy = f32

instance AFType (Complex Double) where
afType Proxy = c64

instance AFType (Complex Float) where
afType Proxy = c32

instance AFType CBool where
afType Proxy = b8

instance AFType Int32 where
afType Proxy = s32

instance AFType Word32 where
afType Proxy = u32

instance AFType Word8 where
afType Proxy = u8

instance AFType Int64 where
afType Proxy = s64

instance AFType Int where
afType Proxy = s64

instance AFType Int16 where
afType Proxy = s16

instance AFType Word16 where
afType Proxy = u16

instance AFType Word64 where
afType Proxy = u64

instance AFType Word where
afType Proxy = u64
instance AFType Double where afType Proxy = f64
instance AFType Float where afType Proxy = f32
instance AFType (Complex Double) where afType Proxy = c64
instance AFType (Complex Float) where afType Proxy = c32
instance AFType CBool where afType Proxy = b8
instance AFType Int32 where afType Proxy = s32
instance AFType Word32 where afType Proxy = u32
instance AFType Word8 where afType Proxy = u8
instance AFType Int64 where afType Proxy = s64
instance AFType Int where afType Proxy = s64
instance AFType Int16 where afType Proxy = s16
instance AFType Word16 where afType Proxy = u16
instance AFType Word64 where afType Proxy = u64
instance AFType Word where afType Proxy = u64

-- | Maps an ArrayFire element type to the scalar type returned by whole-array
-- reductions (e.g. 'meanAll', 'det'). Real and integral element types yield
Expand Down
79 changes: 61 additions & 18 deletions test/ArrayFire/DataSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,67 @@ import ArrayFire hiding (not)
spec :: Spec
spec =
describe "Data tests" $ do
it "Should create constant Array" $ do
constant @Float [1] 1 `shouldBe` 1
constant @Double [1] 1 `shouldBe` 1
constant @Int16 [1] 1 `shouldBe` 1
constant @Int32 [1] 1 `shouldBe` 1
constant @Int64 [1] 1 `shouldBe` 1
constant @Int [1] 1 `shouldBe` 1
constant @Word16 [1] 1 `shouldBe` 1
constant @Word32 [1] 1 `shouldBe` 1
constant @Word64 [1] 1 `shouldBe` 1
constant @Word [1] 1 `shouldBe` 1
constant @CBool [1] 1 `shouldBe` 1
constant @(Complex Double) [1] (1.0 :+ 1.0)
`shouldBe`
constant @(Complex Double) [1] (1.0 :+ 1.0)
constant @(Complex Float) [1] (1.0 :+ 1.0)
`shouldBe`
constant @(Complex Float) [1] (1.0 :+ 1.0)
describe "constant" $ do
it "creates a scalar Float array" $
constant @Float [1] 1 `shouldBe` scalar @Float 1
it "creates a scalar Double array" $
constant @Double [1] 2.5 `shouldBe` scalar @Double 2.5
it "creates a scalar Int16 array" $
constant @Int16 [1] 42 `shouldBe` scalar @Int16 42
it "creates a scalar Int32 array" $
constant @Int32 [1] (-7) `shouldBe` scalar @Int32 (-7)
it "creates a scalar Word8 array" $
constant @Word8 [1] 255 `shouldBe` scalar @Word8 255
it "creates a scalar Word16 array" $
constant @Word16 [1] 1000 `shouldBe` scalar @Word16 1000
it "creates a scalar Word32 array" $
constant @Word32 [1] 999 `shouldBe` scalar @Word32 999
it "creates a CBool array" $
constant @CBool [1] 1 `shouldBe` scalar @CBool 1
it "creates a multi-element array with correct shape" $ do
let a = constant @Double [3,3] 0
getDims a `shouldBe` (3,3,1,1)
it "all elements equal the scalar value" $
constant @Float [4] 3.14 `shouldBe` vector @Float 4 [3.14, 3.14, 3.14, 3.14]

describe "constantComplex" $ do
it "creates a Complex Double array preserving imaginary part" $
constantComplex [1] (1.0 :+ 2.0)
`shouldBe` scalar @(Complex Double) (1.0 :+ 2.0)
it "creates a Complex Float array preserving imaginary part" $
constantComplex [1] (3.0 :+ 4.0 :: Complex Float)
`shouldBe` scalar @(Complex Float) (3.0 :+ 4.0)
it "creates a zero complex array" $
constantComplex [2] (0 :+ 0 :: Complex Double)
`shouldBe` vector @(Complex Double) 2 [0, 0]
it "handles purely real complex values" $
constantComplex [1] (5.0 :+ 0.0 :: Complex Double)
`shouldBe` scalar @(Complex Double) (5.0 :+ 0.0)
it "handles purely imaginary complex values" $
constantComplex [1] (0.0 :+ 7.0 :: Complex Double)
`shouldBe` scalar @(Complex Double) (0.0 :+ 7.0)

describe "constantLong" $ do
it "creates an Int array with value 1" $
constantLong [1] 1 `shouldBe` scalar @Int 1
it "creates an Int array with a negative value" $
constantLong [1] (-42) `shouldBe` scalar @Int (-42)
it "preserves maxBound :: Int without rounding" $
constantLong [1] maxBound `shouldBe` scalar @Int maxBound
it "preserves minBound :: Int without rounding" $
constantLong [1] minBound `shouldBe` scalar @Int minBound
it "creates a multi-element array" $
constantLong [3] 7 `shouldBe` vector @Int 3 [7, 7, 7]

describe "constantULong" $ do
it "creates a Word64 array with value 1" $
constantULong [1] 1 `shouldBe` scalar @Word64 1
it "creates a Word64 array with value 0" $
constantULong [1] 0 `shouldBe` scalar @Word64 0
it "preserves maxBound :: Word64 without rounding" $
constantULong [1] maxBound `shouldBe` scalar @Word64 maxBound
it "creates a multi-element array" $
constantULong [3] 100 `shouldBe` vector @Word64 3 [100, 100, 100]

describe "arange" $ do
it "generates a sequence along dim 0 for a 1D array" $ do
Expand Down
Loading