-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathquick_start.py
109 lines (91 loc) · 5.54 KB
/
quick_start.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import argparse
from loguru import logger
from src.datasets.xinhua import get_task_datasets
from evaluator import BaseEvaluator
from src.llms import GPT
from src.llms import Qwen_7B_Chat
from src.tasks.summary import Summary
from src.tasks.continue_writing import ContinueWriting
from src.tasks.hallucinated_modified import HalluModified
from src.tasks.quest_answer import QuestAnswer1Doc, QuestAnswer2Docs, QuestAnswer3Docs
from src.retrievers import BaseRetriever, CustomBM25Retriever, EnsembleRetriever, EnsembleRerankRetriever
from src.embeddings.base import HuggingfaceEmbeddings
def _str_to_bool(value):
    """Parse a command-line string into a real bool.

    argparse's ``type=bool`` is a well-known pitfall: any non-empty string is
    truthy, so ``--shuffle False`` used to silently parse as ``True``. This
    converter accepts the usual spellings and rejects everything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got: {value!r}")


parser = argparse.ArgumentParser()
# Model related options
parser.add_argument('--model_name', default='qwen7b', help="Name of the model to use")
parser.add_argument('--temperature', type=float, default=0.1, help="Controls the randomness of the model's text generation")
parser.add_argument('--max_new_tokens', type=int, default=1280, help="Maximum number of new tokens to be generated by the model")
# Dataset related options
parser.add_argument('--data_path', default='data/crud_split/split_merged.json', help="Path to the dataset")
# type=_str_to_bool (not type=bool): with type=bool, "--shuffle False" was True.
parser.add_argument('--shuffle', type=_str_to_bool, default=True, help="Whether to shuffle the dataset")
parser.add_argument('--embedding_name', default='sentence-transformers/bge-base-zh-v1.5')
parser.add_argument('--embedding_dim', type=int, default=768)
# Index related options
parser.add_argument('--docs_path', default='data/tmp', help="Path to the retrieval documents")
parser.add_argument('--docs_type', default="txt", help="Type of the documents")
parser.add_argument('--chunk_size', type=int, default=128, help="Chunk size")
parser.add_argument('--chunk_overlap', type=int, default=0, help="Overlap chunk size")
parser.add_argument('--construct_index', action='store_true', help="Whether to construct an index")
parser.add_argument('--add_index', action='store_true', default=False, help="Whether to add an index")
parser.add_argument('--collection_name', default="docs_80k_chuncksize_128_0", help="Name of the collection")
# Retriever related options
parser.add_argument('--retrieve_top_k', type=int, default=8, help="Top k documents to retrieve")
parser.add_argument('--retriever_name', default="base", help="Name of the retriever")
# Metric related options
parser.add_argument('--quest_eval', action='store_true', help="Whether to use QA metrics(RAGQuestEval)")
parser.add_argument('--bert_score_eval', action='store_true', help="Whether to use bert_score metrics")
# Evaluation related options
parser.add_argument('--task', default='event_summary', help="Task to perform")
parser.add_argument('--num_threads', type=int, default=1, help="Number of threads")
# Same bool pitfall as --shuffle: use the converter so "False" disables it.
parser.add_argument('--show_progress_bar', type=_str_to_bool, default=True, help="Whether to show a progress bar")
parser.add_argument('--contain_original_data', action='store_true', help="Whether to contain original data")
args = parser.parse_args()
logger.info(args)
# Instantiate the LLM under evaluation: any "gpt*" name routes to the GPT
# wrapper, "qwen7b" to the local Qwen chat model.
if args.model_name.startswith("gpt"):
    llm = GPT(model_name=args.model_name, temperature=args.temperature, max_new_tokens=args.max_new_tokens)
elif args.model_name == "qwen7b":
    llm = Qwen_7B_Chat(model_name=args.model_name, temperature=args.temperature, max_new_tokens=args.max_new_tokens)
else:
    # Fail fast on unknown names; previously `llm` was left unbound and the
    # script crashed later with a confusing NameError. Mirrors the retriever
    # and task validation below.
    raise ValueError(f"Unknown model: {args.model_name}")
# Embedding model shared by every dense retriever flavour.
embed_model = HuggingfaceEmbeddings(model_name=args.embedding_name)

# Keyword arguments accepted by all four retriever constructors.
_shared_kwargs = dict(
    chunk_size=args.chunk_size,
    chunk_overlap=args.chunk_overlap,
    construct_index=args.construct_index,
    similarity_top_k=args.retrieve_top_k,
)
# Extra arguments used only by the vector-index backed retrievers
# (BM25 keeps no vector collection, so it takes none of these).
_vector_kwargs = dict(
    embed_dim=args.embedding_dim,
    add_index=args.add_index,
    collection_name=args.collection_name,
)

if args.retriever_name == "base":
    retriever = BaseRetriever(
        args.docs_path, embed_model=embed_model,
        **_shared_kwargs, **_vector_kwargs,
    )
elif args.retriever_name == "bm25":
    retriever = CustomBM25Retriever(
        args.docs_path, embed_model=embed_model,
        **_shared_kwargs,
    )
elif args.retriever_name == "hybrid":
    retriever = EnsembleRetriever(
        args.docs_path, embed_model=embed_model,
        **_shared_kwargs, **_vector_kwargs,
    )
elif args.retriever_name == "hybrid-rerank":
    retriever = EnsembleRerankRetriever(
        args.docs_path, embed_model=embed_model,
        **_shared_kwargs, **_vector_kwargs,
    )
else:
    raise ValueError(f"Unknown retriever: {args.retriever_name}")
# Registry mapping each CLI task name to the Task classes it runs;
# 'all' executes every task in sequence.
TASK_REGISTRY = {
    'event_summary': [Summary],
    'continuing_writing': [ContinueWriting],
    'hallu_modified': [HalluModified],
    'quest_answer': [QuestAnswer1Doc, QuestAnswer2Docs, QuestAnswer3Docs],
    'all': [Summary, ContinueWriting, HalluModified, QuestAnswer1Doc, QuestAnswer2Docs, QuestAnswer3Docs],
}
if args.task not in TASK_REGISTRY:
    raise ValueError(f"Unknown task: {args.task}")

# Instantiate every task with the metric switches requested on the CLI.
tasks = []
for task_cls in TASK_REGISTRY[args.task]:
    tasks.append(task_cls(use_quest_eval=args.quest_eval, use_bert_score=args.bert_score_eval))

datasets = get_task_datasets(args.data_path, args.task)

# Tasks and datasets come back in matching order; evaluate them in lockstep.
for current_task, current_dataset in zip(tasks, datasets):
    evaluator = BaseEvaluator(current_task, llm, retriever, current_dataset, num_threads=args.num_threads)
    evaluator.run(show_progress_bar=args.show_progress_bar, contain_original_data=args.contain_original_data)