diff --git a/Project.toml b/Project.toml index f625e9027..92fc3e89b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.41" +version = "0.10.42" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/transform/chaintransform.jl b/src/transform/chaintransform.jl index bd4627b19..9129fb84b 100644 --- a/src/transform/chaintransform.jl +++ b/src/transform/chaintransform.jl @@ -1,5 +1,5 @@ """ - ChainTransform(ts::AbstractVector{<:Transform}) + ChainTransform(transforms) Transformation that applies a chain of transformations `ts` to the input. @@ -19,7 +19,7 @@ julia> map(t2 ∘ t1, ColVecs(X)) == ColVecs(A * (l .* X)) true ``` """ -struct ChainTransform{V<:AbstractVector{<:Transform}} <: Transform +struct ChainTransform{V} <: Transform transforms::V end @@ -28,23 +28,23 @@ end Base.length(t::ChainTransform) = length(t.transforms) # Constructor to create a chain transform with an array of parameters -function ChainTransform(v::AbstractVector{<:Type{<:Transform}}, θ::AbstractVector) +function ChainTransform(v, θ::AbstractVector) @assert length(v) == length(θ) return ChainTransform(v.(θ)) end -Base.:∘(t₁::Transform, t₂::Transform) = ChainTransform([t₂, t₁]) -Base.:∘(t::Transform, tc::ChainTransform) = ChainTransform(vcat(tc.transforms, t)) -Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transforms)) +Base.:∘(t₁::Transform, t₂::Transform) = ChainTransform((t₂, t₁)) +Base.:∘(t::Transform, tc::ChainTransform) = ChainTransform(tuple(tc.transforms..., t)) +Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(tuple(t, tc.transforms...)) (t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x) function _map(t::ChainTransform, x::AbstractVector) - return foldl((x, t) -> map(t, x), t.transforms; init=x) + return foldl((x, t) -> _map(t, x), t.transforms; init=x) end set!(t::ChainTransform, θ) = set!.(t.transforms, θ) -duplicate(t::ChainTransform, θ) = ChainTransform(duplicate.(t.transforms, θ)) +duplicate(t::ChainTransform, θ) = ChainTransform(map(duplicate, t.transforms, θ)) Base.show(io::IO, t::ChainTransform) = printshifted(io, t, 0) diff --git a/test/test_utils.jl b/test/test_utils.jl index b8c349e37..c506d5a69 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -274,3 +274,81 @@ function test_AD(AD::Symbol, k::MOKernel, dims=(in=3, out=2, obs=3)) end end end + +function count_allocs(f, args...) + stats = @timed f(args...) + return Base.gc_alloc_count(stats.gcstats) +end + +""" + constant_allocs_heuristic(f, args1::T, args2::T) where {T} + +True if number of allocations associated with evaluating `f(args1...)` is equal to those +required to evaluate `f(args2...)`. Runs `f` beforehand to ensure that compilation-related +allocations are not included. + +Why is this a good test? In lots of situations it will be the case that the total amount of +memory allocated by a function will vary as the input sizes vary, but the total _number_ +of allocations ought to be constant. A common performance bug is that the number of +allocations actually does scale with the size of the inputs (e.g. due to a type +instability), and we would very much like to know if this is happening. + +Typically this kind of condition is not a sufficient condition for good performance, but it +is certainly a necessary condition. + +This kind of test is very quick to conduct (just requires running `f` 4 times). It's also +easier to write than simply checking that the total number of allocations used to execute +a function is below some arbitrary `f`-dependent threshold. +""" +function constant_allocs_heuristic(f, args1::T, args2::T) where {T} + + # Ensure that we're not counting allocations associated with compilation. + f(args1...) + f(args2...) + + allocs_1 = count_allocs(f, args1...) + allocs_2 = count_allocs(f, args2...) + return allocs_1 == allocs_2 +end + +""" + ad_constant_allocs_heuristic(f, args1::T, args2::T; Δ1=nothing, Δ2=nothing) where {T} + +Assesses `constant_allocs_heuristic` for `f`, `Zygote.pullback(f, args...)` and its +pullback for both of `args1` and `args2`. + +`Δ1` and `Δ2` are passed to the pullback associated with `Zygote.pullback(f, args1...)` +and `Zygote.pullback(f, args2...)` respectively. If left as `nothing`, it is assumed that +the output of the primal is an acceptable cotangent to be passed to the corresponding +pullback. +""" +function ad_constant_allocs_heuristic( + f, args1::T, args2::T; Δ1=nothing, Δ2=nothing +) where {T} + + # Check that primal has constant allocations. + primal_heuristic = constant_allocs_heuristic(f, args1, args2) + + # Check that forwards-pass has constant allocations. + forwards_heuristic = constant_allocs_heuristic( + (args...) -> Zygote.pullback(f, args...), args1, args2 + ) + + # Check that pullback has constant allocations for both arguments. Run twice to remove + # compilation-related allocations. + + # First thing + out1, pb1 = Zygote.pullback(f, args1...) + Δ1_val = Δ1 === nothing ? out1 : Δ1 + pb1(Δ1_val) + allocs_1 = count_allocs(pb1, Δ1_val) + + # Second thing + out2, pb2 = Zygote.pullback(f, args2...) + Δ2_val = Δ2 === nothing ? out2 : Δ2 + pb2(Δ2_val) + allocs_2 = count_allocs(pb2, Δ2 === nothing ? out2 : Δ2) + + pullback_heuristic = allocs_1 == allocs_2 + return primal_heuristic, forwards_heuristic, pullback_heuristic +end diff --git a/test/transform/chaintransform.jl b/test/transform/chaintransform.jl index f8b19cffe..3b6856b1f 100644 --- a/test/transform/chaintransform.jl +++ b/test/transform/chaintransform.jl @@ -7,11 +7,11 @@ f(x) = sin.(x) tf = FunctionTransform(f) - t = ChainTransform([tp, tf]) + t = ChainTransform((tp, tf)) # Check composition constructors. - @test (tf ∘ ChainTransform([tp])).transforms == [tp, tf] - @test (ChainTransform([tf]) ∘ tp).transforms == [tp, tf] + @test (tf ∘ ChainTransform([tp])).transforms == (tp, tf) + @test (ChainTransform([tf]) ∘ tp).transforms == (tp, tf) # Verify correctness. x = ColVecs(randn(rng, 2, 3)) @@ -27,5 +27,14 @@ randn(rng, 4); ADs=[:ForwardDiff, :ReverseDiff], # explicitly pass ADs to exclude :Zygote ) - @test_broken "test_AD of chain transform is currently broken in Zygote, see GitHub issue #263" + + @testset "AD performance" begin + primal, forward, pb = ad_constant_allocs_heuristic((randn(5),), (randn(10),)) do x + k = SEKernel() ∘ (ScaleTransform(0.1) ∘ PeriodicTransform(10.0)) + return kernelmatrix(k, x) + end + @test primal + @test forward + @test pb + end end