Skip to content

Commit

Permalink
Update Turing implementation and other fixes (#33)
Browse files Browse the repository at this point in the history
* Simplify tests and suppress progress with env variable

* up version

* updated code to match new AdvancedHMC

* fix turing implementation

* adapt rest of name

* up julia lower bound

* fix compats

* update dependabot
  • Loading branch information
theogf authored Feb 10, 2025
1 parent aa0a6b9 commit 77b7c83
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 73 deletions.
8 changes: 2 additions & 6 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file

# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "" # See documentation for possible values
- package-ecosystem: "github-actions"
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
5 changes: 2 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@ jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}
strategy:
fail-fast: false
matrix:
version:
- "1.6"
- "1.10"
- "1"
- "nightly"
os:
- ubuntu-latest
- macOS-latest
Expand All @@ -42,6 +40,7 @@ jobs:
- uses: julia-actions/julia-runtest@v1
env:
JULIA_NUM_THREADS: 4
PROGRESS_BARS: false
- uses: julia-actions/julia-processcoverage@v1
- name: Upload coverage reports to Codecov
uses: codecov/[email protected]
Expand Down
11 changes: 8 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
name = "ThermodynamicIntegration"
uuid = "1022446e-a4a4-4a46-8bce-0ffd39f68cd3"
authors = ["Theo Galy-Fajou <[email protected]> and contributors"]
version = "0.2.7"
version = "0.2.8"

[deps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Trapz = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1"

[compat]
AdvancedHMC = "0.2, 0.3, 0.6"
AdvancedHMC = "0.6"
Distributed = "1"
ForwardDiff = "0.10"
LinearAlgebra = "1"
LogDensityProblems = "2"
ProgressMeter = "1"
Random = "1"
Requires = "1"
Statistics = "1"
Trapz = "2"
julia = "1.6"
julia = "1.10"
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ makedocs(;
sitename="ThermodynamicIntegration.jl",
format=Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
canonical="https://theogf.github.io/ThermodynamicIntegration.jl",
canonical="https://theogf.dev/ThermodynamicIntegration.jl",
assets=String[],
),
pages=["Home" => "index.md"],
Expand Down
3 changes: 3 additions & 0 deletions src/ThermodynamicIntegration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ForwardDiff
using ProgressMeter
using Random: Random, AbstractRNG, default_rng
using Requires
using LogDensityProblems
using Statistics
using Trapz

Expand All @@ -25,6 +26,8 @@ function set_adbackend(::Any)
)
end

const SHOW_PROGRESS_BARS = parse(Bool, get(ENV, "PROGRESS_BARS", "true"))

include("thermint.jl")

function __init__()
Expand Down
53 changes: 33 additions & 20 deletions src/thermint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,15 @@ function ThermInt(rng::AbstractRNG, schedule; n_samples::Int=2000, n_warmup::Int
end

function ThermInt(schedule; n_samples::Int=2000, n_warmup::Int=500)
return ThermInt(default_rng(), schedule; n_samples=n_samples, n_warmup=n_warmup)
return ThermInt(default_rng(), schedule; n_samples, n_warmup)
end

function ThermInt(rng::AbstractRNG; n_steps::Int, n_samples::Int=2000, n_warmup::Int=500)
return ThermInt(
rng, range(0, 1; length=n_steps) .^ 5; n_samples=n_samples, n_warmup=n_warmup
)
return ThermInt(rng, range(0, 1; length=n_steps) .^ 5; n_samples, n_warmup)
end

function ThermInt(; n_steps::Int=30, n_samples::Int=2000, n_warmup::Int=500)
return ThermInt(
default_rng(),
range(0, 1; length=n_steps) .^ 5;
n_samples=n_samples,
n_warmup=n_warmup,
)
return ThermInt(default_rng(), range(0, 1; length=n_steps) .^ 5; n_samples, n_warmup)
end

abstract type TIEnsemble end
Expand All @@ -56,7 +49,7 @@ function (alg::ThermInt)(
logprior,
x_init::AbstractVector,
::TISerial=TISerial();
progress=true,
progress=SHOW_PROGRESS_BARS,
kwargs...,
)
p = ProgressMeter.Progress(length(alg.schedule); enabled=progress, desc="TI Sampling:")
Expand All @@ -74,7 +67,12 @@ function check_threads()
end

function (alg::ThermInt)(
loglikelihood, logprior, x_init::AbstractVector, ::TIThreads; progress=true, kwargs...
loglikelihood,
logprior,
x_init::AbstractVector,
::TIThreads;
progress=SHOW_PROGRESS_BARS,
kwargs...,
)
check_threads()
nsteps = length(alg.schedule)
Expand Down Expand Up @@ -104,7 +102,7 @@ function (alg::ThermInt)(
logprior,
x_init::AbstractVector,
::TIDistributed;
progress=false,
progress=SHOW_PROGRESS_BARS,
kwargs...,
)
check_processes()
Expand Down Expand Up @@ -135,28 +133,43 @@ function (alg::ThermInt)(
)
end

struct PowerJoint{Tβ,FL,FP}
β::Tβ
dim::Int
loglikelihood::FL
logprior::FP
end

function LogDensityProblems.logdensity((; β, loglikelihood, logprior)::PowerJoint, θ)
return β * loglikelihood(θ) + logprior(θ)
end
LogDensityProblems.dimension((; dim)::PowerJoint) = dim
function LogDensityProblems.capabilities(::Type{<:PowerJoint})
return LogDensityProblems.LogDensityOrder{0}()
end

function evaluate_loglikelihood(loglikelihood, logprior, alg::ThermInt, x_init, β::Real)
powerlogπ(θ) = β * loglikelihood(θ) + logprior)
samples = sample_powerlogπ(powerlogπ, alg, x_init)
pj = PowerJoint(β, length(x_init), loglikelihood, logprior)
samples = sample_powerlogπ(pj, alg, x_init)
x_init .= samples[end] # Update the initial sample to be the last one of the chain
return mean(loglikelihood, samples)
end

