From 7e2b8dcd12f264833db18b57e814de100aedf6df Mon Sep 17 00:00:00 2001
From: Felix Benning
Date: Wed, 17 May 2023 07:12:22 +0200
Subject: [PATCH 01/17] first draft

---
 Project.toml           |   2 +
 src/KernelFunctions.jl |   1 +
 src/diffKernel.jl      | 108 +++++++++++++++++++++++++++++++++++++++++
 test/diffKernel.jl     |  25 ++++++++++
 test/runtests.jl       |   1 +
 5 files changed, 137 insertions(+)
 create mode 100644 src/diffKernel.jl
 create mode 100644 test/diffKernel.jl

diff --git a/Project.toml b/Project.toml
index c763a8824..f951461be 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,10 +8,12 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
 Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
+OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl
index 63205b5bf..6524a9092 100644
--- a/src/KernelFunctions.jl
+++ b/src/KernelFunctions.jl
@@ -125,6 +125,7 @@ include("chainrules.jl")
 include("zygoterules.jl")
 
 include("TestUtils.jl")
+include("diffKernel.jl")
 
 function __init__()
     @require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin
diff --git a/src/diffKernel.jl b/src/diffKernel.jl
new file mode 100644
index 000000000..b4daaffee
--- /dev/null
+++ b/src/diffKernel.jl
@@ -0,0 +1,108 @@
+using OneHotArrays: OneHotVector
+import ForwardDiff as FD
+import LinearAlgebra as LA
+
+"""
+    DiffPt(x; partial=())
+
+For a covariance kernel k of GP Z, i.e.
+```julia
+    k(x,y) # = Cov(Z(x), Z(y)),
+```
+a DiffPt allows the differentiation of Z, i.e.
+```julia
+    k(DiffPt(x, partial=1), y) # = Cov(∂₁Z(x), Z(y))
+```
+for higher order derivatives partial can be any iterable, i.e.
+```julia
+    k(DiffPt(x, partial=(1,2)), y) # = Cov(∂₁∂₂Z(x), Z(y))
+```
+"""
+struct DiffPt{Dim}
+    pos # the actual position
+    partial
+end
+
+DiffPt(x;partial=()) = DiffPt{length(x)}(x, partial) # convenience constructor
+
+"""
+Take the partial derivative of a function `fun` with input dimesion `dim`.
+If partials=(i,j), then (∂ᵢ∂ⱼ fun) is returned.
+"""
+function partial(fun, dim, partials=())
+    if !isnothing(local next = iterate(partials))
+        idx, state = next
+        return partial(
+            x -> FD.derivative(0) do dx
+                fun(x .+ dx * OneHotVector(idx, dim))
+            end,
+            dim,
+            Base.rest(partials, state),
+        )
+    end
+    return fun
+end
+
+"""
+Take the partial derivative of a function with two dim-dimensional inputs,
+i.e. 2*dim dimensional input
+"""
+function partial(k, dim; partials_x=(), partials_y=())
+    local f(x,y) = partial(t -> k(t,y), dim, partials_x)(x)
+    return (x,y) -> partial(t -> f(x,t), dim, partials_y)(y)
+end
+
+
+
+
+"""
+    _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}
+
+implements `(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim})` for all kernel types. But since
+generics are not allowed in the syntax above by the dispatch system, this
+redirection over `_evaluate` is necessary
+
+unboxes the partial instructions from DiffPt and applies them to k,
+evaluates them at the positions of DiffPt
+"""
+function _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}
+    return partial(
+        k, Dim,
+        partials_x=x.partial, partials_y=y.partial
+    )(x.pos, y.pos)
+end
+
+
+
+#=
+This is a hack to work around the fact that the `where {T<:Kernel}` clause is
+not allowed for the `(::T)(x,y)` syntax. If we were to only implement
+```julia
+    (::Kernel)(::DiffPt,::DiffPt)
+```
+then julia would not know whether to use
+`(::SpecialKernel)(x,y)` or `(::Kernel)(x::DiffPt, y::DiffPt)`
+```
+To avoid this hack, no kernel type T should implement
+```julia
+    (::T)(x,y)
+```
+and instead implement
+```julia
+    _evaluate(k::T, x, y)
+```
+Then there should be only a single
+```julia
+    (k::Kernel)(x,y) = evaluate(k, x, y)
+```
+which all the kernels would fall back to.
+
+This ensures that evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) is always
+more specialized and call beforehand.
+=#
+for T in [SimpleKernel, Kernel] #subtypes(Kernel)
+    (k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = evaluate(k, x, y)
+    (k::T)(x::DiffPt{Dim}, y) where {Dim} = evaluate(k, x, DiffPt(y))
+    (k::T)(x, y::DiffPt{Dim}) where {Dim} = evaluate(k, DiffPt(x), y)
+end
+
diff --git a/test/diffKernel.jl b/test/diffKernel.jl
new file mode 100644
index 000000000..51d86e1b4
--- /dev/null
+++ b/test/diffKernel.jl
@@ -0,0 +1,25 @@
+@testset "diffKernel" begin
+    @testset "smoke test" begin
+        k = MaternKernel()
+        k(1,1)
+        k(1, DiffPt(1, partial=(1,1))) # Cov(Z(x), ∂₁∂₁Z(y)) where x=1, y=1
+        k(DiffPt([1], partial=1), [2]) # Cov(∂₁Z(x), Z(y)) where x=[1], y=[2]
+        k(DiffPt([1,2], partial=(1)), DiffPt([1,2], partial=2))# Cov(∂₁Z(x), ∂₂Z(y)) where x=[1,2], y=[1,2]
+    end
+
+    @testset "Sanity Checks with $k" for k in [MaternKernel()]
+        for x in [0, 1, -1, 42]
+            # for stationary kernels Cov(∂Z(x) , Z(x)) = 0
+            @test k(DiffPt(x, partial=1), x) ≈ 0
+
+            # the slope should be positively correlated with a point further down
+            @test k(
+                DiffPt(x, partial=1), # slope
+                x + 1e-10 # point further down
+            ) > 0
+
+            # correlation with self should be positive
+            @test k(DiffPt(x, partial=1), DiffPt(x, partial=1)) > 0
+        end
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index caf43cb91..accdea1f7 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -176,6 +176,7 @@ include("test_utils.jl")
     include("generic.jl")
     include("chainrules.jl")
     include("zygoterules.jl")
+    include("diffKernel.jl")
 
     @testset "doctests" begin
         DocMeta.setdocmeta!(

From 61798991755d6fdf79fe7a559bcaff4d6235247b Mon Sep 17 00:00:00 2001
From: Felix Benning
Date: Wed, 17 May 2023 07:31:51 +0200
Subject: [PATCH 02/17] export DiffPt

---
 src/KernelFunctions.jl | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl
index 6524a9092..a291fbc59 100644
--- a/src/KernelFunctions.jl
+++ b/src/KernelFunctions.jl
@@ -43,6 +43,8 @@ export MOInput, prepare_isotopic_multi_output_data, prepare_heterotopic_multi_ou
 export IndependentMOKernel,
     LatentFactorMOKernel, IntrinsicCoregionMOKernel, LinearMixingModelKernel
 
+export DiffPt
+
 # Reexports
 export tensor, ⊗, compose
 

From 74085db63a002c9099a682cbf1fb41d764a912f8 Mon Sep 17 00:00:00 2001
From: Felix Benning
Date: Wed, 17 May 2023 07:44:02 +0200
Subject: [PATCH 03/17] green tests

---
 src/diffKernel.jl  | 8 ++++----
 test/diffKernel.jl | 4 ++--
 2
files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffKernel.jl b/src/diffKernel.jl index b4daaffee..2d253852f 100644 --- a/src/diffKernel.jl +++ b/src/diffKernel.jl @@ -93,7 +93,7 @@ and instead implement ``` Then there should be only a single ```julia - (k::Kernel)(x,y) = evaluate(k, x, y) + (k::Kernel)(x,y) = _evaluate(k, x, y) ``` which all the kernels would fall back to. @@ -101,8 +101,8 @@ This ensures that evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) is always more specialized and call beforehand. =# for T in [SimpleKernel, Kernel] #subtypes(Kernel) - (k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = evaluate(k, x, y) - (k::T)(x::DiffPt{Dim}, y) where {Dim} = evaluate(k, x, DiffPt(y)) - (k::T)(x, y::DiffPt{Dim}) where {Dim} = evaluate(k, DiffPt(x), y) + (k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y) + (k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y)) + (k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y) end diff --git a/test/diffKernel.jl b/test/diffKernel.jl index 51d86e1b4..dad313e82 100644 --- a/test/diffKernel.jl +++ b/test/diffKernel.jl @@ -7,7 +7,7 @@ k(DiffPt([1,2], partial=(1)), DiffPt([1,2], partial=2))# Cov(∂₁Z(x), ∂₂Z(y)) where x=[1,2], y=[1,2] end - @testset "Sanity Checks with $k" for k in [MaternKernel()] + @testset "Sanity Checks with $k" for k in [SEKernel()] for x in [0, 1, -1, 42] # for stationary kernels Cov(∂Z(x) , Z(x)) = 0 @test k(DiffPt(x, partial=1), x) ≈ 0 @@ -15,7 +15,7 @@ # the slope should be positively correlated with a point further down @test k( DiffPt(x, partial=1), # slope - x + 1e-10 # point further down + x + 1e-1 # point further down ) > 0 # correlation with self should be positive From 21980a91149a47a47f21aa3fb3b102cf849eceac Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Wed, 17 May 2023 07:49:31 +0200 Subject: [PATCH 04/17] hack avoidance explanation in PR --- src/diffKernel.jl | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/diffKernel.jl b/src/diffKernel.jl index 2d253852f..66ed535c1 100644 --- a/src/diffKernel.jl +++ b/src/diffKernel.jl @@ -83,22 +83,6 @@ not allowed for the `(::T)(x,y)` syntax. If we were to only implement then julia would not know whether to use `(::SpecialKernel)(x,y)` or `(::Kernel)(x::DiffPt, y::DiffPt)` ``` -To avoid this hack, no kernel type T should implement -```julia - (::T)(x,y) -``` -and instead implement -```julia - _evaluate(k::T, x, y) -``` -Then there should be only a single -```julia - (k::Kernel)(x,y) = _evaluate(k, x, y) -``` -which all the kernels would fall back to. - -This ensures that evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) is always -more specialized and call beforehand. =# for T in [SimpleKernel, Kernel] #subtypes(Kernel) (k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y) From 8f7449513904c7fbe2e97f94c38e29e6912df0d7 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Wed, 17 May 2023 07:50:46 +0200 Subject: [PATCH 05/17] run formatter --- src/diffKernel.jl | 49 ++++++++++++++++++----------------------------- 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/src/diffKernel.jl b/src/diffKernel.jl index 66ed535c1..eac172596 100644 --- a/src/diffKernel.jl +++ b/src/diffKernel.jl @@ -19,28 +19,26 @@ for higher order derivatives partial can be any iterable, i.e. 
``` """ struct DiffPt{Dim} - pos # the actual position - partial + pos # the actual position + partial end -DiffPt(x;partial=()) = DiffPt{length(x)}(x, partial) # convenience constructor +DiffPt(x; partial=()) = DiffPt{length(x)}(x, partial) # convenience constructor """ Take the partial derivative of a function `fun` with input dimesion `dim`. If partials=(i,j), then (∂ᵢ∂ⱼ fun) is returned. """ function partial(fun, dim, partials=()) - if !isnothing(local next = iterate(partials)) - idx, state = next - return partial( - x -> FD.derivative(0) do dx - fun(x .+ dx * OneHotVector(idx, dim)) - end, - dim, - Base.rest(partials, state), - ) - end - return fun + if !isnothing(local next = iterate(partials)) + idx, state = next + return partial( + x -> FD.derivative(0) do dx + fun(x .+ dx * OneHotVector(idx, dim)) + end, dim, Base.rest(partials, state) + ) + end + return fun end """ @@ -48,13 +46,10 @@ Take the partial derivative of a function with two dim-dimensional inputs, i.e. 2*dim dimensional input """ function partial(k, dim; partials_x=(), partials_y=()) - local f(x,y) = partial(t -> k(t,y), dim, partials_x)(x) - return (x,y) -> partial(t -> f(x,t), dim, partials_y)(y) + local f(x, y) = partial(t -> k(t, y), dim, partials_x)(x) + return (x, y) -> partial(t -> f(x, t), dim, partials_y)(y) end - - - """ _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel} @@ -65,15 +60,10 @@ redirection over `_evaluate` is necessary unboxes the partial instructions from DiffPt and applies them to k, evaluates them at the positions of DiffPt """ -function _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel} - return partial( - k, Dim, - partials_x=x.partial, partials_y=y.partial - )(x.pos, y.pos) +function _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim,T<:Kernel} + return partial(k, Dim; partials_x=x.partial, partials_y=y.partial)(x.pos, y.pos) end - - #= This is a hack to work around the fact that the `where {T<:Kernel}` clause is not allowed for the `(::T)(x,y)` syntax. 
If we were to only implement @@ -85,8 +75,7 @@ then julia would not know whether to use ``` =# for T in [SimpleKernel, Kernel] #subtypes(Kernel) - (k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y) - (k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y)) - (k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y) + (k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y) + (k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y)) + (k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y) end - From 9c4ff2e371060374a4dd1b926f4ac071a1d75f3f Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Wed, 17 May 2023 07:51:09 +0200 Subject: [PATCH 06/17] also on tests --- test/diffKernel.jl | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/test/diffKernel.jl b/test/diffKernel.jl index dad313e82..d17c1cbc3 100644 --- a/test/diffKernel.jl +++ b/test/diffKernel.jl @@ -1,25 +1,25 @@ @testset "diffKernel" begin - @testset "smoke test" begin - k = MaternKernel() - k(1,1) - k(1, DiffPt(1, partial=(1,1))) # Cov(Z(x), ∂₁∂₁Z(y)) where x=1, y=1 - k(DiffPt([1], partial=1), [2]) # Cov(∂₁Z(x), Z(y)) where x=[1], y=[2] - k(DiffPt([1,2], partial=(1)), DiffPt([1,2], partial=2))# Cov(∂₁Z(x), ∂₂Z(y)) where x=[1,2], y=[1,2] - end + @testset "smoke test" begin + k = MaternKernel() + k(1, 1) + k(1, DiffPt(1; partial=(1, 1))) # Cov(Z(x), ∂₁∂₁Z(y)) where x=1, y=1 + k(DiffPt([1]; partial=1), [2]) # Cov(∂₁Z(x), Z(y)) where x=[1], y=[2] + k(DiffPt([1, 2]; partial=(1)), DiffPt([1, 2]; partial=2))# Cov(∂₁Z(x), ∂₂Z(y)) where x=[1,2], y=[1,2] + end - @testset "Sanity Checks with $k" for k in [SEKernel()] - for x in [0, 1, -1, 42] - # for stationary kernels Cov(∂Z(x) , Z(x)) = 0 - @test k(DiffPt(x, partial=1), x) ≈ 0 + @testset "Sanity Checks with $k" for k in [SEKernel()] + for x in [0, 1, -1, 42] + # for stationary kernels Cov(∂Z(x) , Z(x)) = 0 + @test k(DiffPt(x; partial=1), x) ≈ 0 - # the slope should be positively correlated with a point further down - @test k( - DiffPt(x, partial=1), # slope - x + 1e-1 # point further down - ) > 0 + # the slope should be positively correlated with a point further down + @test k( + DiffPt(x; partial=1), # slope + x + 1e-1, # point further down + ) > 0 - # correlation with self should be positive - @test k(DiffPt(x, partial=1), DiffPt(x, partial=1)) > 0 - end - end + # correlation with self should be positive + @test k(DiffPt(x; partial=1), DiffPt(x; partial=1)) > 0 + end + end end From f2f5203817a97b88fadc088d6ecb0b526a9e64f0 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 23 May 2023 15:08:28 +0200 Subject: [PATCH 07/17] varargs variant --- benchmark/benchmarks.jl | 14 ++++++++--- src/diffKernel.jl | 55 ++++++++++++++++++++++------------------- test/Project.toml | 1 + test/diffKernel.jl | 12 ++++----- 4 files changed, 47 insertions(+), 35 deletions(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 0cb80e5ce..89f69e6d1 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -23,7 +23,7 @@ kernels = Dict( inputtypes = Dict("ColVecs" => (Xc, Yc), "RowVecs" => (Xr, Yr), "Vecs" => (Xv, Yv)) functions = Dict( - "kernelmatrixX" => (kernel, X, Y) -> kernelmatrix(kernel, X), + "kernelmatrixX" => (kernel, X, Y) -> invoke(kernelmatrix, Tuple{kernel, Any}, kernel, X), "kernelmatrixXY" => (kernel, X, Y) -> kernelmatrix(kernel, X, Y), "kernelmatrix_diagX" => (kernel, X, Y) -> kernelmatrix_diag(kernel, X), "kernelmatrix_diagXY" => 
(kernel, X, Y) -> kernelmatrix_diag(kernel, X, Y), @@ -41,6 +41,14 @@ end # Uncomment the following to run benchmark locally -# tune!(SUITE) +tune!(SUITE) -# results = run(SUITE, verbose=true) +results = run(SUITE, verbose=true) + +Xc = ColVecs(rand(2, 2000)) +k = SqExponentialKernel() + +@which kernelmatrix(k, Xc) +@btime kernelmatrix($k, $Xc); +@btime invoke(kernelmatrix, Tuple{KernelFunctions.SimpleKernel, $typeof(Xc)}, $k, $Xc); +# @btime invoke(kernelmatrix, Tuple{Kernel, $typeof(Xc)}, $k, $Xc); \ No newline at end of file diff --git a/src/diffKernel.jl b/src/diffKernel.jl index eac172596..6dc62656c 100644 --- a/src/diffKernel.jl +++ b/src/diffKernel.jl @@ -18,36 +18,39 @@ for higher order derivatives partial can be any iterable, i.e. k(DiffPt(x, partial=(1,2)), y) # = Cov(∂₁∂₂Z(x), Z(y)) ``` """ -struct DiffPt{Dim} - pos # the actual position - partial + +IndexType = Union{Int,Base.AbstractCartesianIndex} + +struct DiffPt{Order,KeyT<:Union{Int,IndexType},T} + pos::T # the actual position + partials::NTuple{Order,KeyT} end -DiffPt(x; partial=()) = DiffPt{length(x)}(x, partial) # convenience constructor +DiffPt(x::T) where {T <: AbstractArray} = DiffPt{0, keytype(T), T}(x, ()::NTuple{0,keytype(T)}) +DiffPt(x::T) where {T <: Number} = DiffPt{0,Int,T}(x, ()::NTuple{0,Int}) +DiffPt(x::T, partial::Int) where T = DiffPt{1,Int,T}(x, (partial,)) +DiffPt(x::T, partials::NTuple{Order,KeyT}) where {T,Order,KeyT} = DiffPt{Order,KeyT,T}(x, partials) -""" -Take the partial derivative of a function `fun` with input dimesion `dim`. -If partials=(i,j), then (∂ᵢ∂ⱼ fun) is returned. -""" -function partial(fun, dim, partials=()) - if !isnothing(local next = iterate(partials)) - idx, state = next - return partial( - x -> FD.derivative(0) do dx - fun(x .+ dx * OneHotVector(idx, dim)) - end, dim, Base.rest(partials, state) - ) - end - return fun +partial(func) = func +function partial(func, partials::Int...) + idx, state = iterate(partials) + return partial( + x -> FD.derivative(0) do dx + return func(x .+ dx * OneHotVector(idx, length(x))) + end, + Base.rest(partials, state)..., + ) end """ Take the partial derivative of a function with two dim-dimensional inputs, i.e. 
2*dim dimensional input """ -function partial(k, dim; partials_x=(), partials_y=()) - local f(x, y) = partial(t -> k(t, y), dim, partials_x)(x) - return (x, y) -> partial(t -> f(x, t), dim, partials_y)(y) +function partial( + k::Fun, partials_x::NTuple{N,T}, partials_y::NTuple{M,T} +) where {N,M,T<:IndexType} + local f(x, y) = partial(t -> k(t, y), partials_x...)(x) + return (x, y) -> partial(t -> f(x, t), partials_y...)(y) end """ @@ -60,8 +63,8 @@ redirection over `_evaluate` is necessary unboxes the partial instructions from DiffPt and applies them to k, evaluates them at the positions of DiffPt """ -function _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim,T<:Kernel} - return partial(k, Dim; partials_x=x.partial, partials_y=y.partial)(x.pos, y.pos) +function _evaluate(k::T, x::DiffPt, y::DiffPt) where {T<:Kernel} + return partial(k, x.partials, y.partials)(x.pos, y.pos) end #= @@ -75,7 +78,7 @@ then julia would not know whether to use ``` =# for T in [SimpleKernel, Kernel] #subtypes(Kernel) - (k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y) - (k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y)) - (k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y) + (k::T)(x::DiffPt, y::DiffPt) = _evaluate(k, x, y) + (k::T)(x::DiffPt, y) = _evaluate(k, x, DiffPt(y)) + (k::T)(x, y::DiffPt) = _evaluate(k, DiffPt(x), y) end diff --git a/test/Project.toml b/test/Project.toml index e16f39a6c..5b21d7b70 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/test/diffKernel.jl b/test/diffKernel.jl index d17c1cbc3..3a5e3c7e1 100644 --- a/test/diffKernel.jl +++ b/test/diffKernel.jl @@ -2,24 +2,24 @@ @testset "smoke test" begin k = MaternKernel() k(1, 1) - k(1, DiffPt(1; partial=(1, 1))) # Cov(Z(x), ∂₁∂₁Z(y)) where x=1, y=1 - k(DiffPt([1]; partial=1), [2]) # Cov(∂₁Z(x), Z(y)) where x=[1], y=[2] - k(DiffPt([1, 2]; partial=(1)), DiffPt([1, 2]; partial=2))# Cov(∂₁Z(x), ∂₂Z(y)) where x=[1,2], y=[1,2] + k(1, DiffPt(1, (1, 1))) # Cov(Z(x), ∂₁∂₁Z(y)) where x=1, y=1 + k(DiffPt([1], 1), [2]) # Cov(∂₁Z(x), Z(y)) where x=[1], y=[2] + k(DiffPt([1, 2], 1), DiffPt([1, 2], 2))# Cov(∂₁Z(x), ∂₂Z(y)) where x=[1,2], y=[1,2] end @testset "Sanity Checks with $k" for k in [SEKernel()] for x in [0, 1, -1, 42] # for stationary kernels Cov(∂Z(x) , Z(x)) = 0 - @test k(DiffPt(x; partial=1), x) ≈ 0 + @test k(DiffPt(x, 1), x) ≈ 0 # the slope should be positively correlated with a point further down @test k( - DiffPt(x; partial=1), # slope + DiffPt(x, 1), # slope x + 1e-1, # point further down ) > 0 # correlation with self should be positive - @test k(DiffPt(x; partial=1), DiffPt(x; partial=1)) > 0 + @test k(DiffPt(x, 1), DiffPt(x, 1)) > 0 end end end From 16938a9333aa71b29e783af77b9b861cf31e3e32 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 23 May 2023 15:10:27 +0200 Subject: [PATCH 08/17] format --- src/diffKernel.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffKernel.jl b/src/diffKernel.jl index 6dc62656c..dffbd9f5e 100644 --- a/src/diffKernel.jl +++ b/src/diffKernel.jl @@ -26,10 +26,12 @@ struct DiffPt{Order,KeyT<:Union{Int,IndexType},T} partials::NTuple{Order,KeyT} end -DiffPt(x::T) where {T <: AbstractArray} = DiffPt{0, keytype(T), 
T}(x, ()::NTuple{0,keytype(T)}) -DiffPt(x::T) where {T <: Number} = DiffPt{0,Int,T}(x, ()::NTuple{0,Int}) -DiffPt(x::T, partial::Int) where T = DiffPt{1,Int,T}(x, (partial,)) -DiffPt(x::T, partials::NTuple{Order,KeyT}) where {T,Order,KeyT} = DiffPt{Order,KeyT,T}(x, partials) +DiffPt(x::T) where {T<:AbstractArray} = DiffPt{0,keytype(T),T}(x, ()::NTuple{0,keytype(T)}) +DiffPt(x::T) where {T<:Number} = DiffPt{0,Int,T}(x, ()::NTuple{0,Int}) +DiffPt(x::T, partial::Int) where {T} = DiffPt{1,Int,T}(x, (partial,)) +function DiffPt(x::T, partials::NTuple{Order,KeyT}) where {T,Order,KeyT} + return DiffPt{Order,KeyT,T}(x, partials) +end partial(func) = func function partial(func, partials::Int...) From 13c3cb148cb406a04849b76655d826328d417455 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 23 May 2023 16:52:41 +0200 Subject: [PATCH 09/17] fix bug --- src/diffKernel.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffKernel.jl b/src/diffKernel.jl index dffbd9f5e..8b0b1aa0c 100644 --- a/src/diffKernel.jl +++ b/src/diffKernel.jl @@ -21,19 +21,24 @@ for higher order derivatives partial can be any iterable, i.e. IndexType = Union{Int,Base.AbstractCartesianIndex} -struct DiffPt{Order,KeyT<:Union{Int,IndexType},T} +struct DiffPt{Order,KeyT<:IndexType,T} pos::T # the actual position partials::NTuple{Order,KeyT} end DiffPt(x::T) where {T<:AbstractArray} = DiffPt{0,keytype(T),T}(x, ()::NTuple{0,keytype(T)}) DiffPt(x::T) where {T<:Number} = DiffPt{0,Int,T}(x, ()::NTuple{0,Int}) -DiffPt(x::T, partial::Int) where {T} = DiffPt{1,Int,T}(x, (partial,)) +DiffPt(x::T, partial::IndexType) where {T} = DiffPt{1,IndexType,T}(x, (partial,)) function DiffPt(x::T, partials::NTuple{Order,KeyT}) where {T,Order,KeyT} return DiffPt{Order,KeyT,T}(x, partials) end partial(func) = func +function partial(func, idx::Int) + return x -> FD.derivative(0) do dx + return func(x .+ dx * OneHotVector(idx, length(x))) + end +end function partial(func, partials::Int...) idx, state = iterate(partials) return partial( @@ -49,7 +54,7 @@ Take the partial derivative of a function with two dim-dimensional inputs, i.e. 
2*dim dimensional input """ function partial( - k::Fun, partials_x::NTuple{N,T}, partials_y::NTuple{M,T} + k, partials_x::NTuple{N,T}, partials_y::NTuple{M,T} ) where {N,M,T<:IndexType} local f(x, y) = partial(t -> k(t, y), partials_x...)(x) return (x, y) -> partial(t -> f(x, t), partials_y...)(y) From 6b7a5e855a4dff7d7fe7f5a17a121684b9ff0c36 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 23 May 2023 17:48:35 +0200 Subject: [PATCH 10/17] trying to get rid of oneHotVector --- src/diffKernel.jl | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/diffKernel.jl b/src/diffKernel.jl index 8b0b1aa0c..2295059da 100644 --- a/src/diffKernel.jl +++ b/src/diffKernel.jl @@ -1,4 +1,3 @@ -using OneHotArrays: OneHotVector import ForwardDiff as FD import LinearAlgebra as LA @@ -33,18 +32,30 @@ function DiffPt(x::T, partials::NTuple{Order,KeyT}) where {T,Order,KeyT} return DiffPt{Order,KeyT,T}(x, partials) end -partial(func) = func -function partial(func, idx::Int) - return x -> FD.derivative(0) do dx - return func(x .+ dx * OneHotVector(idx, length(x))) +""" + tangentCurve(x₀, i::IndexType) +returns the function (t ↦ x₀ + teᵢ) where eᵢ is the unit vector at index i +""" +function tangentCurve(x0::AbstractArray{N,T}, idx::IndexType) where {N, T} + return t -> begin + x = similar(x0) + copyto!(x, x0) + x[idx] +=t + return x end end -function partial(func, partials::Int...) +function tangentCurve(x0::Number, ::IndexType) + return t -> x0 + t +end + +partial(func) = func +function partial(func, idx::IndexType) + return x -> FD.derivative(func ∘ tangentCurve(x, idx), 0) +end +function partial(func, partials::IndexType...) idx, state = iterate(partials) return partial( - x -> FD.derivative(0) do dx - return func(x .+ dx * OneHotVector(idx, length(x))) - end, + x -> FD.derivative(func ∘ tangentCurve(x, idx), 0), Base.rest(partials, state)..., ) end From f57e4d6d21cc6497d481b47839eddb9ae5d588a9 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 23 May 2023 17:52:28 +0200 Subject: [PATCH 11/17] format --- src/diffKernel.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffKernel.jl b/src/diffKernel.jl index 2295059da..e7fcd1e92 100644 --- a/src/diffKernel.jl +++ b/src/diffKernel.jl @@ -36,11 +36,11 @@ end tangentCurve(x₀, i::IndexType) returns the function (t ↦ x₀ + teᵢ) where eᵢ is the unit vector at index i """ -function tangentCurve(x0::AbstractArray{N,T}, idx::IndexType) where {N, T} +function tangentCurve(x0::AbstractArray{N,T}, idx::IndexType) where {N,T} return t -> begin x = similar(x0) copyto!(x, x0) - x[idx] +=t + x[idx] += t return x end end @@ -55,8 +55,7 @@ end function partial(func, partials::IndexType...) idx, state = iterate(partials) return partial( - x -> FD.derivative(func ∘ tangentCurve(x, idx), 0), - Base.rest(partials, state)..., + x -> FD.derivative(func ∘ tangentCurve(x, idx), 0), Base.rest(partials, state)... 
) end From c0f7fef27df610ceb75f10d936d92dfb336c5d75 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 23 May 2023 19:21:09 +0200 Subject: [PATCH 12/17] fix code --- src/diffKernel.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffKernel.jl b/src/diffKernel.jl index e7fcd1e92..c15f308df 100644 --- a/src/diffKernel.jl +++ b/src/diffKernel.jl @@ -38,7 +38,7 @@ returns the function (t ↦ x₀ + teᵢ) where eᵢ is the unit vector at index """ function tangentCurve(x0::AbstractArray{N,T}, idx::IndexType) where {N,T} return t -> begin - x = similar(x0) + x = similar(x0, promote_type(eltype(x0), typeof(t))) copyto!(x, x0) x[idx] += t return x From 1c6f8a22587eea81d12815f74132182d2d2cb38a Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Tue, 23 May 2023 19:22:28 +0200 Subject: [PATCH 13/17] remove dependency on OneHotArrays --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index f951461be..278444f08 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,6 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" From 9aad16fadef8e269a9eeaeb317a138a59897031c Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Wed, 24 May 2023 11:04:41 +0200 Subject: [PATCH 14/17] remove unnecessary type bounds --- src/diffKernel.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffKernel.jl b/src/diffKernel.jl index c15f308df..8ed0a8b88 100644 --- a/src/diffKernel.jl +++ b/src/diffKernel.jl @@ -36,7 +36,7 @@ end tangentCurve(x₀, i::IndexType) returns the function (t ↦ x₀ + teᵢ) where eᵢ is the unit vector at index i """ -function tangentCurve(x0::AbstractArray{N,T}, idx::IndexType) where {N,T} +function tangentCurve(x0::AbstractArray, idx::IndexType) return t -> begin x = similar(x0, promote_type(eltype(x0), typeof(t))) copyto!(x, x0) @@ -64,8 +64,8 @@ Take the partial derivative of a function with two dim-dimensional inputs, i.e. 2*dim dimensional input """ function partial( - k, partials_x::NTuple{N,T}, partials_y::NTuple{M,T} -) where {N,M,T<:IndexType} + k, partials_x::Tuple{Vararg{T}}, partials_y::Tuple{Vararg{T}} +) where {T<:IndexType} local f(x, y) = partial(t -> k(t, y), partials_x...)(x) return (x, y) -> partial(t -> f(x, t), partials_y...)(y) end From 73f7195338802cea3005d041b848985b84cdfc1e Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Thu, 25 May 2023 20:33:13 +0200 Subject: [PATCH 15/17] add partial type --- src/mokernels/differentiable.jl | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 src/mokernels/differentiable.jl diff --git a/src/mokernels/differentiable.jl b/src/mokernels/differentiable.jl new file mode 100644 index 000000000..6eb0f1eb6 --- /dev/null +++ b/src/mokernels/differentiable.jl @@ -0,0 +1,31 @@ + +struct Partial{Order} + indices::CartesianIndex{Order} +end + +function Partial(indices::Integer...) + return Partial{length(indices)}(CartesianIndex(indices)) +end + +compact_string_representation(::Partial{0}) = print(io, "id") +function compact_string_representation(p::Partial) + tuple = Tuple(p.indices) + lower_numbers = @. 
tuple |> digits |> reverse |> n-> '₀' + n + return join(["∂$(join(x))" for x in lower_numbers]) +end +function Base.show(io::IO, p::Partial) + if get(io, :compact, false) + print(io, "Partial($(Tuple(p.indices)))") + else + print(io, compact_string_representation(p)) + end +end + +function Base.show(io::IO, ::MIME"text/html", p::Partial) + tuple = Tuple(p.indices) + if get(io, :compact, false) + print(io, join(map(n->"∂$(n)", tuple),"")) + else + print(io, compact_string_representation(p)) + end +end \ No newline at end of file From 285c866b3d5aac4c4c2a9ed2f8221622effea0c8 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Thu, 25 May 2023 20:35:18 +0200 Subject: [PATCH 16/17] format --- src/mokernels/differentiable.jl | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/mokernels/differentiable.jl b/src/mokernels/differentiable.jl index 6eb0f1eb6..07a354e01 100644 --- a/src/mokernels/differentiable.jl +++ b/src/mokernels/differentiable.jl @@ -1,31 +1,31 @@ struct Partial{Order} - indices::CartesianIndex{Order} + indices::CartesianIndex{Order} end function Partial(indices::Integer...) - return Partial{length(indices)}(CartesianIndex(indices)) + return Partial{length(indices)}(CartesianIndex(indices)) end compact_string_representation(::Partial{0}) = print(io, "id") function compact_string_representation(p::Partial) - tuple = Tuple(p.indices) - lower_numbers = @. tuple |> digits |> reverse |> n-> '₀' + n - return join(["∂$(join(x))" for x in lower_numbers]) + tuple = Tuple(p.indices) + lower_numbers = @. (n -> '₀' + n)(reverse(digits(tuple))) + return join(["∂$(join(x))" for x in lower_numbers]) end function Base.show(io::IO, p::Partial) - if get(io, :compact, false) - print(io, "Partial($(Tuple(p.indices)))") - else - print(io, compact_string_representation(p)) - end + if get(io, :compact, false) + print(io, "Partial($(Tuple(p.indices)))") + else + print(io, compact_string_representation(p)) + end end function Base.show(io::IO, ::MIME"text/html", p::Partial) - tuple = Tuple(p.indices) - if get(io, :compact, false) - print(io, join(map(n->"∂$(n)", tuple),"")) - else - print(io, compact_string_representation(p)) - end + tuple = Tuple(p.indices) + if get(io, :compact, false) + print(io, join(map(n -> "∂$(n)", tuple), "")) + else + print(io, compact_string_representation(p)) + end end \ No newline at end of file From deebf0cdeebe870c9520f3d5c071b0fd5c013b76 Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Thu, 25 May 2023 20:42:49 +0200 Subject: [PATCH 17/17] should have mime type --- src/mokernels/differentiable.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mokernels/differentiable.jl b/src/mokernels/differentiable.jl index 07a354e01..1825efa60 100644 --- a/src/mokernels/differentiable.jl +++ b/src/mokernels/differentiable.jl @@ -13,7 +13,7 @@ function compact_string_representation(p::Partial) lower_numbers = @. (n -> '₀' + n)(reverse(digits(tuple))) return join(["∂$(join(x))" for x in lower_numbers]) end -function Base.show(io::IO, p::Partial) +function Base.show(io::IO, ::MIME"text/plain", p::Partial) if get(io, :compact, false) print(io, "Partial($(Tuple(p.indices)))") else
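A minimal usage sketch of the API this patch series introduces (assuming the whole series is applied: `DiffPt` exported as in PATCH 02/17 and the positional `DiffPt(x, partial)` constructors from PATCH 07/17 onwards; the covariance matrix at the end is assembled with a plain comprehension, which is not an API added by these patches):

```julia
using KernelFunctions

k = SEKernel()
x, y = [1.0, 2.0], [0.5, 1.5]

k(x, y)                        # Cov(Z(x), Z(y))
k(DiffPt(x, 1), y)             # Cov(∂₁Z(x), Z(y))
k(DiffPt(x, 1), DiffPt(y, 2))  # Cov(∂₁Z(x), ∂₂Z(y))

# Sanity check for the SE kernel with unit lengthscale, k(x,y) = exp(-‖x-y‖²/2):
# Cov(∂₁Z(x), Z(y)) should equal -(x[1] - y[1]) * k(x, y).

# Covariance of value and gradient observations at x, assembled point-wise:
pts = [DiffPt(x), DiffPt(x, 1), DiffPt(x, 2)]
K = [k(p, q) for p in pts, q in pts]  # 3×3 covariance of (Z(x), ∂₁Z(x), ∂₂Z(x))
```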
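The `partial` helper that powers `_evaluate` stays internal to `src/diffKernel.jl` (it is not exported). A small sketch of what it computes, assuming the final variadic form after PATCH 14/17 and importing the unexported name explicitly:

```julia
using KernelFunctions: partial  # internal helper, not part of the public API

f(x) = x[1]^2 * x[2]

∂₁f  = partial(f, 1)     # x -> derivative of f along the first coordinate
∂₁₂f = partial(f, 1, 2)  # x -> mixed second derivative ∂₂∂₁f

∂₁f([3.0, 4.0])   # ≈ 2 * 3.0 * 4.0 = 24.0
∂₁₂f([3.0, 4.0])  # ≈ 2 * 3.0 = 6.0
```

The nesting order follows the recursion in `partial(func, partials...)`: the first index is differentiated first, and each step is a `ForwardDiff.derivative` along a one-coordinate tangent curve through the evaluation point.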