Revert "[ONNX] Adjust and add deprecation messages (pytorch#146639)"
This reverts commit 63c2909.

Reverted pytorch#146639 on behalf of https://github.com/atalman due to Sorry Need to revert pytorch#146425 ([comment](pytorch#146639 (comment)))
pytorchmergebot committed Feb 10, 2025
1 parent a36c22f commit 1557b7b
Showing 7 changed files with 151 additions and 120 deletions.
24 changes: 1 addition & 23 deletions torch/onnx/__init__.py
@@ -49,7 +49,6 @@
]

from typing import Any, Callable, TYPE_CHECKING
from typing_extensions import deprecated

import torch
from torch import _C
@@ -169,19 +168,6 @@ def export(
) -> ONNXProgram | None:
r"""Exports a model into ONNX format.
.. versionchanged:: 2.6
*training* is now deprecated. Instead, set the training mode of the model before exporting.
.. versionchanged:: 2.6
*operator_export_type* is now deprecated. Only ONNX is supported.
.. versionchanged:: 2.6
*do_constant_folding* is now deprecated. It is always enabled.
.. versionchanged:: 2.6
*export_modules_as_functions* is now deprecated.
.. versionchanged:: 2.6
*autograd_inlining* is now deprecated.
.. versionchanged:: 2.7
*optimize* is now True by default.
Args:
model: The model to be exported.
args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the
@@ -356,9 +342,6 @@ def forward(self, x):
autograd_inlining: Deprecated.
Flag used to control whether to inline autograd functions.
Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
Returns:
:class:`torch.onnx.ONNXProgram` if dynamo is True, otherwise None.
"""
if dynamo is True or isinstance(model, torch.export.ExportedProgram):
from torch.onnx._internal.exporter import _compat
@@ -419,9 +402,6 @@ def forward(self, x):
return None


@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
)
def dynamo_export(
model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined]
/,
@@ -431,9 +411,6 @@
) -> ONNXProgram:
"""Export a torch.nn.Module to an ONNX graph.
.. deprecated:: 2.6
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
Args:
model: The PyTorch model to be exported to ONNX.
model_args: Positional inputs to ``model``.
@@ -475,6 +452,7 @@ def forward(self, x, bias=None):
onnx_program.save("my_dynamic_model.onnx")
"""

# NOTE: The new exporter is experimental and is not enabled by default.
import warnings

from torch.onnx import _flags
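For orientation (not part of this diff): a minimal sketch of the replacement that the deprecation message points to, ``torch.onnx.export(..., dynamo=True)``. The model and input shapes below are illustrative.

import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


model = TinyModel().eval()
example_input = torch.randn(2, 3)

# With dynamo=True, export returns a torch.onnx.ONNXProgram that can be saved.
onnx_program = torch.onnx.export(model, (example_input,), dynamo=True)
onnx_program.save("tiny_model.onnx")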
72 changes: 72 additions & 0 deletions torch/onnx/_deprecation.py
@@ -0,0 +1,72 @@
"""Utility for deprecating functions."""

import functools
import textwrap
import warnings
from typing import Callable, TypeVar
from typing_extensions import ParamSpec


_T = TypeVar("_T")
_P = ParamSpec("_P")


def deprecated(
since: str, removed_in: str, instructions: str
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""Marks functions as deprecated.
It will result in a warning when the function is called and a note in the
docstring.
Args:
since: The version when the function was first deprecated.
removed_in: The version when the function will be removed.
instructions: The action users should take.
"""

def decorator(function: Callable[_P, _T]) -> Callable[_P, _T]:
@functools.wraps(function)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
warnings.warn(
f"'{function.__module__}.{function.__name__}' "
f"is deprecated in version {since} and will be "
f"removed in {removed_in}. Please {instructions}.",
category=DeprecationWarning,
stacklevel=2,
)
return function(*args, **kwargs)

# Add a deprecation note to the docstring.
docstring = function.__doc__ or ""

# Add a note to the docstring.
deprecation_note = textwrap.dedent(
f"""\
.. deprecated:: {since}
Deprecated and will be removed in version {removed_in}.
Please {instructions}.
"""
)

# Split the docstring at the first blank line (summary / body boundary)
summary_and_body = docstring.split("\n\n", 1)

if len(summary_and_body) > 1:
summary, body = summary_and_body

# Dedent the body. We cannot do this with the presence of the summary because
# the body contains leading whitespaces when the summary does not.
body = textwrap.dedent(body)

new_docstring_parts = [deprecation_note, "\n\n", summary, body]
else:
summary = summary_and_body[0]

new_docstring_parts = [deprecation_note, "\n\n", summary]

wrapper.__doc__ = "".join(new_docstring_parts)

return wrapper

return decorator
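For context (not part of this diff): a small sketch of how the decorator restored above is applied; compare the ``_cast_*`` helpers in torch/onnx/symbolic_opset9.py later in this commit. The helper name ``old_resize`` is hypothetical.

from torch.onnx import _deprecation


@_deprecation.deprecated(
    since="2.0",
    removed_in="the future",
    instructions="use the replacement helper instead",
)
def old_resize(x):
    """Resize x (toy example)."""
    return x


old_resize(1)  # Emits a DeprecationWarning; a deprecation note is also prepended to __doc__.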
12 changes: 12 additions & 0 deletions torch/onnx/_exporter_states.py
@@ -0,0 +1,12 @@
from __future__ import annotations


class ExportTypes:
"""Specifies how the ONNX model is stored."""

# TODO(justinchuby): Deprecate and remove this class.

PROTOBUF_FILE = "Saves model in the specified protobuf file."
ZIP_ARCHIVE = "Saves model in the specified ZIP file (uncompressed)."
COMPRESSED_ZIP_ARCHIVE = "Saves model in the specified ZIP file (compressed)."
DIRECTORY = "Saves model in the specified folder."
29 changes: 10 additions & 19 deletions torch/onnx/_internal/_exporter_legacy.py
@@ -81,14 +81,12 @@ class ONNXFakeContext:


@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
category=DeprecationWarning,
)
class OnnxRegistry:
"""Registry for ONNX functions.
.. deprecated:: 2.6
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
The registry maintains a mapping from qualified names to symbolic functions under a
fixed opset version. It supports registering custom onnx-script functions and for
dispatcher to dispatch calls to the appropriate function.
@@ -231,14 +229,12 @@ def _all_registered_ops(self) -> set[str]:


@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
category=DeprecationWarning,
)
class ExportOptions:
"""Options to influence the TorchDynamo ONNX exporter.
.. deprecated:: 2.6
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
Attributes:
dynamic_shapes: Shape information hint for input/output tensors.
When ``None``, the exporter determines the most compatible setting.
@@ -389,9 +385,8 @@ def enable_fake_mode():
It is highly recommended to initialize the model in fake mode when exporting models that
are too large to fit into memory.
.. note::
This function does not support torch.onnx.export(..., dynamo=True, optimize=True).
Please call ONNXProgram.optimize() outside of the function after the model is exported.
NOTE: This function does not support torch.onnx.export(..., dynamo=True, optimize=True), so
please call ONNXProgram.optimize() outside of the function after the model is exported.
Example::
@@ -448,14 +443,12 @@ def enable_fake_mode():


@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
category=DeprecationWarning,
)
class ONNXRuntimeOptions:
"""Options to influence the execution of the ONNX model through ONNX Runtime.
.. deprecated:: 2.6
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
Attributes:
session_options: ONNX Runtime session options.
execution_providers: ONNX Runtime execution providers to use during model execution.
@@ -708,7 +701,8 @@ def missing_opset(package_name: str):


@deprecated(
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead."
"torch.onnx.dynamo_export is deprecated since 2.6.0. Please use torch.onnx.export(..., dynamo=True) instead.",
category=DeprecationWarning,
)
def dynamo_export(
model: torch.nn.Module | Callable,
@@ -719,9 +713,6 @@
) -> _onnx_program.ONNXProgram:
"""Export a torch.nn.Module to an ONNX graph.
.. deprecated:: 2.6
Please use ``torch.onnx.export(..., dynamo=True)`` instead.
Args:
model: The PyTorch model to be exported to ONNX.
model_args: Positional inputs to ``model``.
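For context (not part of this diff): a hedged sketch of the workflow the ``enable_fake_mode`` docstring note describes — export with optimization disabled inside the fake-mode context, then call ``ONNXProgram.optimize()`` outside it. The model below is illustrative.

import torch


class LargeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1024, 1024)

    def forward(self, x):
        return self.linear(x)


with torch.onnx.enable_fake_mode():
    # Weights are fake tensors here, so the full model never materializes in memory.
    model = LargeModel()
    example_input = torch.randn(1, 1024)
    onnx_program = torch.onnx.export(
        model, (example_input,), dynamo=True, optimize=False
    )

# Optimize outside the fake-mode context, as the docstring note advises.
onnx_program.optimize()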
57 changes: 46 additions & 11 deletions torch/onnx/symbolic_opset9.py
@@ -15,7 +15,6 @@
import sys
import warnings
from typing import Callable, TYPE_CHECKING
from typing_extensions import deprecated

import torch
import torch._C._onnx as _C_onnx
@@ -24,7 +23,7 @@
from torch import _C

# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import _constants, _type_utils, errors, symbolic_helper
from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration

@@ -3316,55 +3315,91 @@ def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_co


@_onnx_symbolic("aten::_cast_Byte")
@deprecated("Avoid using this function and create a Cast node instead")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8)


@_onnx_symbolic("aten::_cast_Char")
@deprecated("Avoid using this function and create a Cast node instead")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
def _cast_Char(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8)


@_onnx_symbolic("aten::_cast_Short")
@deprecated("Avoid using this function and create a Cast node instead")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
def _cast_Short(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16)


@_onnx_symbolic("aten::_cast_Int")
@deprecated("Avoid using this function and create a Cast node instead")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
def _cast_Int(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32)


@_onnx_symbolic("aten::_cast_Long")
@deprecated("Avoid using this function and create a Cast node instead")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
def _cast_Long(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)


@_onnx_symbolic("aten::_cast_Half")
@deprecated("Avoid using this function and create a Cast node instead")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
def _cast_Half(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16)


@_onnx_symbolic("aten::_cast_Float")
@deprecated("Avoid using this function and create a Cast node instead")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
def _cast_Float(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)


@_onnx_symbolic("aten::_cast_Double")
@deprecated("Avoid using this function and create a Cast node instead")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
def _cast_Double(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)


@_onnx_symbolic("aten::_cast_Bool")
@deprecated("Avoid using this function and create a Cast node instead")
@_deprecation.deprecated(
"2.0",
"the future",
"Avoid using this function and create a Cast node instead",
)
def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking):
return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)

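For context (not part of this diff): the deprecation message above asks symbolic-function authors to emit a Cast node directly rather than calling the deprecated ``_cast_*`` helpers. A minimal sketch; ``my_int64_cast`` is hypothetical.

import torch._C._onnx as _C_onnx
from torch.onnx._internal import jit_utils


def my_int64_cast(g: jit_utils.GraphContext, input):
    # Equivalent to the deprecated _cast_Long helper: emit a Cast node directly.
    return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)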