PEP585 update - torch/export (pytorch#145165)
See pytorch#145101 for details.
Pull Request resolved: pytorch#145165
Approved by: https://github.com/bobrenjc93
aorenste authored and pytorchmergebot committed Jan 19, 2025
1 parent 316808e commit b6c5562
Showing 14 changed files with 257 additions and 269 deletions.
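The change is mechanical throughout: the deprecated typing aliases (Dict, List, Tuple, Set, Type) become the builtin generics standardized by PEP 585, which are valid in annotations on Python 3.9+, and abstract containers such as Iterator move to collections.abc. An illustrative before/after sketch (hypothetical function, not taken from the diff):

    # Before: typing aliases, deprecated since Python 3.9
    from typing import Dict, List, Optional, Tuple

    def f(args: Tuple[int, ...], extra: Optional[Dict[str, int]] = None) -> List[int]:
        return [len(args), len(extra or {})]

    # After: PEP 585 builtin generics; no typing imports needed for containers
    from typing import Optional

    def f(args: tuple[int, ...], extra: Optional[dict[str, int]] = None) -> list[int]:
        return [len(args), len(extra or {})]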
torch/export/__init__.py (19 additions & 19 deletions)
@@ -8,12 +8,12 @@
 import typing
 import warnings
 import zipfile
+from collections.abc import Iterator
 from enum import auto, Enum
 from typing import (
     Any,
     Callable,
-    Dict,
-    Iterator,
-    List,
     Optional,
-    Tuple,
@@ -82,12 +82,12 @@
 
 def export_for_training(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
+    args: tuple[Any, ...],
+    kwargs: Optional[dict[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
     strict: bool = True,
-    preserve_module_call_signature: Tuple[str, ...] = (),
+    preserve_module_call_signature: tuple[str, ...] = (),
 ) -> ExportedProgram:
     """
     :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -177,13 +177,13 @@ def export_for_training(
 
 def export_for_inference(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
+    args: tuple[Any, ...],
+    kwargs: Optional[dict[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
     strict: bool = True,
-    preserve_module_call_signature: Tuple[str, ...] = (),
-    decomp_table: Optional[Dict["OpOverload", Optional[Callable]]] = None,
+    preserve_module_call_signature: tuple[str, ...] = (),
+    decomp_table: Optional[dict["OpOverload", Optional[Callable]]] = None,
 ) -> ExportedProgram:
     """
     :func:`export_for_inference` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -262,12 +262,12 @@ def export_for_inference(
 
 def export(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
+    args: tuple[Any, ...],
+    kwargs: Optional[dict[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
     strict: bool = True,
-    preserve_module_call_signature: Tuple[str, ...] = (),
+    preserve_module_call_signature: tuple[str, ...] = (),
 ) -> ExportedProgram:
     """
     :func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -383,8 +383,8 @@ def save(
     ep: ExportedProgram,
     f: Union[str, os.PathLike, io.BytesIO],
     *,
-    extra_files: Optional[Dict[str, Any]] = None,
-    opset_version: Optional[Dict[str, int]] = None,
+    extra_files: Optional[dict[str, Any]] = None,
+    opset_version: Optional[dict[str, int]] = None,
     pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
 ) -> None:
     """
@@ -466,8 +466,8 @@ def forward(self, x):
 def load(
     f: Union[str, os.PathLike, io.BytesIO],
     *,
-    extra_files: Optional[Dict[str, Any]] = None,
-    expected_opset_version: Optional[Dict[str, int]] = None,
+    extra_files: Optional[dict[str, Any]] = None,
+    expected_opset_version: Optional[dict[str, int]] = None,
 ) -> ExportedProgram:
     """
@@ -577,7 +577,7 @@ def load(
 
 
 def register_dataclass(
-    cls: Type[Any],
+    cls: type[Any],
     *,
     serialized_type_name: Optional[str] = None,
 ) -> None:
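These signature changes are annotation-only; call sites behave the same. A minimal usage sketch against the public torch.export API, with a hypothetical module M:

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + 1

    # Example inputs are passed as a plain tuple, matching args: tuple[Any, ...]
    ep = export(M(), args=(torch.randn(2, 3),))
    print(ep.graph)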
torch/export/_draft_export.py (23 additions & 23 deletions)
@@ -2,7 +2,7 @@
 import logging
 import os
 from enum import IntEnum
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch._logging._internal
@@ -26,7 +26,7 @@ def __str__(self) -> str:
         return self.name
 
 
-def prettify_stack(stack: List[Dict[str, str]], str_to_filename: Dict[str, str]) -> str:
+def prettify_stack(stack: list[dict[str, str]], str_to_filename: dict[str, str]) -> str:
     res = ""
     for frame in stack:
         if frame["filename"] not in str_to_filename:


def filter_stack(
stack: List[Dict[str, str]], str_to_filename: Dict[str, str]
) -> List[Dict[str, str]]:
stack: list[dict[str, str]], str_to_filename: dict[str, str]
) -> list[dict[str, str]]:
for i, s in enumerate(reversed(stack)):
s["filename"] = str(s["filename"])
if s["filename"] not in str_to_filename:
@@ -50,22 +50,22 @@ def filter_stack(
     return stack[-3:]
 
 
-def hash_stack(stack: List[Dict[str, str]]) -> str:
+def hash_stack(stack: list[dict[str, str]]) -> str:
     return ";".join(f'line: {s["line"]} filename: {s["filename"]}' for s in stack)
 
 
 class FailureReport:
     def __init__(
-        self, failure_type: FailureType, data: Dict[str, Any], xfail: bool = False
+        self, failure_type: FailureType, data: dict[str, Any], xfail: bool = False
     ) -> None:
         self.failure_type: FailureType = failure_type
-        self.data: Dict[str, Any] = data
+        self.data: dict[str, Any] = data
         self.xfail: bool = xfail
 
     def __repr__(self) -> str:
         return f"FailureReport(failure_type={self.failure_type}, xfail={self.xfail}, data={self.data})"
 
-    def print(self, str_to_filename: Dict[str, str]) -> str:
+    def print(self, str_to_filename: dict[str, str]) -> str:
         if self.failure_type == FailureType.MISSING_FAKE_KERNEL:
             op = self.data["op"]
 
@@ -113,8 +113,8 @@ def print(self, str_to_filename: Dict[str, str]) -> str:
 
 
 class DraftExportReport:
-    def __init__(self, failures: List[FailureReport], str_to_filename: Dict[str, str]):
-        self.failures: List[FailureReport] = failures
+    def __init__(self, failures: list[FailureReport], str_to_filename: dict[str, str]):
+        self.failures: list[FailureReport] = failures
         self.str_to_filename = str_to_filename
 
     def successful(self) -> bool:
@@ -156,10 +156,10 @@ def apply_suggested_fixes(self) -> None:
 
 
 class CaptureStructuredTrace(logging.Handler):
-    def __init__(self, specific_log_keys: List[str]):
+    def __init__(self, specific_log_keys: list[str]):
         super().__init__()
         self.specific_log_keys = specific_log_keys
-        self.logs: List[Tuple[str, Dict[str, Any]]] = []
+        self.logs: list[tuple[str, dict[str, Any]]] = []
         self.logger = logging.getLogger("torch.__trace")
         self.prev_get_dtrace = False
 
@@ -185,14 +185,14 @@ def emit(self, record: Any) -> None:
 
 def draft_export(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
+    args: tuple[Any, ...],
+    kwargs: Optional[dict[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
-    preserve_module_call_signature: Tuple[str, ...] = (),
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
+    preserve_module_call_signature: tuple[str, ...] = (),
     strict: bool = False,
     pre_dispatch: bool = False,
-) -> Tuple[ExportedProgram, DraftExportReport]:
+) -> tuple[ExportedProgram, DraftExportReport]:
     kwargs = kwargs or {}
     dynamic_shapes = dynamic_shapes or {}
 
@@ -234,15 +234,15 @@ def draft_export(
         preserve_module_call_signature=preserve_module_call_signature,
     )
 
-    str_to_filename: Dict[str, str] = {
+    str_to_filename: dict[str, str] = {
         str(v): k for (k, v) in torch._logging.structured.INTERN_TABLE.items()
     }
-    failures: List[FailureReport] = []
-    custom_ops_logs: Dict[
-        Any, Tuple[Dict[str, Any], FailureType]
+    failures: list[FailureReport] = []
+    custom_ops_logs: dict[
+        Any, tuple[dict[str, Any], FailureType]
     ] = {}  # Dedup custom ops
-    data_dependent_logs: Dict[
-        str, Dict[str, Any]
+    data_dependent_logs: dict[
+        str, dict[str, Any]
     ] = {}  # Dedup data dependent errors based on stacktrace
 
     for log_name, log_contents in capture_structured_log.logs:
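draft_export is an internal API, but based only on the signature and attributes visible in this diff, a hedged usage sketch:

    import torch
    from torch.export._draft_export import draft_export

    class M(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

    # After this change, the return type is a plain tuple[ExportedProgram, DraftExportReport]
    ep, report = draft_export(M(), args=(torch.randn(2, 2),))
    if not report.successful():
        for failure in report.failures:  # list[FailureReport]
            print(failure)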
torch/export/_remove_effect_tokens_pass.py (5 additions & 6 deletions)
@@ -1,6 +1,5 @@
 # mypy: allow-untyped-defs
 import operator
-from typing import List
 
 import torch
 from torch._higher_order_ops.effects import _get_schema, with_effects
@@ -22,7 +21,7 @@ def _remove_effect_tokens_from_graph_helper(
     inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs
 
     output_node = None
-    with_effect_nodes: List[torch.fx.Node] = []
+    with_effect_nodes: list[torch.fx.Node] = []
 
     # The output node needs to check its args against output_token_names (collected from output_spec).
     # Therefore, we only need to find the top-level output node
@@ -127,8 +126,8 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
     This function does an inplace modification on the given ExportedProgram.
     """
     num_tokens: int = 0
-    input_token_names: List[str] = []
-    new_input_specs: List[InputSpec] = []
+    input_token_names: list[str] = []
+    new_input_specs: list[InputSpec] = []
     for inp in ep.graph_signature.input_specs:
         if inp.kind == InputKind.TOKEN:
             num_tokens += 1
@@ -138,8 +137,8 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
             new_input_specs.append(inp)
 
     num_out_tokens: int = 0
-    new_output_specs: List[OutputSpec] = []
-    output_token_names: List[OutputSpec] = []
+    new_output_specs: list[OutputSpec] = []
+    output_token_names: list[OutputSpec] = []
     for out in ep.graph_signature.output_specs:
         if out.kind == OutputKind.TOKEN:
             num_out_tokens += 1
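For intuition, the input-spec bookkeeping above amounts to the comprehension below (a sketch of the filtering only; the real pass also updates the graph itself and the output specs), where ep is an ExportedProgram:

    from torch.export.graph_signature import InputKind

    # Drop effect-token inputs and count how many were removed
    new_input_specs = [
        inp for inp in ep.graph_signature.input_specs if inp.kind != InputKind.TOKEN
    ]
    num_tokens = len(ep.graph_signature.input_specs) - len(new_input_specs)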
torch/export/_swap.py (14 additions & 14 deletions)
@@ -2,7 +2,7 @@
 import operator
 import types
 from collections import defaultdict
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Optional
 
 import torch
 import torch.fx._pytree as fx_pytree
@@ -19,7 +19,7 @@
 log = logging.getLogger(__name__)
 
 
-def _get_getitem_users(node: torch.fx.Node) -> Set[torch.fx.Node]:
+def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]:
     node_users = list(node.users.keys())
     getitem_users = set()
     for user in node_users:
@@ -172,9 +172,9 @@ def _remove_extraneous_pytrees(gm: torch.fx.GraphModule) -> None:
 def _construct_inputs(
     gm: torch.fx.GraphModule,
     signature: ModuleCallSignature,
-    node_name_map: Dict[str, torch.fx.Node],
-) -> Tuple[List[torch.fx.Node], Dict[str, torch.fx.Node]]:
-    tree_unflatten_args: List[Optional[torch.fx.Node]] = []
+    node_name_map: dict[str, torch.fx.Node],
+) -> tuple[list[torch.fx.Node], dict[str, torch.fx.Node]]:
+    tree_unflatten_args: list[Optional[torch.fx.Node]] = []
     for input_ in signature.inputs:
         if isinstance(input_, ConstantArgument) and input_.value is None:
             # Constants should be directly embedded into the graph and not used
@@ -213,8 +213,8 @@
 
 def _insert_call_module(
     gm: torch.fx.GraphModule,
-    args_nodes: List[torch.fx.Node],
-    kwargs_nodes: Dict[str, torch.fx.Node],
+    args_nodes: list[torch.fx.Node],
+    kwargs_nodes: dict[str, torch.fx.Node],
     module_to_swap: torch.nn.Module,
     name: str,
 ) -> torch.fx.Node:
@@ -229,8 +229,8 @@ def _deconstruct_outputs(
     gm: torch.fx.GraphModule,
     signature: ModuleCallSignature,
     module_node: torch.fx.Node,
-    node_name_map: Dict[str, torch.fx.Node],
-    orig_outputs: Tuple[torch.fx.Node, ...],
+    node_name_map: dict[str, torch.fx.Node],
+    orig_outputs: tuple[torch.fx.Node, ...],
 ) -> None:
     from .unflatten import _generate_flatten_spec
 
@@ -246,17 +246,17 @@
 
 def _swap_module_helper(
     gm: torch.fx.GraphModule,
-    modules_to_swap: Dict[str, torch.nn.Module],
-    module_call_graph: Dict[str, ModuleCallSignature],
+    modules_to_swap: dict[str, torch.nn.Module],
+    module_call_graph: dict[str, ModuleCallSignature],
 ) -> torch.fx.GraphModule:
     log.debug("Starting graph:")
     log.debug(gm.graph)
 
     legalize_graph(gm)
 
-    partitions: Dict[str, NodeList] = defaultdict(list)
+    partitions: dict[str, NodeList] = defaultdict(list)
 
-    node_name_map: Dict[str, torch.fx.Node] = {
+    node_name_map: dict[str, torch.fx.Node] = {
         node.name: node for node in gm.graph.nodes
     }
 
@@ -399,7 +399,7 @@ def _fix_input_output_signature(
 
 
 def _swap_modules(
-    ep: ExportedProgram, modules_to_swap: Dict[str, torch.nn.Module]
+    ep: ExportedProgram, modules_to_swap: dict[str, torch.nn.Module]
 ) -> torch.fx.GraphModule:
     """
     Unlifts the given ExportedProgram into a fx.GraphModule, and then swaps
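_swap_modules is likewise internal; per the signature above, a hedged sketch of how it is driven, assuming a hypothetical ExportedProgram ep whose module call graph contains a submodule named "submod":

    import torch
    from torch.export._swap import _swap_modules

    # Swap the traced submodule for an eager replacement; returns a fx.GraphModule
    gm = _swap_modules(ep, {"submod": torch.nn.ReLU()})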
(Diffs for the remaining 10 changed files not shown.)
