diff --git a/examples/smartfft/Project.toml b/examples/smartfft/Project.toml new file mode 100644 index 0000000..5a771b3 --- /dev/null +++ b/examples/smartfft/Project.toml @@ -0,0 +1,10 @@ +[deps] +Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" +DaggerWebDash = "cfc5aa84-1a2a-41ab-b391-ede92ecae40c" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" +ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" +TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63" diff --git a/examples/smartfft/SmartSolve.jl b/examples/smartfft/SmartSolve.jl new file mode 100644 index 0000000..13579c8 --- /dev/null +++ b/examples/smartfft/SmartSolve.jl @@ -0,0 +1,87 @@ +module SmartSolve + +using Plots + +const ALGORITHMS = Dict{Symbol, Vector{Function}}() + +function add_candidate_algorithm!(alg_name::Symbol, alg_func::Function) + if !haskey(ALGORITHMS, alg_name) + ALGORITHMS[alg_name] = Function[] + end + push!(ALGORITHMS[alg_name], alg_func) +end + +const BENCHMARK_RESULTS = Dict{Symbol, Dict{Function, Dict{Int, Vector{Float64}}}}() + +function benchmark_algorithms!(f::Function, alg_name::Symbol, Ns) + if !haskey(BENCHMARK_RESULTS, alg_name) + BENCHMARK_RESULTS[alg_name] = Dict{Function, Dict{Int, Vector{Float64}}}() + end + for alg in ALGORITHMS[alg_name] + @info "Benchmarking $alg_name with $alg" + if !haskey(BENCHMARK_RESULTS[alg_name], alg) + BENCHMARK_RESULTS[alg_name][alg] = Dict{Int, Vector{Float64}}() + end + for n in Ns + if !haskey(BENCHMARK_RESULTS[alg_name][alg], n) + BENCHMARK_RESULTS[alg_name][alg][n] = Float64[] + end + times = Float64[] + for iter in 1:10 + t = f(alg, n) + push!(times, t) + end + BENCHMARK_RESULTS[alg_name][alg][n] = times + end + end +end + +function select_best_algorithm(alg_name::Symbol, input::Array) + if !haskey(BENCHMARK_RESULTS, alg_name) + error("No benchmark results for $alg_name") + end + best_alg = nothing + best_time = Inf + first_alg = first(ALGORITHMS[alg_name]) + Ns = sort(collect(keys(BENCHMARK_RESULTS[alg_name][first_alg]))) + closest_n = Ns[argmin(abs.(Ns .- size(input, 1)))] + for alg in ALGORITHMS[alg_name] + times = BENCHMARK_RESULTS[alg_name][alg][closest_n] + time = minimum(times) + if time < best_time + best_time = time + best_alg = alg + end + end + return best_alg +end + +function benchmark_results(alg_name::Symbol) + if !haskey(BENCHMARK_RESULTS, alg_name) + error("No benchmark results for $alg_name") + end + first_alg = first(ALGORITHMS[alg_name]) + results = zeros(length(ALGORITHMS[alg_name]), length(BENCHMARK_RESULTS[alg_name][first_alg])) + i = 0 + for alg in ALGORITHMS[alg_name] + i += 1 + j = 0 + for n in sort(collect(keys(BENCHMARK_RESULTS[alg_name][alg]))) + j += 1 + times = BENCHMARK_RESULTS[alg_name][alg][n] + results[i, j] = minimum(times) + end + end + return results +end + +function plot_benchmark_results(alg_name::Symbol) + results = benchmark_results(alg_name) + algs = permutedims(map(string, collect(ALGORITHMS[alg_name]))) + return plot(results', yscale=:log10, + title="FFT Performance", xlabel="Input Size", ylabel="Time (s)", + legend=:outertopright, + labels=algs) +end + +end # module SmartSolve \ No newline at end of file diff --git a/examples/smartfft/generate_gantt.jl b/examples/smartfft/generate_gantt.jl new file mode 100644 index 0000000..308030e --- /dev/null +++ b/examples/smartfft/generate_gantt.jl @@ -0,0 +1,75 @@ +using Revise +using Dagger +#using DaggerWebDash +#import DaggerWebDash: GanttPlot, LinePlot +#using TimespanLogging +using DataFrames +using Plots +using GraphViz +using LinearAlgebra +using ScopedValues + +const IS_SETUP = Ref{Bool}(false) +function setup() + if IS_SETUP[] + return + end + ml = TimespanLogging.MultiEventLog() + ml[:core] = TimespanLogging.Events.CoreMetrics() + ml[:id] = TimespanLogging.Events.IDMetrics() + ml[:wsat] = Dagger.Events.WorkerSaturation() + ml[:loadavg] = TimespanLogging.Events.CPULoadAverages() + ml[:bytes] = Dagger.Events.BytesAllocd() + ml[:mem] = TimespanLogging.Events.MemoryFree() + ml[:esat] = TimespanLogging.Events.EventSaturation() + ml[:psat] = Dagger.Events.ProcessorSaturation() + lw = TimespanLogging.Events.LogWindow(20*10^9, :core) + logs_df = DataFrame([key=>[] for key in keys(ml.consumers)]...) + ts = DaggerWebDash.TableStorage(logs_df) + push!(lw.creation_handlers, ts) + d3r = DaggerWebDash.D3Renderer(8080; seek_store=ts) + push!(lw.creation_handlers, d3r) + push!(lw.deletion_handlers, d3r) + push!(d3r, GanttPlot(:core, :id, :esat, :psat; title="Overview")) + # TODO: push!(d3r, ProfileViewer(:core, :profile, "Profile Viewer")) + #push!(d3r, LinePlot(:core, :wsat, "Worker Saturation", "Running Tasks")) + #push!(d3r, LinePlot(:core, :loadavg, "CPU Load Average", "Average Running Threads")) + #push!(d3r, LinePlot(:core, :bytes, "Allocated Bytes", "Bytes")) + #push!(d3r, LinePlot(:core, :mem, "Available Memory", "% Free")) + #push!(d3r, GraphPlot(:core, :id, :timeline, :profile, "DAG")) + ml.aggregators[:d3r] = d3r + ml.aggregators[:logwindow] = lw + Dagger.Sch.eager_context().log_sink = ml + IS_SETUP[] = true +end + +function generate(A, B, C; nt=Threads.nthreads()) + @info "Generating for $nt threads" + Dagger.enable_logging!(;metrics=false, all_task_deps=true) + GC.enable(false) + Dagger.with_options(;scope=Dagger.scope(;threads=1:nt)) do + @with Dagger.DATADEPS_SCHEDULE_REUSABLE => false begin + @time mul!(C, A, B) + end + end + GC.enable(true) + logs = Dagger.fetch_logs!() + Dagger.disable_logging!() + display(Dagger.render_logs(logs, :plots_gantt; target=:execution, color_init_hash=UInt(2))) + return +end +function generate_all() + #setup() + Dagger.MemPool.MEM_RESERVED[] = 0 + A = rand(Blocks(512, 512), 2048, 2048) + B = rand(Blocks(512, 512), 2048, 2048) + C = zeros(Blocks(512, 512), 2048, 2048) + wait.(A.chunks) + wait.(B.chunks) + wait.(C.chunks) + generate(A, B, C, nt=1) + generate(A, B, C, nt=4) + generate(A, B, C, nt=8) + generate(A, B, C, nt=16) +end +generate_all() \ No newline at end of file diff --git a/examples/smartfft/step1_candidate_algorithms_1d.jl b/examples/smartfft/step1_candidate_algorithms_1d.jl new file mode 100644 index 0000000..9d9465b --- /dev/null +++ b/examples/smartfft/step1_candidate_algorithms_1d.jl @@ -0,0 +1,289 @@ +using FFTW + +### 1. Recursive Radix-2 FFT (Cooley-Tukey Algorithm) +function recursive_fft(x::Vector{<:Number}) + n = length(x) + if n == 1 + return x + end + + # Split into even and odd indices + even = recursive_fft(x[1:2:n]) + odd = recursive_fft(x[2:2:n]) + + # Combine results + result = zeros(Complex{Float64}, n) + for k in 0:n÷2-1 + t = exp(-2im * π * k / n) * odd[k+1] + result[k+1] = even[k+1] + t + result[k+1+n÷2] = even[k+1] - t + end + + return result +end + +### 2. Iterative Radix-2 FFT (Cooley-Tukey Algorithm) +function iterative_fft(x::Vector{<:Number}) + n = length(x) + + # Bit-reversal permutation + result = complex.(x) + j = 1 + for i in 1:n-1 + if i < j + result[i], result[j] = result[j], result[i] + end + m = n ÷ 2 + while m <= j + j -= m + m ÷= 2 + end + j += m + end + + # Butterflies + for s in 1:log2(n) + m = 2^s + wm = exp(-2im * π / m) + for k in 0:m:n-1 + w = 1.0 + 0.0im + for j in 0:m÷2-1 + t = w * result[k+j+m÷2+1] + u = result[k+j+1] + result[k+j+1] = u + t + result[k+j+m÷2+1] = u - t + w *= wm + end + end + end + + return result +end + +### 3. Bluestein's FFT Algorithm (Chirp Z-Transform) - Handles any size n +function bluestein_fft(x::Vector{<:Number}) + n = length(x) + + # Find next power of 2 >= 2n-1 + m = 2^ceil(Int, log2(2n-1)) + + # Create chirp sequences + a = [exp(im * π * (k^2) / n) for k in 0:n-1] + b = [exp(-im * π * (k^2) / n) for k in -(n-1):n-1] + + # Zero padding + a_padded = zeros(Complex{Float64}, m) + b_padded = zeros(Complex{Float64}, m) + + # Fill in values + a_padded[1:n] = x .* a + b_padded[1:2n-1] = b + + # Convolve using FFT (assuming the presence of an FFT function) + c_padded = ifft(fft(a_padded) .* fft(b_padded))[1:n] + + # Adjust final result + result = c_padded .* a + + return result +end + +### 4. Stockham Radix-2 FFT Algorithm (Auto-Sorting) +function stockham_fft(x::Vector{<:Number}) + n = length(x) + + # Ensure n is a power of 2 + if log2(n) != floor(log2(n)) + error("Input length must be a power of 2") + end + + # Initialize buffers + buffer1 = complex.(x) + buffer2 = zeros(Complex{Float64}, n) + + for s in 1:Int(log2(n)) + m = 2^s + half_m = m ÷ 2 + + for k in 0:half_m:n-1 + w = 1.0 + 0.0im + wm = exp(-2im * π / m) + + for j in 0:half_m-1 + even_idx = k + j + 1 + odd_idx = even_idx + half_m + + # Butterfly operation + even_val = buffer1[even_idx] + odd_val = buffer1[odd_idx] + + buffer2[k÷half_m * m + j + 1] = even_val + w * odd_val + buffer2[k÷half_m * m + j + half_m + 1] = even_val - w * odd_val + + w *= wm + end + end + + # Swap buffers for next iteration + buffer1, buffer2 = buffer2, buffer1 + end + + return buffer1 +end + +### 5. Radix-4 FFT Algorithm +function radix4_fft(x::Vector{<:Number}) + n = length(x) + + # Ensure n is a power of 4 + if log2(n) % 2 != 0 + error("Input length must be a power of 4") + end + + if n == 1 + return x + end + + # Split into four parts + x0 = radix4_fft(x[1:4:n]) + x1 = radix4_fft(x[2:4:n]) + x2 = radix4_fft(x[3:4:n]) + x3 = radix4_fft(x[4:4:n]) + + # Combine results + result = zeros(Complex{Float64}, n) + for k in 0:n÷4-1 + w1 = exp(-2im * π * k / n) + w2 = w1 * w1 + w3 = w2 * w1 + + t0 = x0[k+1] + t1 = w1 * x1[k+1] + t2 = w2 * x2[k+1] + t3 = w3 * x3[k+1] + + result[k+1] = t0 + t1 + t2 + t3 + result[k+1+n÷4] = t0 - im*t1 - t2 + im*t3 + result[k+1+2n÷4] = t0 - t1 + t2 - t3 + result[k+1+3n÷4] = t0 + im*t1 - t2 - im*t3 + end + + return result +end + +### 6. Split-Radix FFT Algorithm +function split_radix_fft(x::Vector{<:Number}) + n = length(x) + + if n == 1 + return x + end + + if n == 2 + return [x[1] + x[2], x[1] - x[2]] + end + + # Split into even, every 4th, and every 4th+2 + even = split_radix_fft(x[1:2:n]) + odd1 = split_radix_fft(x[2:4:n]) + odd3 = split_radix_fft(x[4:4:n]) + + # Combine results + result = zeros(Complex{Float64}, n) + for k in 0:n÷4-1 + w1 = exp(-2im * π * k / n) + w3 = exp(-6im * π * k / n) + + result[k+1] = even[k+1] + w1 * odd1[k+1] + w3 * odd3[k+1] + result[k+1+n÷2] = even[k+1] - w1 * odd1[k+1] - w3 * odd3[k+1] + + result[k+1+n÷4] = even[k+1+n÷4] - im * (w1 * odd1[k+1] - w3 * odd3[k+1]) + result[k+1+3n÷4] = even[k+1+n÷4] + im * (w1 * odd1[k+1] - w3 * odd3[k+1]) + end + + return result +end + +### 7. Prime-Factor Algorithm (PFA) for FFT +function prime_factor_fft(x::Vector{<:Number}, n1::Int, n2::Int) + n = length(x) + + # Ensure n = n1 * n2 and n1, n2 are coprime + if n != n1 * n2 + error("n must equal n1 * n2") + end + + # Create 2D array from 1D + X = zeros(Complex{Float64}, n1, n2) + for i in 0:n-1 + i1 = i % n1 + i2 = i % n2 + X[i1+1, i2+1] = x[i+1] + end + + # FFT along each dimension + for i1 in 1:n1 + X[i1, :] = recursive_fft(X[i1, :]) + end + + for i2 in 1:n2 + X[:, i2] = recursive_fft(X[:, i2]) + end + + # Map back to 1D + result = zeros(Complex{Float64}, n) + for i1 in 0:n1-1 + for i2 in 0:n2-1 + k = (i1 * n2 + i2) % n + result[k+1] = X[i1+1, i2+1] + end + end + + return result +end + +### 8. FFT using Two Butterfly Algorithms (Radix-2^2) +function butterfly_fft(x::Vector{<:Number}) + n = length(x) + + # Bit-reversal permutation + result = complex.(x) + j = 1 + for i in 1:n-1 + if i < j + result[i], result[j] = result[j], result[i] + end + m = n ÷ 2 + while m <= j + j -= m + m ÷= 2 + end + j += m + end + + # 2-point DFT (butterfly) + for i in 1:2:n + t = result[i] + result[i] = t + result[i+1] + result[i+1] = t - result[i+1] + end + + # 4-point and higher radix butterflies + for stage in 2:log2(n) + m = 2^stage + half_m = m ÷ 2 + + for k in 0:m:n-1 + for j in 0:half_m-1 + idx1 = k + j + 1 + idx2 = idx1 + half_m + w = exp(-2im * π * j / m) + t = w * result[idx2] + result[idx2] = result[idx1] - t + result[idx1] = result[idx1] + t + end + end + end + + return result +end \ No newline at end of file diff --git a/examples/smartfft/step1_candidate_algorithms_2d.jl b/examples/smartfft/step1_candidate_algorithms_2d.jl new file mode 100644 index 0000000..41b58f7 --- /dev/null +++ b/examples/smartfft/step1_candidate_algorithms_2d.jl @@ -0,0 +1,353 @@ +using FFTW + +# 1. Radix-2 Cooley-Tukey Algorithm (Row-Column Method) +function fft2d_cooley_tukey!(A::Matrix{ComplexF64}) + N = size(A, 1) + # Perform 1D FFT on each row + for i in 1:N + fft1d_cooley_tukey!(view(A, i, :)) + end + # Perform 1D FFT on each column + for j in 1:N + fft1d_cooley_tukey!(view(A, :, j)) + end + return A +end + +# Helper 1D Cooley-Tukey FFT (recursive, in-place) +function fft1d_cooley_tukey!(x::AbstractVector{ComplexF64}) + N = length(x) + if N ≤ 1 return x end + + # Check if N is a power of 2 + if N & (N - 1) != 0 + error("Length must be a power of 2") + end + + # Split into even and odd indices + even = x[1:2:N] + odd = x[2:2:N] + + # Recursive calls + fft1d_cooley_tukey!(even) + fft1d_cooley_tukey!(odd) + + # Combine results + for k in 1:N÷2 + t = exp(-2im * π * (k-1) / N) * odd[k] + x[k] = even[k] + t + x[k + N÷2] = even[k] - t + end + + return x +end + +# 2. Stockham Auto-Sort Algorithm +function fft2d_stockham!(A::Matrix{ComplexF64}) + N = size(A, 1) + # Apply to rows + for i in 1:N + A[i, :] = fft1d_stockham(A[i, :]) + end + # Apply to columns + for j in 1:N + A[:, j] = fft1d_stockham(A[:, j]) + end + return A +end + +function fft1d_stockham(x::AbstractVector{ComplexF64}) + N = length(x) + if N ≤ 1 return copy(x) end + + # Check if N is a power of 2 + if N & (N - 1) != 0 + error("Length must be a power of 2") + end + + # Two buffers to avoid allocation in the loop + buffer1 = copy(x) + buffer2 = similar(x) + + m = 1 + while m < N + for k in 0:m-1 + w = exp(-2im * π * k / (2 * m)) + for j in 0:(N÷(2*m)-1) + idx1 = j * 2 * m + k + 1 + idx2 = idx1 + m + + t1 = buffer1[idx1] + t2 = buffer1[idx2] * w + + buffer2[j * m + k + 1] = t1 + t2 + buffer2[j * m + k + 1 + N÷2] = t1 - t2 + end + end + buffer1, buffer2 = buffer2, buffer1 + m *= 2 + end + + return buffer1 +end + +# 3. Eight-Step FFT Algorithm (Cache-Efficient) +function fft2d_eight_step!(A::Matrix{ComplexF64}) + N = size(A, 1) + # 1. Transpose + A = transpose!(copy(A), A) + + # 2. 1D FFTs on rows + for i in 1:N + A[i, :] = fft1d_iterative!(copy(A[i, :])) + end + + # 3. Apply twiddle factors + for j in 0:N-1 + for i in 0:N-1 + A[i+1, j+1] *= exp(-2im * π * i * j / N) + end + end + + # 4. Transpose + A = transpose!(copy(A), A) + + # 5. 1D FFTs on rows (originally columns) + for i in 1:N + A[i, :] = fft1d_iterative!(copy(A[i, :])) + end + + # 6. Transpose back to original orientation + A = transpose!(copy(A), A) + + return A +end + +# Helper iterative 1D FFT +function fft1d_iterative!(x::AbstractVector{ComplexF64}) + N = length(x) + if N ≤ 1 return x end + + # Bit reversal + j = 1 + for i in 1:N-1 + if i < j + x[i], x[j] = x[j], x[i] + end + m = N ÷ 2 + while m ≥ 1 && j > m + j -= m + m ÷= 2 + end + j += m + end + + # Butterfly operations + for s in 1:Int(log2(N)) + m = 2^s + wm = exp(-2im * π / m) + for k in 0:m:N-1 + w = 1.0 + 0.0im + for j in 0:m÷2-1 + t = w * x[k + j + m÷2 + 1] + u = x[k + j + 1] + x[k + j + 1] = u + t + x[k + j + m÷2 + 1] = u - t + w *= wm + end + end + end + + return x +end + +# 4. Split-Radix FFT (Iterative Implementation) +function fft2d_split_radix!(A::Matrix{ComplexF64}) + N = size(A, 1) + # Perform 1D FFT on each row + for i in 1:N + fft1d_split_radix!(view(A, i, :)) + end + # Perform 1D FFT on each column + for j in 1:N + fft1d_split_radix!(view(A, :, j)) + end + return A +end + +function fft1d_split_radix!(x::AbstractVector{ComplexF64}) + N = length(x) + if N <= 1 return x end + + # Check if N is a power of 2 + if N & (N - 1) != 0 + error("Length must be a power of 2") + end + + # Bit-reversal permutation (similar to Cooley-Tukey) + j = 1 + for i in 1:N-1 + if i < j + x[i], x[j] = x[j], x[i] + end + m = N ÷ 2 + while m ≥ 1 && j > m + j -= m + m ÷= 2 + end + j += m + end + + # L-shaped butterflies for split-radix + L = 2 + while L <= N + # Radix-2 part + m = L ÷ 2 + wm = exp(-2im * π / L) + + for k in 0:L:N-1 + w = 1.0 + 0.0im + for j in 0:m-1 + temp = w * x[k+j+m+1] + x[k+j+m+1] = x[k+j+1] - temp + x[k+j+1] += temp + w *= wm + end + end + + # Extra butterflies for split-radix strategy when L >= 4 + if L >= 4 + m = L ÷ 4 + wm = exp(-2im * π / L) + for k in L÷2:L:N-1 + w = exp(-2im * π * (k % L) / L) + for j in 0:m-1 + idx1 = k + j + 1 + idx2 = k + j + m + 1 + + temp1 = x[idx1] + temp2 = x[idx2] + + x[idx1] = (temp1 + 1im * temp2) / sqrt(2.0) + x[idx2] = (temp1 - 1im * temp2) / sqrt(2.0) + end + end + end + + L *= 2 + end + + return x +end + +# 5. Bluestein's FFT Algorithm (Chirp Z-Transform) +function fft2d_bluestein!(A::Matrix{ComplexF64}) + N = size(A, 1) + # Perform 1D FFT on each row + for i in 1:N + A[i, :] = fft1d_bluestein!(copy(A[i, :])) + end + # Perform 1D FFT on each column + for j in 1:N + A[:, j] = fft1d_bluestein!(copy(A[:, j])) + end + return A +end + +function fft1d_bluestein!(x::AbstractVector{ComplexF64}) + N = length(x) + if N ≤ 1 return x end + + # Find M that is a power of 2 and M ≥ 2N-1 + M = 1 + while M < 2*N-1 + M *= 2 + end + + # Precompute chirp factors with explicit handling of indices + a = [exp(im * π * (n^2) / N) for n in 0:N-1] + + # Create b with careful indexing to avoid out-of-bounds access + b = zeros(ComplexF64, 2*N-1) + for n in -(N-1):(N-1) + b[n+(N-1)+1] = exp(-im * π * (n^2) / N) + end + + # Zero pad sequences to length M + a_padded = zeros(ComplexF64, M) + b_padded = zeros(ComplexF64, M) + + # Fill a_padded with data * chirp + for n in 1:N + a_padded[n] = x[n] * a[n] + end + + # Fill b_padded with proper indexing + for n in 1:(2*N-1) + b_padded[n] = b[n] + end + + # Perform convolution using FFT + A_fft = fft1d_iterative!(copy(a_padded)) + B_fft = fft1d_iterative!(copy(b_padded)) + + # Element-wise multiplication + C_fft = A_fft .* B_fft + + # Inverse FFT + c = fft1d_iterative!(copy(C_fft)) + c ./= M # Normalization for inverse FFT + + # Extract result and apply chirp correction safely + result = similar(x) + for n in 1:N + result[n] = c[n] * conj(a[n]) + end + + return result +end + +# Helper function for transposes +function transpose!(dest::Matrix{ComplexF64}, src::Matrix{ComplexF64}) + n, m = size(src) + for j in 1:m + for i in 1:n + dest[j, i] = src[i, j] + end + end + return dest +end + +# Example usage: +using FFTW +function test_fft(N::Int=8) + # Create test matrix + A = [ComplexF64(i + j) for i in 1:N, j in 1:N] + A_copy = copy(A) + + # Test implementations + println("Testing Cooley-Tukey implementation:") + result1 = fft2d_cooley_tukey!(copy(A)) + + println("Testing Stockham implementation:") + result2 = fft2d_stockham!(copy(A)) + + println("Testing Eight-Step implementation:") + result3 = fft2d_eight_step!(copy(A)) + + println("Testing Split-Radix implementation:") + result4 = fft2d_split_radix!(copy(A)) + + println("Testing Bluestein implementation:") + result5 = fft2d_bluestein!(copy(A)) + + # Verify results against Julia's built-in FFT + reference = fft(A_copy) + + println("\nMaximum difference from reference:") + println("Cooley-Tukey: ", maximum(abs.(result1 - reference))) + println("Stockham: ", maximum(abs.(result2 - reference))) + println("Eight-Step: ", maximum(abs.(result3 - reference))) + println("Split-Radix: ", maximum(abs.(result4 - reference))) + println("Bluestein: ", maximum(abs.(result5 - reference))) +end \ No newline at end of file diff --git a/examples/smartfft/step2_connect_algorithms_1d.jl b/examples/smartfft/step2_connect_algorithms_1d.jl new file mode 100644 index 0000000..4254613 --- /dev/null +++ b/examples/smartfft/step2_connect_algorithms_1d.jl @@ -0,0 +1,22 @@ +# Load AI-generated candidate algorithms +include("step1_candidate_algorithms_1d.jl") + +# Connect candidate algorithms to SmartSolve +# Note: Commented-out algorithms were non-functional +SmartSolve.add_candidate_algorithm!(:fft1, recursive_fft) +#SmartSolve.add_candidate_algorithm!(:fft1, iterative_fft) +SmartSolve.add_candidate_algorithm!(:fft1, bluestein_fft) +#SmartSolve.add_candidate_algorithm!(:fft1, stockham_fft) +SmartSolve.add_candidate_algorithm!(:fft1, radix4_fft) +SmartSolve.add_candidate_algorithm!(:fft1, split_radix_fft) +#SmartSolve.add_candidate_algorithm!(:fft1, butterfly_fft) + +# Benchmark all algorithms with an NxNxN tensor +SmartSolve.benchmark_algorithms!(:fft1, 2:10) do alg, n + # Create a 1D input tensor of size (4^n) + input = rand(ComplexF64, 4^n) + # Benchmark this algorithm + perf = @elapsed alg(input) + # Return the performance result + return perf +end \ No newline at end of file diff --git a/examples/smartfft/step2_connect_algorithms_2d.jl b/examples/smartfft/step2_connect_algorithms_2d.jl new file mode 100644 index 0000000..081e00e --- /dev/null +++ b/examples/smartfft/step2_connect_algorithms_2d.jl @@ -0,0 +1,19 @@ +# Load AI-generated candidate algorithms +include("step1_candidate_algorithms_2d.jl") + +# Connect candidate algorithms to SmartSolve +SmartSolve.add_candidate_algorithm!(:fft2, fft2d_cooley_tukey!) +SmartSolve.add_candidate_algorithm!(:fft2, fft2d_stockham!) +SmartSolve.add_candidate_algorithm!(:fft2, fft2d_eight_step!) +SmartSolve.add_candidate_algorithm!(:fft2, fft2d_split_radix!) +SmartSolve.add_candidate_algorithm!(:fft2, fft2d_bluestein!) + +# Benchmark all algorithms with an NxNxN tensor +SmartSolve.benchmark_algorithms!(:fft2, 2:5) do alg, n + # Create a 1D input tensor of size (4^n, 4^n) + input = rand(ComplexF64, 4^n, 4^n) + # Benchmark this algorithm + perf = @elapsed alg(input) + # Return the performance result + return perf +end \ No newline at end of file diff --git a/examples/smartfft/step3_selection_best_algorithm.jl b/examples/smartfft/step3_selection_best_algorithm.jl new file mode 100644 index 0000000..69b30ab --- /dev/null +++ b/examples/smartfft/step3_selection_best_algorithm.jl @@ -0,0 +1,23 @@ +# Load SmartSolve +include("SmartSolve.jl") + +# Load our database of algorithms +include("step2_connect_algorithms_1d.jl") +include("step2_connect_algorithms_2d.jl") + +# Plot the performance results +p = SmartSolve.plot_benchmark_results(:fft1) +p = SmartSolve.plot_benchmark_results(:fft2) + +# Select the best algorithm for a given input size +input = rand(ComplexF64, 4^1); +smartfft_1d! = SmartSolve.select_best_algorithm(:fft1, input) + +input = rand(ComplexF64, 4^7); +smartfft_1d! = SmartSolve.select_best_algorithm(:fft1, input) + +input = rand(ComplexF64, 4^1, 4^1); +smartfft_2d! = SmartSolve.select_best_algorithm(:fft2, input) + +input = rand(ComplexF64, 4^5, 4^5); +smartfft_2d! = SmartSolve.select_best_algorithm(:fft2, input) \ No newline at end of file diff --git a/examples/smartfft/step4_integration_dagger.jl b/examples/smartfft/step4_integration_dagger.jl new file mode 100644 index 0000000..886c1e4 --- /dev/null +++ b/examples/smartfft/step4_integration_dagger.jl @@ -0,0 +1,41 @@ +# Load our optimizing scheduler, Dagger +using Dagger + +# Load our benchmarked database of algorithms +include("step3_selection_best_algorithm.jl") + +# Let's compare against serial FFTW +using FFTW +input = rand(ComplexF64, 4^5, 4^5); +println("FFTW") +FFTW.set_num_threads(1) +@time fft!(input); +println("SmartSolve") +@time smartfft_2d!(input); + +# If our data is large, Dagger can scale the algorithm for us +include("generate_gantt.jl") + +# TODO: This will use SmartSolve's FFT within Dagger's Pencil FFT, +# potentially even allowing each pencil's FFT to be optimized for +# where the pencil will execute (optimize for specific CPU/GPU architecture) +function daggerfft!(A, B) + # This is a fully parallel 2D FFT algorithm + # It supports multi-CPU, multi-GPU, and distributed computing + A_parts = A.chunks + B_parts = B.chunks + Dagger.spawn_datadeps() do + for idx in eachindex(A_parts) + Dagger.@spawn name="smartfft!(dim 1)" smartfft!(InOut(A_parts[idx])) + end + end + copyto!(B, A) + Dagger.spawn_datadeps() do + for idx in eachindex(B_parts) + Dagger.@spawn name="smartfft!(dim 2)" smartfft!(InOut(B_parts[idx])) + end + end +end +A = rand(Blocks(4^5, div(4^5, 16)), ComplexF64, 4^5, 4^5); +B = zeros(Blocks(div(4^5, 16), 4^5), ComplexF64, 4^5, 4^5); +@time daggerfft!(A, B); \ No newline at end of file