Commit

docs: Normalizing Flow (RealNVP) example (#1215)
* fix: patch reactant bug?

* docs: Normalizing Flow (RealNVP) example

* feat: print out throughput info

* fix: use smaller size for CI
avik-pal authored Jan 20, 2025
1 parent 6b18e29 commit 521fefd
Showing 11 changed files with 290 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
- version = "1.5.1"
+ version = "1.5.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
1 change: 1 addition & 0 deletions docs/make.jl
@@ -31,6 +31,7 @@ pages = [
"tutorials/intermediate/4_PINN2DPDE.md",
"tutorials/intermediate/5_ConvolutionalVAE.md",
"tutorials/intermediate/6_GCN_Cora.md",
"tutorials/intermediate/7_RealNVP.md",
],
"Advanced" => [
"tutorials/advanced/1_GravitationalWaveForm.md"
4 changes: 4 additions & 0 deletions docs/src/.vitepress/config.mts
@@ -226,6 +226,10 @@ export default defineConfig({
text: "Graph Convolutional Network on Cora",
link: "/tutorials/intermediate/6_GCN_Cora",
},
+ {
+     text: "Normalizing Flows for Density Estimation",
+     link: "/tutorials/intermediate/7_RealNVP",
+ }
],
},
{
Binary file added docs/src/public/realnvp.png
6 changes: 6 additions & 0 deletions docs/src/tutorials/index.md
@@ -77,6 +77,12 @@ const intermediate = [
src: "../gcn_cora.jpg",
caption: "Graph Convolutional Network on Cora",
desc: "Train a Graph Convolutional Network on Cora dataset."
},
+ {
+     href: "intermediate/7_RealNVP",
+     src: "../realnvp.png",
+     caption: "Normalizing Flows for Density Estimation",
+     desc: "Train a normalizing flow for density estimation on the Moons dataset.",
+ }
];
1 change: 1 addition & 0 deletions docs/tutorials.jl
@@ -14,6 +14,7 @@ const INTERMEDIATE_TUTORIALS = [
"PINN2DPDE/main.jl" => "CUDA",
"ConvolutionalVAE/main.jl" => "CUDA",
"GCN_Cora/main.jl" => "CUDA",
"RealNVP/main.jl" => "CUDA",
]
const ADVANCED_TUTORIALS = [
"GravitationalWaveForm/main.jl" => "CPU",
2 changes: 0 additions & 2 deletions examples/GCN_Cora/Project.toml
@@ -4,7 +4,6 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
- MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -18,7 +17,6 @@ Enzyme = "0.13.28"
GNNGraphs = "1"
Lux = "1.5"
MLDatasets = "0.7.18"
- MLUtils = "0.4.5"
OneHotArrays = "0.2"
Optimisers = "0.4.4"
Printf = "1.10"
4 changes: 2 additions & 2 deletions examples/GCN_Cora/main.jl
@@ -3,8 +3,8 @@
# This example is based on [GCN MLX tutorial](https://github.com/ml-explore/mlx-examples/blob/main/gcn/). While we are doing this manually, we recommend directly using
# [GNNLux.jl](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/).

- using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, GNNGraphs, MLUtils,
-     ConcreteStructs, Printf, OneHotArrays, Optimisers
+ using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, GNNGraphs, ConcreteStructs,
+     Printf, OneHotArrays, Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
11 changes: 11 additions & 0 deletions examples/RealNVP/Project.toml
@@ -0,0 +1,11 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
259 changes: 259 additions & 0 deletions examples/RealNVP/main.jl
@@ -0,0 +1,259 @@
# # [Normalizing Flows for Density Estimation](@id RealNVP-Tutorial)

# This tutorial demonstrates how to use Lux to train a
# [RealNVP](https://arxiv.org/abs/1605.08803). This is based on the
# [RealNVP implementation in MLX](https://github.com/ml-explore/mlx-examples/blob/main/normalizing_flow/).
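
# RealNVP fits a density by composing invertible coupling transforms. For an invertible
# map ``f`` with base density ``p_Z`` (a standard normal here), the change-of-variables
# formula gives
#
# ```math
# \log p_X(x) = \log p_Z(f(x)) + \log \left| \det J_f(x) \right|,
# ```
#
# so training only needs transforms whose Jacobian log-determinant is cheap to compute;
# the coupling layers below are built so that this term reduces to a simple sum.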

using Lux, Reactant, Random, Statistics, Enzyme, MLUtils, ConcreteStructs, Printf,
Optimisers, CairoMakie

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

# ## Define & Load the Moons Dataset

# We define a function to generate data from the moons dataset. We use the code here from
# [this tutorial](https://liorsinai.github.io/machine-learning/2024/08/19/micrograd-5-mlp.html#moons-dataset).

function make_moons(
rng::AbstractRNG, ::Type{T}, n_samples::Int=100;
noise::Union{Nothing, AbstractFloat}=nothing
) where {T}
n_moons = n_samples ÷ 2
t_min, t_max = T(0), T(π)
t_inner = rand(rng, T, n_moons) * (t_max - t_min) .+ t_min
t_outer = rand(rng, T, n_moons) * (t_max - t_min) .+ t_min
outer_circ_x = cos.(t_outer)
outer_circ_y = sin.(t_outer) .+ T(1)
inner_circ_x = 1 .- cos.(t_inner)
inner_circ_y = 1 .- sin.(t_inner) .- T(1)

data = [outer_circ_x outer_circ_y; inner_circ_x inner_circ_y]
z = permutedims(data, (2, 1))
noise !== nothing && (z .+= T(noise) * randn(rng, T, size(z)))
return z
end
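
# Each call returns a `2 × n_samples` matrix whose rows are the x/y coordinates of points
# drawn from two interleaved half-circles, optionally perturbed by Gaussian noise.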

# Let's visualize the dataset

begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

z = make_moons(Random.default_rng(), Float32, 10_000; noise=0.1)
scatter!(ax, z[1, :], z[2, :]; markersize=2)

fig
end

# ---

function load_moons_dataloader(
args...; batchsize::Int, noise::Union{Nothing, AbstractFloat}=nothing, kwargs...
)
return DataLoader(
make_moons(args...; noise); batchsize, shuffle=true, partial=false, kwargs...
)
end
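
# Note that `partial=false` drops the last incomplete batch, so every step sees exactly
# `batchsize` samples; keeping the batch shape fixed also helps avoid recompiling for a
# different input shape when the model is compiled with Reactant.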

# ## Bijectors Implementation

abstract type AbstractBijector end

@concrete struct AffineBijector <: AbstractBijector
shift <: AbstractArray
log_scale <: AbstractArray
end

function AffineBijector(shift_and_log_scale::AbstractArray{T, N}) where {T, N}
n = size(shift_and_log_scale, 1) ÷ 2
idxs = ntuple(Returns(Colon()), N - 1)
return AffineBijector(
shift_and_log_scale[1:n, idxs...], shift_and_log_scale[(n + 1):end, idxs...]
)
end
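
# The conditioner network will emit `2n` features for an `n`-dimensional input: the first
# half is the shift and the second half the log-scale, split along the leading dimension
# by the constructor above.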

function forward_and_log_det(bj::AffineBijector, x::AbstractArray)
y = x .* exp.(bj.log_scale) .+ bj.shift
return y, bj.log_scale
end

function inverse_and_log_det(bj::AffineBijector, y::AbstractArray)
x = (y .- bj.shift) ./ exp.(bj.log_scale)
return x, -bj.log_scale
end
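
# As a quick illustrative sanity check (not part of the original example), the inverse
# should undo the forward transform, and the two log-determinants should cancel:

let
    rng = Xoshiro(0)
    bj = AffineBijector(randn(rng, Float32, 4, 8)) # rows 1:2 = shift, rows 3:4 = log-scale
    x = randn(rng, Float32, 2, 8)
    y, log_det = forward_and_log_det(bj, x)
    x̂, inv_log_det = inverse_and_log_det(bj, y)
    @assert x̂ ≈ x && all(log_det .+ inv_log_det .== 0)
end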

@concrete struct MaskedCoupling <: AbstractBijector
mask <: AbstractArray
conditioner
bijector
end

function apply_mask(bj::MaskedCoupling, x::AbstractArray, fn::F) where {F}
x_masked = x .* (1 .- bj.mask)
bijector_params = bj.conditioner(x_masked)
y, log_det = fn(bijector_params)
log_det = log_det .* bj.mask
y = ifelse.(bj.mask, y, x)
return y, dsum(log_det; dims=Tuple(collect(1:(ndims(x) - 1))))
end

function forward_and_log_det(bj::MaskedCoupling, x::AbstractArray)
return apply_mask(bj, x, params -> forward_and_log_det(bj.bijector(params), x))
end

function inverse_and_log_det(bj::MaskedCoupling, y::AbstractArray)
return apply_mask(bj, y, params -> inverse_and_log_det(bj.bijector(params), y))
end
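
# Entries where the mask is `true` are transformed, while the masked-out entries pass
# through unchanged and act as the conditioner's input. Since no transformed entry
# influences its own parameters, the Jacobian is triangular and the log-determinant
# reduces to the per-entry terms retained by `log_det .* bj.mask`.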

# ## Model Definition

function MLP(in_dims::Int, hidden_dims::Int, out_dims::Int, n_layers::Int; activation=gelu)
return Chain(
Dense(in_dims => hidden_dims, activation),
[Dense(hidden_dims => hidden_dims, activation) for _ in 1:(n_layers - 1)]...,
Dense(hidden_dims => out_dims)
)
end

@concrete struct RealNVP <: AbstractLuxContainerLayer{(:conditioners,)}
conditioners
dist_dims::Int
n_transforms::Int
end

const StatefulRealNVP{M} = StatefulLuxLayer{M, <:RealNVP}

function Lux.initialstates(rng::AbstractRNG, l::RealNVP)
mask_list = [collect(1:(l.dist_dims)) .% 2 .== i % 2 for i in 1:(l.n_transforms)] .|>
Vector{Bool}
return (; mask_list, conditioners=Lux.initialstates(rng, l.conditioners))
end
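
# The mask parity alternates from one transform to the next, so dimensions left untouched
# by one coupling layer are updated by the following one.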

function RealNVP(; n_transforms::Int, dist_dims::Int, hidden_dims::Int, n_layers::Int)
conditioners = [MLP(dist_dims, hidden_dims, 2 * dist_dims, n_layers; activation=gelu)
for _ in 1:n_transforms]
conditioners = NamedTuple{ntuple(Base.Fix1(Symbol, :conditioners_), n_transforms)}(
Tuple(conditioners)
)
return RealNVP(conditioners, dist_dims, n_transforms)
end

log_prob(x::AbstractArray{T}) where {T} = -T(0.5 * log(2π)) .- T(0.5) .* abs2.(x)
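
# This is the elementwise log-density of the standard-normal base distribution,
# ``\log \mathcal{N}(x; 0, 1) = -\tfrac{1}{2}\log(2\pi) - \tfrac{1}{2}x^2``; summing it
# over the feature dimensions gives ``\log p_Z``.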

function log_prob(l::StatefulRealNVP, x::AbstractArray{T}) where {T}
smodels = [StatefulLuxLayer{true}(
conditioner, l.ps.conditioners[i], l.st.conditioners[i])
for (i, conditioner) in enumerate(l.model.conditioners)]

lprob = zeros_like(x, size(x, ndims(x)))
for (mask, conditioner) in Iterators.reverse(zip(l.st.mask_list, smodels))
bj = MaskedCoupling(mask, conditioner, AffineBijector)
x, log_det = inverse_and_log_det(bj, x)
lprob += log_det
end
lprob += dsum(log_prob(x); dims=Tuple(collect(1:(ndims(x) - 1))))

conditioners = NamedTuple{ntuple(
Base.Fix1(Symbol, :conditioners_), l.model.n_transforms)
}(Tuple([smodel.st for smodel in smodels]))
l.st = merge(l.st, (; conditioners))

return lprob
end
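
# Density evaluation runs the flow in *reverse*: the data is pulled back through the
# inverse of each coupling layer until it reaches the base distribution, accumulating the
# inverse log-determinants along the way, exactly as in the change-of-variables formula.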

function sample(
rng::AbstractRNG, ::Type{T}, d::StatefulRealNVP,
nsamples::Int, nsteps::Int=length(d.model.conditioners)
) where {T}
@assert 1 ≤ nsteps ≤ length(d.model.conditioners)

smodels = [StatefulLuxLayer{true}(
conditioner, d.ps.conditioners[i], d.st.conditioners[i])
for (i, conditioner) in enumerate(d.model.conditioners)]

x = randn(rng, T, d.model.dist_dims, nsamples)
for (i, (mask, conditioner)) in enumerate(zip(d.st.mask_list, smodels))
x, _ = forward_and_log_det(MaskedCoupling(mask, conditioner, AffineBijector), x)
i ≥ nsteps && break
end
return x
end
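
# Sampling runs in the opposite direction: standard-normal noise is pushed forward
# through the first `nsteps` coupling layers. Passing `nsteps < n_transforms` is a handy
# way to visualize how the flow progressively warps the base distribution, which is what
# the final section does.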

# ## Helper Functions

dsum(x; dims) = dropdims(sum(x; dims); dims)

function loss_function(model, ps, st, x)
smodel = StatefulLuxLayer{true}(model, ps, st)
lprob = log_prob(smodel, x)
return -mean(lprob), smodel.st, (;)
end
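
# Minimizing the negative mean log-probability is maximum-likelihood training of the
# flow. The empty `NamedTuple` is returned because we track no additional statistics.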

# ## Training the Model

function main(;
maxiters::Int=10_000, n_train_samples::Int=100_000, batchsize::Int=128,
n_transforms::Int=6, hidden_dims::Int=16, n_layers::Int=4,
lr::Float64=0.0004, noise::Float64=0.06
)
rng = Random.default_rng()
Random.seed!(rng, 0)

dataloader = load_moons_dataloader(rng, Float32, n_train_samples; batchsize, noise) |>
xdev |> Iterators.cycle

model = RealNVP(; n_transforms, dist_dims=2, hidden_dims, n_layers)
ps, st = Lux.setup(rng, model) |> xdev
opt = Adam(lr)

train_state = Training.TrainState(model, ps, st, opt)
@printf "Total Trainable Parameters: %d\n" Lux.parameterlength(ps)

total_samples = 0
start_time = time()

for (iter, x) in enumerate(dataloader)
total_samples += size(x, ndims(x))
(_, loss, _, train_state) = Training.single_train_step!(
AutoEnzyme(), loss_function, x, train_state;
return_gradients=Val(false)
)

isnan(loss) && error("NaN loss encountered in iter $(iter)!")

if iter == 1 || iter == maxiters || iter % 1000 == 0
throughput = total_samples / (time() - start_time)
@printf "Iter: [%6d/%6d]\tTraining Loss: %.6f\t\
Throughput: %.6f samples/s\n" iter maxiters loss throughput
end

iter ≥ maxiters && break
end

return StatefulLuxLayer{true}(
model, train_state.parameters, Lux.testmode(train_state.states)
)
end

trained_model = main()
nothing #hide

# ## Visualizing the Results
z_stages = Matrix{Float32}[]
for i in 1:(trained_model.model.n_transforms)
z = @jit sample(Random.default_rng(), Float32, trained_model, 10_000, i)
push!(z_stages, Array(z))
end

begin
fig = Figure(; size=(1200, 800))

for (idx, z) in enumerate(z_stages)
i, j = (idx - 1) ÷ 3, (idx - 1) % 3
ax = Axis(fig[i, j]; title="$(idx) transforms")
scatter!(ax, z[1, :], z[2, :]; markersize=2)
end

fig
end
10 changes: 5 additions & 5 deletions ext/LuxReactantExt/training.jl
@@ -4,23 +4,23 @@ mutable struct StatsAndNewStateWrapper
end

function wrapped_objective_function(
-     fn::F, model, ps, st, data, cache::StatsAndNewStateWrapper
+     fn::F, model, ps, data, cache::StatsAndNewStateWrapper
) where {F}
- loss, stₙ, stats = fn(model, ps, st, data)
+ loss, stₙ, stats = fn(model, ps, cache.st, data)
cache.stats = stats
cache.st = stₙ
return loss
end

function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
- stats_wrapper = StatsAndNewStateWrapper(nothing, nothing)
+ st_stats_wrapper = StatsAndNewStateWrapper(nothing, st)
res = Enzyme.gradient(
Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
Const(wrapped_objective_function), Const(objective_function),
-     Const(model), ps, Const(st), Const(data), Const(stats_wrapper)
+     Const(model), ps, Const(data), Const(st_stats_wrapper)
)
loss, dps = res.val, res.derivs[3]
- return dps, loss, stats_wrapper.stats, stats_wrapper.st
+ return dps, loss, st_stats_wrapper.stats, st_stats_wrapper.st
end

function maybe_dump_to_mlir_file!(f::F, args...) where {F}

2 comments on commit 521fefd

@avik-pal (Member Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/123373

Tip: Release Notes

Did you know you can add release notes too? Just add markdown-formatted text underneath the comment after the text
"Release notes:", and it will be added to the registry PR. If TagBot is installed, it will also be added to the
release that TagBot creates. For example:

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that you create a tag on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v1.5.2 -m "<description of version>" 521fefded8398091ed0b63c9cbce688d85d12571
git push origin v1.5.2
