Skip to content

Commit

Permalink
fix function name bug (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yusuke Oda authored Dec 15, 2022
1 parent ff947f7 commit 01a775d
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 27 deletions.
11 changes: 8 additions & 3 deletions src/latexify/codegen/algorithmic_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __init__(
use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols
)
self._identifier_converter = identifier_converter.IdentifierConverter(
use_math_symbols=use_math_symbols
use_math_symbols=use_math_symbols,
use_mathrm=False,
)
self._indent_level = 0

Expand All @@ -63,6 +64,8 @@ def visit_Expr(self, node: ast.Expr) -> str:
# TODO(ZibingZhang): support nested functions
def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
"""Visit a FunctionDef node."""
name_latex = self._identifier_converter.convert(node.name)[0]

# Arguments
arg_strs = [
self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args
Expand All @@ -71,7 +74,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
latex = self._add_indent("\\begin{algorithmic}\n")
with self._increment_level():
latex += self._add_indent(
f"\\Function{{{node.name}}}{{${', '.join(arg_strs)}$}}\n"
f"\\Function{{{name_latex}}}{{${', '.join(arg_strs)}$}}\n"
)

with self._increment_level():
Expand Down Expand Up @@ -197,6 +200,8 @@ def visit_Expr(self, node: ast.Expr) -> str:
# TODO(ZibingZhang): support nested functions
def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
"""Visit a FunctionDef node."""
name_latex = self._identifier_converter.convert(node.name)[0]

# Arguments
arg_strs = [
self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args
Expand All @@ -209,7 +214,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
return (
r"\begin{array}{l} "
+ self._add_indent(r"\mathbf{function}")
+ rf" \ \mathrm{{{node.name}}}({', '.join(arg_strs)})"
+ rf" \ {name_latex}({', '.join(arg_strs)})"
+ f"{self._LINE_BREAK}{body}{self._LINE_BREAK}"
+ self._add_indent(r"\mathbf{end \ function}")
+ r" \end{array}"
Expand Down
6 changes: 3 additions & 3 deletions src/latexify/codegen/algorithmic_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def test_visit_while_with_else() -> None:
("a = b = 0", r"a \gets b \gets 0"),
],
)
def test_visit_assign_jupyter(code: str, latex: str) -> None:
def test_visit_assign_ipython(code: str, latex: str) -> None:
node = ast.parse(textwrap.dedent(code)).body[0]
assert isinstance(node, ast.Assign)
assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex
Expand All @@ -188,7 +188,7 @@ def test_visit_assign_jupyter(code: str, latex: str) -> None:
(
r"\begin{array}{l}"
r" \mathbf{function}"
r" \ \mathrm{f}(x) \\"
r" \ f(x) \\"
r" \hspace{1em} \mathbf{return} \ x \\"
r" \mathbf{end \ function}"
r" \end{array}"
Expand All @@ -199,7 +199,7 @@ def test_visit_assign_jupyter(code: str, latex: str) -> None:
(
r"\begin{array}{l}"
r" \mathbf{function}"
r" \ \mathrm{f}(a, b, c) \\"
r" \ f(a, b, c) \\"
r" \hspace{1em} \mathbf{return} \ 3 \\"
r" \mathbf{end \ function}"
r" \end{array}"
Expand Down
13 changes: 10 additions & 3 deletions src/latexify/codegen/identifier_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@ class IdentifierConverter:
"""

_use_math_symbols: bool
_use_mathrm: bool

def __init__(self, *, use_math_symbols: bool) -> None:
"""Initializer.
def __init__(self, *, use_math_symbols: bool, use_mathrm: bool = True) -> None:
r"""Initializer.
Args:
use_math_symbols: Whether to convert identifiers with math symbol names to
appropriate LaTeX command.
use_mathrm: Whether to wrap the resulting expression by \mathrm, if
applicable.
"""
self._use_math_symbols = use_math_symbols
self._use_mathrm = use_mathrm

def convert(self, name: str) -> tuple[str, bool]:
"""Converts Python identifier to LaTeX expression.
Expand All @@ -44,4 +48,7 @@ def convert(self, name: str) -> tuple[str, bool]:
if len(name) == 1 and name != "_":
return name, True

return r"\mathrm{" + name.replace("_", r"\_") + "}", False
escaped = name.replace("_", r"\_")
wrapped = rf"\mathrm{{{escaped}}}" if self._use_mathrm else escaped

return wrapped, False
37 changes: 19 additions & 18 deletions src/latexify/codegen/identifier_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,32 @@


@pytest.mark.parametrize(
"name,use_math_symbols,expected",
"name,use_math_symbols,use_mathrm,expected",
[
("a", False, ("a", True)),
("_", False, (r"\mathrm{\_}", False)),
("aa", False, (r"\mathrm{aa}", False)),
("a1", False, (r"\mathrm{a1}", False)),
("a_", False, (r"\mathrm{a\_}", False)),
("_a", False, (r"\mathrm{\_a}", False)),
("_1", False, (r"\mathrm{\_1}", False)),
("__", False, (r"\mathrm{\_\_}", False)),
("a_a", False, (r"\mathrm{a\_a}", False)),
("a__", False, (r"\mathrm{a\_\_}", False)),
("a_1", False, (r"\mathrm{a\_1}", False)),
("alpha", False, (r"\mathrm{alpha}", False)),
("alpha", True, (r"\alpha", True)),
("foo", False, (r"\mathrm{foo}", False)),
("foo", True, (r"\mathrm{foo}", False)),
("a", False, True, ("a", True)),
("_", False, True, (r"\mathrm{\_}", False)),
("aa", False, True, (r"\mathrm{aa}", False)),
("a1", False, True, (r"\mathrm{a1}", False)),
("a_", False, True, (r"\mathrm{a\_}", False)),
("_a", False, True, (r"\mathrm{\_a}", False)),
("_1", False, True, (r"\mathrm{\_1}", False)),
("__", False, True, (r"\mathrm{\_\_}", False)),
("a_a", False, True, (r"\mathrm{a\_a}", False)),
("a__", False, True, (r"\mathrm{a\_\_}", False)),
("a_1", False, True, (r"\mathrm{a\_1}", False)),
("alpha", False, True, (r"\mathrm{alpha}", False)),
("alpha", True, True, (r"\alpha", True)),
("foo", False, True, (r"\mathrm{foo}", False)),
("foo", True, True, (r"\mathrm{foo}", False)),
("foo", True, False, (r"foo", False)),
],
)
def test_identifier_converter(
name: str, use_math_symbols: bool, expected: tuple[str, bool]
name: str, use_math_symbols: bool, use_mathrm: bool, expected: tuple[str, bool]
) -> None:
assert (
identifier_converter.IdentifierConverter(
use_math_symbols=use_math_symbols
use_math_symbols=use_math_symbols, use_mathrm=use_mathrm
).convert(name)
== expected
)

0 comments on commit 01a775d

Please sign in to comment.