From 683b556dd6d31961e640e256fbddb17bb4c823d1 Mon Sep 17 00:00:00 2001 From: Chris de Graaf Date: Wed, 9 Oct 2019 22:57:32 +0700 Subject: [PATCH] Support keyword arguments, kind of This implements the last solution outlined in #7. For some reason, it's a lot faster than I remember this approach being. TODO: - update docs to include section about reusing contexts - add note in docs to explain when kwargs are discarded - add tests for keyword arguments, discarding, and warning --- src/Contexts.jl | 9 ++++++ src/SimpleMock.jl | 3 +- src/mock_fun.jl | 81 ++++++++++++++++++++++++++++------------------- test/mock_fun.jl | 40 ++++++++++++----------- 4 files changed, 80 insertions(+), 53 deletions(-) create mode 100644 src/Contexts.jl diff --git a/src/Contexts.jl b/src/Contexts.jl new file mode 100644 index 0000000..d0818b5 --- /dev/null +++ b/src/Contexts.jl @@ -0,0 +1,9 @@ +module Contexts + +using Core: kwftype + +using Cassette: Cassette, overdub, posthook, prehook, recurse, @context + +using ..SimpleMock: Metadata, should_mock, update! + +end diff --git a/src/SimpleMock.jl b/src/SimpleMock.jl index 5577ca9..b0b3596 100644 --- a/src/SimpleMock.jl +++ b/src/SimpleMock.jl @@ -7,7 +7,7 @@ using Base: Callable, invokelatest, unwrap_unionall using Base.Iterators: Pairs using Core: Builtin, IntrinsicFunction -using Cassette: Cassette, overdub, posthook, prehook, recurse, @context +using Cassette: Cassette, overdub, posthook, prehook export Call, @@ -29,6 +29,7 @@ export include("metadata.jl") include("filters.jl") +include("Contexts.jl") include("mock_type.jl") include("mock_fun.jl") diff --git a/src/mock_fun.jl b/src/mock_fun.jl index 36d4638..e947bb0 100644 --- a/src/mock_fun.jl +++ b/src/mock_fun.jl @@ -1,27 +1,8 @@ -@context MockCtx - -# TODO: Maybe these should be inlined, but it slows down compilation a lot. - -@noinline function Cassette.prehook(ctx::MockCtx{Metadata{true}}, f, args...) - @nospecialize f args - update!(ctx.metadata, prehook, f, args...) -end - -@noinline function Cassette.posthook(ctx::MockCtx{Metadata{true}}, v, f, args...) - @nospecialize v f args - update!(ctx.metadata, posthook, f, args...) -end - """ - mock(f::Function, args...; filters::Vector{<:Function}=Function[]) + mock(f::Function[, ctx::Symbol], args...; filters::Vector{<:Function}=Function[]) Run `f` with specified functions mocked out. -!!! note - Keyword arguments to mocked functions are not supported. - If you call a mocked function with keyword arguments, it will dispatch to the original function. - For more details, see [Cassette#48](https://github.com/jrevels/Cassette.jl/issues/48). - ## Examples Mocking a single function: @@ -82,23 +63,33 @@ Filter functions take a single argument of type [`Metadata`](@ref). If any filter rejects, then mocking is not performed. See [Filter Functions](@ref) for a list of included filters, as well as building blocks for you to create your own. """ -function mock(f::Function, args...; filters::Vector{<:Function}=Function[]) +mock(f::Function, args...; filters::Vector{<:Function}=Function[]) = + mock(f, gensym(), args...; filters=filters) + +function mock(f::Function, ctx::Symbol, args...; filters::Vector{<:Function}=Function[]) mocks = map(sig2mock, args) # ((f, sig) => mock). isempty(mocks) && throw(ArgumentError("At least one function must be mocked")) + # Create the new context type if it doesn't already exist. + context_is_new = !isdefined(Contexts, ctx) + context_is_new && make_context(ctx) + Ctx = getfield(Contexts, ctx) + # Implement the overdubs, but only if they aren't already implemented. has_new_overdub = false foreach(map(first, mocks)) do k fun = k[1] sig = k[2:end] - if !overdub_exists(fun, sig) - make_overdub(fun, sig) + if context_is_new || !overdub_exists(Ctx, fun, sig) + make_overdub(Ctx, fun, sig) has_new_overdub = true end end # Only use `invokelatest` if the Context/overdub implementations are new. - od_args = [MockCtx(; metadata=Metadata(Dict(mocks), filters)), f, map(last, mocks)...] + meta = Metadata(Dict(mocks), filters) + c = context_is_new ? invokelatest(Ctx; metadata=meta) : Ctx(; metadata=meta) + od_args = [c, f, map(last, mocks)...] return has_new_overdub ? invokelatest(overdub, od_args...) : overdub(od_args...) end @@ -108,8 +99,25 @@ sig2mock(p::Pair) = (p.first, Vararg{Any}) => p.second sig2mock(t::Tuple) = t => Mock() sig2mock(f) = (f, Vararg{Any}) => Mock() +# Create a new context type. +make_context(Ctx::Symbol) = @eval Contexts begin + @context $Ctx + + # TODO: Maybe these should be inlined, but it slows down compilation a lot. + + @noinline function Cassette.prehook(ctx::$Ctx{Metadata{true}}, f, args...) + @nospecialize f args + update!(ctx.metadata, prehook, f, args...) + end + + @noinline function Cassette.posthook(ctx::$Ctx{Metadata{true}}, v, f, args...) + @nospecialize v f args + update!(ctx.metadata, posthook, f, args...) + end +end + # Has a given function and signature already been overdubbed? -overdub_exists(::F, sig::Tuple) where F = any(methods(overdub)) do m +overdub_exists(::Type{Ctx}, ::F, sig::Tuple) where {Ctx, F} = any(methods(overdub)) do m squashed = foldl(sig; init=[]) do acc, T if T isa DataType && T.name.name === :Vararg append!(acc, repeat([T.parameters[1]], T.parameters[2])) @@ -117,11 +125,11 @@ overdub_exists(::F, sig::Tuple) where F = any(methods(overdub)) do m push!(acc, T) end end - m.sig === Tuple{typeof(overdub), MockCtx, F, squashed...} + m.sig === Tuple{typeof(overdub), Ctx, F, squashed...} end # Implement `overdub` for a given Context, function, and signature. -function make_overdub(f::F, sig::Tuple) where F +function make_overdub(::Type{Ctx}, f::F, sig::Tuple) where {Ctx, F} sig_exs = Expr[] sig_names = [] @@ -149,12 +157,19 @@ function make_overdub(f::F, sig::Tuple) where F end end - @eval @inline function Cassette.overdub(ctx::MockCtx, f::$F, $(sig_exs...)) - method = (f, $(sig...)) - if should_mock(ctx.metadata, method) - ctx.metadata.mocks[method]($(sig_names...)) - else - recurse(ctx, f, $(sig_names...)) + @eval Contexts begin + @inline function Cassette.overdub(ctx::$Ctx, f::$F, $(sig_exs...); kwargs...) + method = (f, $(sig...)) + if should_mock(ctx.metadata, method) + ctx.metadata.mocks[method]($(sig_names...); kwargs...) + else + isempty(kwargs) || @warn "Discarding keyword arguments" kwargs + recurse(ctx, f, $(sig_names...)) + end end + + # https://github.com/jrevels/Cassette.jl/issues/48#issuecomment-440605481 + @inline Cassette.overdub(ctx::$Ctx, ::kwftype($F), kwargs, f::$F, $(sig_exs...)) = + overdub(ctx, f, $(sig_names...); kwargs...) end end diff --git a/test/mock_fun.jl b/test/mock_fun.jl index 5bf0568..e9a099d 100644 --- a/test/mock_fun.jl +++ b/test/mock_fun.jl @@ -1,22 +1,3 @@ -@testset "mock does not overwrite methods" begin - # https://github.com/fredrikekre/jlpkg/blob/3b1c2400932dbe13fa7c3cba92bde3842315976c/src/cli.jl#L151-L160 - o = JLOptions() - if o.warn_overwrite == 0 - args = map(n -> n === :warn_overwrite ? 1 : getfield(o, n), fieldnames(JLOptions)) - unsafe_store!(cglobal(:jl_options, JLOptions), JLOptions(args...)) - end - mock(identity, identity) - out = @capture_err mock(identity, identity) - @test isempty(out) -end - -@testset "Reusing Context" begin - f(x) = strip(uppercase(x)) - # If the method checks aren't working properly, this will throw. - @test mock(_g -> f(" hi "), strip => identity) == " HI " - @test mock(_g -> f(" hi "), uppercase => identity) == "hi" -end - @testset "Basics" begin mock(identity) do id identity(10) @@ -96,6 +77,27 @@ end end end +@testset "mock does not overwrite methods" begin + # https://github.com/fredrikekre/jlpkg/blob/3b1c2400932dbe13fa7c3cba92bde3842315976c/src/cli.jl#L151-L160 + o = JLOptions() + if o.warn_overwrite == 0 + args = map(n -> n === :warn_overwrite ? 1 : getfield(o, n), fieldnames(JLOptions)) + unsafe_store!(cglobal(:jl_options, JLOptions), JLOptions(args...)) + end + ctx = gensym() + mock(identity, ctx, identity) + out = @capture_err mock(identity, ctx, identity) + @test isempty(out) +end + +@testset "Reusing Context" begin + f(x) = strip(uppercase(x)) + # If the method checks aren't working properly, this will throw. + ctx = gensym() + @test mock(_f -> f(" hi "), ctx, strip => identity) == " HI " + @test mock(_f -> f(" hi "), ctx, uppercase => identity) == "hi" +end + @testset "Filters" begin @testset "Maximum/minimum depth" begin f(x) = identity(x)