From 49c54392beaea4e5a8f53e12b9dde7415a89e145 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Wed, 26 Jul 2023 09:04:34 -0300 Subject: [PATCH] Add inv for SequentialTransform --- src/sequential.jl | 26 ++++++++++++++------------ test/runtests.jl | 1 + 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/sequential.jl b/src/sequential.jl index bb24b7e..20e86b2 100644 --- a/src/sequential.jl +++ b/src/sequential.jl @@ -11,22 +11,12 @@ struct SequentialTransform <: Transform transforms::Vector{Transform} end -# AbstractTrees interface -AbstractTrees.nodevalue(::SequentialTransform) = SequentialTransform -AbstractTrees.children(s::SequentialTransform) = s.transforms - -Base.show(io::IO, s::SequentialTransform) = - print(io, join(s.transforms, " → ")) - -function Base.show(io::IO, ::MIME"text/plain", s::SequentialTransform) - tree = AbstractTrees.repr_tree(s, context=io) - print(io, tree[begin:end-1]) # remove \n at end -end - isrevertible(s::SequentialTransform) = all(isrevertible, s.transforms) isinvertible(s::SequentialTransform) = all(isinvertible, s.transforms) +Base.inv(s::SequentialTransform) = SequentialTransform([inv(t) for t in reverse(s.transforms)]) + function apply(s::SequentialTransform, table) allcache = [] current = table @@ -80,3 +70,15 @@ Create a [`SequentialTransform`](@ref) transform with →(t1::Identity, t2::Identity) = Identity() →(t1::Transform, t2::Identity) = t1 →(t1::Identity, t2::Transform) = t2 + +# AbstractTrees interface +AbstractTrees.nodevalue(::SequentialTransform) = SequentialTransform +AbstractTrees.children(s::SequentialTransform) = s.transforms + +Base.show(io::IO, s::SequentialTransform) = + print(io, join(s.transforms, " → ")) + +function Base.show(io::IO, ::MIME"text/plain", s::SequentialTransform) + tree = AbstractTrees.repr_tree(s, context=io) + print(io, tree[begin:end-1]) # remove \n at end +end diff --git a/test/runtests.jl b/test/runtests.jl index 61f76e4..f4d16b5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,5 +5,6 @@ using Test @test TransformsBase.isrevertible(Identity()) @test TransformsBase.isinvertible(Identity()) @test inv(Identity()) == Identity() + @test inv(Identity() → Identity()) == Identity() @test (Identity() → Identity()) == Identity() end