Skip to content

Commit

Permalink
Support keyword arguments, kind of
Browse files Browse the repository at this point in the history
This implements the last solution outlined in #7.
For some reason, it's a lot faster than I remember this approach
being.

TODO:
- update docs to include section about reusing contexts
- add note in docs to explain when kwargs are discarded
- add tests for keyword arguments, discarding, and warning
  • Loading branch information
christopher-dG committed Oct 9, 2019
1 parent 60de62a commit 683b556
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 53 deletions.
9 changes: 9 additions & 0 deletions src/Contexts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module Contexts

using Core: kwftype

using Cassette: Cassette, overdub, posthook, prehook, recurse, @context

using ..SimpleMock: Metadata, should_mock, update!

end
3 changes: 2 additions & 1 deletion src/SimpleMock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Base: Callable, invokelatest, unwrap_unionall
using Base.Iterators: Pairs
using Core: Builtin, IntrinsicFunction

using Cassette: Cassette, overdub, posthook, prehook, recurse, @context
using Cassette: Cassette, overdub, posthook, prehook

export
Call,
Expand All @@ -29,6 +29,7 @@ export

include("metadata.jl")
include("filters.jl")
include("Contexts.jl")
include("mock_type.jl")
include("mock_fun.jl")

Expand Down
81 changes: 48 additions & 33 deletions src/mock_fun.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,8 @@
@context MockCtx

# TODO: Maybe these should be inlined, but it slows down compilation a lot.

@noinline function Cassette.prehook(ctx::MockCtx{Metadata{true}}, f, args...)
@nospecialize f args
update!(ctx.metadata, prehook, f, args...)
end

@noinline function Cassette.posthook(ctx::MockCtx{Metadata{true}}, v, f, args...)
@nospecialize v f args
update!(ctx.metadata, posthook, f, args...)
end

"""
mock(f::Function, args...; filters::Vector{<:Function}=Function[])
mock(f::Function[, ctx::Symbol], args...; filters::Vector{<:Function}=Function[])
Run `f` with specified functions mocked out.
!!! note
Keyword arguments to mocked functions are not supported.
If you call a mocked function with keyword arguments, it will dispatch to the original function.
For more details, see [Cassette#48](https://github.com/jrevels/Cassette.jl/issues/48).
## Examples
Mocking a single function:
Expand Down Expand Up @@ -82,23 +63,33 @@ Filter functions take a single argument of type [`Metadata`](@ref).
If any filter rejects, then mocking is not performed.
See [Filter Functions](@ref) for a list of included filters, as well as building blocks for you to create your own.
"""
function mock(f::Function, args...; filters::Vector{<:Function}=Function[])
mock(f::Function, args...; filters::Vector{<:Function}=Function[]) =
mock(f, gensym(), args...; filters=filters)

function mock(f::Function, ctx::Symbol, args...; filters::Vector{<:Function}=Function[])
mocks = map(sig2mock, args) # ((f, sig) => mock).
isempty(mocks) && throw(ArgumentError("At least one function must be mocked"))

# Create the new context type if it doesn't already exist.
context_is_new = !isdefined(Contexts, ctx)
context_is_new && make_context(ctx)
Ctx = getfield(Contexts, ctx)

# Implement the overdubs, but only if they aren't already implemented.
has_new_overdub = false
foreach(map(first, mocks)) do k
fun = k[1]
sig = k[2:end]
if !overdub_exists(fun, sig)
make_overdub(fun, sig)
if context_is_new || !overdub_exists(Ctx, fun, sig)
make_overdub(Ctx, fun, sig)
has_new_overdub = true
end
end

# Only use `invokelatest` if the Context/overdub implementations are new.
od_args = [MockCtx(; metadata=Metadata(Dict(mocks), filters)), f, map(last, mocks)...]
meta = Metadata(Dict(mocks), filters)
c = context_is_new ? invokelatest(Ctx; metadata=meta) : Ctx(; metadata=meta)
od_args = [c, f, map(last, mocks)...]
return has_new_overdub ? invokelatest(overdub, od_args...) : overdub(od_args...)
end

Expand All @@ -108,20 +99,37 @@ sig2mock(p::Pair) = (p.first, Vararg{Any}) => p.second
sig2mock(t::Tuple) = t => Mock()
sig2mock(f) = (f, Vararg{Any}) => Mock()

# Create a new context type.
make_context(Ctx::Symbol) = @eval Contexts begin
@context $Ctx

# TODO: Maybe these should be inlined, but it slows down compilation a lot.

