Light GraphRAG (#4585)
### What problem does this PR solve?

#4543

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
KevinHuSh authored Jan 22, 2025
1 parent 1a36766 commit dd0ebbe
Showing 55 changed files with 5,523 additions and 4,062 deletions.
api/apps/chunk_app.py (15 changes: 12 additions & 3 deletions)

@@ -155,7 +155,7 @@ def set():
r"[\n\t]",
req["content_with_weight"]) if len(t) > 1]
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
- d = beAdoc(d, arr[0], arr[1], not any(
+ d = beAdoc(d, q, a, not any(
[rag_tokenizer.is_chinese(t) for t in q + a]))

v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
@@ -270,6 +270,7 @@ def retrieval_test():
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.0))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
+ use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024))
tenant_ids = []

@@ -301,12 +302,20 @@ def retrieval_test():
question += keyword_extraction(chat_mdl, question)

labels = label_question(question, [kb])
- retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
- ranks = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
+ ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
rank_feature=labels
)
+ if use_kg:
+     ck = settings.kg_retrievaler.retrieval(question,
+                                            tenant_ids,
+                                            kb_ids,
+                                            embd_mdl,
+                                            LLMBundle(kb.tenant_id, LLMType.CHAT))
+     if ck["content_with_weight"]:
+         ranks["chunks"].insert(0, ck)

for c in ranks["chunks"]:
c.pop("vector", None)
ranks["labels"] = labels
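Across this commit the retrieval endpoints follow one pattern: dense and keyword retrieval always go through `settings.retrievaler`, while knowledge-graph retrieval is no longer tied to a `knowledge_graph` parser type and becomes an opt-in `use_kg` request flag that, when the KG retriever returns content, prepends one synthesized chunk. A minimal client-side sketch of the new flag; the base URL, route prefix, and auth header are assumptions, only the payload fields come from the handler above:

```python
import requests  # any HTTP client works; this call is illustrative only

# Assumed host/route and token scheme; payload fields mirror retrieval_test().
resp = requests.post(
    "http://localhost:9380/v1/chunk/retrieval_test",
    headers={"Authorization": "Bearer <api-token>"},
    json={
        "kb_id": "<kb_id>",
        "question": "Which suppliers are linked to product X?",
        "use_kg": True,                  # new flag added by this commit
        "similarity_threshold": 0.2,
        "vector_similarity_weight": 0.3,
        "top_k": 1024,
    },
)
chunks = resp.json()["data"]["chunks"]  # chunks[0] is the KG chunk when it has content
```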
api/apps/conversation_app.py (2 changes: 1 addition & 1 deletion)

@@ -31,7 +31,7 @@
from api import settings
from api.utils.api_utils import get_json_result
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
- from graphrag.mind_map_extractor import MindMapExtractor
+ from graphrag.general.mind_map_extractor import MindMapExtractor


@manager.route('/set', methods=['POST']) # noqa: F821
api/apps/kb_app.py (36 changes: 35 additions & 1 deletion)

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+ import json

from flask import request
from flask_login import login_required, current_user

@@ -272,4 +274,36 @@ def rename_tags(kb_id):
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
search.index_name(kb.tenant_id),
kb_id)
- return get_json_result(data=True)
\ No newline at end of file
+ return get_json_result(data=True)


+ @manager.route('/<kb_id>/knowledge_graph', methods=['GET'])  # noqa: F821
+ @login_required
+ def knowledge_graph(kb_id):
+     if not KnowledgebaseService.accessible(kb_id, current_user.id):
+         return get_json_result(
+             data=False,
+             message='No authorization.',
+             code=settings.RetCode.AUTHENTICATION_ERROR
+         )
+     e, kb = KnowledgebaseService.get_by_id(kb_id)
+     req = {
+         "kb_id": [kb_id],
+         "knowledge_graph_kwd": ["graph"]
+     }
+     sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [kb_id])
+     obj = {"graph": {}, "mind_map": {}}
+     for id in sres.ids[:1]:
+         ty = sres.field[id]["knowledge_graph_kwd"]
+         try:
+             content_json = json.loads(sres.field[id]["content_with_weight"])
+         except Exception:
+             continue
+
+         obj[ty] = content_json
+
+     if "nodes" in obj["graph"]:
+         obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
+     if "edges" in obj["graph"]:
+         obj["graph"]["edges"] = sorted(obj["graph"]["edges"], key=lambda x: x.get("weight", 0), reverse=True)[:128]
+     return get_json_result(data=obj)
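A sketch of consuming the new graph endpoint. The URL prefix and token handling are assumptions; the response shape (top 256 nodes by `pagerank`, top 128 edges by `weight`, plus a `mind_map` slot) follows the handler above:

```python
import requests  # illustrative client-side usage

kb_id = "<kb_id>"
r = requests.get(
    f"http://localhost:9380/v1/kb/{kb_id}/knowledge_graph",  # assumed prefix
    headers={"Authorization": "Bearer <api-token>"},
)
data = r.json()["data"]
nodes = data["graph"].get("nodes", [])  # at most 256, sorted by pagerank, descending
edges = data["graph"].get("edges", [])  # at most 128, sorted by weight, descending
mind_map = data["mind_map"]             # may stay {}: the query above only fetches "graph" chunks
```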
api/apps/sdk/dify_retrieval.py (16 changes: 13 additions & 3 deletions)

@@ -15,7 +15,7 @@
#
from flask import request, jsonify

- from api.db import LLMType, ParserType
+ from api.db import LLMType
from api.db.services.dialog_service import label_question
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
@@ -30,6 +30,7 @@ def retrieval(tenant_id):
req = request.json
question = req["query"]
kb_id = req["knowledge_id"]
+ use_kg = req.get("use_kg", False)
retrieval_setting = req.get("retrieval_setting", {})
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
top = int(retrieval_setting.get("top_k", 1024))
@@ -45,8 +46,7 @@

embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

- retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
- ranks = retr.retrieval(
+ ranks = settings.retrievaler.retrieval(
question,
embd_mdl,
kb.tenant_id,
@@ -58,6 +58,16 @@
top=top,
rank_feature=label_question(question, [kb])
)

+ if use_kg:
+     ck = settings.kg_retrievaler.retrieval(question,
+                                            [tenant_id],
+                                            [kb_id],
+                                            embd_mdl,
+                                            LLMBundle(kb.tenant_id, LLMType.CHAT))
+     if ck["content_with_weight"]:
+         ranks["chunks"].insert(0, ck)

records = []
for c in ranks["chunks"]:
c.pop("vector", None)
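The Dify-compatible endpoint takes the same flag in its request body. An assumed example payload; the field names mirror the handler (`query`, `knowledge_id`, `retrieval_setting`), and transport details and auth are left out:

```python
# Request body for the Dify-style retrieval endpoint sketched above.
payload = {
    "knowledge_id": "<dataset_id>",      # the dataset to search
    "query": "What is Light GraphRAG?",
    "use_kg": True,                      # optional; defaults to False
    "retrieval_setting": {
        "score_threshold": 0.2,
        "top_k": 1024,
    },
}
```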
api/apps/sdk/doc.py (17 changes: 13 additions & 4 deletions)

@@ -1297,22 +1297,23 @@ def retrieval_test(tenant_id):
kb_ids = req["dataset_ids"]
if not isinstance(kb_ids, list):
return get_error_data_result("`dataset_ids` should be a list")
- kbs = KnowledgebaseService.get_by_ids(kb_ids)
for id in kb_ids:
if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id):
return get_error_data_result(f"You don't own the dataset {id}.")
+ kbs = KnowledgebaseService.get_by_ids(kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1:
return get_result(
message='Datasets use different embedding models."',
- code=settings.RetCode.AUTHENTICATION_ERROR,
+ code=settings.RetCode.DATA_ERROR,
)
if "question" not in req:
return get_error_data_result("`question` is required.")
page = int(req.get("page", 1))
size = int(req.get("page_size", 30))
question = req["question"]
doc_ids = req.get("document_ids", [])
+ use_kg = req.get("use_kg", False)
if not isinstance(doc_ids, list):
return get_error_data_result("`documents` should be a list")
doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
@@ -1342,8 +1343,7 @@ def retrieval_test(tenant_id):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)

- retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
- ranks = retr.retrieval(
+ ranks = settings.retrievaler.retrieval(
question,
embd_mdl,
kb.tenant_id,
@@ -1358,6 +1358,15 @@
highlight=highlight,
rank_feature=label_question(question, kbs)
)
+ if use_kg:
+     ck = settings.kg_retrievaler.retrieval(question,
+                                            [k.tenant_id for k in kbs],
+                                            kb_ids,
+                                            embd_mdl,
+                                            LLMBundle(kb.tenant_id, LLMType.CHAT))
+     if ck["content_with_weight"]:
+         ranks["chunks"].insert(0, ck)

for c in ranks["chunks"]:
c.pop("vector", None)

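The SDK variant accepts several datasets at once, and the KG pass then collects tenant IDs from all of them (`[k.tenant_id for k in kbs]`). An assumed request body; field names come from the handler, while the route and auth are not shown in this diff:

```python
# Illustrative request body for the SDK retrieval endpoint above.
payload = {
    "dataset_ids": ["<ds_1>", "<ds_2>"],  # must all share one embedding model
    "question": "How do the extracted entities relate?",
    "use_kg": True,                       # prepend one KG chunk across all datasets
    "page": 1,
    "page_size": 30,
}
```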
api/db/init_data.py (2 changes: 1 addition & 1 deletion)

@@ -133,7 +133,7 @@ def init_llm_factory():
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
TenantService.filter_update([1 == 1], {
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag"})
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"})
## insert openai two embedding models to the current openai user.
# print("Start to insert 2 OpenAI embedding models...")
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
api/db/services/dialog_service.py (11 changes: 9 additions & 2 deletions)

@@ -197,8 +197,7 @@ def chat(dialog, messages, stream=True, **kwargs):

embedding_model_name = embedding_list[0]

- is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
- retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
+ retriever = settings.retrievaler

questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
@@ -275,6 +274,14 @@ def chat(dialog, messages, stream=True, **kwargs):
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs)
)
+ if prompt_config.get("use_kg"):
+     ck = settings.kg_retrievaler.retrieval(" ".join(questions),
+                                            tenant_ids,
+                                            dialog.kb_ids,
+                                            embd_mdl,
+                                            LLMBundle(dialog.tenant_id, LLMType.CHAT))
+     if ck["content_with_weight"]:
+         kbinfos["chunks"].insert(0, ck)

retrieval_ts = timer()

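In chat, the knowledge-graph pass is now driven by the dialog's prompt configuration rather than by every KB using the knowledge-graph parser. A sketch of enabling it; the surrounding prompt_config keys are illustrative, only `use_kg` is introduced by this diff:

```python
# Hypothetical dialog prompt configuration; only "use_kg" comes from this commit.
prompt_config = {
    "system": "You are a helpful assistant. Answer using the knowledge base.",
    "parameters": [{"key": "knowledge", "optional": False}],
    "use_kg": True,  # prepend one KG-derived chunk to kbinfos["chunks"] before prompting
}
```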
api/db/services/document_service.py (62 changes: 46 additions & 16 deletions)

@@ -28,7 +28,7 @@
from api.db.db_utils import bulk_insert_into_db
from api import settings
from api.utils import current_timestamp, get_format_time, get_uuid
- from graphrag.mind_map_extractor import MindMapExtractor
+ from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search, rag_tokenizer
@@ -105,8 +105,19 @@ def insert(cls, doc):
@classmethod
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
- settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
cls.clear_chunk_num(doc.id)
+ try:
+     settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+     settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
+                                  {"remove": {"source_id": doc.id}},
+                                  search.index_name(tenant_id), doc.kb_id)
+     settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
+                                  {"removed_kwd": "Y"},
+                                  search.index_name(tenant_id), doc.kb_id)
+     settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}},
+                                  search.index_name(tenant_id), doc.kb_id)
+ except Exception:
+     pass
return cls.delete_by_id(doc.id)

@classmethod
@@ -142,7 +153,7 @@ def get_newly_uploaded(cls):
@DB.connection_context()
def get_unfinished_docs(cls):
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
- cls.model.run]
+ cls.model.run, cls.model.parser_id]
docs = cls.model.select(*fields) \
.where(
cls.model.status == StatusEnum.VALID.value,
@@ -295,9 +306,9 @@ def get_chunking_config(cls, doc_id):
Tenant.asr_id,
Tenant.llm_id,
)
-         .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
-         .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
-         .where(cls.model.id == doc_id)
+     .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
+     .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
+     .where(cls.model.id == doc_id)
)
configs = configs.dicts()
if not configs:
@@ -365,6 +376,12 @@ def begin2parse(cls, docid):
@classmethod
@DB.connection_context()
def update_progress(cls):
+ MSG = {
+     "raptor": "Start RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval).",
+     "graphrag": "Start Graph Extraction",
+     "graph_resolution": "Start Graph Resolution",
+     "graph_community": "Start Graph Community Reports Generation"
+ }
docs = cls.get_unfinished_docs()
for d in docs:
try:
@@ -390,15 +407,27 @@ def update_progress(cls):
prg = -1
status = TaskStatus.FAIL.value
elif finished:
if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(
" raptor") < 0:
queue_raptor_tasks(d)
m = "\n".join(sorted(msg))
if d["parser_config"].get("raptor", {}).get("use_raptor") and m.find(MSG["raptor"]) < 0:
queue_raptor_o_graphrag_tasks(d, "raptor", MSG["raptor"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and m.find(MSG["graphrag"]) < 0:
queue_raptor_o_graphrag_tasks(d, "graphrag", MSG["graphrag"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
and d["parser_config"].get("graphrag", {}).get("resolution") \
and m.find(MSG["graph_resolution"]) < 0:
queue_raptor_o_graphrag_tasks(d, "graph_resolution", MSG["graph_resolution"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
and d["parser_config"].get("graphrag", {}).get("community") \
and m.find(MSG["graph_community"]) < 0:
queue_raptor_o_graphrag_tasks(d, "graph_community", MSG["graph_community"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
msg.append("------ RAPTOR -------")
else:
status = TaskStatus.DONE.value

msg = "\n".join(msg)
msg = "\n".join(sorted(msg))
info = {
"process_duation": datetime.timestamp(
datetime.now()) -
@@ -430,7 +459,7 @@ def do_cancel(cls, doc_id):
return False


- def queue_raptor_tasks(doc):
+ def queue_raptor_o_graphrag_tasks(doc, ty, msg):
chunking_config = DocumentService.get_chunking_config(doc["id"])
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
@@ -443,15 +472,16 @@ def new_task():
"doc_id": doc["id"],
"from_page": 100000000,
"to_page": 100000000,
"progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval)."
"progress_msg": datetime.now().strftime("%H:%M:%S") + " " + msg
}

task = new_task()
for field in ["doc_id", "from_page", "to_page"]:
hasher.update(str(task.get(field, "")).encode("utf-8"))
+ hasher.update(ty.encode("utf-8"))
task["digest"] = hasher.hexdigest()
bulk_insert_into_db(Task, [task], True)
task["type"] = "raptor"
task["task_type"] = ty
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."


@@ -489,7 +519,7 @@ def dummy(prog=None, msg=""):
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
doc_nm = {}
@@ -592,4 +622,4 @@ def embedding(doc_id, cnts, batch_size=16):
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

return [d["id"] for d, _ in files]
return [d["id"] for d, _ in files]
api/db/services/file_service.py (2 changes: 1 addition & 1 deletion)

@@ -401,7 +401,7 @@ def dummy(prog=None, msg=""):
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
for file in file_objs: