Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 6, 2025
1 parent 8f623e4 commit a794c87
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 53 deletions.
3 changes: 3 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ steps:
FLUX_TEST_CUDA: "true"
FLUX_TEST_CPU: "false"
FLUX_TEST_ENZYME: "false"
FLUX_TEST_REACTANT: "true"
timeout_in_minutes: 60

- label: "Metal - Julia 1"
Expand All @@ -37,6 +38,7 @@ steps:
FLUX_TEST_METAL: "true"
FLUX_TEST_CPU: "false"
FLUX_TEST_ENZYME: "false"
FLUX_TEST_REACTANT: "true"

- label: "AMDGPU - Julia 1"
plugins:
Expand All @@ -59,6 +61,7 @@ steps:
FLUX_TEST_AMDGPU: "true"
FLUX_TEST_CPU: "false"
FLUX_TEST_ENZYME: "false"
FLUX_TEST_REACTANT: "false"
JULIA_NUM_THREADS: 4

env:
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ OneHotArrays = "0.2.4"
Optimisers = "0.4.1"
Preferences = "1"
ProgressLogging = "0.1"
Reactant = "0.2.16"
Reexport = "1.0"
Setfield = "1.1"
SpecialFunctions = "2.1.2"
Expand Down
51 changes: 0 additions & 51 deletions test/ext_enzyme/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,6 @@ using Enzyme: Enzyme, Duplicated, Const, Active
end
end

@testset "Reactant Models" begin
function loss(model, x)
mean(model(x))
end

models_xs = [
(Dense(2=>4), randn(Float32, 2), "Dense"),
(Chain(Dense(2=>4, tanh), Dense(4=>3)), randn(Float32, 2), "Chain(Dense, Dense)"),
(f64(Chain(Dense(2=>4), Dense(4=>2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"),
(Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"),
(Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),
(Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),
(Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
(Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),
(SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
(Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), # Passes on 1.10, fails on 1.11 with MethodError: no method matching function_attributes(::LLVM.UserOperandSet)
(ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
(first LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
(BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), # AssertionError: Base.isconcretetype(typ)
(first MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"), # AssertionError: Base.isconcretetype(typ)
]

for (model, x, name) in models_xs
@testset "Enzyme grad check $name" begin
println("testing $name with Enzyme")
test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true, test_reactant=true)
end
end
end

@testset "Recurrent Layers" begin
function loss(model, x)
mean(model(x))
Expand All @@ -81,27 +51,6 @@ end
end
end

@testset "Reactant Recurrent Layers" begin
function loss(model, x)
mean(model(x))
end

models_xs = [
(RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
(LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
(GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
(Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
(Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
]

for (model, x, name) in models_xs
@testset "check grad $name" begin
println("testing $name")
test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true, test_reactant=true)
end
end
end

@testset "gradient, withgradient, Duplicated" begin
# Tests above are about how Enzyme digests Flux layers.
# Tests here are just the interface Flux.gradient(f, Duplicated(model)) etc.
Expand Down
Loading

0 comments on commit a794c87

Please sign in to comment.