From afa4f2c5d18a18ff22e9512e19490b7861cfb307 Mon Sep 17 00:00:00 2001 From: "Estrada Irribarra, Rodrigo Andres" Date: Wed, 23 Oct 2024 12:07:27 -0300 Subject: [PATCH] feat(chat): integrate prompt_toolkit for advanced input handling and rich for output formatting - Added prompt_toolkit PromptSession to handle multi-line input, history, and cursor movement. - Integrated rich for rendering Markdown and colorized output. - Adjusted input processing to pass user queries to the assistant via create_message. - Ensured the output is displayed using rich's console.print for proper formatting. - Included error handling for invalid commands and exceptions during execution. --- examples/example_usage.sh | 4 +-- poetry.lock | 10 +++---- pyproject.toml | 1 + storycraftr/agent/agents.py | 4 ++- storycraftr/cmd/chat.py | 60 ++++++++++++++++++++++++------------- 5 files changed, 49 insertions(+), 30 deletions(-) diff --git a/examples/example_usage.sh b/examples/example_usage.sh index 463ab78..a1ea84c 100644 --- a/examples/example_usage.sh +++ b/examples/example_usage.sh @@ -12,9 +12,7 @@ while [[ "$#" -gt 0 ]]; do done # Ejecuta los comandos con la variable COMMAND, que puede ser 'poetry run storycraftr' o 'storycraftr' -$COMMAND init "La Purga de los dioses" --primary-language "es" --alternate-languages "en" --author "Rodrigo Estrada" --genre "science fiction" --behavior "behavior.txt" --reference-author="Brandon Sanderson" - -cd "La Purga de los dioses" +$COMMAND init "The Purge of the gods" --primary-language "en" --alternate-languages "es" --author "Rodrigo Estrada" --genre "science fiction" --behavior "behavior.txt" --reference-author="Brandon Sanderson" $COMMAND outline general-outline "Summarize the overall plot of a dystopian science fiction where advanced technology, resembling magic, has led to the fall of humanity’s elite and the rise of a manipulative villain who seeks to destroy both the ruling class and the workers." diff --git a/poetry.lock b/poetry.lock index 8ca53b3..8372ea2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -574,7 +574,7 @@ files = [ parso = ">=0.8.3,<0.9.0" [package.extras] -docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.22)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"] +docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"] qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] @@ -1402,7 +1402,7 @@ files = [ ] [package.dependencies] -alabaster = ">=0.7.24" +alabaster = ">=0.7.14" babel = ">=2.13" colorama = {version = ">=0.4.6", markers = "sys_platform == \"win32\""} docutils = ">=0.20,<0.22" @@ -1423,7 +1423,7 @@ tomli = {version = ">=2", markers = "python_version < \"3.11\""} [package.extras] docs = ["sphinxcontrib-websupport"] lint = ["flake8 (>=6.0)", "mypy (==1.11.1)", "pyright (==1.1.384)", "pytest (>=6.0)", "ruff (==0.6.9)", "sphinx-lint (>=0.9)", "tomli (>=2)", "types-Pillow (==10.2.0.20240822)", "types-Pygments (==2.18.0.20240506)", "types-colorama (==0.4.15.20240311)", "types-defusedxml (==0.7.0.20240218)", "types-docutils (==0.21.0.20241005)", "types-requests (==2.32.0.20240914)", "types-urllib3 (==1.26.25.14)"] -test = ["cython (>=3.0)", "defusedxml (>=0.7.2)", "pytest (>=8.0)", "setuptools (>=70.0)", "typing_extensions (>=4.9)"] +test = ["cython (>=3.0)", "defusedxml (>=0.7.1)", "pytest (>=8.0)", "setuptools (>=70.0)", "typing_extensions (>=4.9)"] [[package]] name = "sphinx-click" @@ -1517,7 +1517,7 @@ files = [ [package.extras] lint = ["mypy", "ruff (==0.5.5)", "types-docutils"] standalone = ["Sphinx (>=5)"] -test = ["defusedxml (>=0.7.2)", "pytest"] +test = ["defusedxml (>=0.7.1)", "pytest"] [[package]] name = "sphinxcontrib-serializinghtml" @@ -1673,4 +1673,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "20efef1e1b074ee0018d068f9e4416e6a0c475ff586f115c890668c7c6f5126f" +content-hash = "4754fb8e0b11603dad3988f4a16e029de751c2923726f8940f7ecd5e0fc2fbb0" diff --git a/pyproject.toml b/pyproject.toml index 4096799..6eb5837 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ rich = "^13.9.2" python-dotenv = "^1.0.1" pyyaml = "^6.0.2" requests = "^2.32.3" +prompt-toolkit = "^3.0.48" [tool.poetry.dev-dependencies] pytest = "^6.2.4" diff --git a/storycraftr/agent/agents.py b/storycraftr/agent/agents.py index 67eeefb..c5ba030 100644 --- a/storycraftr/agent/agents.py +++ b/storycraftr/agent/agents.py @@ -174,8 +174,10 @@ def create_or_get_assistant(book_path: str, progress: Progress = None, task=None assistant = client.beta.assistants.create( instructions=instructions, name=name, - tools=[{"type": "code_interpreter"}, {"type": "file_search"}], + tools=[{"type": "file_search"}], model="gpt-4o", + temperature=0.7, # Nivel de creatividad balanceado + top_p=1.0, # Considerar todas las opciones ) client.beta.assistants.update( diff --git a/storycraftr/cmd/chat.py b/storycraftr/cmd/chat.py index d726849..ac326a1 100644 --- a/storycraftr/cmd/chat.py +++ b/storycraftr/cmd/chat.py @@ -6,6 +6,8 @@ from storycraftr.utils.core import load_book_config from storycraftr.agent.agents import get_thread, create_or_get_assistant, create_message from storycraftr.cmd import iterate, outline, worldbuilding, chapters +from prompt_toolkit import PromptSession +from prompt_toolkit.history import InMemoryHistory console = Console() @@ -39,37 +41,53 @@ def chat(book_path=None): assistant = create_or_get_assistant(book_path) thread = get_thread() + session = PromptSession(history=InMemoryHistory()) + console.print("[bold green]USE help() to get help and exit() to exit[/bold green]") while True: - user_input = console.input("[bold blue]You:[/bold blue] ") - - if user_input.lower() == "exit()": - console.print("[bold red]Exiting chat...[/bold red]") - break + try: + # Capture user input with prompt_toolkit + user_input = session.prompt("You: ") - if user_input.lower() == "help()": - display_help() - continue + if user_input.lower() == "exit()": + console.print("[bold red]Exiting chat...[/bold red]") + break - if user_input.startswith("!"): - execute_cli_command(user_input[1:]) - continue + if user_input.lower() == "help()": + display_help() + continue - user_input = ( - f"Answer the next prompt formatted on markdown (text): {user_input}" - ) + if user_input.startswith("!"): + execute_cli_command(user_input[1:]) + continue - try: - response = create_message( - book_path, thread_id=thread.id, content=user_input, assistant=assistant + # Pass the user input to the assistant for processing + user_input = ( + f"Answer the next prompt formatted on markdown (text): {user_input}" ) + + try: + # Generate the response + response = create_message( + book_path, + thread_id=thread.id, + content=user_input, + assistant=assistant, + ) + + # Render Markdown response + markdown_response = Markdown(response) + console.print(markdown_response) + + except Exception as e: + console.print(f"[bold red]Error: {str(e)}[/bold red]") + + except KeyboardInterrupt: + console.print("[bold red]Exiting chat...[/bold red]") + break except Exception as e: console.print(f"[bold red]Error: {str(e)}[/bold red]") - continue - - markdown_response = Markdown(response) - console.print(markdown_response) def execute_cli_command(user_input):