diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index cce7429118..eb554c5916 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -760,12 +760,37 @@ def branch_writer( config, cast(Sequence[Union[Send, ChannelWriteEntry]], writes) ) + # detect branch input schema + input = None + + # detect input schema annotation in the branch callable + try: + if ( + isinstance(branch.path, RunnableCallable) + and ( + isfunction(branch.path.func) + or ismethod(getattr(branch.path.func, "__call__", None)) + ) + and ( + hints := get_type_hints(getattr(branch.path.func, "__call__")) + or get_type_hints(branch.path.func) + ) + ): + first_parameter_name = next( + iter( + inspect.signature( + cast(FunctionType, branch.path.func) + ).parameters.keys() + ) + ) + if input_hint := hints.get(first_parameter_name): + if isinstance(input_hint, type) and get_type_hints(input_hint): + input = input_hint + except (TypeError, StopIteration): + pass + + schema = input or self.builder.schema # attach branch publisher - schema = ( - self.builder.nodes[start].input - if start in self.builder.nodes - else self.builder.schema - ) self.nodes[start] |= branch.run( branch_writer, _get_state_reader(self.builder, schema) if with_reader else None, diff --git a/libs/langgraph/tests/test_state.py b/libs/langgraph/tests/test_state.py index 0a4a8725ac..3c21d3e76d 100644 --- a/libs/langgraph/tests/test_state.py +++ b/libs/langgraph/tests/test_state.py @@ -1,4 +1,5 @@ import inspect +import operator import warnings from dataclasses import dataclass, field from typing import Annotated as Annotated2 @@ -318,3 +319,68 @@ def class_method(self, state): # class method assert _get_node_name(MyClass().class_method) == "class_method" + + +def test_input_schema_conditional_edge(): + class OverallState(TypedDict): + foo: Annotated[int, operator.add] + bar: str + + class PrivateState(TypedDict): + baz: str + + builder = StateGraph(OverallState) + + def node_1(state: OverallState): + return {"foo": 1, "baz": "bar"} + + def node_2(state: PrivateState): + return {"foo": 1, "bar": state["baz"]} + + def node_3(state: OverallState): + return {"foo": 1} + + def router(state: OverallState): + if state["foo"] == 2: + return "node_3" + else: + return "__end__" + + builder.add_node(node_1) + builder.add_node(node_2) + builder.add_node(node_3) + builder.add_conditional_edges("node_2", router) + builder.add_edge("__start__", "node_1") + builder.add_edge("node_1", "node_2") + graph = builder.compile() + assert graph.invoke({"foo": 0}) == {"foo": 3, "bar": "bar"} + + +def test_private_input_schema_conditional_edge(): + class OverallState(TypedDict): + foo: Annotated[int, operator.add] + bar: str + + class PrivateState(TypedDict): + baz: str + + builder = StateGraph(OverallState) + + def node_1(state: OverallState): + return {"foo": 1, "baz": "meow"} + + def node_2(state: PrivateState): + return {"foo": 1, "bar": state["baz"]} + + def router(state: PrivateState): + if state["baz"] == "meow": + return "node_2" + else: + return "__end__" + + builder.add_node(node_1) + builder.add_node(node_2) + builder.add_conditional_edges("node_1", router) + builder.add_edge("__start__", "node_1") + graph = builder.compile() + assert graph.invoke({"foo": 0}) == {"foo": 2, "bar": "meow"}