让Agent拥有记忆:从认知心理学到工程实践

📖 本文是学习 Hello-Agents 项目时整理的笔记与思考。


一、问题的起点:无状态的大模型

当前大模型虽然能力很强,但在设计上它们天然是无状态的。这意味着,每一次对话在模型眼中都是全新开始——它不会"记住"你之前问过什么,也不知道你上周讨论了哪些话题。

这种设计在单次问答场景下没有任何问题,但一旦我们希望构建一个真正可用的 Agent 系统,无状态就会带来一系列令人头疼的问题:

  • 上下文丢失:多轮对话无法延续,用户不得不反复重申背景信息,体验极差。
  • 个性化缺失:无法记住用户偏好、习惯与历史,每次交互都像是"第一次见面"。
  • 学习能力受限:无法从过往经验中积累知识、迭代改进,Agent 永远停留在出厂状态。
  • 一致性问题:在多轮对话中,模型可能给出相互矛盾的回答,严重损害用户信任。

除了"遗忘"这一先天缺陷,大模型的知识本身也有根本性的局限——它是静态的、有时效性的。这些知识完全来自训练数据,并由此带来另一系列问题:

  1. 知识时效性:训练数据有截止日期,模型无法获取最新信息,对于快速变化的领域尤为致命。
  2. 专业领域深度不足:通用模型在细分领域的知识往往流于表面,无法满足专业需求。
  3. 幻觉问题:缺乏外部知识验证,模型容易"一本正经地胡说八道",即所谓的幻觉(Hallucination)。
  4. 可解释性差:回答缺乏信息来源,用户无法验证真实性,可信度打折扣。

正是为了解决以上这些问题,在构建 Agent 时,记忆(Memory)系统检索增强生成(RAG)系统成为两大核心基础设施。


二、向人类学习:认知心理学中的记忆模型

在设计 Agent 的记忆系统之前,不妨先回头看看人类自身的记忆是如何运作的。毕竟,人类经过数百万年演化形成的记忆机制,本身就是一套经过极致优化的信息处理系统。

根据认知心理学的研究,人类记忆可以划分为以下几个层次:

感觉记忆(Sensory Memory)

持续时间极短(0.5–3 秒),容量巨大,负责暂时保存感官接收到的所有原始信息。这是记忆处理的"缓冲区",绝大多数信息在此阶段被过滤丢弃,只有少数被注意力选中后才进入下一阶段。

工作记忆(Working Memory)

持续时间较短(15–30 秒),容量极为有限(经典的"7±2"法则),负责当前任务的信息加工与推理。工作记忆相当于人类大脑的"CPU",是一切有意识思维活动的舞台。

长期记忆(Long-term Memory)

持续时间长达数十年甚至终生,容量几乎无限。长期记忆进一步细分为:

  • 程序性记忆:技能与习惯,往往内隐于行为中(如骑自行车、打字盲打),难以用语言描述。
  • 陈述性记忆:可以用语言明确表达的知识,又分为:
    • 语义记忆:抽象的通用知识与概念(如"巴黎是法国首都"、"水的化学式是 H₂O")。
    • 情景记忆:与时间、地点绑定的个人经历与事件(如"昨天的会议内容"、"上次旅行的见闻")。

这套分层记忆体系的精妙之处在于:不同类型的信息,以不同的形式、存储在不同的位置,并以不同的机制被检索。这正是我们在设计 Agent 记忆系统时最值得借鉴的核心思想。


三、Agent 的记忆系统设计

借鉴人类记忆的分层架构,我们可以为 Agent 构建对应的记忆体系。简单来说,Agent 同样需要短期记忆长期记忆,并在长期记忆层面进一步细化。

3.1 工作记忆(Working Memory)

工作记忆是记忆系统中最活跃的部分,负责存储当前对话会话中的临时信息。它的设计重点在于快速访问自动清理,以确保系统的响应速度与资源效率。

在实现上,工作记忆通常采用纯内存存储方案,配合 TTL(Time To Live)机制进行自动过期清理。这种设计的优势是访问速度极快(毫秒级),但也意味着工作记忆的内容在系统重启后会丢失。这种特性本身正符合工作记忆的定位:存储临时的、易变的、会话级的信息,而非需要持久化的长期知识。

典型存储内容:当前对话的历史消息、正在处理的任务状态、临时计算结果等。

3.2 情景记忆(Episodic Memory)

情景记忆负责存储具体的事件和经历,设计重点在于保持事件的完整性与时间序列关系。举例来说,"用户上周询问过 Python 异步编程的问题,并对某种解决方案表示满意",这就是一条典型的情景记忆。

在实现上,情景记忆通常采用 SQLite + Qdrant 的混合存储方案

  • SQLite 负责结构化数据的存储和复杂查询(如按时间段、按话题筛选历史事件);
  • Qdrant 负责高效的向量检索,支持基于语义相似度找到相关历史场景。

这种混合架构兼顾了结构化查询的精确性和语义检索的灵活性。

3.3 语义记忆(Semantic Memory)

语义记忆是记忆系统中最复杂的部分,负责存储抽象的概念、规则和知识。与情景记忆关注"发生了什么"不同,语义记忆关注"这个领域的知识是什么",例如用户的专业背景、行业术语、产品知识库等。

在实现上,语义记忆通常采用 Neo4j 图数据库 + Qdrant 向量数据库的混合架构

  • Neo4j 负责存储实体与实体之间的关系(知识图谱),支持复杂的关系推理,例如"A 是 B 的子类,B 属于 C 领域";
  • Qdrant 负责高维向量的语义检索,支持模糊的语义匹配。

语义记忆的检索实现了混合搜索策略,结合向量检索的语义理解能力和图检索的关系推理能力,两者相辅相成,大幅提升了知识检索的质量。

3.4 感知记忆(Perceptual Memory)

感知记忆支持文本、图像、音频等多种模态的数据存储与检索,对应人类的多感官记忆。

在实现上,感知记忆采用模态分离的存储策略,为不同模态的数据创建独立的向量集合。这种设计避免了不同模态特征向量维度不匹配的问题,同时保证了检索的准确性。

感知记忆的检索支持两种模式:

  • 同模态检索:利用专业的单模态编码器进行精确匹配,例如用图像检索图像;
  • 跨模态检索:需要更复杂的语义对齐机制,例如用文字描述检索相关图片,这通常依赖多模态嵌入模型(如 CLIP)来实现语义空间的统一。

四、知识的外延:什么是 RAG?

解决了对话记忆的问题,我们还需要面对另一个挑战:如何让 Agent 具备超出训练数据范围的、实时的、私有的知识?这正是 RAG 技术的用武之地。

检索增强生成(Retrieval-Augmented Generation,RAG) 是一种结合了信息检索与文本生成的技术。它的核心思想是:在生成回答之前,先从外部知识库中检索相关信息,然后将检索到的内容作为上下文提供给大语言模型,从而生成更准确、更可靠的回答。

"检索增强生成"可以拆解为三个关键词:

  • 检索:从知识库中查询与问题最相关的内容片段;
  • 增强:将检索结果注入 Prompt,辅助模型生成,弥补模型知识的盲区;
  • 生成:输出兼具准确性与透明度的最终答案,并可附带信息来源,增强可信度。

4.1 基本工作流

一个完整的 RAG 应用流程主要分为两大核心环节:

数据准备阶段:系统对外部文档进行数据提取、文本分割和向量化,将非结构化的知识转化为一个可检索的向量数据库。这个过程通常是离线完成的。

应用推理阶段:用户发起提问后,系统首先将问题向量化,然后从知识库中检索最相关的文本片段,将其与原始问题一同注入 Prompt,最终由大语言模型生成综合了检索内容的高质量答案。

4.2 RAG 的发展历程

RAG 技术并非一蹴而就,它经历了清晰的三个发展阶段:

第一阶段:朴素 RAG(Naive RAG,2020–2021)

这是 RAG 的萌芽阶段,流程直接而简单,通常被称为"检索-读取"(Retrieve-Read)模式。检索方面主要依赖传统的关键词匹配算法,如 TF-IDF 或 BM25,这些方法对字面匹配效果不错,但难以理解语义上的相似性——用户换一种说法提问,检索结果可能天差地别。生成方面则将检索到的文档内容不加处理地直接拼接到 Prompt 上下文中。这个阶段验证了 RAG 的可行性,但存在明显的精度瓶颈。

第二阶段:高级 RAG(Advanced RAG,2022–2023)

随着向量数据库和文本嵌入技术的成熟,RAG 进入快速发展阶段。检索方式转向基于稠密嵌入(Dense Embedding) 的语义检索——将文本转换为高维向量后,模型能够理解和匹配语义上的相似性,不再局限于关键词的字面匹配。此外,研究者和工程师在检索与生成的各个环节引入了大量优化技术,如查询重写(让问题更清晰)、文档分块策略(更合理地切割文本)、重排序(Re-ranking)(对检索结果进行二次精排)等。

第三阶段:模块化 RAG(Modular RAG,2023–至今)

在高级 RAG 的基础上,现代 RAG 系统进一步向模块化、自动化和智能化的方向演进。系统的各个部分被设计成可插拔、可组合的独立模块,以适应更多样化、更复杂的应用场景。在检索侧,涌现出混合检索(结合稀疏检索与稠密检索的优势)、多查询扩展(对同一个问题生成多个变体并聚合结果)、假设性文档嵌入(HyDE) 等创新方法。在生成侧,则引入了思维链推理自我反思与修正等技术,让模型能够对自己的答案进行评估和迭代改进。


五、Memory 设计的 5 大核心问题

一个生产可用的记忆系统,绝不只是"把信息存起来"这么简单。在工程实践中,我们需要认真回答以下五个关键问题。


问题 1:什么时候存储?

信息的价值是不均等的。并非所有对话内容都值得被永久记忆——如果对每句话都照单全收,记忆库将迅速膨胀,噪声淹没信号,反而降低检索质量。

常见策略对比

策略 描述 优点 缺点
全部存储 不加筛选地存储所有信息 实现简单 内存爆炸,噪声多
关键信息存储 只存储重要性超过阈值的信息 质量高,存储精准 需要设计重要性判断逻辑
周期性存储 每 N 轮对话总结一次 自然压缩信息 可能遗漏细节

推荐方案——重要性评分机制

def calculate_importance(text: str) -> float:
    """计算信息重要性(0-1)"""
    score = 0.0

    # 1. 语义相关性(与用户画像相关)
    score += semantic_relevance(text) * 0.4

    # 2. 实体密度(包含多少关键实体)
    score += entity_density(text) * 0.3

    # 3. 时效性(是否是最新信息)
    score += recency(text) * 0.2

    # 4. 用户明确要求记住
    if "记住" in text or "别忘了" in text:
        score += 0.5

    return min(score, 1.0)

# 只存储重要性 > 0.6 的信息
if calculate_importance(text) > 0.6:
    memory.store(text)

这套评分机制结合了语义相关性、信息密度、时效性和用户意图四个维度,能在自动化的前提下较好地过滤噪声、保留真正有价值的信息。


问题 2:如何存储?

确定了"存什么",接下来要考虑"用什么格式存"。存储格式直接影响后续的检索效率与信息质量。

存储格式选择

格式 优点 缺点 适合场景
原文 信息完整,无损失 冗余、占用空间大 重要对话、需要精确还原
摘要 节省空间,便于快速阅读 可能丢失关键细节 一般性对话记录
结构化(实体+关系) 便于检索和推理 提取成本高,依赖 NLP 能力 知识密集型对话

推荐方案——混合存储

class HybridMemoryStorage:
    def store(self, conversation: str):
        # 1. 原文存储(向量检索用)
        self.vector_store.add(conversation)

        # 2. 结构化提取(知识图谱用)
        entities = extract_entities(conversation)
        relations = extract_relations(conversation)
        self.knowledge_graph.add(entities, relations)

        # 3. 摘要存储(压缩用)
        summary = summarize(conversation)
        self.summary_store.add(summary)

混合存储的思路是:用原文保底,用摘要提速,用结构化增强推理,三种格式各司其职,互为补充。


问题 3:如何检索?

存储只是手段,检索才是目的。一个优秀的记忆检索系统,需要在速度、准确性和覆盖率之间取得平衡。

推荐方案——混合检索策略

def retrieve_memory(query: str, user_id: str):
    """混合检索策略"""
    results = []

    # 1. 向量检索(语义相似)
    vector_results = vector_db.search(query, top_k=5)
    results.extend(vector_results)

    # 2. 图谱检索(关系推理)
    entities = extract_entities(query)
    graph_results = knowledge_graph.query(entities)
    results.extend(graph_results)

    # 3. 时间过滤(最近的优先)
    results = filter_by_recency(results, days=30)

    # 4. 重排序(综合评分)
    results = rerank(results, query)

    return results[:5]

这套策略将语义相似度检索(向量检索)结构化关系推理(图谱检索)时间衰减过滤结果重排序整合在一起,最终返回质量最高的 Top-K 条记忆,大幅提升了检索的召回率和精准度。


问题 4:何时遗忘?

"忘记"在记忆设计中往往被忽视,但它与"记住"同等重要。无限堆积的记忆不仅浪费存储资源,还会增加检索噪声,降低系统整体性能。合理的遗忘机制,是保持记忆系统健康运转的必要条件。

遗忘策略

def should_forget(memory_item) -> bool:
    """判断是否应该遗忘"""

    # 1. 时间衰减:久远且不重要的信息自动淘汰
    age_days = (now - memory_item.timestamp).days
    if age_days > 90 and memory_item.importance < 0.5:
        return True

    # 2. 访问频率:从未被检索到的记忆意义有限
    if memory_item.access_count == 0 and age_days > 30:
        return True

    # 3. 空间限制:超出容量时淘汰最不重要的记忆
    if memory_store.size() > MAX_SIZE:
        return memory_item.importance < threshold

    # 4. 信息冗余:重复的信息只保留最新版本
    if has_duplicate(memory_item):
        return True

    return False

这套遗忘机制从时间、频率、空间和冗余四个维度综合判断,策略上与人类的记忆遗忘曲线(艾宾浩斯遗忘曲线)有异曲同工之妙:不常用的、老旧的、重复的信息会逐渐淡出,而频繁访问的、重要的信息则得以长期保留。


问题 5:如何更新?

现实世界是不断变化的。用户今天的想法可能和上个月完全不同,事实信息也可能随时间推移而失效。如果记忆系统不具备更新能力,存储的旧信息反而会"毒化"Agent 的判断,产生错误的个性化。

更新策略

def update_memory(new_info: str, user_id: str):
    """更新记忆"""

    # 1. 检索与新信息相关的已有记忆
    existing = memory.search(new_info, user_id)

    for mem in existing:
        # 2. 如果存在冲突,判断哪条信息更新
        if is_conflict(mem, new_info):
            if is_newer(new_info):
                # 用新信息覆盖旧信息
                memory.update(mem.id, new_info)
            else:
                # 保留已有的旧信息(新信息反而更旧)
                pass

        # 3. 如果是补充信息,则合并
        elif is_complementary(mem, new_info):
            merged = merge(mem, new_info)
            memory.update(mem.id, merged)

    # 4. 如果是全新信息,直接添加
    if not existing:
        memory.add(new_info, user_id)

更新逻辑的核心在于冲突检测补充合并的区分:冲突时以更新的信息为准,补充时则将新旧信息融合,避免覆盖有效信息;对于全新的信息,则直接写入。


六、实战案例

以下案例是基于 HelloAgents 开发的智能文档问答助手,旨在深化理解 agent 的 memory 和 rag 是如何工作的。

"""  
智能问答助手 - 基于 HelloAgents 的智能文档问答系统  

这是一个完整的PDF学习助手应用,支持:  
- 加载 PDF 文档并构建知识库  
- 智能问答(基于 rag)  
- 学习历程记录(基于 memory)  
- 学习回顾和报告生成  
"""  
from dotenv import load_dotenv  
load_dotenv()  
import os  
import json  
import time  
from datetime import datetime  
from typing import Dict, Any, Optional, List, Tuple  
import gradio as gr  

from hello_agents.tools import RAGTool, MemoryTool  

def _mask_secret(value: Optional[str], keep: int = 6) -> str:  
    """对敏感信息做最小必要脱敏"""  
    if not value:  
        return "(empty)"  
    if len(value) <= keep * 2:  
        return "*" * len(value)  
    return f"{value[:keep]}...{value[-keep:]}"  

