This repository was archived by the owner on Nov 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathNumeric.hs
98 lines (89 loc) · 2.7 KB
/
Numeric.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-- | Fourier transforms and FFT-based convolution over length-indexed
-- vectors ("Data.Vector.Generic.Sized"), implemented on top of "Math.FFT".
--
-- Exports the forward and inverse transforms ('fft', 'ifft'), three
-- convolution variants (complex 'convolve', real 'rconvolve', integral
-- 'zconvolve'), and re-exports the 'FFT.FFTWReal' class that the
-- floating-point entry points are constrained by.
module AOC.Common.Numeric (
fft
, ifft
, convolve
, rconvolve
, zconvolve
, FFT.FFTWReal
) where
import Data.Complex
import GHC.TypeNats
import qualified Data.Array.CArray as CA
import qualified Data.Array.IArray as IA
import qualified Data.Ix as Ix
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Sized as SVG
import qualified Foreign.Storable as FS
import qualified Math.FFT as FFT
import qualified Math.FFT.Base as FFT
-- | Apply 'FFT.dft' to a length-indexed vector of complex samples.
--
-- The payload is copied into a 'CA.CArray' ('toCA'), transformed, and copied
-- back into the caller's vector type ('fromCA'). 'SVG.withVectorUnsafe'
-- re-wraps the result at the same type-level length @n@, so this relies on
-- the transform returning exactly as many elements as it was given.
fft :: (FFT.FFTWReal a, VG.Vector v (Complex a))
    => SVG.Vector v n (Complex a)
    -> SVG.Vector v n (Complex a)
fft = SVG.withVectorUnsafe $ \samples ->
    fromCA (FFT.dft (toCA samples))
-- | Apply 'FFT.idft' to a length-indexed vector of complex coefficients.
--
-- Mirrors 'fft': the payload makes a round trip through a 'CA.CArray'
-- ('toCA' / 'fromCA'), and 'SVG.withVectorUnsafe' re-wraps the result at the
-- same type-level length @n@ on the assumption that the transform preserves
-- the element count.
ifft :: (FFT.FFTWReal a, VG.Vector v (Complex a))
     => SVG.Vector v n (Complex a)
     -> SVG.Vector v n (Complex a)
ifft = SVG.withVectorUnsafe $ \coeffs ->
    fromCA (FFT.idft (toCA coeffs))
-- | Copy the elements of a 'CA.CArray' into a generic vector, in index order.
--
-- The result has one element per index in the array's bounds. Elements are
-- read with 'IA.elems' rather than by indexing from 0, so the conversion does
-- not assume that the array's lower bound is 0.
fromCA
    :: (FS.Storable a, VG.Vector v (Complex a))
    => CA.CArray Int (Complex a)   -- ^ source array
    -> v (Complex a)               -- ^ elements of the array, lowest index first
fromCA arr = VG.fromListN size (IA.elems arr)
  where
    -- Number of indices covered by the array's bounds (0 for an empty range).
    size = Ix.rangeSize (IA.bounds arr)
-- | Copy a generic vector of complex values into a zero-based 'CA.CArray'.
--
-- The array is indexed @0 .. length - 1@; an empty vector therefore yields
-- the empty range @(0, -1)@.
toCA
    :: (FS.Storable a, VG.Vector v (Complex a))
    => v (Complex a)               -- ^ source vector
    -> CA.CArray Int (Complex a)   -- ^ same elements, indexed from 0
toCA vec = IA.listArray (0, lastIx) items
  where
    lastIx = VG.length vec - 1
    items  = VG.toList vec
-- | FFT-based convolution of two complex-valued vectors.
--
-- Each operand is extended with trailing zeros to the common output length
-- @n + m - 1@ (the literal @0@ is the all-zero sized vector, whose length is
-- fixed by the result type). The two padded vectors are transformed with
-- 'fft', multiplied element-wise, and transformed back with 'ifft'.
convolve
    :: ( VG.Vector v (Complex a)
       , KnownNat n, 1 <= n
       , KnownNat m, 1 <= m
       , FFT.FFTWReal a
       )
    => SVG.Vector v n (Complex a)
    -> SVG.Vector v m (Complex a)
    -> SVG.Vector v (n + m - 1) (Complex a)
convolve x y = ifft (fft (x SVG.++ 0) * fft (y SVG.++ 0))
-- | FFT-based convolution of two real-valued vectors.
--
-- Both operands are lifted into the complex numbers with a zero imaginary
-- part, convolved with 'convolve', and the real part of every output element
-- is returned (the imaginary part is discarded).
rconvolve
    :: ( VG.Vector v (Complex a)
       , VG.Vector v a
       , KnownNat n, 1 <= n
       , KnownNat m, 1 <= m
       , FFT.FFTWReal a
       )
    => SVG.Vector v n a
    -> SVG.Vector v m a
    -> SVG.Vector v (n + m - 1) a
rconvolve x y = SVG.map realPart (convolve xComplex yComplex)
  where
    xComplex = SVG.map (\re -> re :+ 0) x
    yComplex = SVG.map (\re -> re :+ 0) y
-- | FFT-based convolution of two integral-valued vectors.
--
-- Every element is converted to 'Double' with 'fromIntegral', the vectors are
-- convolved with 'rconvolve', and each output element is converted back with
-- 'round'.
--
-- NOTE(review): the computation goes through 'Double', so results are
-- presumably only exact while the intermediate values stay within 'Double'
-- precision -- confirm for large inputs.
zconvolve
    :: ( VG.Vector v (Complex Double)
       , VG.Vector v Double
       , VG.Vector v a
       , KnownNat n, 1 <= n
       , KnownNat m, 1 <= m
       , Integral a
       )
    => SVG.Vector v n a
    -> SVG.Vector v m a
    -> SVG.Vector v (n + m - 1) a
zconvolve x y = SVG.map roundDouble (rconvolve xReal yReal)
  where
    xReal = SVG.map fromIntegral x
    yReal = SVG.map fromIntegral y

    -- Pins the intermediate element type to 'Double'.
    roundDouble :: Integral b => Double -> b
    roundDouble = round