diff --git a/app/chatbot/model.py b/app/chatbot/model.py index c86c7fe..b8a264a 100644 --- a/app/chatbot/model.py +++ b/app/chatbot/model.py @@ -7,35 +7,16 @@ from app.embeddings.chunk_utils import * from app.utils.constants import ENGINE_NAME import pinecone - - -# Store the value of the VECTOR_DB environment variable in a variable -vector_db = os.getenv('VECTOR_DB').upper() - -if vector_db == "MILVUS": +if os.getenv('VECTOR_DB').upper() == "MILVUS": from pymilvus import connections, Collection import app.utils.vectordb.start_milvus as vector_db import app.embeddings.embeddings_utils as model_embedding -if vector_db == "PINECONE": +if os.getenv('VECTOR_DB').upper() == "PINECONE": from sentence_transformers import SentenceTransformer -class LazyModel: - def __init__(self, load_model_func): - self.load_model_func = load_model_func - self._model = None - - @property - def model(self): - if self._model is None: - self._model = self.load_model_func() # Model is loaded here if it hasn't been loaded already - return self._model - - def __call__(self, *args, **kwargs): - return self.model(*args, **kwargs) -# This creates a LazyModel instance that will load the model the first time it's used -llama2_model = LazyModel(load_llama_model) +llama2_model = load_llama_model() class TextInput(BaseModel): inputs: str @@ -78,7 +59,7 @@ def get_responses(engine, temperature, token_count, question): if token_count == "" or token_count is None: token_count = 100 - if vector_db == "MILVUS": + if os.getenv('VECTOR_DB').upper() == "MILVUS": # Load Milvus Vector DB collection vector_db_collection = Collection('cloudera_docs') vector_db_collection.load() @@ -86,11 +67,11 @@ def get_responses(engine, temperature, token_count, question): # Phase 1: Get nearest knowledge base chunk for a user question from a vector db vdb_question = question - if vector_db == "MILVUS": + if os.getenv('VECTOR_DB').upper() == "MILVUS": context_chunk = get_nearest_chunk_from_milvus_vectordb(vector_db_collection, vdb_question) vector_db_collection.release() - if vector_db == "PINECONE": + if os.getenv('VECTOR_DB').upper() == "PINECONE": PINECONE_API_KEY = os.getenv('PINECONE_API_KEY') PINECONE_ENVIRONMENT = os.getenv('PINECONE_ENVIRONMENT') PINECONE_INDEX = os.getenv('PINECONE_INDEX')