Skip to content

Commit

Permalink
fix: model lazy load
Browse files Browse the repository at this point in the history
  • Loading branch information
thammuio committed Jan 4, 2024
1 parent 352d006 commit 941f222
Showing 1 changed file with 6 additions and 25 deletions.
31 changes: 6 additions & 25 deletions app/chatbot/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,16 @@
from app.embeddings.chunk_utils import *
from app.utils.constants import ENGINE_NAME
import pinecone


# Store the value of the VECTOR_DB environment variable in a variable
vector_db = os.getenv('VECTOR_DB').upper()

if vector_db == "MILVUS":
if os.getenv('VECTOR_DB').upper() == "MILVUS":
from pymilvus import connections, Collection
import app.utils.vectordb.start_milvus as vector_db
import app.embeddings.embeddings_utils as model_embedding

if vector_db == "PINECONE":
if os.getenv('VECTOR_DB').upper() == "PINECONE":
from sentence_transformers import SentenceTransformer

class LazyModel:
def __init__(self, load_model_func):
self.load_model_func = load_model_func
self._model = None

@property
def model(self):
if self._model is None:
self._model = self.load_model_func() # Model is loaded here if it hasn't been loaded already
return self._model

def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)

# This creates a LazyModel instance that will load the model the first time it's used
llama2_model = LazyModel(load_llama_model)
llama2_model = load_llama_model()

class TextInput(BaseModel):
inputs: str
Expand Down Expand Up @@ -78,19 +59,19 @@ def get_responses(engine, temperature, token_count, question):
if token_count == "" or token_count is None:
token_count = 100

if vector_db == "MILVUS":
if os.getenv('VECTOR_DB').upper() == "MILVUS":
# Load Milvus Vector DB collection
vector_db_collection = Collection('cloudera_docs')
vector_db_collection.load()

# Phase 1: Get nearest knowledge base chunk for a user question from a vector db
vdb_question = question

if vector_db == "MILVUS":
if os.getenv('VECTOR_DB').upper() == "MILVUS":
context_chunk = get_nearest_chunk_from_milvus_vectordb(vector_db_collection, vdb_question)
vector_db_collection.release()

if vector_db == "PINECONE":
if os.getenv('VECTOR_DB').upper() == "PINECONE":
PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
PINECONE_ENVIRONMENT = os.getenv('PINECONE_ENVIRONMENT')
PINECONE_INDEX = os.getenv('PINECONE_INDEX')
Expand Down

0 comments on commit 941f222

Please sign in to comment.