From 2006338f36962a14d7c1e007ce808470136eccc7 Mon Sep 17 00:00:00 2001 From: scicco Date: Sun, 3 Nov 2024 22:38:30 +0100 Subject: [PATCH] RecallSettings now extends cat.utils.BaseModelDict --- core/cat/looking_glass/recall_settings.py | 55 +++++++++++------------ core/cat/looking_glass/stray_cat.py | 17 ++++--- 2 files changed, 38 insertions(+), 34 deletions(-) diff --git a/core/cat/looking_glass/recall_settings.py b/core/cat/looking_glass/recall_settings.py index 1bf4442f0..7c95c36a5 100644 --- a/core/cat/looking_glass/recall_settings.py +++ b/core/cat/looking_glass/recall_settings.py @@ -1,31 +1,30 @@ """Module for retrieving default configurations for episodic, declarative and procedural memories""" +from typing import Any +from cat.utils import BaseModelDict -class RecallSettings: - """Class for retrieving default configurations for episodic, declarative and procedural memories""" - - DEFAULT_K = 3 - DEFAULT_TRESHOLD = 0.5 - - def _build_settings( - self, - recall_query_embedding, - user_id=None, - k=DEFAULT_K, - treshold=DEFAULT_TRESHOLD, - ): - return { - "embedding": recall_query_embedding, - "k": k, - "threshold": treshold, - "metadata": {"source": user_id} if user_id else None, - } - - def default_episodic_config(self, recall_query_embedding, user_id): - return self._build_settings(recall_query_embedding, user_id) - - def default_declarative_config(self, recall_query_embedding): - return self._build_settings(recall_query_embedding) - - def default_procedural_config(self, recall_query_embedding): - return self._build_settings(recall_query_embedding) + +class RecallSettingsMetadata(BaseModelDict): + """Settigs's metadata for default configurations + + Variables: + source (str): the source of the recall query + """ + + source: str + + +class RecallSettings(BaseModelDict): + """Class for retrieving default configurations for episodic, declarative and procedural memories + + Variables: + embedding (Any): the embedding of the recall query - default None + k (int): the number of memories to return - default 3 + threshold (float): the threshold - default 0.5 + metadata (RecallSettingsMetadata): metadata - default None + """ + + embedding: Any + k: int = 3 + threshold: float = 0.5 + metadata: RecallSettingsMetadata = None diff --git a/core/cat/looking_glass/stray_cat.py b/core/cat/looking_glass/stray_cat.py index b93754e45..9e9d042f0 100644 --- a/core/cat/looking_glass/stray_cat.py +++ b/core/cat/looking_glass/stray_cat.py @@ -15,7 +15,7 @@ from cat.log import log from cat.looking_glass.cheshire_cat import CheshireCat from cat.looking_glass.callbacks import NewTokenHandler, ModelInteractionHandler -from cat.looking_glass.recall_settings import RecallSettings +from cat.looking_glass.recall_settings import RecallSettingsMetadata, RecallSettings from cat.memory.working_memory import WorkingMemory from cat.convo.messages import CatMessage, UserMessage, MessageWhy, Role, EmbedderModelInteraction from cat.agents import AgentOutput @@ -233,13 +233,18 @@ def recall_relevant_memories_to_working_memory(self, query=None): self.mad_hatter.execute_hook("before_cat_recalls_memories", cat=self) # Setting default recall configs for each memory - recall_settings = RecallSettings() - - default_episodic_recall_config = recall_settings.default_episodic_config(recall_query_embedding=recall_query_embedding, user_id=self.user_id) + default_episodic_recall_config = RecallSettings( + embedding=recall_query_embedding, + metadata=RecallSettingsMetadata(source=self.user_id), + ) - default_declarative_recall_config = recall_settings.default_declarative_config(recall_query_embedding=recall_query_embedding) + default_declarative_recall_config = RecallSettings( + embedding=recall_query_embedding + ) - default_procedural_recall_config = recall_settings.default_procedural_config(recall_query_embedding=recall_query_embedding) + default_procedural_recall_config = RecallSettings( + embedding=recall_query_embedding + ) # hooks to change recall configs for each memory recall_configs = [