Skip to content

Commit

Permalink
[python/knowpro] Change class structure to suit Python's type system …
Browse files Browse the repository at this point in the history
…better (#803)

(This is only for the Python port.)

Change TMeta into a base class. We then get:

- IMessage has IKnowledgeSource as a base class, and no longer has a
metadata field -- you just call message.get_knowledge().
- IConversation is generic in the message type (TMessage) and its
messages field is a list of TMessage.
- add_metadata_to_index() is generic in the message type (I think
there's a pattern here :-).
- PodcastMessageMeta becomes base class PodcastMessageBase (but it could
easily be just blended into PodcastMessage).
- PodcastMessage derives both from PodcastMessageBase and IMessage.
- Podcast derives from IConversation[PodcastMessage] and its messages
field is a list of PodcastMessage.
- assign_message() takes a sequence of PodcastMessage.
  • Loading branch information
gvanrossum-ms authored Mar 8, 2025
1 parent c685f0c commit 15dbaca
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 35 deletions.
4 changes: 2 additions & 2 deletions python/kp/knowpro/convindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def text_range_from_location(
]


def add_metadata_to_index(
messages: list[IMessage],
def add_metadata_to_index[TMessage: IMessage](
messages: list[TMessage],
semantic_refs: list[SemanticRef],
semantic_ref_index: ITermToSemanticRefIndex,
knowledge_validator: KnowledgeValidator | None = None,
Expand Down
9 changes: 3 additions & 6 deletions python/kp/knowpro/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,9 @@ class DeletionInfo(Protocol):


@runtime_checkable
class IMessage[TMeta: IKnowledgeSource = Any](Protocol):
class IMessage(IKnowledgeSource, Protocol):
# The text of the message, split into chunks.
text_chunks: list[str]
# For example, e-mail has subject, from and to fields;
# a chat message has a sender and a recipient.
metadata: TMeta
timestamp: str | None = None
tags: list[str]
deletion_info: DeletionInfo | None = None
Expand Down Expand Up @@ -270,10 +267,10 @@ class IConversationSecondaryIndexes(Protocol):


@runtime_checkable
class IConversation[TMeta: IKnowledgeSource = Any](Protocol):
class IConversation[TMessage: IMessage = Any](Protocol):
name_tag: str
tags: list[str]
messages: list[IMessage[TMeta]]
messages: list[TMessage]
semantic_refs: list[SemanticRef] | None
semantic_ref_index: ITermToSemanticRefIndex | None
secondary_indexes: IConversationSecondaryIndexes | None
Expand Down
45 changes: 18 additions & 27 deletions python/kp/memconv/import_podcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,11 @@


@dataclass
class PodcastMessageMeta(interfaces.IKnowledgeSource):
"""Metadata for podcast messages."""
class PodcastMessageBase(interfaces.IKnowledgeSource):
"""Base class for podcast messages."""

# Instance variables types.
speaker: str
listeners: list[str]

def __init__(self, speaker: str):
self.speaker = speaker
self.listeners = []
listeners: list[str] = field(init=False, default_factory=list)

def get_knowledge(self) -> kplib.KnowledgeResponse:
if not self.speaker:
Expand Down Expand Up @@ -64,21 +59,10 @@ def get_knowledge(self) -> kplib.KnowledgeResponse:
)


def assign_message_listeners(
msgs: Sequence[interfaces.IMessage[PodcastMessageMeta]],
participants: set[str],
) -> None:
for msg in msgs:
if msg.metadata.speaker:
listeners = [p for p in participants if p != msg.metadata.speaker]
msg.metadata.listeners = listeners


@dataclass
class PodcastMessage(interfaces.IMessage[PodcastMessageMeta]):
class PodcastMessage(interfaces.IMessage, PodcastMessageBase):
timestamp: str | None = field(init=False, default=None)
text_chunks: list[str]
metadata: PodcastMessageMeta
tags: list[str] = field(default_factory=list)

def add_timestamp(self, timestamp: str) -> None:
Expand All @@ -89,7 +73,7 @@ def add_content(self, content: str) -> None:


@dataclass
class Podcast(interfaces.IConversation[PodcastMessageMeta]):
class Podcast(interfaces.IConversation[PodcastMessage]):
# Instance variables not passed to `__init__()`.
# TODO
# settings: ConversationSettings = field(
Expand All @@ -107,10 +91,7 @@ class Podcast(interfaces.IConversation[PodcastMessageMeta]):

# __init__() parameters, in that order (via `@dataclass`).
name_tag: str = field(default="")
# NOTE: `messages: list[PodcastMessage]` doesn't work because of invariance.
messages: list[interfaces.IMessage[PodcastMessageMeta]] = field(
default_factory=list
)
messages: list[PodcastMessage] = field(default_factory=list)
tags: list[str] = field(default_factory=list)
semantic_refs: list[interfaces.SemanticRef] | None = field(default_factory=list)

Expand Down Expand Up @@ -181,6 +162,16 @@ async def build_index(
# pass


def assign_message_listeners(
msgs: Sequence[PodcastMessage],
participants: set[str],
) -> None:
for msg in msgs:
if msg.speaker:
listeners = [p for p in participants if p != msg.speaker]
msg.listeners = listeners


# NOTE: Doesn't need to be async (Python file I/O is synchronous)
def import_podcast(
transcript_file_path: str,
Expand All @@ -195,7 +186,7 @@ def import_podcast(
transcript_lines = [line.rstrip() for line in transcript_lines if line.strip()]
turn_parse_regex = re.compile(r"^(?P<speaker>[A-Z0-9 ]+:)?(?P<speech>.*)$")
participants: set[str] = set()
msgs: list[interfaces.IMessage[PodcastMessageMeta]] = []
msgs: list[PodcastMessage] = []
cur_msg: PodcastMessage | None = None
for line in transcript_lines:
match = turn_parse_regex.match(line)
Expand All @@ -215,7 +206,7 @@ def import_podcast(
speaker = speaker[:-1]
speaker = speaker.lower() # TODO: locale
participants.add(speaker)
cur_msg = PodcastMessage([speech], PodcastMessageMeta(speaker))
cur_msg = PodcastMessage(speaker, [speech])
if cur_msg:
msgs.append(cur_msg)
assign_message_listeners(msgs, participants)
Expand Down

0 comments on commit 15dbaca

Please sign in to comment.