@noinline function Cassette.prehook(ctx::$Ctx{Metadata{true}}, f, args...)
@nospecialize f args
update!(ctx.metadata, prehook, f, args...)
end

@noinline function Cassette.posthook(ctx::$Ctx{Metadata{true}}, v, f, args...)
@nospecialize v f args
update!(ctx.metadata, posthook, f, args...)
end
end

# Has a given function and signature already been overdubbed?
overdub_exists(::F, sig::Tuple) where F = any(methods(overdub)) do m
overdub_exists(::Type{Ctx}, ::F, sig::Tuple) where {Ctx, F} = any(methods(overdub)) do m
squashed = foldl(sig; init=[]) do acc, T
if T isa DataType && T.name.name === :Vararg
append!(acc, repeat([T.parameters[1]], T.parameters[2]))
else
push!(acc, T)
end
end
m.sig === Tuple{typeof(overdub), MockCtx, F, squashed...}
m.sig === Tuple{typeof(overdub), Ctx, F, squashed...}
end

# Implement `overdub` for a given Context, function, and signature.
function make_overdub(f::F, sig::Tuple) where F
function make_overdub(::Type{Ctx}, f::F, sig::Tuple) where {Ctx, F}
sig_exs = Expr[]
sig_names = []

Expand Down Expand Up @@ -149,12 +157,19 @@ function make_overdub(f::F, sig::Tuple) where F
end
end

@eval @inline function Cassette.overdub(ctx::MockCtx, f::$F, $(sig_exs...))
method = (f, $(sig...))
if should_mock(ctx.metadata, method)
ctx.metadata.mocks[method]($(sig_names...))
else
recurse(ctx, f, $(sig_names...))
@eval Contexts begin
@inline function Cassette.overdub(ctx::$Ctx, f::$F, $(sig_exs...); kwargs...)
method = (f, $(sig...))
if should_mock(ctx.metadata, method)
ctx.metadata.mocks[method]($(sig_names...); kwargs...)
else
isempty(kwargs) || @warn "Discarding keyword arguments" kwargs
recurse(ctx, f, $(sig_names...))
end
end

# https://github.com/jrevels/Cassette.jl/issues/48#issuecomment-440605481
@inline Cassette.overdub(ctx::$Ctx, ::kwftype($F), kwargs, f::$F, $(sig_exs...)) =
overdub(ctx, f, $(sig_names...); kwargs...)
end
end
40 changes: 21 additions & 19 deletions test/mock_fun.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,3 @@
@testset "mock does not overwrite methods" begin
# https://github.com/fredrikekre/jlpkg/blob/3b1c2400932dbe13fa7c3cba92bde3842315976c/src/cli.jl#L151-L160
o = JLOptions()
if o.warn_overwrite == 0
args = map(n -> n === :warn_overwrite ? 1 : getfield(o, n), fieldnames(JLOptions))
unsafe_store!(cglobal(:jl_options, JLOptions), JLOptions(args...))
end
mock(identity, identity)
out = @capture_err mock(identity, identity)
@test isempty(out)
end

@testset "Reusing Context" begin
f(x) = strip(uppercase(x))
# If the method checks aren't working properly, this will throw.
@test mock(_g -> f(" hi "), strip => identity) == " HI "
@test mock(_g -> f(" hi "), uppercase => identity) == "hi"
end

@testset "Basics" begin
mock(identity) do id
identity(10)
Expand Down Expand Up @@ -96,6 +77,27 @@ end
end
end

@testset "mock does not overwrite methods" begin
# https://github.com/fredrikekre/jlpkg/blob/3b1c2400932dbe13fa7c3cba92bde3842315976c/src/cli.jl#L151-L160
o = JLOptions()
if o.warn_overwrite == 0
args = map(n -> n === :warn_overwrite ? 1 : getfield(o, n), fieldnames(JLOptions))
unsafe_store!(cglobal(:jl_options, JLOptions), JLOptions(args...))
end
ctx = gensym()
mock(identity, ctx, identity)
out = @capture_err mock(identity, ctx, identity)
@test isempty(out)
end

@testset "Reusing Context" begin
f(x) = strip(uppercase(x))
# If the method checks aren't working properly, this will throw.
ctx = gensym()
@test mock(_f -> f(" hi "), ctx, strip => identity) == " HI "
@test mock(_f -> f(" hi "), ctx, uppercase => identity) == "hi"
end

@testset "Filters" begin
@testset "Maximum/minimum depth" begin
f(x) = identity(x)
Expand Down

0 comments on commit 683b556

Please sign in to comment.