feat: add PositiveDefinite #89

Open · wants to merge 27 commits into main from positive-definite-container

Commits (27)
d275ee2  feat: add `PositiveDefinite` and corresponding tests (nicholaskl97, Dec 18, 2024)
663e5e9  Added `NNlib` import to Positive Definite Container test (nicholaskl97, Dec 19, 2024)
dde544c  Including and exporting `PositiveDefinite` (nicholaskl97, Dec 19, 2024)
1ba4321  Fixed `PositiveDefinite` inner constructors (nicholaskl97, Dec 19, 2024)
46ffc9e  Fixed incorrect function call in `PositiveDefinite` test (nicholaskl97, Dec 19, 2024)
536eaea  Updated `PositiveDefinite` to account for possibly changing state of … (nicholaskl97, Jan 2, 2025)
1dd9854  Replaced call to `mapslices` in `PositiveDefinite` with `mapreduce(..… (nicholaskl97, Jan 9, 2025)
6ba2a24  Fixed broken call to `mapreduce` in `PositiveDefinite` (nicholaskl97, Jan 9, 2025)
ebf0efe  Fixed typo in `PositiveDefinite` test and removed `==` from test in f… (nicholaskl97, Jan 9, 2025)
8441c0a  Removed unnecessary definition of `norm2` from `PositiveDefinite` (nicholaskl97, Jan 9, 2025)
4322db8  Added `ShiftTo` container (nicholaskl97, Jan 15, 2025)
81557ac  Removed vector fields from `PositiveDefinite` and `ShiftTo` (nicholaskl97, Jan 22, 2025)
bde526f  Added `init` to `PositiveDefinite` call to `mapreduce` (nicholaskl97, Jan 23, 2025)
fbc2496  Fixed typo in `ShiftTo` test (nicholaskl97, Jan 23, 2025)
2cac710  Formatting changes (nicholaskl97, Jan 23, 2025)
acd1b8a  Added `@allowscalar` to `PeriodicEmbedding` when determining which in… (nicholaskl97, Jan 23, 2025)
43a8846  Forgot to import `@allowscalar` (nicholaskl97, Jan 23, 2025)
c3c9cd1  Removed `@allowscalar` from `PeriodicEmbedding` (nicholaskl97, Jan 23, 2025)
d718310  Trying to match types better in `PositiveDefinite` `mapreduce` for th… (nicholaskl97, Jan 23, 2025)
fceb28a  Make `PositiveDefinite` match/utilize `WeightInitializers` (nicholaskl97, Jan 23, 2025)
8ca7f04  Forgot to import `WeightInitializers` (nicholaskl97, Jan 23, 2025)
bdf3b78  Removed unnecessary `WeightInitializers` import (nicholaskl97, Jan 24, 2025)
00d75b9  Merge branch 'main' into positive-definite-container (avik-pal, Jan 24, 2025)
9f6629e  `ShiftTo` and `PositiveDefinite` no longer ignore state from one of t… (nicholaskl97, Jan 24, 2025)
3ea0e79  Simplified `PositiveDefinite` `mapreduce` (nicholaskl97, Jan 27, 2025)
5c6888b  Fixed `PositiveDefinite` erroring on taking the gradient of `permuted… (nicholaskl97, Jan 30, 2025)
76e0fe2  Improved `PositiveDefinite` test code coverage (nicholaskl97, Jan 30, 2025)
10 changes: 10 additions & 0 deletions docs/ref.bib
@@ -109,3 +109,13 @@ @misc{zagoruyko2017wideresidualnetworks
primaryclass = {cs.CV},
url = {https://arxiv.org/abs/1605.07146}
}

@misc{gaby2022lyapunovnetdeepneuralnetwork,
title = {Lyapunov-Net: A Deep Neural Network Architecture for Lyapunov Function Approximation},
author = {Nathan Gaby and Fumin Zhang and Xiaojing Ye},
year = {2022},
eprint = {2109.13359},
archiveprefix = {arXiv},
primaryclass = {cs.LG},
url = {https://arxiv.org/abs/2109.13359}
}
6 changes: 4 additions & 2 deletions src/layers/Layers.jl
@@ -11,7 +11,7 @@ using Static: Static

using ForwardDiff: ForwardDiff

-using Lux: Lux, LuxOps, StatefulLuxLayer
+using Lux: Lux, LuxOps, StatefulLuxLayer, WeightInitializers
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer
using MLDataDevices: get_device, CPUDevice
using NNlib: NNlib
@@ -31,6 +31,7 @@ const NORM_LAYER_DOC = "Function with signature `f(i::Integer, dims::Integer, ac

include("attention.jl")
include("conv_norm_act.jl")
include("containers.jl")
include("dynamic_expressions.jl")
include("encoder.jl")
include("embeddings.jl")
@@ -42,6 +43,7 @@ include("tensor_product.jl")
@compat(public,
(ClassTokens, ConvBatchNormActivation, ConvNormActivation, DynamicExpressionsLayer,
HamiltonianNN, MultiHeadSelfAttention, MLP, PatchEmbedding, PeriodicEmbedding,
-        SplineLayer, TensorProductLayer, ViPosEmbedding, VisionTransformerEncoder))
+        PositiveDefinite, ShiftTo, SplineLayer, TensorProductLayer, ViPosEmbedding,
+        VisionTransformerEncoder))

end
144 changes: 144 additions & 0 deletions src/layers/containers.jl
@@ -0,0 +1,144 @@
"""
PositiveDefinite(model, x0; ψ, r)
PositiveDefinite(model; in_dims, ψ, r)

Constructs a Lyapunov-Net [gaby2022lyapunovnetdeepneuralnetwork](@citep), which is positive
definite about `x0` whenever `ψ` and `r` meet certain conditions described below.

For a model `ϕ`,
`PositiveDefinite(ϕ, x0; ψ, r)(x, ps, st) = ψ(ϕ(x, ps, st) - ϕ(x0, ps, st)) + r(x, x0)`.
This results in a model which maps `x0` to `0` and any other input to a positive number
(i.e., a model which is positive definite about `x0`) whenever `ψ` is positive definite
about zero and `r` returns a positive number for any non-equal inputs and zero for equal
inputs.

## Arguments
- `model`: the underlying model being transformed into a positive definite function
- `x0`: The unique input that will be mapped to zero instead of a positive number

## Keyword Arguments
- `in_dims`: the number of input dimensions if `x0` is not provided; uses
  `x0 = zeros32(in_dims)` (a `Float32` zero vector)
- `ψ`: a positive definite function (about zero); defaults to ``ψ(x) = ||x||^2``
- `r`: a bivariate function such that `r(x0, x0) = 0` and
`r(x, x0) > 0` whenever `x ≠ x0`; defaults to ``r(x, y) = ||x - y||^2``

## Inputs
- `x`: will be passed directly into `model`, so it must meet the input requirements of
  that model

## Returns
- The output of the positive definite model
- The state of the positive definite model. If the underlying model changes its state, the
  state will be updated according to the call with the input `x`, not the call using
  `x0`.

## States
- `st`: a `NamedTuple` containing the state of the underlying `model` and the `x0` value

## Parameters
- Same as the underlying `model`
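
## Example

A minimal sketch of usage (a `Lux.Dense` layer stands in here for an arbitrary model;
qualify the name, e.g. `Layers.PositiveDefinite`, if it is not in scope):

```julia
using Lux, Random

model = Dense(2 => 1)
pd = PositiveDefinite(model; in_dims=2)  # x0 defaults to the zero vector
ps, st = Lux.setup(Random.default_rng(), pd)

y0, _ = pd(zeros(Float32, 2), ps, st)  # x0 maps to zero (up to roundoff)
y, _ = pd(randn(Float32, 2), ps, st)   # any other input maps to a positive value
```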
"""
@concrete struct PositiveDefinite <: AbstractLuxWrapperLayer{:model}
model <: AbstractLuxLayer
init_x0 <: Function
in_dims::Integer
ψ <: Function
r <: Function

function PositiveDefinite(model, x0::AbstractVector; ψ=Base.Fix1(sum, abs2),
r=Base.Fix1(sum, abs2) ∘ -)
return PositiveDefinite(model, (rng, in_dims) -> copy(x0), length(x0), ψ, r)
end
function PositiveDefinite(model; in_dims::Integer, ψ=Base.Fix1(sum, abs2),
r=Base.Fix1(sum, abs2) ∘ -)
return PositiveDefinite(model, zeros32, in_dims, ψ, r)
end
end

function LuxCore.initialstates(rng::AbstractRNG, pd::PositiveDefinite)
return (; model=LuxCore.initialstates(rng, pd.model), x0=pd.init_x0(rng, pd.in_dims))
end

function (pd::PositiveDefinite)(x::AbstractVector, ps, st)
out, new_st = pd(reshape(x, :, 1), ps, st)
return vec(out), new_st
end

function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st)
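    # Evaluate the model at x0 with the current parameters; the state update from this
    # call is discarded in favor of the one from the call on `x` below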
ϕ0, _ = pd.model(st.x0, ps, st.model)
ϕx, new_model_st = pd.model(x, ps, st.model)
ϕx_cols = eachcol(ϕx)
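    # For each column xᵢ of x, compute the scalar ψ(ϕ(xᵢ) - ϕ(x0)) + r(xᵢ, x0);
    # `permutedims` then reshapes the collected length-N vector into a 1×N row vector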
return (
permutedims(
mapreduce(vcat, zip(eachcol(x), ϕx_cols); init=empty(first(ϕx_cols))) do (x, ϕx)
pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0)
end
),
merge(st, (; model=new_model_st))
)
end

"""
ShiftTo(model, in_val, out_val)

Vertically shifts the output of `model` so that the output is `out_val` when the input is `in_val`.

For a model `ϕ`, `ShiftTo(ϕ, in_val, out_val)(x, ps, st) = ϕ(x, ps, st) + Δϕ`,
where `Δϕ = out_val - ϕ(in_val, ps, st)`.

## Arguments
- `model`: the underlying model whose output will be shifted
- `in_val`: The input that will be mapped to `out_val`
- `out_val`: The value that the output will be shifted to when the input is `in_val`

## Inputs
- `x`: will be passed directly into `model`, so it must meet the input requirements of
  that model

## Returns
- The output of the shifted model
- The state of the shifted model. If the underlying model changes its state, the
  state will be updated according to the call with the input `x`, not the call using
  `in_val`.

## States
- `st`: a `NamedTuple` containing the state of the underlying `model` and the `in_val` and
`out_val` values

## Parameters
- Same as the underlying `model`
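
## Example

A minimal sketch of usage (again with a stand-in `Lux.Dense` model):

```julia
using Lux, Random

model = Dense(2 => 2)
s = ShiftTo(model, ones(Float32, 2), zeros(Float32, 2))
ps, st = Lux.setup(Random.default_rng(), s)

y0, _ = s(ones(Float32, 2), ps, st)  # ≈ zeros(2), up to floating-point roundoff
```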
"""
@concrete struct ShiftTo <: AbstractLuxWrapperLayer{:model}
model <: AbstractLuxLayer
init_in_val <: Function
init_out_val <: Function
function ShiftTo(model, in_val::AbstractVector, out_val::AbstractVector)
_in_val = copy(in_val)
_out_val = copy(out_val)
return ShiftTo(model, () -> _in_val, () -> _out_val)
end
end

function LuxCore.initialstates(rng::AbstractRNG, s::ShiftTo)
return (;
model=LuxCore.initialstates(rng, s.model),
in_val=s.init_in_val(),
out_val=s.init_out_val()
)
end

function (s::ShiftTo)(x::AbstractVector, ps, st)
out, new_st = s(reshape(x, :, 1), ps, st)
return vec(out), new_st
end

function (s::ShiftTo)(x::AbstractMatrix, ps, st)
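    # Evaluate the model at in_val to compute the shift Δϕ; the state update from this
    # call is discarded in favor of the one from the call on `x` below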
ϕ0, _ = s.model(st.in_val, ps, st.model)
Δϕ = st.out_val .- ϕ0
ϕx, new_model_st = s.model(x, ps, st.model)
return (
ϕx .+ Δϕ,
merge(st, (; model=new_model_st))
)
end
44 changes: 44 additions & 0 deletions test/layer_tests.jl
@@ -282,3 +282,47 @@ end
end
end
end

@testitem "Positive Definite Container" setup=[SharedTestSetup] tags=[:layers] begin
using NNlib

@testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
model = Layers.MLP(2, (4, 4, 2), NNlib.gelu)
pd = Layers.PositiveDefinite(model; in_dims=2)
ps, st = Lux.setup(StableRNG(0), pd) |> dev

x = randn(StableRNG(0), Float32, 2, 2) |> aType
x0 = zeros(Float32, 2) |> aType

y, _ = pd(x, ps, st)
z, _ = model(x, ps, st.model)
z0, _ = model(x0, ps, st.model)
y_by_hand = sum(abs2, z .- z0; dims=1) .+ sum(abs2, x .- x0; dims=1)

@test maximum(abs, y - y_by_hand) < 1.0f-8

@jet pd(x, ps, st)

__f = (x, ps) -> sum(first(pd(x, ps, st)))
@test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3)
end
end

@testitem "ShiftTo Container" setup=[SharedTestSetup] tags=[:layers] begin
using NNlib

@testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
model = Layers.MLP(2, (4, 4, 2), NNlib.gelu)
s = Layers.ShiftTo(model, ones(2), zeros(2))
ps, st = Lux.setup(StableRNG(0), s) |> dev

y0, _ = s(st.in_val, ps, st)
@test maximum(abs, y0) < 1.0f-8

x = randn(StableRNG(0), Float32, 2, 2) |> aType
@jet s(x, ps, st)

__f = (x, ps) -> sum(first(s(x, ps, st)))
@test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3)
end
end