From 15dbaca1b0dc32f141b0d9f0c1b0d37b438f182e Mon Sep 17 00:00:00 2001 From: gvanrossum-ms Date: Sat, 8 Mar 2025 11:14:57 -0800 Subject: [PATCH] [python/knowpro] Change class structure to suit Python's type system better (#803) (This is only for the Python port.) Change TMeta into a base class. We then get: - IMessage has IKnowledgeSource as a base class, and no longer has a metadata field -- you just call message.get_knowledge(). - IConversation is generic in the message type (TMessage) and its messages field is a list of TMessage. - add_metadata_to_index() is generic in the message type (I think there's a pattern here :-). - PodcastMessageMeta becomes base class PodcastMessageBase (but it could easily be just blended into PodcastMessage). - PodcastMessage derives both from PodcastMessageBase and IMessage. - Podcast derives from IConversation[PodcastMessage] and its messages field is a list of PodcastMessage. - assign_message_listeners() takes a sequence of PodcastMessage. --- python/kp/knowpro/convindex.py | 4 +-- python/kp/knowpro/interfaces.py | 9 ++---- python/kp/memconv/import_podcasts.py | 45 +++++++++++----------------- 3 files changed, 23 insertions(+), 35 deletions(-) diff --git a/python/kp/knowpro/convindex.py b/python/kp/knowpro/convindex.py index 2594cb5db..f66cfd34d 100644 --- a/python/kp/knowpro/convindex.py +++ b/python/kp/knowpro/convindex.py @@ -42,8 +42,8 @@ def text_range_from_location( ] -def add_metadata_to_index( - messages: list[IMessage], +def add_metadata_to_index[TMessage: IMessage]( + messages: list[TMessage], semantic_refs: list[SemanticRef], semantic_ref_index: ITermToSemanticRefIndex, knowledge_validator: KnowledgeValidator | None = None, diff --git a/python/kp/knowpro/interfaces.py b/python/kp/knowpro/interfaces.py index d85201c8a..a0632f6bc 100644 --- a/python/kp/knowpro/interfaces.py +++ b/python/kp/knowpro/interfaces.py @@ -36,12 +36,9 @@ class DeletionInfo(Protocol): @runtime_checkable -class IMessage[TMeta: IKnowledgeSource = 
Any](Protocol): +class IMessage(IKnowledgeSource, Protocol): # The text of the message, split into chunks. text_chunks: list[str] - # For example, e-mail has subject, from and to fields; - # a chat message has a sender and a recipient. - metadata: TMeta timestamp: str | None = None tags: list[str] deletion_info: DeletionInfo | None = None @@ -270,10 +267,10 @@ class IConversationSecondaryIndexes(Protocol): @runtime_checkable -class IConversation[TMeta: IKnowledgeSource = Any](Protocol): +class IConversation[TMessage: IMessage = Any](Protocol): name_tag: str tags: list[str] - messages: list[IMessage[TMeta]] + messages: list[TMessage] semantic_refs: list[SemanticRef] | None semantic_ref_index: ITermToSemanticRefIndex | None secondary_indexes: IConversationSecondaryIndexes | None diff --git a/python/kp/memconv/import_podcasts.py b/python/kp/memconv/import_podcasts.py index b4277e8cc..8bf7427cf 100644 --- a/python/kp/memconv/import_podcasts.py +++ b/python/kp/memconv/import_podcasts.py @@ -11,16 +11,11 @@ @dataclass -class PodcastMessageMeta(interfaces.IKnowledgeSource): - """Metadata for podcast messages.""" +class PodcastMessageBase(interfaces.IKnowledgeSource): + """Base class for podcast messages.""" - # Instance variables types. 
speaker: str - listeners: list[str] - - def __init__(self, speaker: str): - self.speaker = speaker - self.listeners = [] + listeners: list[str] = field(init=False, default_factory=list) def get_knowledge(self) -> kplib.KnowledgeResponse: if not self.speaker: @@ -64,21 +59,10 @@ def get_knowledge(self) -> kplib.KnowledgeResponse: ) -def assign_message_listeners( - msgs: Sequence[interfaces.IMessage[PodcastMessageMeta]], - participants: set[str], -) -> None: - for msg in msgs: - if msg.metadata.speaker: - listeners = [p for p in participants if p != msg.metadata.speaker] - msg.metadata.listeners = listeners - - @dataclass -class PodcastMessage(interfaces.IMessage[PodcastMessageMeta]): +class PodcastMessage(interfaces.IMessage, PodcastMessageBase): timestamp: str | None = field(init=False, default=None) text_chunks: list[str] - metadata: PodcastMessageMeta tags: list[str] = field(default_factory=list) def add_timestamp(self, timestamp: str) -> None: @@ -89,7 +73,7 @@ def add_content(self, content: str) -> None: @dataclass -class Podcast(interfaces.IConversation[PodcastMessageMeta]): +class Podcast(interfaces.IConversation[PodcastMessage]): # Instance variables not passed to `__init__()`. # TODO # settings: ConversationSettings = field( @@ -107,10 +91,7 @@ class Podcast(interfaces.IConversation[PodcastMessageMeta]): # __init__() parameters, in that order (via `@dataclass`). name_tag: str = field(default="") - # NOTE: `messages: list[PodcastMessage]` doesn't work because of invariance. 
- messages: list[interfaces.IMessage[PodcastMessageMeta]] = field( - default_factory=list - ) + messages: list[PodcastMessage] = field(default_factory=list) tags: list[str] = field(default_factory=list) semantic_refs: list[interfaces.SemanticRef] | None = field(default_factory=list) @@ -181,6 +162,16 @@ async def build_index( # pass +def assign_message_listeners( + msgs: Sequence[PodcastMessage], + participants: set[str], +) -> None: + for msg in msgs: + if msg.speaker: + listeners = [p for p in participants if p != msg.speaker] + msg.listeners = listeners + + # NOTE: Doesn't need to be async (Python file I/O is synchronous) def import_podcast( transcript_file_path: str, @@ -195,7 +186,7 @@ def import_podcast( transcript_lines = [line.rstrip() for line in transcript_lines if line.strip()] turn_parse_regex = re.compile(r"^(?P<speaker>[A-Z0-9 ]+:)?(?P<speech>.*)$") participants: set[str] = set() - msgs: list[interfaces.IMessage[PodcastMessageMeta]] = [] + msgs: list[PodcastMessage] = [] cur_msg: PodcastMessage | None = None for line in transcript_lines: match = turn_parse_regex.match(line) @@ -215,7 +206,7 @@ def import_podcast( speaker = speaker[:-1] speaker = speaker.lower() # TODO: locale participants.add(speaker) - cur_msg = PodcastMessage([speech], PodcastMessageMeta(speaker)) + cur_msg = PodcastMessage(speaker, [speech]) if cur_msg: msgs.append(cur_msg) assign_message_listeners(msgs, participants)