diff --git a/Project.toml b/Project.toml
index bad32cec7..53ac0c162 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal and contributors"]
-version = "1.5.1"
+version = "1.5.2"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/docs/make.jl b/docs/make.jl
index ce91fb6b4..dcc31a8f8 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -31,6 +31,7 @@ pages = [
             "tutorials/intermediate/4_PINN2DPDE.md",
             "tutorials/intermediate/5_ConvolutionalVAE.md",
             "tutorials/intermediate/6_GCN_Cora.md",
+            "tutorials/intermediate/7_RealNVP.md",
         ],
         "Advanced" => [
             "tutorials/advanced/1_GravitationalWaveForm.md"
diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts
index a66f2e5d3..12c1c8468 100644
--- a/docs/src/.vitepress/config.mts
+++ b/docs/src/.vitepress/config.mts
@@ -226,6 +226,10 @@ export default defineConfig({
                         text: "Graph Convolutional Network on Cora",
                         link: "/tutorials/intermediate/6_GCN_Cora",
                     },
+                    {
+                        text: "Normalizing Flows for Density Estimation",
+                        link: "/tutorials/intermediate/7_RealNVP",
+                    }
                 ],
             },
            {
diff --git a/docs/src/public/realnvp.png b/docs/src/public/realnvp.png
new file mode 100644
index 000000000..d8b330f4b
Binary files /dev/null and b/docs/src/public/realnvp.png differ
diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md
index bfbacf500..f7e17c169 100644
--- a/docs/src/tutorials/index.md
+++ b/docs/src/tutorials/index.md
@@ -77,6 +77,12 @@ const intermediate = [
         src: "../gcn_cora.jpg",
         caption: "Graph Convolutional Network on Cora",
         desc: "Train a Graph Convolutional Network on Cora dataset."
+    },
+    {
+        href: "intermediate/7_RealNVP",
+        src: "../realnvp.png",
+        caption: "Normalizing Flows for Density Estimation",
+        desc: "Train a normalizing flow for density estimation on the Moons dataset.",
+    }
 ];
diff --git a/docs/tutorials.jl b/docs/tutorials.jl
index 4f878e6f8..806a876eb 100644
--- a/docs/tutorials.jl
+++ b/docs/tutorials.jl
@@ -14,6 +14,7 @@ const INTERMEDIATE_TUTORIALS = [
     "PINN2DPDE/main.jl" => "CUDA",
     "ConvolutionalVAE/main.jl" => "CUDA",
     "GCN_Cora/main.jl" => "CUDA",
+    "RealNVP/main.jl" => "CUDA",
 ]
 const ADVANCED_TUTORIALS = [
     "GravitationalWaveForm/main.jl" => "CPU",
diff --git a/examples/GCN_Cora/Project.toml b/examples/GCN_Cora/Project.toml
index 376c28306..eca9ad7bb 100644
--- a/examples/GCN_Cora/Project.toml
+++ b/examples/GCN_Cora/Project.toml
@@ -4,7 +4,6 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
-MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -18,7 +17,6 @@ Enzyme = "0.13.28"
 GNNGraphs = "1"
 Lux = "1.5"
 MLDatasets = "0.7.18"
-MLUtils = "0.4.5"
 OneHotArrays = "0.2"
 Optimisers = "0.4.4"
 Printf = "1.10"
diff --git a/examples/GCN_Cora/main.jl b/examples/GCN_Cora/main.jl
index fc322471e..cc44c3c18 100644
--- a/examples/GCN_Cora/main.jl
+++ b/examples/GCN_Cora/main.jl
@@ -3,8 +3,8 @@
 # This example is based on [GCN MLX tutorial](https://github.com/ml-explore/mlx-examples/blob/main/gcn/). While we are doing this manually, we recommend directly using
 # [GNNLux.jl](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/).
 
-using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, GNNGraphs, MLUtils,
-    ConcreteStructs, Printf, OneHotArrays, Optimisers
+using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, GNNGraphs, ConcreteStructs,
+    Printf, OneHotArrays, Optimisers
 
 const xdev = reactant_device(; force=true)
 const cdev = cpu_device()
diff --git a/examples/RealNVP/Project.toml b/examples/RealNVP/Project.toml
new file mode 100644
index 000000000..5979705b7
--- /dev/null
+++ b/examples/RealNVP/Project.toml
@@ -0,0 +1,11 @@
+[deps]
+CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
+ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
diff --git a/examples/RealNVP/main.jl b/examples/RealNVP/main.jl
new file mode 100644
index 000000000..a90955232
--- /dev/null
+++ b/examples/RealNVP/main.jl
@@ -0,0 +1,259 @@
+# # [Normalizing Flows for Density Estimation](@id RealNVP-Tutorial)
+
+# This tutorial demonstrates how to use Lux to train a
+# [RealNVP](https://arxiv.org/abs/1605.08803). This is based on the
+# [RealNVP implementation in MLX](https://github.com/ml-explore/mlx-examples/blob/main/normalizing_flow/).
+
+using Lux, Reactant, Random, Statistics, Enzyme, MLUtils, ConcreteStructs, Printf,
+    Optimisers, CairoMakie
+
+const xdev = reactant_device(; force=true)
+const cdev = cpu_device()
+
+# ## Define & Load the Moons Dataset
+
+# We define a function to generate data from the moons dataset. We use the code here from
+# [this tutorial](https://liorsinai.github.io/machine-learning/2024/08/19/micrograd-5-mlp.html#moons-dataset).
+
+function make_moons(
+    rng::AbstractRNG, ::Type{T}, n_samples::Int=100;
+    noise::Union{Nothing, AbstractFloat}=nothing
+) where {T}
+    n_moons = n_samples ÷ 2
+    t_min, t_max = T(0), T(π)
+    t_inner = rand(rng, T, n_moons) * (t_max - t_min) .+ t_min
+    t_outer = rand(rng, T, n_moons) * (t_max - t_min) .+ t_min
+    outer_circ_x = cos.(t_outer)
+    outer_circ_y = sin.(t_outer) .+ T(1)
+    inner_circ_x = 1 .- cos.(t_inner)
+    inner_circ_y = 1 .- sin.(t_inner) .- T(1)
+
+    data = [outer_circ_x outer_circ_y; inner_circ_x inner_circ_y]
+    z = permutedims(data, (2, 1))
+    noise !== nothing && (z .+= T(noise) * randn(rng, T, size(z)))
+    return z
+end
+
+# Let's visualize the dataset
+
+begin
+    fig = Figure()
+    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
+
+    z = make_moons(Random.default_rng(), Float32, 10_000; noise=0.1)
+    scatter!(ax, z[1, :], z[2, :]; markersize=2)
+
+    fig
+end
+
+# ---
+
+function load_moons_dataloader(
+    args...; batchsize::Int, noise::Union{Nothing, AbstractFloat}=nothing, kwargs...
+)
+    return DataLoader(
+        make_moons(args...; noise); batchsize, shuffle=true, partial=false, kwargs...
+    )
+end
+
+# ## Bijectors Implementation
+
+abstract type AbstractBijector end
+
+@concrete struct AffineBijector <: AbstractBijector
+    shift <: AbstractArray
+    log_scale <: AbstractArray
+end
+
+function AffineBijector(shift_and_log_scale::AbstractArray{T, N}) where {T, N}
+    n = size(shift_and_log_scale, 1) ÷ 2
+    idxs = ntuple(Returns(Colon()), N - 1)
+    return AffineBijector(
+        shift_and_log_scale[1:n, idxs...], shift_and_log_scale[(n + 1):end, idxs...]
+    )
+end
+
+function forward_and_log_det(bj::AffineBijector, x::AbstractArray)
+    y = x .* exp.(bj.log_scale) .+ bj.shift
+    return y, bj.log_scale
+end
+
+function inverse_and_log_det(bj::AffineBijector, y::AbstractArray)
+    x = (y .- bj.shift) ./ exp.(bj.log_scale)
+    return x, -bj.log_scale
+end
+
+@concrete struct MaskedCoupling <: AbstractBijector
+    mask <: AbstractArray
+    conditioner
+    bijector
+end
+
+function apply_mask(bj::MaskedCoupling, x::AbstractArray, fn::F) where {F}
+    x_masked = x .* (1 .- bj.mask)
+    bijector_params = bj.conditioner(x_masked)
+    y, log_det = fn(bijector_params)
+    log_det = log_det .* bj.mask
+    y = ifelse.(bj.mask, y, x)
+    return y, dsum(log_det; dims=Tuple(collect(1:(ndims(x) - 1))))
+end
+
+function forward_and_log_det(bj::MaskedCoupling, x::AbstractArray)
+    return apply_mask(bj, x, params -> forward_and_log_det(bj.bijector(params), x))
+end
+
+function inverse_and_log_det(bj::MaskedCoupling, y::AbstractArray)
+    return apply_mask(bj, y, params -> inverse_and_log_det(bj.bijector(params), y))
+end
+
+# ## Model Definition
+
+function MLP(in_dims::Int, hidden_dims::Int, out_dims::Int, n_layers::Int; activation=gelu)
+    return Chain(
+        Dense(in_dims => hidden_dims, activation),
+        [Dense(hidden_dims => hidden_dims, activation) for _ in 1:(n_layers - 1)]...,
+        Dense(hidden_dims => out_dims)
+    )
+end
+
+@concrete struct RealNVP <: AbstractLuxContainerLayer{(:conditioners,)}
+    conditioners
+    dist_dims::Int
+    n_transforms::Int
+end
+
+const StatefulRealNVP{M} = StatefulLuxLayer{M, <:RealNVP}
+
+function Lux.initialstates(rng::AbstractRNG, l::RealNVP)
+    mask_list = [collect(1:(l.dist_dims)) .% 2 .== i % 2 for i in 1:(l.n_transforms)] .|>
+                Vector{Bool}
+    return (; mask_list, conditioners=Lux.initialstates(rng, l.conditioners))
+end
+
+function RealNVP(; n_transforms::Int, dist_dims::Int, hidden_dims::Int, n_layers::Int)
+    conditioners = [MLP(dist_dims, hidden_dims, 2 * dist_dims, n_layers; activation=gelu)
+                    for _ in 1:n_transforms]
+    conditioners = NamedTuple{ntuple(Base.Fix1(Symbol, :conditioners_), n_transforms)}(
+        Tuple(conditioners)
+    )
+    return RealNVP(conditioners, dist_dims, n_transforms)
+end
+
+log_prob(x::AbstractArray{T}) where {T} = -T(0.5 * log(2π)) .- T(0.5) .* abs2.(x)
+
+function log_prob(l::StatefulRealNVP, x::AbstractArray{T}) where {T}
+    smodels = [StatefulLuxLayer{true}(
+                   conditioner, l.ps.conditioners[i], l.st.conditioners[i])
+               for (i, conditioner) in enumerate(l.model.conditioners)]
+
+    lprob = zeros_like(x, size(x, ndims(x)))
+    for (mask, conditioner) in Iterators.reverse(zip(l.st.mask_list, smodels))
+        bj = MaskedCoupling(mask, conditioner, AffineBijector)
+        x, log_det = inverse_and_log_det(bj, x)
+        lprob += log_det
+    end
+    lprob += dsum(log_prob(x); dims=Tuple(collect(1:(ndims(x) - 1))))
+
+    conditioners = NamedTuple{ntuple(
+        Base.Fix1(Symbol, :conditioners_), l.model.n_transforms)
+    }(Tuple([smodel.st for smodel in smodels]))
+    l.st = merge(l.st, (; conditioners))
+
+    return lprob
+end
+
+function sample(
+    rng::AbstractRNG, ::Type{T}, d::StatefulRealNVP,
+    nsamples::Int, nsteps::Int=length(d.model.conditioners)
+) where {T}
+    @assert 1 ≤ nsteps ≤ length(d.model.conditioners)
+
+    smodels = [StatefulLuxLayer{true}(
+                   conditioner, d.ps.conditioners[i], d.st.conditioners[i])
+               for (i, conditioner) in enumerate(d.model.conditioners)]
+
+    x = randn(rng, T, d.model.dist_dims, nsamples)
+    for (i, (mask, conditioner)) in enumerate(zip(d.st.mask_list, smodels))
+        x, _ = forward_and_log_det(MaskedCoupling(mask, conditioner, AffineBijector), x)
+        i ≥ nsteps && break
+    end
+    return x
+end
+
+# ## Helper Functions
+
+dsum(x; dims) = dropdims(sum(x; dims); dims)
+
+function loss_function(model, ps, st, x)
+    smodel = StatefulLuxLayer{true}(model, ps, st)
+    lprob = log_prob(smodel, x)
+    return -mean(lprob), smodel.st, (;)
+end
+
+# ## Training the Model
+
+function main(;
+    maxiters::Int=10_000, n_train_samples::Int=100_000, batchsize::Int=128,
+    n_transforms::Int=6, hidden_dims::Int=16, n_layers::Int=4,
+    lr::Float64=0.0004, noise::Float64=0.06
+)
+    rng = Random.default_rng()
+    Random.seed!(rng, 0)
+
+    dataloader = load_moons_dataloader(rng, Float32, n_train_samples; batchsize, noise) |>
+                 xdev |> Iterators.cycle
+
+    model = RealNVP(; n_transforms, dist_dims=2, hidden_dims, n_layers)
+    ps, st = Lux.setup(rng, model) |> xdev
+    opt = Adam(lr)
+
+    train_state = Training.TrainState(model, ps, st, opt)
+    @printf "Total Trainable Parameters: %d\n" Lux.parameterlength(ps)
+
+    total_samples = 0
+    start_time = time()
+
+    for (iter, x) in enumerate(dataloader)
+        total_samples += size(x, ndims(x))
+        (_, loss, _, train_state) = Training.single_train_step!(
+            AutoEnzyme(), loss_function, x, train_state;
+            return_gradients=Val(false)
+        )
+
+        isnan(loss) && error("NaN loss encountered in iter $(iter)!")
+
+        if iter == 1 || iter == maxiters || iter % 1000 == 0
+            throughput = total_samples / (time() - start_time)
+            @printf "Iter: [%6d/%6d]\tTraining Loss: %.6f\t\
+                     Throughput: %.6f samples/s\n" iter maxiters loss throughput
+        end
+
+        iter ≥ maxiters && break
+    end
+
+    return StatefulLuxLayer{true}(
+        model, train_state.parameters, Lux.testmode(train_state.states)
+    )
+end
+
+trained_model = main()
+nothing #hide
+
+# ## Visualizing the Results
+z_stages = Matrix{Float32}[]
+for i in 1:(trained_model.model.n_transforms)
+    z = @jit sample(Random.default_rng(), Float32, trained_model, 10_000, i)
+    push!(z_stages, Array(z))
+end
+
+begin
+    fig = Figure(; size=(1200, 800))
+
+    for (idx, z) in enumerate(z_stages)
+        i, j = (idx - 1) ÷ 3, (idx - 1) % 3
+        ax = Axis(fig[i, j]; title="$(idx) transforms")
+        scatter!(ax, z[1, :], z[2, :]; markersize=2)
+    end
+
+    fig
+end
diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl
index d6c0c1c8d..306935fd1 100644
--- a/ext/LuxReactantExt/training.jl
+++ b/ext/LuxReactantExt/training.jl
@@ -4,23 +4,23 @@ mutable struct StatsAndNewStateWrapper
 end
 
 function wrapped_objective_function(
-    fn::F, model, ps, st, data, cache::StatsAndNewStateWrapper
+    fn::F, model, ps, data, cache::StatsAndNewStateWrapper
 ) where {F}
-    loss, stₙ, stats = fn(model, ps, st, data)
+    loss, stₙ, stats = fn(model, ps, cache.st, data)
     cache.stats = stats
     cache.st = stₙ
     return loss
 end
 
 function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
-    stats_wrapper = StatsAndNewStateWrapper(nothing, nothing)
+    st_stats_wrapper = StatsAndNewStateWrapper(nothing, st)
     res = Enzyme.gradient(
         Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
        Const(wrapped_objective_function), Const(objective_function),
-        Const(model), ps, Const(st), Const(data), Const(stats_wrapper)
+        Const(model), ps, Const(data), Const(st_stats_wrapper)
     )
     loss, dps = res.val, res.derivs[3]
-    return dps, loss, stats_wrapper.stats, stats_wrapper.st
+    return dps, loss, st_stats_wrapper.stats, st_stats_wrapper.st
 end
 
 function maybe_dump_to_mlir_file!(f::F, args...) where {F}
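
A quick way to sanity-check the coupling layers introduced in `examples/RealNVP/main.jl` is to confirm that the inverse transform undoes the forward transform and that the two log-determinants cancel. The sketch below assumes the bijector definitions from that file (`AffineBijector`, `MaskedCoupling`, `forward_and_log_det`, `inverse_and_log_det`, `dsum`) are already in scope, e.g. copied into a REPL session; the array sizes and the `reverse`-based stand-in conditioner are illustrative only and are not part of the tutorial.

```julia
using Random

# Assumes the bijector definitions from examples/RealNVP/main.jl are in scope:
# AffineBijector, MaskedCoupling, forward_and_log_det, inverse_and_log_det, dsum.
rng = Random.default_rng()
x = randn(rng, Float32, 2, 16)          # (dist_dims, batch)

# AffineBijector: forward followed by inverse recovers x, and the log-determinants cancel.
params = randn(rng, Float32, 4, 16)     # rows 1:2 -> shift, rows 3:4 -> log_scale
bj = AffineBijector(params)
y, logdet_fwd = forward_and_log_det(bj, x)
x_rec, logdet_inv = inverse_and_log_det(bj, y)
@assert x_rec ≈ x
@assert logdet_fwd ≈ -logdet_inv

# MaskedCoupling with a hypothetical hand-written conditioner standing in for the
# tutorial's MLP: it only ever sees the masked-out coordinates, which is what keeps
# the coupling exactly invertible.
mask = Bool[1, 0]
conditioner = z -> vcat(reverse(z; dims=1), reverse(z; dims=1))
mc = MaskedCoupling(mask, conditioner, AffineBijector)
y2, ld_fwd = forward_and_log_det(mc, x)
x2, ld_inv = inverse_and_log_det(mc, y2)
@assert x2 ≈ x
@assert ld_fwd ≈ -ld_inv
```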
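
Similarly, the scalar `log_prob` helper in the tutorial is just the standard-normal log-density written out elementwise. A minimal sketch of that equivalence, using Distributions.jl purely for the comparison (it is not a dependency of the example):

```julia
using Distributions

# Hand-written standard-normal log-density, copied from examples/RealNVP/main.jl.
log_prob(x::AbstractArray{T}) where {T} = -T(0.5 * log(2π)) .- T(0.5) .* abs2.(x)

x = randn(Float32, 2, 8)
reference = logpdf.(Normal(0.0f0, 1.0f0), x)   # Distributions.jl, comparison only
@assert log_prob(x) ≈ reference
```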