function sample_powerlogπ(powerlogπ, alg::ThermInt, x_init)
D = length(x_init)
function sample_powerlogπ(pj::PowerJoint, alg::ThermInt, x_init)
D = LogDensityProblems.dimension(pj)
metric = DiagEuclideanMetric(D)
hamiltonian = get_hamiltonian(metric, powerlogπ, alg)
hamiltonian = get_hamiltonian(metric, pj, alg)

initial_ϵ = find_good_stepsize(hamiltonian, x_init)
integrator = Leapfrog(initial_ϵ)

proposal = AdvancedHMC.NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))

samples, _ = sample(
alg.rng,
hamiltonian,
proposal,
kernel,
x_init,
alg.n_samples,
adaptor,
Expand Down
31 changes: 18 additions & 13 deletions src/turing.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using .Turing: DynamicPPL, Prior

function (alg::ThermInt)(
model::DynamicPPL.Model, ::TISerial=TISerial(); progress=true, kwargs...
model::DynamicPPL.Model, ::TISerial=TISerial(); progress=SHOW_PROGRESS_BARS, kwargs...
)
nsteps = length(alg.schedule)
p = Progress(nsteps; enabled=progress, desc="TI Sampling")
Expand All @@ -14,7 +14,9 @@ function (alg::ThermInt)(
return trapz(alg.schedule, ΔlogZ)
end

function (alg::ThermInt)(model::DynamicPPL.Model, ::TIThreads; progress=true, kwargs...)
function (alg::ThermInt)(
model::DynamicPPL.Model, ::TIThreads; progress=SHOW_PROGRESS_BARS, kwargs...
)
check_threads()
nsteps = length(alg.schedule)
nthreads = min(Threads.nthreads(), nsteps)
Expand All @@ -30,7 +32,7 @@ function (alg::ThermInt)(model::DynamicPPL.Model, ::TIThreads; progress=true, kw
end

function (alg::ThermInt)(
model::DynamicPPL.Model, ::TIDistributed; progress=false, kwargs...
model::DynamicPPL.Model, ::TIDistributed; progress=SHOW_PROGRESS_BARS, kwargs...
)
check_processes()
progress && @warn "progress is not possible with distributed computing for now."
Expand All @@ -44,31 +46,34 @@ function (alg::ThermInt)(
end

function evaluate_loglikelihood(model::DynamicPPL.Model, alg::ThermInt, β::Real)
powerlogπ = power_logjoint(model, β)
logprior = get_logprior(model)
loglikelihood = get_loglikelihood(model)
x_init = vec(Array(sample(model, Prior(), 1))) # Bad ugly hack cause I don't know how to sample from the prior
samples = sample_powerlogπ(powerlogπ, alg, x_init)
x_init = vec(Array(sample(model, Prior(), 1; progress=false))) # Bad ugly hack cause I don't know how to sample from the prior
pj = PowerJoint(β, length(x_init), loglikelihood, logprior)
samples = sample_powerlogπ(pj, alg, x_init)
return mean(loglikelihood, samples)
end

function power_logjoint(model, β)
ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), β)
"""
Build a logprior function acting on the flattened version of the parameters.
"""
function get_logprior(model)
spl = DynamicPPL.SampleFromPrior()
vi = DynamicPPL.VarInfo(model)
return function f(z)
varinfo = DynamicPPL.VarInfo(vi, spl, z)
model(varinfo, spl, ctx)
return DynamicPPL.getlogp(varinfo)
return DynamicPPL.logprior(model, varinfo)
end
end

"""
Build a loglikelihood function acting on the flattened version of the parameters.
"""
function get_loglikelihood(model)
ctx = DynamicPPL.LikelihoodContext()
spl = DynamicPPL.SampleFromPrior()
vi = DynamicPPL.VarInfo(model)
return function f(z)
varinfo = DynamicPPL.VarInfo(vi, spl, z)
model(varinfo, spl, ctx)
return DynamicPPL.getlogp(varinfo)
return DynamicPPL.loglikelihood(model, varinfo)
end
end
1 change: 0 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7 changes: 1 addition & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
using ThermodynamicIntegration
using Distributed
using Distributions
using Test
using LinearAlgebra
using Turing
using Suppressor

@testset "ThermodynamicIntegration.jl" begin
include("thermint.jl")
end
50 changes: 30 additions & 20 deletions test/thermint.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,45 @@
using ThermodynamicIntegration
using Distributed
using Distributions
using Test
using LinearAlgebra
using Turing

function test_basic_model(
alg::ThermInt, method::ThermodynamicIntegration.TIEnsemble=TISerial(); D=5, atol=1e-1
)
prior = MvNormal(Diagonal(0.5 * ones(D)))
likelihood = MvNormal(Diagonal(2.0 * ones(D)))
logprior(x) = logpdf(prior, x)
loglikelihood(x) = logpdf(likelihood, x)
# We capture stdout to avoid having the progress meter in CI.
@capture_out logZ = alg(logprior, loglikelihood, rand(prior), method)
true_logZ = -0.5 * (logdet(cov(prior) + cov(likelihood)) + D * log(2π))
@testset "$(nameof(typeof(alg))) - $(nameof(typeof(method)))" begin
prior = MvNormal(Diagonal(0.5 * ones(D)))
likelihood = MvNormal(Diagonal(2.0 * ones(D)))
lprior(x) = logpdf(prior, x)
llikelihood(x) = logpdf(likelihood, x)
logZ = alg(lprior, llikelihood, rand(prior), method)
true_logZ = -0.5 * (logdet(cov(prior) + cov(likelihood)) + D * log(2π))

@test logZ true_logZ atol = atol
@test_throws ArgumentError alg(logprior, loglikelihood, first(rand(prior)), method)
@test logZ true_logZ atol = atol
@test_throws ArgumentError alg(lprior, llikelihood, first(rand(prior)), method)
end
end

function test_basic_turing(
alg::ThermInt, method::ThermodynamicIntegration.TIEnsemble=TISerial(); D=5, atol=1e-1
)
prior = MvNormal(Diagonal(0.5 * ones(D)))
likelihood = MvNormal(Diagonal(2.0 * ones(D)))
@model function gauss(y)
x ~ prior
return y ~ MvNormal(x, cov(likelihood))
end
m = gauss(zeros(D))
logZ = alg(m, method)
true_logZ = -0.5 * (logdet(cov(prior) + cov(likelihood)) + D * log(2π))
@testset "Turing - $(nameof(typeof(alg))) - $(nameof(typeof(method)))" begin
prior = MvNormal(Diagonal(0.5 * ones(D)))
likelihood = MvNormal(Diagonal(2.0 * ones(D)))
@model function gauss(y)
x ~ prior
return y ~ MvNormal(x, cov(likelihood))
end
m = gauss(zeros(D))
logZ = alg(m, method)
true_logZ = -0.5 * (logdet(cov(prior) + cov(likelihood)) + D * log(2π))

@test logZ true_logZ atol = atol
@test logZ true_logZ atol = atol
end
end

@testset "Basic model" begin
@testset "Test basic model with different options" begin
alg = ThermInt(; n_samples=5000)
# Test serialized version
test_basic_model(alg, TISerial())
Expand Down

2 comments on commit 77b7c83

@theogf
Copy link
Owner Author

@theogf theogf commented on 77b7c83 Feb 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/124752

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.8 -m "<description of version>" 77b7c831af5de0a0fe57efd76e7299cd1668029d
git push origin v0.2.8

Please sign in to comment.