Skip to content

Commit

Permalink
Fix Parallel TPF accuracy when run with multiple workers
Browse files Browse the repository at this point in the history
  • Loading branch information
ShlokG committed Aug 12, 2021
1 parent ff2ca1a commit bb6487a
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 8 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# StateSpaceRoutine.jl v0.4.2 Release notes
- Fix Tempered Particle Filter when run in parallel with multiple nodes.

# StateSpaceRoutines.jl v0.4.1 Release notes
- Fix Koopman disturbance smoother when Z is time-varying.
- Add further tests in time-varying cases.
Expand Down
1 change: 1 addition & 0 deletions src/filters/tempered_particle_filter/correction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ function weight_kernel!(coeff_terms::V, log_e_1_terms::V, log_e_2_terms::V,
error = y_t - Ψ(s_t_nontemp[:, i])
return dot(error, inv_HH * error)
end

sq_error = @sync @distributed (vcat) for i in 1:n_particles
error_closure(i)
end
Expand Down
37 changes: 32 additions & 5 deletions src/filters/tempered_particle_filter/mutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function mutation!(Φ::Function, Ψ::Function, QQ::Matrix{Float64},
n_particles = size(ϵ_t, 2)

# Initialize vector of acceptances
accept_vec = parallel ? SharedVector{Int}(n_particles) : Vector{Int}(undef, n_particles)
accept_vec = Vector{Int}(undef, n_particles)

# Used to generate new draws of ϵ
dist_ϵ = MvNormal(c^2 * diag(QQ))
Expand All @@ -44,11 +44,38 @@ function mutation!(Φ::Function, Ψ::Function, QQ::Matrix{Float64},
@everywhere mh_steps_closure(i::Int) = mh_steps(Φ, Ψ, dist_ϵ, y_t, s_t1[:,i], s_t[:,i], ϵ_t[:,i],
scaled_det_HH, scaled_inv_HH, n_mh_steps;
poolmodel = poolmodel)

s_t, ϵ_t, accept_vec .= @sync @distributed (vector_reduce) for i in 1:n_particles
mh_steps_closure(i)
#=
@floop DistributedEx() for i in 1:n_particles#DistributedEx(threads_basesize = Int(ceil(n_particles / nworkers()))) begin
#=s_t_fin = Matrix(undef, size(s_t,1), 0)
ϵ_t_fin = Matrix(undef, size(ϵ_t,1), 0)
accept_vec_fin = []
for i in 1:n_particles=#
s_t2, ϵ_t2, accept_vec2 = mh_steps_closure(i)
#= @reduce() do (s_t_fin = Matrix(undef, size(s_t,1), 0); s_t2)
s_t_fin = hcat(s_t_fin, s_t2)
end
@reduce() do (ϵ_t_fin = Matrix(undef, size(ϵ_t,1), 0); ϵ_t2)
ϵ_t_fin = hcat(ϵ_t_fin, ϵ_t2)
end
@reduce() do (accept_vec_fin = []; accept_vec2)
accept_vec_fin = append!(accept_vec_fin, accept_vec2)
end
=#
@reduce(s_t_fin = hcat(Matrix(undef, size(s_t,1), 0), s_t2), ϵ_t_fin = hcat(Matrix(undef, size(ϵ_t,1), 0), ϵ_t2),
accept_vec_fin = append!([],accept_vec2))
end
#end
s_t .= s_t_fin
ϵ_t .= ϵ_t_fin
accept_vec .= accept_vec_fin
=#
s_t2, ϵ_t2, accept_vec2 = @sync @distributed (vector_reduce) for i in 1:n_particles
vector_reshape(mh_steps_closure(i)...)
end
accept_vec = vec(accept_vec)
s_t .= s_t2
ϵ_t .= ϵ_t2
accept_vec .= vec(accept_vec2)
else
for i in 1:n_particles
s_t[:,i], ϵ_t[:,i], accept_vec[i] =
Expand Down
29 changes: 26 additions & 3 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ end
function vec_red_scal(args...)
```
vector_reduce but scalar_reduce for last argument
Hard coded for 3 arguments
"""
function vec_scal_reduce(args...)
nargs1 = length(args) # The number of times the loop is run
Expand All @@ -64,14 +65,36 @@ function vec_scal_reduce(args...)
return_arg = args[1]
for i in 1:nargs2
for j in 2:nargs1
if i < nargs2
return_arg[i] = hcat(return_arg[i], args[j][i])
if i == nargs2
append!(return_arg[i], args[j][i][1])
#return_arg[i] = vcat(return_arg[i], args[j][i])
else
append!(return_arg[i], args[j][i])
return_arg[i] = hcat(return_arg[i], args[j][i])
end
end
end
return return_arg
#=
nargs1 = length(args) # The number of times the loop is run
nargs2 = length(args[1]) # The number of variables output by a single run
if nargs2 == 1
return args
end
arg1, arg2, arg3 = args[1]
#arg2 = args[1][2]
#arg3 = args[1][3]
for j in 2:nargs1
arg1 = hcat(arg1, args[j][1])
arg2 = hcat(arg2, args[j][2])
arg3 = vcat(arg3, args[j][3])
end
return_arg = (arg1, arg2, arg3)
return return_arg
=#
end

"""
Expand Down

0 comments on commit bb6487a

Please sign in to comment.