Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use Base.Fix1 instead of closures in ForwardDiffStaticArraysExt.jl #735

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

jondeuce
Copy link

Constant propagation of the tag T into the closure d -> value(T,d) seems to fail in ForwardDiff.vector_mode_jacobian!(::ImmutableDiffResult, ...)

@inline function ForwardDiff.vector_mode_jacobian!(result::ImmutableDiffResult, f::F, x::StaticArray) where {F}
T = typeof(Tag(f, eltype(x)))
ydual = static_dual_eval(T, f, x)
result = DiffResults.jacobian!(result, extract_jacobian(T, ydual, x))
result = DiffResults.value!(d -> value(T,d), result, ydual)
return result
end

so I replaced the closure with Base.Fix1. I preemptively replaced two other closures as well.

On the one hand, this is probably a compiler regression (though I'm not sure when; this bug occurs on both 1.9 and 1.11) that should be fixed elsewhere. But on the other hand, StaticArrays puts a lot of pressure on the compiler and closures are well-known to be finicky with respect to type inference, so IMO getting rid of the closures is worth the slight loss of readability.

MWE

using StaticArrays, ForwardDiff, DiffResults

f(x) = x .^ 2 ./ 2

function withjacobian(x::SVector)
    res = DiffResults.JacobianResult(x)
    res = ForwardDiff.jacobian!(res, f, x)
    return DiffResults.value(res), DiffResults.jacobian(res)
end

@code_warntype withjacobian(SVector(1.0, 2.0))
@btime withjacobian($(SVector(1.0, 2.0)))

Without this PR:

MethodInstance for withjacobian(::SVector{2, Float64})
  from withjacobian(x::SVector) @ Main ~/ForwardDiffBug/bug.jl:8
Arguments
  #self#::Core.Const(Main.withjacobian)
  x::SVector{2, Float64}
Locals
  res::DiffResults.ImmutableDiffResult{1, _A, Tuple{SMatrix{2, 2, Float64, 4}}} where _A
Body::Tuple{Any, SMatrix{2, 2, Float64, 4}}
1%1  = DiffResults.JacobianResult::Core.Const(DiffResults.JacobianResult)
│         (res = (%1)(x))
│   %3  = ForwardDiff.jacobian!::Core.Const(ForwardDiff.jacobian!)
│   %4  = res::Core.PartialStruct(DiffResults.ImmutableDiffResult{1, SVector{2, Float64}, Tuple{SMatrix{2, 2, Float64, 4}}}, Any[SVector{2, Float64}, Core.Const(([0.0 0.0; 0.0 0.0],))])
│   %5  = Main.f::Core.Const(Main.f)
│         (res = (%3)(%4, %5, x))
│   %7  = DiffResults.value::Core.Const(DiffResults.value)
│   %8  = res::DiffResults.ImmutableDiffResult{1, _A, Tuple{SMatrix{2, 2, Float64, 4}}} where _A
│   %9  = (%7)(%8)::Any%10 = DiffResults.jacobian::Core.Const(DiffResults.jacobian)
│   %11 = res::DiffResults.ImmutableDiffResult{1, _A, Tuple{SMatrix{2, 2, Float64, 4}}} where _A
│   %12 = (%10)(%11)::SMatrix{2, 2, Float64, 4}%13 = Core.tuple(%9, %12)::Tuple{Any, SMatrix{2, 2, Float64, 4}}
└──       return %13

  2.998 μs (17 allocations: 752 bytes)
([0.5, 2.0], [1.0 0.0; 0.0 2.0])

With this PR:

MethodInstance for withjacobian(::SVector{2, Float64})
  from withjacobian(x::SVector) @ Main ~/ForwardDiffBug/bug.jl:8
Arguments
  #self#::Core.Const(Main.withjacobian)
  x::SVector{2, Float64}
Locals
  res::DiffResults.ImmutableDiffResult{1, SVector{2, Float64}, Tuple{SMatrix{2, 2, Float64, 4}}}
Body::Tuple{SVector{2, Float64}, SMatrix{2, 2, Float64, 4}}
1%1  = DiffResults.JacobianResult::Core.Const(DiffResults.JacobianResult)
│         (res = (%1)(x))
│   %3  = ForwardDiff.jacobian!::Core.Const(ForwardDiff.jacobian!)
│   %4  = res::Core.PartialStruct(DiffResults.ImmutableDiffResult{1, SVector{2, Float64}, Tuple{SMatrix{2, 2, Float64, 4}}}, Any[SVector{2, Float64}, Core.Const(([0.0 0.0; 0.0 0.0],))])
│   %5  = Main.f::Core.Const(Main.f)
│         (res = (%3)(%4, %5, x))
│   %7  = DiffResults.value::Core.Const(DiffResults.value)
│   %8  = res::DiffResults.ImmutableDiffResult{1, SVector{2, Float64}, Tuple{SMatrix{2, 2, Float64, 4}}}%9  = (%7)(%8)::SVector{2, Float64}%10 = DiffResults.jacobian::Core.Const(DiffResults.jacobian)
│   %11 = res::DiffResults.ImmutableDiffResult{1, SVector{2, Float64}, Tuple{SMatrix{2, 2, Float64, 4}}}%12 = (%10)(%11)::SMatrix{2, 2, Float64, 4}%13 = Core.tuple(%9, %12)::Tuple{SVector{2, Float64}, SMatrix{2, 2, Float64, 4}}
└──       return %13

  2.845 ns (0 allocations: 0 bytes)
([0.5, 2.0], [1.0 0.0; 0.0 2.0])

Version info:

julia> Pkg.status()
Status `~/ForwardDiffBug/Project.toml`
  [163ba53b] DiffResults v1.1.0
  [f6369f11] ForwardDiff v0.11.0-DEV `~/.julia/dev/ForwardDiff`
  [90137ffa] StaticArrays v1.9.11

julia> versioninfo()
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × AMD Ryzen 9 3950X 16-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 32 default, 0 interactive, 16 GC (on 32 virtual cores)

Copy link

codecov bot commented Feb 11, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 86.14%. Comparing base (c310fb5) to head (881ad59).
Report is 14 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #735      +/-   ##
==========================================
- Coverage   89.57%   86.14%   -3.44%     
==========================================
  Files          11       10       -1     
  Lines         969      895      -74     
==========================================
- Hits          868      771      -97     
- Misses        101      124      +23     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well spotted, this seems fine.

I think the closure gets typeof(T) == DataType, MWE is:

julia> let T = Float32
         @code_warntype map(x -> convert(T, x), 1:3)
       end
MethodInstance for map(::var"#27#28"{DataType}, ::UnitRange{Int64})
  from map(f, A::AbstractArray) @ Base abstractarray.jl:3371
Arguments
  #self#::Core.Const(map)
  f::var"#27#28"{DataType}
  A::UnitRange{Int64}
Body::Vector
1%1 = Base.collect_similar::Core.Const(Base.collect_similar)
│   %2 = Base.Generator(f, A)::Base.Generator{UnitRange{Int64}, var"#27#28"{DataType}}%3 = (%1)(A, %2)::Vector
└──      return %3

julia> VERSION
v"1.11.3"

It's fixed on master:

julia> let T = Float32
         @code_warntype map(x -> convert(T, x), 1:3)
       end
MethodInstance for map(::var"#8#9"{Type{Float32}}, ::UnitRange{Int64})
  from map(f, A::AbstractArray) @ Base abstractarray.jl:3411
Arguments
  #self#::Core.Const(map)
  f::var"#8#9"{Type{Float32}}
  A::UnitRange{Int64}
Body::Vector{Float32}
1%1 = Base.collect_similar::Core.Const(Base.collect_similar)
│   %2 = Base.Generator::Core.Const(Base.Generator)
│   %3 = (%2)(f, A)::Base.Generator{UnitRange{Int64}, var"#8#9"{Type{Float32}}}%4 = (%1)(A, %3)::Vector{Float32}
└──      return %4


julia> VERSION
v"1.12.0-DEV.1731"

julia> @code_warntype withjacobian(SVector(1.0, 2.0))  # from above
MethodInstance for withjacobian(::SVector{2, Float64})
  from withjacobian(x::SVector) @ Main REPL[9]:1
Arguments
  #self#::Core.Const(Main.withjacobian)
  x::SVector{2, Float64}
Locals
  res::DiffResults.ImmutableDiffResult{1, SVector{2, Float64}, Tuple{SMatrix{2, 2, Float64, 4}}}
Body::Tuple{SVector{2, Float64}, SMatrix{2, 2, Float64, 4}}

(@v1.12) pkg> st ForwardDiff
Status `~/.julia/environments/v1.12/Project.toml`
  [f6369f11] ForwardDiff v0.10.38

@devmotion
Copy link
Member

devmotion commented Feb 11, 2025

#640 attempted the same (but less extensively) to fix #639, however, was missing a test. Maybe at least the example above could be added as a test?

@jondeuce
Copy link
Author

Added checks that the MWEs from this PR and from #639 infer.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants