Spaces:
Sleeping
Sleeping
Added retrieval num chunks options
Browse files- .gitignore +1 -0
- app.py +24 -82
- handler.py +0 -14
- input_reader.py +0 -22
- rag.py +12 -4
.gitignore
CHANGED
|
@@ -3,3 +3,4 @@
|
|
| 3 |
.env
|
| 4 |
__pycache__
|
| 5 |
__pycache__/*
|
|
|
|
|
|
| 3 |
.env
|
| 4 |
__pycache__
|
| 5 |
__pycache__/*
|
| 6 |
+
__DELETE__*
|
app.py
CHANGED
|
@@ -65,6 +65,8 @@ def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k,
|
|
| 65 |
"temperature": temperature
|
| 66 |
}
|
| 67 |
|
|
|
|
|
|
|
| 68 |
output, context, source = generate(input_, model_parameters)
|
| 69 |
sources_markup = ""
|
| 70 |
|
|
@@ -87,13 +89,7 @@ def clear():
|
|
| 87 |
None,
|
| 88 |
None,
|
| 89 |
None,
|
| 90 |
-
gr.
|
| 91 |
-
gr.Slider(value=MAX_NEW_TOKENS),
|
| 92 |
-
gr.Slider(value=1.0),
|
| 93 |
-
gr.Slider(value=50),
|
| 94 |
-
gr.Slider(value=0.99),
|
| 95 |
-
gr.Checkbox(value=False),
|
| 96 |
-
gr.Slider(value=0.35),
|
| 97 |
)
|
| 98 |
|
| 99 |
|
|
@@ -102,25 +98,12 @@ def gradio_app():
|
|
| 102 |
# App Description
|
| 103 |
# =====================================================================================================================================
|
| 104 |
with gr.Row():
|
| 105 |
-
with gr.Column():
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# """# Demo de Retrieval-Augmented Generation per la Viquipèdia
|
| 109 |
-
# 🔍 **Retrieval-Augmented Generation** (RAG) és una tecnologia d'IA que permet interrogar un repositori de documents amb preguntes
|
| 110 |
-
# en llenguatge natural, i combina tècniques de recuperació d'informació avançades amb models generatius per redactar una resposta
|
| 111 |
-
# fent servir només la informació existent en els documents del repositori.
|
| 112 |
-
|
| 113 |
-
# 🎯 **Objectiu:** Aquest és un demostrador amb Viquipèdia i genera la resposta fent servir el model salamandra-7b-instruct.
|
| 114 |
-
|
| 115 |
-
# ⚠️ **Advertencies**: Aquesta versió és experimental. El contingut generat per aquest model no està supervisat i pot ser incorrecte.
|
| 116 |
-
# Si us plau, tingueu-ho en compte quan exploreu aquest recurs. El model en inferencia asociat a aquesta demo de desenvolupament no funciona continuament. Si vol fer proves,
|
| 117 |
-
# contacteu amb nosaltres a Langtech.
|
| 118 |
-
# """
|
| 119 |
-
)
|
| 120 |
|
| 121 |
-
|
| 122 |
-
# with gr.Row(equal_height=True):
|
| 123 |
with gr.Row(equal_height=False):
|
|
|
|
| 124 |
# User Input
|
| 125 |
# =====================================================================================================================================
|
| 126 |
with gr.Column(scale=2, variant="panel"):
|
|
@@ -131,69 +114,25 @@ def gradio_app():
|
|
| 131 |
placeholder="Qui va crear la guerra de les Galaxies ?",
|
| 132 |
)
|
| 133 |
|
| 134 |
-
|
| 135 |
-
# with gr.Column(variant="panel"):
|
| 136 |
with gr.Row(variant="default"):
|
| 137 |
-
# with gr.Row(variant="panel"):
|
| 138 |
clear_btn = Button("Clear",)
|
| 139 |
submit_btn = Button("Submit", variant="primary", interactive=False)
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
with gr.Accordion("Model parameters (not used)", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
|
| 144 |
-
num_chunks = Slider(
|
| 145 |
-
minimum=1,
|
| 146 |
-
maximum=6,
|
| 147 |
-
step=1,
|
| 148 |
-
value=5,
|
| 149 |
-
label="Number of chunks"
|
| 150 |
-
)
|
| 151 |
-
max_new_tokens = Slider(
|
| 152 |
-
minimum=50,
|
| 153 |
-
maximum=2000,
|
| 154 |
-
step=1,
|
| 155 |
-
value=MAX_NEW_TOKENS,
|
| 156 |
-
label="Max tokens"
|
| 157 |
-
)
|
| 158 |
-
repetition_penalty = Slider(
|
| 159 |
-
minimum=0.1,
|
| 160 |
-
maximum=2.0,
|
| 161 |
-
step=0.1,
|
| 162 |
-
value=1.0,
|
| 163 |
-
label="Repetition penalty"
|
| 164 |
-
)
|
| 165 |
-
top_k = Slider(
|
| 166 |
-
minimum=1,
|
| 167 |
-
maximum=100,
|
| 168 |
-
step=1,
|
| 169 |
-
value=50,
|
| 170 |
-
label="Top k"
|
| 171 |
-
)
|
| 172 |
-
top_p = Slider(
|
| 173 |
-
minimum=0.01,
|
| 174 |
-
maximum=0.99,
|
| 175 |
-
value=0.99,
|
| 176 |
-
label="Top p"
|
| 177 |
-
)
|
| 178 |
-
do_sample = Checkbox(
|
| 179 |
-
value=False,
|
| 180 |
-
label="Do sample"
|
| 181 |
-
)
|
| 182 |
-
temperature = Slider(
|
| 183 |
-
minimum=0.1,
|
| 184 |
-
maximum=1,
|
| 185 |
-
value=0.35,
|
| 186 |
-
label="Temperature"
|
| 187 |
-
)
|
| 188 |
|
| 189 |
-
parameters_compontents = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature]
|
| 190 |
|
| 191 |
# Add Examples manually
|
| 192 |
-
gr.Examples(
|
| 193 |
-
examples=[
|
| 194 |
["Qui va crear la guerra de les Galaxies?"],
|
| 195 |
["Quin era el nom real de Voltaire?"],
|
| 196 |
-
["Què fan al BSC?"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
],
|
| 198 |
inputs=[input_], # only inputs
|
| 199 |
)
|
|
@@ -246,14 +185,16 @@ def gradio_app():
|
|
| 246 |
clear_btn.click(
|
| 247 |
fn=clear,
|
| 248 |
inputs=[],
|
| 249 |
-
outputs=[input_, output, source_context, context_evaluation
|
| 250 |
-
|
| 251 |
-
|
|
|
|
| 252 |
)
|
| 253 |
|
| 254 |
submit_btn.click(
|
| 255 |
fn=submit_input,
|
| 256 |
-
inputs=[input_]+ parameters_compontents,
|
|
|
|
| 257 |
outputs=[output, source_context, context_evaluation],
|
| 258 |
api_name="get-results"
|
| 259 |
)
|
|
@@ -269,6 +210,7 @@ def gradio_app():
|
|
| 269 |
# fn=submit_input,
|
| 270 |
# )
|
| 271 |
|
|
|
|
| 272 |
demo.launch(show_api=True)
|
| 273 |
|
| 274 |
|
|
|
|
| 65 |
"temperature": temperature
|
| 66 |
}
|
| 67 |
|
| 68 |
+
print("Model parameters: ", model_parameters)
|
| 69 |
+
|
| 70 |
output, context, source = generate(input_, model_parameters)
|
| 71 |
sources_markup = ""
|
| 72 |
|
|
|
|
| 89 |
None,
|
| 90 |
None,
|
| 91 |
None,
|
| 92 |
+
gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
)
|
| 94 |
|
| 95 |
|
|
|
|
| 98 |
# App Description
|
| 99 |
# =====================================================================================================================================
|
| 100 |
with gr.Row():
|
| 101 |
+
with gr.Column():
|
| 102 |
+
gr.Markdown("""# Demo de Retrieval (only) Viquipèdia""")
|
| 103 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
|
|
|
|
|
|
| 105 |
with gr.Row(equal_height=False):
|
| 106 |
+
|
| 107 |
# User Input
|
| 108 |
# =====================================================================================================================================
|
| 109 |
with gr.Column(scale=2, variant="panel"):
|
|
|
|
| 114 |
placeholder="Qui va crear la guerra de les Galaxies ?",
|
| 115 |
)
|
| 116 |
|
|
|
|
|
|
|
| 117 |
with gr.Row(variant="default"):
|
|
|
|
| 118 |
clear_btn = Button("Clear",)
|
| 119 |
submit_btn = Button("Submit", variant="primary", interactive=False)
|
| 120 |
|
| 121 |
+
with gr.Row(variant="default"):
|
| 122 |
+
num_chunks = gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
|
|
|
| 124 |
|
| 125 |
# Add Examples manually
|
| 126 |
+
gr.Examples( examples=[
|
|
|
|
| 127 |
["Qui va crear la guerra de les Galaxies?"],
|
| 128 |
["Quin era el nom real de Voltaire?"],
|
| 129 |
+
["Què fan al BSC?"],
|
| 130 |
+
|
| 131 |
+
# No existèix aquesta entrada a la VDB
|
| 132 |
+
# https://ca.wikipedia.org/wiki/Imperi_Gal%C3%A0ctic
|
| 133 |
+
# ["Què és un Imperi Galàctic?"],
|
| 134 |
+
# ["Què és l'Imperi Galàctic d'Isaac Asimov?"],
|
| 135 |
+
# ["Què és l'Imperi Galàctic de la Guerra de les Galàxies?"]
|
| 136 |
],
|
| 137 |
inputs=[input_], # only inputs
|
| 138 |
)
|
|
|
|
| 185 |
clear_btn.click(
|
| 186 |
fn=clear,
|
| 187 |
inputs=[],
|
| 188 |
+
outputs=[input_, output, source_context, context_evaluation, num_chunks],
|
| 189 |
+
# outputs=[input_, output, source_context, context_evaluation] + parameters_compontents,
|
| 190 |
+
queue=False,
|
| 191 |
+
api_name=False
|
| 192 |
)
|
| 193 |
|
| 194 |
submit_btn.click(
|
| 195 |
fn=submit_input,
|
| 196 |
+
# inputs=[input_] + parameters_compontents,
|
| 197 |
+
inputs=[input_] + [num_chunks],
|
| 198 |
outputs=[output, source_context, context_evaluation],
|
| 199 |
api_name="get-results"
|
| 200 |
)
|
|
|
|
| 210 |
# fn=submit_input,
|
| 211 |
# )
|
| 212 |
|
| 213 |
+
# input_, output, source_context, context_evaluation, num_chunks = clear()
|
| 214 |
demo.launch(show_api=True)
|
| 215 |
|
| 216 |
|
handler.py
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
|
| 3 |
-
class ContentHandler():
|
| 4 |
-
content_type = "application/json"
|
| 5 |
-
accepts = "application/json"
|
| 6 |
-
|
| 7 |
-
def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
|
| 8 |
-
input_str = json.dumps({'inputs': prompt, 'parameters': model_kwargs})
|
| 9 |
-
return input_str.encode('utf-8')
|
| 10 |
-
|
| 11 |
-
def transform_output(self, output: bytes) -> str:
|
| 12 |
-
response_json = json.loads(output.read().decode("utf-8"))
|
| 13 |
-
return response_json[0]["generated_text"]
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_reader.py
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
from typing import List
|
| 2 |
-
|
| 3 |
-
from llama_index.core.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
|
| 4 |
-
from llama_index.core.readers import SimpleDirectoryReader
|
| 5 |
-
from llama_index.core.schema import Document
|
| 6 |
-
from llama_index.core import Settings
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class InputReader:
|
| 10 |
-
def __init__(self, input_dir: str) -> None:
|
| 11 |
-
self.reader = SimpleDirectoryReader(input_dir=input_dir)
|
| 12 |
-
|
| 13 |
-
def parse_documents(
|
| 14 |
-
self,
|
| 15 |
-
show_progress: bool = True,
|
| 16 |
-
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
| 17 |
-
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
|
| 18 |
-
) -> List[Document]:
|
| 19 |
-
Settings.chunk_size = chunk_size
|
| 20 |
-
Settings.chunk_overlap = chunk_overlap
|
| 21 |
-
documents = self.reader.load_data(show_progress=show_progress)
|
| 22 |
-
return documents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rag.py
CHANGED
|
@@ -42,6 +42,7 @@ class RAG:
|
|
| 42 |
logging.info("RAG loaded!")
|
| 43 |
logging.info( self.vectore_store)
|
| 44 |
|
|
|
|
| 45 |
def rerank_contexts(self, instruction, contexts, number_of_contexts=1):
|
| 46 |
"""
|
| 47 |
Rerank the contexts based on their relevance to the given instruction.
|
|
@@ -86,21 +87,28 @@ class RAG:
|
|
| 86 |
|
| 87 |
logging.info("RETRIEVE DOCUMENTS")
|
| 88 |
logging.info(f"Instruction: {instruction}")
|
|
|
|
|
|
|
|
|
|
| 89 |
embedding = self.vectore_store._embed_query(instruction)
|
| 90 |
logging.info(f"Query embedding generated: {len(embedding)}")
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
| 94 |
logging.info(f"Documents retrieved: {len(documents_retrieved)}")
|
| 95 |
|
| 96 |
-
# documents_retrieved = self.vectore_store.similarity_search_with_score(instruction, k=self.rerank_number_contexts)
|
| 97 |
|
|
|
|
|
|
|
| 98 |
if self.rerank_model:
|
| 99 |
logging.info("RERANK DOCUMENTS")
|
| 100 |
documents_reranked = self.rerank_contexts(instruction, documents_retrieved, number_of_contexts=number_of_contexts)
|
| 101 |
else:
|
| 102 |
logging.info("NO RERANKING")
|
| 103 |
documents_reranked = documents_retrieved[:number_of_contexts]
|
|
|
|
| 104 |
|
| 105 |
return documents_reranked
|
| 106 |
|
|
|
|
| 42 |
logging.info("RAG loaded!")
|
| 43 |
logging.info( self.vectore_store)
|
| 44 |
|
| 45 |
+
|
| 46 |
def rerank_contexts(self, instruction, contexts, number_of_contexts=1):
|
| 47 |
"""
|
| 48 |
Rerank the contexts based on their relevance to the given instruction.
|
|
|
|
| 87 |
|
| 88 |
logging.info("RETRIEVE DOCUMENTS")
|
| 89 |
logging.info(f"Instruction: {instruction}")
|
| 90 |
+
|
| 91 |
+
# Embed the query
|
| 92 |
+
# ==============================================================================================================
|
| 93 |
embedding = self.vectore_store._embed_query(instruction)
|
| 94 |
logging.info(f"Query embedding generated: {len(embedding)}")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Retrieve documents
|
| 98 |
+
# ==============================================================================================================
|
| 99 |
+
documents_retrieved = self.vectore_store.similarity_search_with_score_by_vector(embedding, k=number_of_contexts)
|
| 100 |
logging.info(f"Documents retrieved: {len(documents_retrieved)}")
|
| 101 |
|
|
|
|
| 102 |
|
| 103 |
+
# Reranking
|
| 104 |
+
# ==============================================================================================================
|
| 105 |
if self.rerank_model:
|
| 106 |
logging.info("RERANK DOCUMENTS")
|
| 107 |
documents_reranked = self.rerank_contexts(instruction, documents_retrieved, number_of_contexts=number_of_contexts)
|
| 108 |
else:
|
| 109 |
logging.info("NO RERANKING")
|
| 110 |
documents_reranked = documents_retrieved[:number_of_contexts]
|
| 111 |
+
# ==============================================================================================================
|
| 112 |
|
| 113 |
return documents_reranked
|
| 114 |
|