def _print_llm_diagnostics(rag_tool) -> None:  
    """打印当前运行时的LLM诊断信息,不修改依赖库逻辑"""  
    print("\n" + "=" * 60)  
    print("🔎 HelloAgents LLM 诊断信息")  
    print("=" * 60)  

    env_summary = {  
        "LLM_MODEL_ID": os.getenv("LLM_MODEL_ID"),  
        "LLM_BASE_URL": os.getenv("LLM_BASE_URL"),  
        "LLM_API_KEY": _mask_secret(os.getenv("LLM_API_KEY")),  
        "OPENAI_API_KEY_exists": bool(os.getenv("OPENAI_API_KEY")),  
        "DEEPSEEK_API_KEY_exists": bool(os.getenv("DEEPSEEK_API_KEY")),  
        "DASHSCOPE_API_KEY_exists": bool(os.getenv("DASHSCOPE_API_KEY")),  
        "MODELSCOPE_API_KEY_exists": bool(os.getenv("MODELSCOPE_API_KEY")),  
        "KIMI_API_KEY_exists": bool(os.getenv("KIMI_API_KEY") or os.getenv("MOONSHOT_API_KEY")),  
        "ZHIPU_API_KEY_exists": bool(os.getenv("ZHIPU_API_KEY") or os.getenv("GLM_API_KEY")),  
        "OLLAMA_API_KEY_exists": bool(os.getenv("OLLAMA_API_KEY")),  
        "VLLM_API_KEY_exists": bool(os.getenv("VLLM_API_KEY")),  
    }  

    print("📦 环境变量摘要:")  
    for key, value in env_summary.items():  
        print(f"  - {key}: {value}")  

    llm = getattr(rag_tool, "llm", None)  
    if llm is None:  
        print("❌ rag_tool.llm 尚未初始化")  
        print("=" * 60 + "\n")  
        return  

    print("🤖 RAGTool 实际绑定的 LLM 实例:")  
    print(f"  - provider: {getattr(llm, 'provider', None)}")  
    print(f"  - model: {getattr(llm, 'model', None)}")  
    print(f"  - base_url: {getattr(llm, 'base_url', None)}")  
    print(f"  - api_key(masked): {_mask_secret(getattr(llm, 'api_key', None))}")  
    print(f"  - timeout: {getattr(llm, 'timeout', None)}")  
    print("=" * 60 + "\n")  

class PDFLearningAssistant:  
    def __init__(self,user_id:str):  
        """  
        args:        user_id 用户第Id,用来隔离不同用户数据  
        """        self.user_id = user_id  
        self.session_id = f"session_{user_id}_{datetime.now().strftime('%Y%m%d%H%M%S')}"  

        # 初始化工具  
        self.memory_tool = MemoryTool(user_id=user_id)  
        self.rag_tool = RAGTool(rag_namespace=f"pdf_{user_id}")  
        _print_llm_diagnostics(self.rag_tool)  

        # 学习统计  
        self.stats = {  
            "session_start":datetime.now(),  
            "documents_loaded": 0,  
            "questions_asked": 0,  
            "concepts_learned": 0  
        }  

        # 当前加载的文档  
        self.current_document = None  

    def load_document(self,pdf_path:str)->Dict[str,Any]:  
        """加载PDF文档到知识库  

        Args:            pdf_path: PDF文件路径  

        Returns:            Dict: 包含success和message的结果  
        """        if not os.path.exists(pdf_path):  
            return {"success": False, "message": f"文件不存在: {pdf_path}"}  

        start_time = time.time()  

        try:  
            # 使用 rag 处理文档  
            result = self.rag_tool.run({  
                "action": "add_document",  
                "file_path": pdf_path,  
                "chunk_size": 1000,  
                "chunk_overlap": 200  
            })  

            process_time = time.time() - start_time  

            self.current_document = os.path.basename(pdf_path)  
            self.stats["documents_loaded"] += 1  

            # 记录到学习记忆  
            self.memory_tool.run({  
                "action":"add",  
                "content":f"加载了文档《{self.current_document}》",  
                "memory_type":"episodic",  
                "importance":0.9,  
                "event_type":"document_loaded",  
                "session_id":self.session_id  
            })  

            return {  
                "success": True,  
                "message":f"加载成功!(耗时: {process_time:.1f}秒)",  
                "document": self.current_document  
            }  
        except Exception as e:  
            return {  
                "success": False,  
                "message": f"加载失败: {str(e)}"  
            }  

    def ask(self,question:str,use_advanced_search:bool = True)->str:  
        """  
        向文档提问  
        Args:            question: 用户问题  
            use_advanced_search: 是否使用高级检索(MQE + HyDE)  

        Returns:            str: 答案  
        """        if not self.current_document:  
            return "⚠️ 请先加载文档!使用 load_document() 方法加载PDF文档。"  

        # 记录问题到工作记忆  
        self.memory_tool.run({  
            "action": "add",  
            "content": f"提问: {question}",  
            "memory_type": "working",  
            "importance": 0.6,  
            "session_id": self.session_id  
        })  

        # 使用 rag 检索答案  
        answer = self.rag_tool.run({  
            "action": "ask",  
            "question": question,  
            "limit": 5,  
            "enable_advanced_search": use_advanced_search,  
            "enable_mqe": use_advanced_search,  
            "enable_hyde": use_advanced_search  
        })  

        # 记录到情景记忆  
        self.memory_tool.run({  
            "action": "add",  
            "content": f"关于'{question}'的学习",  
            "memory_type": "episodic",  
            "importance": 0.7,  
            "event_type": "qa_interaction",  
            "session_id": self.session_id  
        })  

        self.stats["questions_asked"] += 1  

        return answer  

    def add_note(self,content:str,concept:Optional[str] = None):  
        """添加学习笔记  

        Args:            content: 笔记内容  
            concept: 相关概念(可选)  
        """        self.memory_tool.run({  
            "action": "add",  
            "content": content,  
            "memory_type": "semantic",  
            "importance": 0.8,  
            "concept": concept or "general",  
            "session_id": self.session_id  
        })  
        self.stats["concepts_learned"] += 1  

    def recall(self,query:str,limit:int=5)->str:  
        """回顾学习历程  

        Args:            query: 查询关键词  
            limit: 返回结果数量  

        Returns:            str: 相关记忆  
        """        result = self.memory_tool.run({  
            "action": "search",  
            "query": query,  
            "limit": limit  
        })  

        return str(result)  

    def get_stats(self)->Dict[str,Any]:  
        """获取学习统计  

        Returns:            Dict: 统计信息  
        """  
        duration = (datetime.now() - self.stats["session_start"]).total_seconds()  

        return {  
            "会话时长": f"{duration:.0f}秒",  
            "加载文档": self.stats["documents_loaded"],  
            "提问次数": self.stats["questions_asked"],  
            "学习笔记": self.stats["concepts_learned"],  
            "当前文档": self.current_document or "未加载"  
        }  

    def generate_report(self,save_to_file:bool=True)->Dict[str,Any]:  
        """生成学习报告  

        Args:            save_to_file: 是否保存到文件  

        Returns:            Dict: 学习报告  
        """        # 获取记忆摘要  
        memory_summary = self.memory_tool.run({  
            "action": "summary",  
            "limit": 10,  
        })  

        # 获取 rag 统计  
        rag_stats = self.rag_tool.run({  
            "action": "stats",  
        })  

        # 生成报告  
        duration = (datetime.now() - self.stats["session_start"]).total_seconds()  
        report = {  
            "session_info": {  
                "session_id": self.session_id,  
                "user_id": self.user_id,  
                "start_time": self.stats["session_start"].isoformat(),  
                "duration_seconds": duration  
            },  
            "learning_metrics": {  
                "documents_loaded": self.stats["documents_loaded"],  
                "questions_asked": self.stats["questions_asked"],  
                "concepts_learned": self.stats["concepts_learned"]  
            },  
            "memory_summary": memory_summary,  
            "rag_status": rag_stats  
        }  

        # 保存到文件  
        if save_to_file:  
            report_file = f"learning_report_{self.session_id}.json"  
            try:  
                with open(report_file,'w',encoding='utf-8') as f:  
                    json.dump(report,f,ensure_ascii=False,indent=2)  
                report["report_file"] = report_file  
            except Exception as e:  
                report["save_error"] = str(e)  

        return report  

def create_gradio_ui():  
    """创建Gradio Web UI"""  
    # 全局助手实例  
    assistant_state = {"assistant": None}  

    def init_assistant(user_id: str) -> str:  
        """初始化助手"""  
        if not user_id:  
            user_id = "web_user"  
        try:  
            assistant_state["assistant"] = PDFLearningAssistant(user_id=user_id)  
            return f"✅ 助手已初始化 (用户: {user_id})"  
        except Exception as e:  
            assistant_state["assistant"] = None  
            return f"❌ 助手初始化失败: {type(e).__name__}: {e}"  

    def load_pdf(pdf_file) -> str:  
        """加载PDF文件"""  
        if assistant_state["assistant"] is None:  
            return "❌ 请先初始化助手"  

        if pdf_file is None:  
            return "❌ 请上传PDF文件"  

        # Gradio上传的文件是临时文件对象  
        pdf_path = pdf_file.name  
        result = assistant_state["assistant"].load_document(pdf_path)  

        if result["success"]:  
            return f"✅ {result['message']}\n📄 文档: {result['document']}"  
        else:  
            return f"❌ {result['message']}"  

    def _normalize_history(history: Optional[List]) -> List[Dict[str, str]]:  
        """兼容旧版二维数组格式,统一转换为 Chatbot messages 格式。"""  
        normalized_history: List[Dict[str, str]] = []  

        if not history:  
            return normalized_history  

        for item in history:  
            if isinstance(item, dict) and "role" in item and "content" in item:  
                normalized_history.append({  
                    "role": str(item["role"]),  
                    "content": str(item["content"]),  
                })  
            elif isinstance(item, (list, tuple)) and len(item) == 2:  
                user_message, assistant_message = item  
                normalized_history.append({  
                    "role": "user",  
                    "content": "" if user_message is None else str(user_message),  
                })  
                normalized_history.append({  
                    "role": "assistant",  
                    "content": "" if assistant_message is None else str(assistant_message),  
                })  

        return normalized_history  

    def chat(message: str, history: Optional[List]) -> Tuple[str, List[Dict[str, str]]]:  
        """聊天功能"""  
        normalized_history = _normalize_history(history)  

        if assistant_state["assistant"] is None:  
            normalized_history.extend([  
                {"role": "user", "content": message},  
                {"role": "assistant", "content": "❌ 请先初始化助手并加载文档"},  
            ])  
            return "", normalized_history  

        if not message.strip():  
            return "", normalized_history  

        # 判断是技术问题还是回顾问题  
        if any(keyword in message for keyword in ["之前", "学过", "回顾", "历史", "记得"]):  
            # 回顾学习历程  
            response = assistant_state["assistant"].recall(message)  
            response = f"🧠 **学习回顾**\n\n{response}"  
        else:  
            # 技术问答  
            response = assistant_state["assistant"].ask(message)  
            response = f"💡 **回答**\n\n{response}"  

        normalized_history.extend([  
            {"role": "user", "content": message},  
            {"role": "assistant", "content": response},  
        ])  
        return "", normalized_history  

    def add_note_ui(note_content: str, concept: str) -> str:  
        """添加笔记"""  
        if assistant_state["assistant"] is None:  
            return "❌ 请先初始化助手"  

        if not note_content.strip():  
            return "❌ 笔记内容不能为空"  

        assistant_state["assistant"].add_note(note_content, concept or None)  
        return f"✅ 笔记已保存: {note_content[:50]}..."  

    def get_stats_ui() -> str:  
        """获取统计信息"""  
        if assistant_state["assistant"] is None:  
            return "❌ 请先初始化助手"  

        stats = assistant_state["assistant"].get_stats()  
        result = "📊 **学习统计**\n\n"  
        for key, value in stats.items():  
            result += f"- **{key}**: {value}\n"  
        return result  

    def generate_report_ui() -> str:  
        """生成报告"""  
        if assistant_state["assistant"] is None:  
            return "❌ 请先初始化助手"  

        report = assistant_state["assistant"].generate_report(save_to_file=True)  

        result = f"✅ 学习报告已生成\n\n"  
        result += f"**会话信息**\n"  
        result += f"- 会话时长: {report['session_info']['duration_seconds']:.0f}秒\n"  
        result += f"- 加载文档: {report['learning_metrics']['documents_loaded']}\n"  
        result += f"- 提问次数: {report['learning_metrics']['questions_asked']}\n"  
        result += f"- 学习笔记: {report['learning_metrics']['concepts_learned']}\n"  

        if "report_file" in report:  
            result += f"\n💾 报告已保存至: {report['report_file']}"  

        return result  

    # 创建Gradio界面  
    with gr.Blocks(title="智能文档问答助手") as demo:  
        gr.Markdown("""  
        # 📚 智能文档问答助手  

        基于HelloAgents的智能文档问答系统,支持:  
        - 📄 加载PDF文档并构建知识库  
        - 💬 智能问答(基于RAG)  
        - 📝 学习笔记记录  
        - 🧠 学习历程回顾  
        - 📊 学习报告生成  
        """)  

        with gr.Tab("🏠 开始使用"):  
            with gr.Row():  
                user_id_input = gr.Textbox(  
                    label="用户ID",  
                    placeholder="输入你的用户ID(可选,默认为web_user)",  
                    value="web_user"  
                )  
                init_btn = gr.Button("初始化助手", variant="primary")  

            init_output = gr.Textbox(label="初始化状态", interactive=False)  
            init_btn.click(init_assistant, inputs=[user_id_input], outputs=[init_output])  

            gr.Markdown("### 📄 加载PDF文档")  
            pdf_upload = gr.File(  
                label="上传PDF文件",  
                file_types=[".pdf"],  
                type="filepath"  
            )  
            load_btn = gr.Button("加载文档", variant="primary")  
            load_output = gr.Textbox(label="加载状态", interactive=False)  
            load_btn.click(load_pdf, inputs=[pdf_upload], outputs=[load_output])  

        with gr.Tab("💬 智能问答"):  
            gr.Markdown("### 向文档提问或回顾学习历程")  
            chatbot = gr.Chatbot(  
                label="对话历史",  
                height=400  
            )  
            with gr.Row():  
                msg_input = gr.Textbox(  
                    label="输入问题",  
                    placeholder="例如:什么是Transformer? 或 我之前学过什么?",  
                    scale=4  
                )  
                send_btn = gr.Button("发送", variant="primary", scale=1)  

            gr.Examples(  
                examples=[  
                    "什么是大语言模型?",  
                    "Transformer架构有哪些核心组件?",  
                    "如何训练大语言模型?",  
                    "我之前学过什么内容?",  
                    "回顾一下关于注意力机制的学习"  
                ],  
                inputs=msg_input  
            )  

            msg_input.submit(chat, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot])  
            send_btn.click(chat, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot])  

        with gr.Tab("📝 学习笔记"):  
            gr.Markdown("### 记录学习心得和重要概念")  
            note_content = gr.Textbox(  
                label="笔记内容",  
                placeholder="输入你的学习笔记...",  
                lines=3  
            )  
            concept_input = gr.Textbox(  
                label="相关概念(可选)",  
                placeholder="例如:transformer, attention"  
            )  
            note_btn = gr.Button("保存笔记", variant="primary")  
            note_output = gr.Textbox(label="保存状态", interactive=False)  
            note_btn.click(add_note_ui, inputs=[note_content, concept_input], outputs=[note_output])  

        with gr.Tab("📊 学习统计"):  
            gr.Markdown("### 查看学习进度和统计信息")  
            stats_btn = gr.Button("刷新统计", variant="primary")  
            stats_output = gr.Markdown()  
            stats_btn.click(get_stats_ui, outputs=[stats_output])  

            gr.Markdown("### 生成学习报告")  
            report_btn = gr.Button("生成报告", variant="primary")  
            report_output = gr.Textbox(label="报告状态", interactive=False)  
            report_btn.click(generate_report_ui, outputs=[report_output])  

    return demo  

def main():  
    """主函数 - 启动Gradio Web UI"""  
    print("\n" + "=" * 60)  
    print("智能文档问答助手")  
    print("=" * 60)  
    print("正在启动Web界面...\n")  

    demo = create_gradio_ui()  
    demo.launch(  
        server_name="0.0.0.0",  
        server_port=7861,  
        share=False,  
        show_error=True,  
        theme=gr.themes.Soft()  
    )  

if __name__ == "__main__":  
    main()
暂无评论

发送评论 编辑评论


				
上一篇