feat: add PositiveDefinite #89

Draft: nicholaskl97 wants to merge 21 commits into main from positive-definite-container
Conversation

nicholaskl97

A PositiveDefinite container wraps an underlying model, producing a model that returns a positive number whenever the input is nonzero (or not equal to a different point specified when constructing the container). This is useful in, among other applications, neural Lyapunov methods.
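If I'm reading the constructor defaults shown later in this thread correctly, the wrapped model computes something like ψ(ϕ(x) - ϕ(x0)) + r(x, x0), which vanishes at x0 and is positive elsewhere. A minimal standalone sketch of that idea (plain functions, not the PR's actual Lux layer):

```julia
# Sketch of the positive-definite wrapper idea, using the default choices
# from the constructor below. Not the PR's implementation.
ψ = Base.Fix1(sum, abs2)        # ψ(v) = sum of squares
r = Base.Fix1(sum, abs2) ∘ -    # r(x, y) = sum of squared differences
positive_definite(ϕ, x, x0) = ψ(ϕ(x) - ϕ(x0)) + r(x, x0)

positive_definite(identity, [1.0, 2.0], zeros(2))  # > 0 away from x0
positive_definite(identity, zeros(2), zeros(2))    # == 0 at x0
```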

@avik-pal force-pushed the positive-definite-container branch from 0445dd8 to ebf0efe on January 9, 2025.
"""
@concrete struct PositiveDefinite <: AbstractLuxWrapperLayer{:model}
model <: AbstractLuxLayer
x0 <: AbstractVector
Member:

Don't store a vector here. Instead pass in an initialization_function (ideally from WeightInitializers.jl) and construct the vector inside initialstates.

Author:

I believe I've resolved this, but I'd like you to check if I did what you were asking when you get the chance.

Comment on lines 110 to 111:

```julia
    in_val <: AbstractVector
    out_val <: AbstractVector
```
Member:

Same as above.

src/layers/embeddings.jl: outdated review thread (resolved).
```julia
end

function PositiveDefinite(model; in_dims::Integer, ψ=Base.Fix1(sum, abs2),
        r=Base.Fix1(sum, abs2) ∘ -)
    return PositiveDefinite(model, () -> zeros(in_dims), ψ, r)
```
Member:

Pass in zeros32?

`() -> copy(x0)`

Here you can pass a dummy function that takes in `(rng, in_dims)` and ignores them.
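If I understand the suggestion, it's something like the following (the closure and its name are my illustration, not code from the PR):

```julia
# Hypothetical sketch: an init-style closure with the (rng, dims...) signature
# that WeightInitializers functions use, but which ignores both arguments and
# returns a copy of a user-supplied x0. The default could then be zeros32.
make_x0_init(x0::AbstractVector) = (rng, dims...) -> copy(x0)
```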

Author:

Do you mean something like the change I just made?

```julia
end

function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st)
    ϕ0, _ = pd.model(st.x0, ps, st.model)
```
Member:

Don't ignore the returned states here.
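For illustration, one way to keep those states, assuming the sub-model state lives at `st.model` as in the snippet above (a fragment sketch, not the PR's final code):

```julia
# Hedged sketch: thread the sub-model state through both forward calls
# instead of discarding it, then merge it back into the wrapper's state.
ϕ0, model_st = pd.model(st.x0, ps, st.model)
ϕx, model_st = pd.model(x, ps, model_st)
# ... compute the positive-definite output from ϕx and ϕ0 as before ...
st = merge(st, (; model = model_st))
```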

```julia
end

function (s::ShiftTo)(x::AbstractMatrix, ps, st)
    ϕ0, _ = s.model(st.in_val, ps, st.model)
```
Member:

Same as above.

Author (nicholaskl97), Jan 23, 2025:
If model is the mathematical function $$\phi$$, then ShiftTo is supposed to represent $$\phi(x) - \phi(x_0) + y_0$$, so that $$x_0$$ always gets mapped to $$y_0$$ (PositiveDefinite does something similar). In the mathematical representation, $$\phi(x)$$ and $$\phi(x_0)$$ are in a sense evaluated at the same time. In my current code, one call has to happen first, and using the output state of the first call as the input state of the second would make my choice of which gets called first matter, when it really shouldn't. I figured it made the most sense to just use the state from the call $$\phi(x)$$.

Is it possible to combine the calls? I suppose I could call with `hcat(x, st.in_val)`, but I don't know how I'd separate the last column back out without using the scalar indexing that CUDA hates so much.
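Possibly something like this works, assuming the wrapped model acts column-wise (so batching x together with x0 doesn't change per-column outputs); names mirror the snippet above, but this is only a sketch:

```julia
# Hedged sketch: one combined forward pass, then split with column views.
# Range and column views avoid the scalar indexing that CUDA disallows.
y, model_st = s.model(hcat(x, st.in_val), ps, st.model)
ϕx = view(y, :, 1:size(x, 2))      # columns corresponding to the batch x
ϕ0 = view(y, :, size(x, 2) + 1)    # last column: the model applied to x0
shifted = ϕx .- ϕ0 .+ st.out_val   # ϕ(x) - ϕ(x0) + y0, broadcast per column
```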

```diff
@@ -11,7 +11,7 @@ using Static: Static

 using ForwardDiff: ForwardDiff

-using Lux: Lux, LuxOps, StatefulLuxLayer
+using Lux: Lux, LuxOps, StatefulLuxLayer, WeightInitializers
```
Member:

Import it as a package, not from Lux.
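That is, assuming WeightInitializers is added as a direct dependency in Project.toml, the import would look something like:

```julia
using WeightInitializers: WeightInitializers
```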
