Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nested op_name scope in reduce functions #367

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
60 changes: 34 additions & 26 deletions src/ops/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,36 +176,44 @@ end
# TODO Clean this up
for reduction in [:sum, :prod, :min, :max, :all, :any, :mean]
@eval @op function $(Symbol("reduce_", reduction))(n::AbstractTensor; axis=nothing, keep_dims=false, name=nothing)
if name === nothing
name = get_name("reduce")
end
if axis == nothing
local desc
shape = get_shape(n)
nodetype = $(capitalize(reduction))

if axis == nothing && shape.rank_unknown
n = Tensor(n) # TODO: rewrite this
range_start = constant(Int32(0))
range_delta = constant(Int32(1))
desc = NodeDescription("Rank", "$name/rank")
add_input(desc, n)
rank = Tensor(Operation(desc), 1)
desc = NodeDescription("Range", "$name/range")
add_input(desc, range_start)
add_input(desc, rank)
add_input(desc, range_delta)
range = Tensor(Operation(desc), 1)
desc = NodeDescription($(capitalize(reduction)), name)
add_input(desc, n)
add_input(desc, range)
Tensor(Operation(desc), 1)
rank = tf.with_op_name(nothing, "Rank") do
desc_rank = NodeDescription("Rank")
add_input(desc_rank, n)
Tensor(Operation(desc_rank), 1)
end
range = tf.with_op_name(nothing, "range") do
@tf start = constant(Int32(0))
@tf delta = constant(Int32(1))
desc_range = NodeDescription("Range")
add_input(desc_range, start)
add_input(desc_range, rank)
add_input(desc_range, delta)
Tensor(Operation(desc_range), 1)
end
tf.with_op_name(name, nodetype) do
desc = NodeDescription(nodetype)
add_input(desc, n)
add_input(desc, range)
end
else
if isa(axis, Number)
axis = [axis]
tf.with_op_name(name, nodetype) do
if axis == nothing
axis = 1:length(shape.dims)
end
@tf reduction_indices = constant(Int32.(axis.-1))
desc = NodeDescription(nodetype)
add_input(desc, Tensor(n))
add_input(desc, reduction_indices)
desc["keep_dims"] = keep_dims
end
axis = [Int32(idx-1) for idx in axis]
desc = NodeDescription($(capitalize(reduction)), name)
add_input(desc, Tensor(n))
add_input(desc, Tensor(axis))
desc["keep_dims"] = keep_dims
Tensor(Operation(desc), 1)
end
Tensor(Operation(desc), 1)
end
end

Expand Down
41 changes: 38 additions & 3 deletions test/meta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ end
@testset "Naming" begin
let
g = Graph()
local i, j_jl, j, k, ijk, ij, ij2, fq, m, W, Y, Ysum1, Ysum2, Ysum3, Ysum4
local i, j_jl, j, k, ijk, ij, ij2, fq, m, W, Y,
Ysum1, Ysum2, Ysum3, Ysum4, Ysum5, Ysum6, Ysum7, Ysum8,
p, psum1, psum2, psum3, psum4, psum5
as_default(g) do
@tf begin
i = constant(1.0)
Expand Down Expand Up @@ -46,6 +48,29 @@ end
Ysum3 = reduce_sum(Y, keep_dims=true) # With a comma (issue #188)

Ysum4 = reduce_sum(Y, keep_dims=true, name="namefor_Ysum4") # With a comma (issue #188)

Ysum5 = reduce_sum(Y, axis=2)

nn.tf.with_op_name("level1") do
Ysum6 = reduce_sum(Y)
nn.tf.with_op_name("level2") do
Ysum7 = reduce_sum(Y)
Ysum8 = reduce_sum(Y, axis=1)
end
end

p = placeholder(Float32)
psum1 = reduce_sum(p)
psum2 = reduce_sum(p, axis=1)

nn.tf.with_op_name("anotherlevel1") do
psum3 = reduce_sum(p)

nn.tf.with_op_name("level2") do
psum4 = reduce_sum(p)
psum5 = reduce_sum(p, axis=1)
end
end
end
end

Expand All @@ -68,8 +93,18 @@ end
@test Ysum2 == get_tensor_by_name(g, "Ysum2")
@test Ysum3 == get_tensor_by_name(g, "Ysum3")
@test Ysum4 == get_tensor_by_name(g, "namefor_Ysum4")


@test Ysum5 == get_tensor_by_name(g, "Ysum5")
@test Ysum6 == get_tensor_by_name(g, "level1/Ysum6")
@test Ysum7 == get_tensor_by_name(g, "level1/level2/Ysum7")
@test Ysum8 == get_tensor_by_name(g, "level1/level2/Ysum8")

@test psum1 == get_tensor_by_name(g, "psum1")
@test psum2 == get_tensor_by_name(g, "psum2")
@test psum3 == get_tensor_by_name(g, "anotherlevel1/psum3")
@test psum4 == get_tensor_by_name(g, "anotherlevel1/level2/psum4")
@test psum5 == get_tensor_by_name(g, "anotherlevel1/level2/psum5")

@test_throws TensorFlow.TFException reduce_sum(p, name="Ysum1")
end
end

Expand Down