From def2c317381f71cf00566324ad23fccede1ddfe5 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Sun, 27 Oct 2024 21:34:39 +0100 Subject: [PATCH] feat: Add Redis prompt registry (#21) * add redis registry * run redis in the CI * split the jobs * skip coverage on mac and win * reformat * fix linting * add docs --- .github/workflows/test.yml | 98 +++++++++++---- .gitignore | 1 + docs/python.md | 12 ++ docs/registry.md | 62 ++++++++- mkdocs.yml | 29 ++--- pyproject.toml | 200 ++++++++++++++---------------- src/banks/registries/directory.py | 70 +++++++++++ src/banks/registries/file.py | 65 +++++++++- src/banks/registries/redis.py | 82 ++++++++++++ tests/conftest.py | 21 ++++ tests/test_redis_registry.py | 131 +++++++++++++++++++ 11 files changed, 622 insertions(+), 149 deletions(-) create mode 100644 src/banks/registries/redis.py create mode 100644 tests/conftest.py create mode 100644 tests/test_redis_registry.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0a1e062..cbc9f1d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -5,7 +5,7 @@ name: test on: push: branches: - - main + - main pull_request: concurrency: @@ -17,33 +17,89 @@ env: FORCE_COLOR: "1" jobs: - run: - name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} - runs-on: ${{ matrix.os }} + linux: + name: Python ${{ matrix.python-version }} on Linux + runs-on: ubuntu-latest + services: + redis: + image: redis + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + strategy: fail-fast: false matrix: - os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ['3.10', '3.11', '3.12'] + python-version: ["3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - if: matrix.python-version == '3.12' + name: Lint + run: hatch run lint:all + + - name: Run tests + run: hatch run cov + + - if: matrix.python-version == '3.12' + name: Report Coveralls + uses: coverallsapp/github-action@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} + macos: + name: Python ${{ matrix.python-version }} on macOS + runs-on: macos-latest - - name: Install Hatch - run: pip install --upgrade hatch + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Run tests + run: hatch run test + + windows: + name: Python ${{ matrix.python-version }} on Windows + runs-on: windows-latest + + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 - - if: matrix.python-version == '3.12' && runner.os == 'Linux' - name: Lint - run: hatch run lint:all + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} - - name: Run tests - run: hatch run cov + - name: Install Hatch + run: pip install --upgrade hatch - - if: matrix.python-version == '3.12' && runner.os == 'Linux' - name: Report Coveralls - uses: coverallsapp/github-action@v2 \ No newline at end of file + - name: Run tests + run: hatch run test diff --git a/.gitignore b/.gitignore index 1b374ca..ea7f815 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,4 @@ cython_debug/ # IDEs and editors .idea/ .vscode/ +.zed/ diff --git a/docs/python.md b/docs/python.md index 516b5e9..1748ed5 100644 --- a/docs/python.md +++ b/docs/python.md @@ -5,6 +5,18 @@ ::: banks.prompt.AsyncPrompt +::: banks.registries.directory.DirectoryPromptRegistry + options: + inherited_members: true + +::: banks.registries.file.FilePromptRegistry + options: + inherited_members: true + +::: banks.registries.redis.RedisPromptRegistry + options: + inherited_members: true + ## Default macros Banks' package comes with default template macros you can use in your prompts. diff --git a/docs/registry.md b/docs/registry.md index 14205cb..32380c0 100644 --- a/docs/registry.md +++ b/docs/registry.md @@ -1,9 +1,63 @@ ## Prompt registry (BETA) -Prompt registry is a storage API for versioned prompts. It allows you to store and retrieve prompts from local storage. -Currently, it supports storing templates in a single JSON file or in a file system directory, but it can be extended to -support other storage backends. +The Prompt Registry provides a storage API for managing versioned prompts. It allows you to store and retrieve prompts from different storage backends. Currently, Banks supports two storage implementations: + +- Directory-based storage +- Redis-based storage ### Usage -Coming soon. \ No newline at end of file +```python +from banks import Prompt +from banks.registries.directory import DirectoryPromptRegistry +from pathlib import Path + +# Create a registry +registry = DirectoryPromptRegistry(Path("./prompts")) + +# Create and store a prompt +prompt = Prompt( + text="Write a blog post about {{topic}}", + name="blog_writer", + version="1.0", + metadata={"author": "John Doe"} +) +registry.set(prompt=prompt) + +# Retrieve a prompt +retrieved_prompt = registry.get(name="blog_writer", version="1.0") +``` + +### Directory Registry + +The DirectoryPromptRegistry stores prompts as individual files in a directory. Each prompt is saved as a `.jinja` file with the naming pattern `{name}.{version}.jinja`. + +```python +# Initialize directory registry +registry = DirectoryPromptRegistry( + directory_path=Path("./prompts"), + force_reindex=False # Set to True to rebuild the index +) +``` + +### Redis Registry + +The RedisPromptRegistry stores prompts in Redis using a key-value structure. + +```python +from banks.registries.redis import RedisPromptRegistry + +registry = RedisPromptRegistry( + redis_url="redis://localhost:6379", + prefix="banks:prompt:" +) +``` + +### Common Features + +Both implementations support: + +- Versioning with automatic "0" default version +- Overwrite protection with `overwrite=True` option +- Metadata storage +- Error handling for missing/invalid prompts diff --git a/mkdocs.yml b/mkdocs.yml index 76c7652..f0de8f0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -8,12 +8,12 @@ theme: scheme: slate nav: - - Home: 'index.md' - - Examples: 'examples.md' - - Python API: 'python.md' - - Prompt API: 'prompt.md' - - Configuration: 'config.md' - - Prompt Registry: 'registry.md' + - Home: "index.md" + - Examples: "examples.md" + - Python API: "python.md" + - Prompt API: "prompt.md" + - Configuration: "config.md" + - Prompt Registry: "registry.md" plugins: - search @@ -22,13 +22,14 @@ plugins: python: paths: [src] options: - docstring_style: google - show_root_heading: true - show_root_full_path: true - show_symbol_type_heading: true - show_source: false - show_signature_annotations: true - show_bases: false + docstring_style: google + show_root_heading: true + show_root_full_path: true + show_symbol_type_heading: false + show_source: false + show_signature_annotations: true + show_bases: false + separate_signature: true markdown_extensions: - attr_list @@ -42,4 +43,4 @@ markdown_extensions: - pymdownx.inlinehilite - pymdownx.snippets - pymdownx.superfences - - admonition \ No newline at end of file + - admonition diff --git a/pyproject.toml b/pyproject.toml index f5ada15..cca1ce3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,24 +10,23 @@ readme = "README.md" requires-python = ">=3.10" license = "MIT" keywords = [] -authors = [ - { name = "Massimiliano Pippi", email = "mpippi@gmail.com" }, -] +authors = [{ name = "Massimiliano Pippi", email = "mpippi@gmail.com" }] classifiers = [ - "Development Status :: 4 - Beta", - "Programming Language :: Python", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ - "griffe", - "jinja2", - "litellm", - "pydantic", - "deprecated", + "griffe", + "jinja2", + "litellm", + "pydantic", + "deprecated", + "redis", ] [project.urls] @@ -40,51 +39,34 @@ path = "src/banks/__about__.py" [tool.hatch.envs.default] dependencies = [ - "coverage[toml]>=6.5", - "pytest", - "pytest-cov", - "pytest-asyncio", - "mkdocs-material", - "mkdocstrings[python]", - "simplemma", + "coverage[toml]>=6.5", + "pytest", + "pytest-cov", + "pytest-asyncio", + "mkdocs-material", + "mkdocstrings[python]", + "simplemma", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "pytest --cov --cov-report=xml {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report -m", -] -cov = [ - "test-cov", - "cov-report", -] +cov-report = ["- coverage combine", "coverage report -m"] +cov = ["test-cov", "cov-report"] docs = "mkdocs {args:build}" [[tool.hatch.envs.all.matrix]] python = ["3.10", "3.11", "3.12"] [tool.hatch.envs.lint] -detached = false # Normally the linting env can be detached, but mypy doesn't install all the stubs we need -dependencies = [ - "mypy>=1.0.0", - "ruff>=0.0.243", - "pylint", -] +detached = false # Normally the linting env can be detached, but mypy doesn't install all the stubs we need +dependencies = ["mypy>=1.0.0", "ruff>=0.0.243", "pylint"] [tool.hatch.envs.lint.scripts] -check = [ - "ruff format --check {args}", - "ruff check {args:.}", -] +check = ["ruff format --check {args}", "ruff check {args:.}"] lint = "pylint {args:src/banks}" typing = "mypy --install-types --non-interactive {args:src/banks}" -all = [ - "check", - "typing", - "lint", -] +all = ["check", "typing", "lint"] fmt = "ruff format {args}" [tool.hatch.build.targets.wheel] @@ -100,51 +82,59 @@ line-length = 120 exclude = ["cookbooks"] [tool.ruff.lint] select = [ - "A", - "ARG", - "B", - "C", - "DTZ", - "E", - "EM", - "F", - "FBT", - "I", - "ICN", - "ISC", - "N", - "PLC", - "PLE", - "PLR", - "PLW", - "Q", - "RUF", - "S", - "T", - "TID", - "UP", - "W", - "YTT", + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", ] ignore = [ - # Allow methods like 'set' - "A003", - # Allow non-abstract empty methods in abstract base classes - "B027", - # Allow boolean positional values in function calls, like `dict.get(... True)` - "FBT003", - # Ignore checks for possible passwords - "S105", "S106", "S107", - # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", - # Avoid conflicts with the formatter - "ISC001", - # Magic numbers - "PLR2004", + # Allow unused arguments + "ARG001", + # Allow methods like 'set' + "A003", + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... True)` + "FBT003", + # Ignore checks for possible passwords + "S105", + "S106", + "S107", + # Ignore complexity + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", + # Avoid conflicts with the formatter + "ISC001", + # Magic numbers + "PLR2004", ] unfixable = [ - # Don't touch unused imports - "F401", + # Don't touch unused imports + "F401", ] [tool.ruff.lint.isort] @@ -163,43 +153,39 @@ source_pkgs = ["banks", "tests"] branch = true parallel = true omit = [ - "src/banks/__about__.py", - "tests/*", - "src/banks/extensions/docs.py", - # deprecated modules, to be removed - "src/banks/extensions/inference_endpoint.py", - "src/banks/extensions/generate.py", + "src/banks/__about__.py", + "tests/*", + "src/banks/extensions/docs.py", + # deprecated modules, to be removed + "src/banks/extensions/inference_endpoint.py", + "src/banks/extensions/generate.py", ] [tool.coverage.paths] banks = ["src/banks", "*/banks/src/banks"] [tool.coverage.report] -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] module = [ - # "griffe.*", - "litellm.*", - "simplemma.*", + # "griffe.*", + "litellm.*", + "simplemma.*", ] ignore_missing_imports = true [tool.pylint] disable = [ - "line-too-long", - "too-few-public-methods", - "missing-module-docstring", - "missing-class-docstring", - "missing-function-docstring", - "cyclic-import", + "line-too-long", + "too-few-public-methods", + "missing-module-docstring", + "missing-class-docstring", + "missing-function-docstring", + "cyclic-import", ] max-args = 10 [tool.pytest.ini_options] -asyncio_default_fixture_loop_scope = "function" \ No newline at end of file +asyncio_default_fixture_loop_scope = "function" diff --git a/src/banks/registries/directory.py b/src/banks/registries/directory.py index 437adf1..b6c46b3 100644 --- a/src/banks/registries/directory.py +++ b/src/banks/registries/directory.py @@ -1,6 +1,10 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT +""" +Directory-based prompt registry implementation that stores prompts as files. +""" + import time from pathlib import Path @@ -19,10 +23,22 @@ class PromptFile(PromptModel): + """Model representing a prompt file stored on disk.""" + path: Path = Field(exclude=True) @classmethod def from_prompt_path(cls: type[Self], prompt: Prompt, path: Path) -> Self: + """ + Create a PromptFile instance from a Prompt object and save it to disk. + + Args: + prompt: The Prompt object to save + path: Directory path where the prompt file should be stored + + Returns: + A new PromptFile instance + """ prompt_file = path / f"{prompt.name}.{prompt.version}.jinja" prompt_file.write_text(prompt.raw) return cls( @@ -31,11 +47,25 @@ def from_prompt_path(cls: type[Self], prompt: Prompt, path: Path) -> Self: class PromptFileIndex(BaseModel): + """Index tracking all prompt files in the directory.""" + files: list[PromptFile] = Field(default=[]) class DirectoryPromptRegistry: + """Registry that stores prompts as files in a directory structure.""" + def __init__(self, directory_path: Path, *, force_reindex: bool = False): + """ + Initialize the directory prompt registry. + + Args: + directory_path: Path to directory where prompts will be stored + force_reindex: Whether to force rebuilding the index from disk + + Raises: + ValueError: If directory_path is not a directory + """ if not directory_path.is_dir(): msg = "{directory_path} must be a directory." raise ValueError(msg) @@ -49,9 +79,23 @@ def __init__(self, directory_path: Path, *, force_reindex: bool = False): @property def path(self) -> Path: + """Get the directory path where prompts are stored.""" return self._path def get(self, *, name: str, version: str | None = None) -> Prompt: + """ + Retrieve a prompt by name and version. + + Args: + name: Name of the prompt to retrieve + version: Version of the prompt (optional) + + Returns: + The requested Prompt object + + Raises: + PromptNotFoundError: If prompt doesn't exist + """ version = version or DEFAULT_VERSION for pf in self._index.files: if pf.name == name and pf.version == version and pf.path.exists(): @@ -59,6 +103,16 @@ def get(self, *, name: str, version: str | None = None) -> Prompt: raise PromptNotFoundError def set(self, *, prompt: Prompt, overwrite: bool = False): + """ + Store a prompt in the registry. + + Args: + prompt: The Prompt object to store + overwrite: Whether to overwrite existing prompt + + Raises: + InvalidPromptError: If prompt exists and overwrite=False + """ try: version = prompt.version or DEFAULT_VERSION idx, pf = self._get_prompt_file(name=prompt.name, version=version) @@ -76,12 +130,15 @@ def set(self, *, prompt: Prompt, overwrite: bool = False): self._save() def _load(self): + """Load the prompt index from disk.""" self._index = PromptFileIndex.model_validate_json(self._index_path.read_text()) def _save(self): + """Save the prompt index to disk.""" self._index_path.write_text(self._index.model_dump_json()) def _scan(self): + """Scan directory for prompt files and build the index.""" self._index: PromptFileIndex = PromptFileIndex() for path in self._path.glob("*.jinja*"): name, version = path.stem.rsplit(".", 1) if "." in path.stem else (path.stem, DEFAULT_VERSION) @@ -91,6 +148,19 @@ def _scan(self): self._index_path.write_text(self._index.model_dump_json()) def _get_prompt_file(self, *, name: str | None, version: str) -> tuple[int, PromptFile]: + """ + Find a prompt file in the index. + + Args: + name: Name of the prompt + version: Version of the prompt + + Returns: + Tuple of (index position, PromptFile) + + Raises: + PromptNotFoundError: If prompt doesn't exist in index + """ for i, pf in enumerate(self._index.files): if pf.name == name and pf.version == version: return i, pf diff --git a/src/banks/registries/file.py b/src/banks/registries/file.py index f10b1f8..97f2d2c 100644 --- a/src/banks/registries/file.py +++ b/src/banks/registries/file.py @@ -1,6 +1,14 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT +""" +File-based prompt registry implementation that stores all prompts in a single JSON file. + +This module provides functionality to store and retrieve prompts using a single JSON file +as the storage backend. The file contains an index of all prompts with their associated +metadata and content. +""" + from pathlib import Path from pydantic import BaseModel @@ -10,17 +18,27 @@ class PromptRegistryIndex(BaseModel): + """ + Model representing the registry index containing all prompts. + + Stores a list of PromptModel objects that represent all prompts in the registry. + """ + prompts: list[PromptModel] = [] class FilePromptRegistry: - """A prompt registry storing all the prompt data in a single JSON file.""" + """A prompt registry storing all prompt data in a single JSON file.""" def __init__(self, registry_index: Path) -> None: - """Creates an instance of the File Prompt Registry. + """ + Initialize the file prompt registry. Args: - registry_index: The path to the index file. + registry_index: Path to the JSON file that will store the prompts + + Note: + Creates parent directories if they don't exist. """ self._index_fpath: Path = registry_index self._index: PromptRegistryIndex = PromptRegistryIndex(prompts=[]) @@ -31,10 +49,33 @@ def __init__(self, registry_index: Path) -> None: self._index_fpath.parent.mkdir(parents=True, exist_ok=True) def get(self, *, name: str, version: str | None = None) -> Prompt: + """ + Retrieve a prompt by name and version. + + Args: + name: Name of the prompt to retrieve + version: Version of the prompt (optional) + + Returns: + The requested Prompt object + + Raises: + PromptNotFoundError: If the requested prompt doesn't exist + """ _, model = self._get_prompt_model(name, version) return Prompt(**model.model_dump()) def set(self, *, prompt: Prompt, overwrite: bool = False) -> None: + """ + Store a prompt in the registry. + + Args: + prompt: The Prompt object to store + overwrite: Whether to overwrite an existing prompt + + Raises: + InvalidPromptError: If prompt exists and overwrite=False + """ try: idx, p_model = self._get_prompt_model(prompt.name, prompt.version) if overwrite: @@ -49,10 +90,28 @@ def set(self, *, prompt: Prompt, overwrite: bool = False) -> None: self._save() def _save(self) -> None: + """ + Save the prompt index to the JSON file. + + Writes the current state of the registry to disk. + """ with open(self._index_fpath, "w", encoding="locale") as f: f.write(self._index.model_dump_json()) def _get_prompt_model(self, name: str | None, version: str | None) -> tuple[int, PromptModel]: + """ + Find a prompt model in the index by name and version. + + Args: + name: Name of the prompt + version: Version of the prompt + + Returns: + Tuple of (index position, PromptModel) + + Raises: + PromptNotFoundError: If the prompt doesn't exist in the index + """ for i, model in enumerate(self._index.prompts): if model.name == name and model.version == version: return i, model diff --git a/src/banks/registries/redis.py b/src/banks/registries/redis.py new file mode 100644 index 0000000..283c341 --- /dev/null +++ b/src/banks/registries/redis.py @@ -0,0 +1,82 @@ +import json +from typing import cast + +import redis + +from banks import Prompt +from banks.errors import InvalidPromptError, PromptNotFoundError +from banks.prompt import DEFAULT_VERSION, PromptModel + + +class RedisPromptRegistry: + """A prompt registry that stores prompts in Redis.""" + + def __init__( + self, + redis_url: str = "redis://localhost:6379", + prefix: str = "banks:prompt:", + ) -> None: + """ + Initialize the Redis prompt registry. + + Parameters: + redis_url: Redis connection URL + prefix: Key prefix for storing prompts in Redis + """ + self._redis = redis.from_url(redis_url, decode_responses=True) + self._prefix = prefix + + def _make_key(self, name: str, version: str) -> str: + """Create Redis key for a prompt.""" + return f"{self._prefix}{name}:{version}" + + def get(self, *, name: str, version: str | None = None) -> Prompt: + """ + Get a prompt by name and version. + + Parameters: + name: Name of the prompt + version: Version of the prompt (optional) + + Returns: + Prompt instance + + Raises: + PromptNotFoundError: If prompt doesn't exist + """ + version = version or DEFAULT_VERSION + key = self._make_key(name, version) + + data = self._redis.get(key) + if not data: + msg = f"Cannot find prompt with name '{name}' and version '{version}'" + raise PromptNotFoundError(msg) + + prompt_data = json.loads(cast(str, data)) + return Prompt(**prompt_data) + + def set(self, *, prompt: Prompt, overwrite: bool = False) -> None: + """ + Store a prompt in Redis. + + Parameters: + prompt: Prompt instance to store + overwrite: Whether to overwrite existing prompt + + Raises: + InvalidPromptError: If prompt exists and overwrite=False + """ + version = prompt.version or DEFAULT_VERSION + key = self._make_key(prompt.name, version) + + # Check if prompt already exists + if self._redis.exists(key) and not overwrite: + msg = f"Prompt with name '{prompt.name}' already exists. Use overwrite=True to overwrite" + raise InvalidPromptError(msg) + + # Convert prompt to serializable format + prompt_model = PromptModel.from_prompt(prompt) + prompt_data = prompt_model.model_dump() + + # Store in Redis + self._redis.set(key, json.dumps(prompt_data)) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3e84d99 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,21 @@ +import pytest +import redis + + +def is_redis_available(): + try: + redis.Redis(host="localhost", port=6379).ping() + return True + except redis.ConnectionError: + return False + + +def pytest_configure(config): + config.addinivalue_line("markers", "redis: mark test as requiring Redis") + + +@pytest.fixture(autouse=True) +def skip_by_redis(request): + if request.node.get_closest_marker("redis"): + if not is_redis_available(): + pytest.skip("Redis is not available") diff --git a/tests/test_redis_registry.py b/tests/test_redis_registry.py new file mode 100644 index 0000000..5c5568e --- /dev/null +++ b/tests/test_redis_registry.py @@ -0,0 +1,131 @@ +# banks/tests/test_redis_registry.py +import pytest +import redis + +from banks.errors import InvalidPromptError, PromptNotFoundError +from banks.prompt import Prompt +from banks.registries.redis import RedisPromptRegistry + + +@pytest.fixture +def redis_client(): + client = redis.Redis(host="localhost", port=6379, db=0) + # Clear test database before each test + client.flushdb() + return client + + +@pytest.fixture +def registry(redis_client): # type: ignore[ARG001] + return RedisPromptRegistry(redis_url="redis://localhost:6379") + + +@pytest.mark.redis +def test_get_not_found(registry): + with pytest.raises(PromptNotFoundError): + registry.get(name="nonexistent") + + +@pytest.mark.redis +def test_set_and_get_prompt(registry): + prompt = Prompt("Hello {{name}}!", name="greeting") + registry.set(prompt=prompt) + + retrieved = registry.get(name="greeting") + assert retrieved.raw == "Hello {{name}}!" + assert retrieved.name == "greeting" + assert retrieved.version == "0" # default version + assert retrieved.metadata == {} + + +@pytest.mark.redis +def test_set_existing_no_overwrite(registry): + prompt = Prompt("Hello {{name}}!", name="greeting") + registry.set(prompt=prompt) + + new_prompt = Prompt("Hi {{name}}!", name="greeting") + with pytest.raises( + InvalidPromptError, match="Prompt with name 'greeting' already exists. Use overwrite=True to overwrite" + ): + registry.set(prompt=new_prompt) + + +@pytest.mark.redis +def test_set_existing_overwrite(registry): + prompt = Prompt("Hello {{name}}!", name="greeting") + registry.set(prompt=prompt) + + new_prompt = Prompt("Hi {{name}}!", name="greeting") + registry.set(prompt=new_prompt, overwrite=True) + + retrieved = registry.get(name="greeting") + assert retrieved.raw == "Hi {{name}}!" + + +@pytest.mark.redis +def test_set_multiple_versions(registry): + prompt_v1 = Prompt("Version 1", name="multi", version="1") + prompt_v2 = Prompt("Version 2", name="multi", version="2") + + registry.set(prompt=prompt_v1) + registry.set(prompt=prompt_v2) + + retrieved_v1 = registry.get(name="multi", version="1") + assert retrieved_v1.raw == "Version 1" + + retrieved_v2 = registry.get(name="multi", version="2") + assert retrieved_v2.raw == "Version 2" + + +@pytest.mark.redis +def test_get_with_version(registry): + prompt = Prompt("Test {{var}}", name="test", version="1.0") + registry.set(prompt=prompt) + + retrieved = registry.get(name="test", version="1.0") + assert retrieved.version == "1.0" + assert retrieved.raw == "Test {{var}}" + + +@pytest.mark.redis +def test_set_with_metadata(registry): + prompt = Prompt("Test prompt", name="test", metadata={"author": "John", "category": "test"}) + registry.set(prompt=prompt) + + retrieved = registry.get(name="test") + assert retrieved.metadata == {"author": "John", "category": "test"} + + +@pytest.mark.redis +def test_update_metadata(registry): + # Initial prompt with metadata + prompt = Prompt("Test prompt", name="test", metadata={"score": 90}) + registry.set(prompt=prompt) + + # Update metadata + updated_prompt = registry.get(name="test") + updated_prompt.metadata["score"] = 95 + registry.set(prompt=updated_prompt, overwrite=True) + + # Verify update + retrieved = registry.get(name="test") + assert retrieved.metadata["score"] == 95 + + +@pytest.mark.redis +def test_invalid_redis_connection(): + with pytest.raises(redis.ConnectionError): + registry = RedisPromptRegistry(redis_url="redis://nonexistent:6379") + registry.get(name="test") + + +@pytest.mark.redis +def test_custom_prefix(redis_client): + registry = RedisPromptRegistry(redis_url="redis://localhost:6379", prefix="custom:prefix:") + + prompt = Prompt("Test", name="test") + registry.set(prompt=prompt) + + # Verify the key in Redis has the custom prefix + key = "custom:prefix:test:0" + assert redis_client.exists(key)