From 8ea5a272f433a9fcd8a1da2096d0695ebc574a62 Mon Sep 17 00:00:00 2001
From: thammuio
Date: Thu, 4 Jan 2024 15:08:06 -0500
Subject: [PATCH] feat: lazyload model and performance

---
Review note: the VECTOR_DB string is kept in `vector_db_type`, not `vector_db`.
The MILVUS branch runs `import app.utils.vectordb.start_milvus as vector_db`,
which rebinds `vector_db` to a module; with a shared name every later
`vector_db == "MILVUS"` test in get_responses() was false and the Milvus path
was silently skipped. The `index` line below predates this rename.

 app/chatbot/model.py | 31 +++++++++++++++++++++++++------
 1 file changed, 25 insertions(+), 6 deletions(-)

diff --git a/app/chatbot/model.py b/app/chatbot/model.py
index b8a264a..c86c7fe 100644
--- a/app/chatbot/model.py
+++ b/app/chatbot/model.py
@@ -7,16 +7,35 @@ from app.embeddings.chunk_utils import *
 from app.utils.constants import ENGINE_NAME
 import pinecone
 
-if os.getenv('VECTOR_DB').upper() == "MILVUS":
+
+
+# Read VECTOR_DB once at import; named *_type because the Milvus import below binds `vector_db` to a module
+vector_db_type = os.getenv('VECTOR_DB').upper()
+
+if vector_db_type == "MILVUS":
     from pymilvus import connections, Collection
     import app.utils.vectordb.start_milvus as vector_db
     import app.embeddings.embeddings_utils as model_embedding
 
-if os.getenv('VECTOR_DB').upper() == "PINECONE":
+if vector_db_type == "PINECONE":
     from sentence_transformers import SentenceTransformer
 
+class LazyModel:
+    def __init__(self, load_model_func):
+        self.load_model_func = load_model_func
+        self._model = None
+
+    @property
+    def model(self):
+        if self._model is None:
+            self._model = self.load_model_func()  # Model is loaded here if it hasn't been loaded already
+        return self._model
+
+    def __call__(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
 
-llama2_model = load_llama_model()
+# This creates a LazyModel instance that will load the model the first time it's used
+llama2_model = LazyModel(load_llama_model)
 
 class TextInput(BaseModel):
     inputs: str
@@ -59,7 +78,7 @@ def get_responses(engine, temperature, token_count, question):
     if token_count == "" or token_count is None:
         token_count = 100
 
-    if os.getenv('VECTOR_DB').upper() == "MILVUS":
+    if vector_db_type == "MILVUS":
         # Load Milvus Vector DB collection
         vector_db_collection = Collection('cloudera_docs')
         vector_db_collection.load()
@@ -67,11 +86,11 @@ def get_responses(engine, temperature, token_count, question):
     # Phase 1: Get nearest knowledge base chunk for a user question from a vector db
     vdb_question = question
 
-    if os.getenv('VECTOR_DB').upper() == "MILVUS":
+    if vector_db_type == "MILVUS":
         context_chunk = get_nearest_chunk_from_milvus_vectordb(vector_db_collection, vdb_question)
         vector_db_collection.release()
 
-    if os.getenv('VECTOR_DB').upper() == "PINECONE":
+    if vector_db_type == "PINECONE":
         PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
         PINECONE_ENVIRONMENT = os.getenv('PINECONE_ENVIRONMENT')
         PINECONE_INDEX = os.getenv('PINECONE_INDEX')