Auto-Retrieval: RAG的智能进化

人工智能
我们借助LlamaCloud来实现,主要通过在LlamaCloud检索器上设置一个Auto-Retrieval功能。在高层次上,我们的自动检索函数使用一个调用函数的LLM来推断用户查询的元数据过滤器——比仅仅使用原始语义查询产生更精确和相关的检索结果。

Auto-Retrieval是一种高级的RAG技术,它在启动向量数据库检索之前使用Agent LLM动态推断元数据过滤器参数和语义查询,而不是将用户查询直接发送到向量数据库检索接口(例如密集向量搜索)的朴素RAG。您可以将其视为查询扩展/重写的一种形式,也可以将其视为函数调用的一种特定形式;后文我们将给出实现逻辑和代码。达到效果如下:

用户输入

Give me a summary of the SWE-bench paper

推理结果

改写查询: summary of the SWE-bench paper
过滤参数: {"filters": [{"key": "file_name", "value": "swebench.pdf", "operator": "=="}], "condition": "and"}

实现步骤

我们借助LlamaCloud来实现,主要通过在LlamaCloud检索器上设置一个Auto-Retrieval功能。在高层次上,我们的自动检索函数使用一个调用函数的LLM来推断用户查询的元数据过滤器——比仅仅使用原始语义查询产生更精确和相关的检索结果。
  • 定义一个自定义prompt来生成元数据过滤器
  • 给定一个用户查询,首先执行块级检索,从检索到的块中动态召回元数据。
  • 在auto-retrieval prompt中注入元数据作为少量示例。目的是向LLM展示现有的、相关的元数据值示例,以便LLM可以推断出正确的元数据过滤器。

文档级检索器返回整个文件级别的文档,而块级检索器返回特定的块,实现如此简单。

from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
import os


index = LlamaCloudIndex(
  name="research_papers_page",
  project_name="llamacloud_demo",
  api_key=os.environ["LLAMA_CLOUD_API_KEY"]
)


doc_retriever = index.as_retriever(
    retrieval_mode="files_via_content",
    # retrieval_mode="files_via_metadata",
    files_top_k=1
)


chunk_retriever = index.as_retriever(
    retrieval_mode="chunks",
    rerank_top_n=5
)

代码实现

接下来我们将根据上面的流程给出实现代码:

from llama_index.core.prompts import ChatPromptTemplate
from llama_index.core.vector_stores.types import VectorStoreInfo, VectorStoreQuerySpec, MetadataInfo, MetadataFilters
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core import Response


import json


SYS_PROMPT = """\
Your goal is to structure the user's query to match the request schema provided below.
You MUST call the tool in order to generate the query spec.


<< Structured Request Schema >>
When responding use a markdown code snippet with a JSON object formatted in the \
following schema:


{schema_str}


The query string should contain only text that is expected to match the contents of \
documents. Any conditions in the filter should not be mentioned in the query as well.


Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters take into account the descriptions of attributes.
Make sure that filters are only used as needed. If there are no filters that should be \
applied return [] for the filter value.\


If the user's query explicitly mentions number of documents to retrieve, set top_k to \
that number, otherwise do not set top_k.


The schema of the metadata filters in the vector db table is listed below, along with some example metadata dictionaries from relevant rows.
The user will send the input query string.


Data Source:
```json
{info_str}
```


Example metadata from relevant chunks:
{example_rows}


"""


example_rows_retriever = index.as_retriever(
    retrieval_mode="chunks",
    rerank_top_n=4
)


def get_example_rows_fn(**kwargs):
    """Retrieve relevant few-shot examples."""
    query_str = kwargs["query_str"]
    nodes = example_rows_retriever.retrieve(query_str)
    # get the metadata, join them
    metadata_list = [n.metadata for n in nodes]


    return "\n".join([json.dumps(m) for m in metadata_list])
        
    


# TODO: define function mapping for `example_rows`.
chat_prompt_tmpl = ChatPromptTemplate.from_messages(
    [
        ("system", SYS_PROMPT),
        ("user", "{query_str}"),
    ],
    function_mappings={
        "example_rows": get_example_rows_fn
    }
)




## NOTE: this is a dataclass that contains information about the metadata
vector_store_info = VectorStoreInfo(
    content_info="contains content from various research papers",
    metadata_info=[
        MetadataInfo(
            name="file_name",
            type="str",
            description="Name of the source paper",
        ),
    ],
)


def auto_retriever_rag(query: str, retriever: BaseRetriever) -> Response:
    """Synthesizes an answer to your question by feeding in an entire relevant document as context."""
    print(f"> User query string: {query}")
    # Use structured predict to infer the metadata filters and query string.
    query_spec = llm.structured_predict(
        VectorStoreQuerySpec,
        chat_prompt_tmpl,
        info_str=vector_store_info.json(indent=4),
        schema_str=VectorStoreQuerySpec.schema_json(indent=4),
        query_str=query
    )
    # build retriever and query engine
    filters = MetadataFilters(filters=query_spec.filters) if len(query_spec.filters) > 0 else None
    print(f"> Inferred query string: {query_spec.query}")
    if filters:
        print(f"> Inferred filters: {filters.json()}")
    query_engine = RetrieverQueryEngine.from_args(
        retriever, 
        llm=llm,
        response_mode="tree_summarize"
    )
    # run query
    return query_engine.query(query_spec.query)

效果展示

auto_doc_rag("Give me a summary of the SWE-bench paper") 
print(str(response))
> User query string: Give me a summary of the SWE-bench paper
> Inferred query string: summary of the SWE-bench paper
> Inferred filters: {"filters": [{"key": "file_name", "value": "swebench.pdf", "operator": "=="}], "condition": "and"}
The construction of SWE-Bench involves a three-stage pipeline:


1. **Repo Selection and Data Scraping**: Pull requests (PRs) are collected from 12 popular open-source Python repositories on GitHub, resulting in approximately 90,000 PRs. These repositories are chosen for their popularity, better maintenance, clear contributor guidelines, and extensive test coverage.


2. **Attribute-Based Filtering**: Candidate tasks are created by selecting merged PRs that resolve a GitHub issue and make changes to the test files of the repository. This indicates that the user likely contributed tests to check whether the issue has been resolved.


3. **Execution-Based Filtering**: For each candidate task, the PR’s test content is applied, and the associated test results are logged before and after the PR’s other content is applied. Tasks are filtered out if they do not have at least one test where its status changes from fail to pass or if they result in installation or runtime errors.


Through these stages, the original 90,000 PRs are filtered down to 2,294 task


责任编辑:武晓燕 来源: 哎呀AIYA
相关推荐

2023-07-07 07:06:47

2023-06-06 10:19:28

2019-03-08 09:54:29

华为

2024-09-03 11:31:04

2021-09-07 10:06:00

人工智能机器学习技术

2019-05-30 20:54:05

华为

2023-05-26 14:02:29

AI智能

2019-04-28 09:19:33

存储

2019-05-09 22:10:36

AI

2024-04-08 07:52:24

2021-12-10 18:53:43

百度数字化转型

2013-01-10 09:58:50

CESCES 2013智能手表

2019-03-21 18:59:18

华为中国生态伙伴大会201智能进化

2019-03-21 19:21:33

华为中国生态伙伴大会201智能进化

2022-12-22 13:42:14

华为云

2020-04-13 11:10:45

华为

2019-03-18 09:29:23

华为中国生态伙伴大会201

2019-03-25 09:50:24

华为云

2018-06-25 08:24:46

人工智能机器学习技术
点赞
收藏

51CTO技术栈公众号