You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

98 lines
2.8 KiB

#!/usr/bin/env python3
"""
Session Initialization Hook - cold-start memory preloading.

Automatically retrieves recently active context whenever /new is issued or a
new session starts.
"""
import asyncio
import sys
import os
import logging
# Put this file's directory first on the import path so that sibling modules
# (presumably including mem0_client) can be imported when run as a script.
sys.path.insert(0, os.path.dirname(__file__))
# Point OpenAI-compatible clients at the DashScope compatible-mode endpoint.
# NOTE(review): both variables are overwritten unconditionally, even when no
# DashScope key is found below — confirm that clobbering a user-supplied
# OpenAI base URL is intended.
os.environ['OPENAI_API_BASE'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
os.environ['OPENAI_BASE_URL'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
# Prefer the mem0-specific DashScope key; fall back to the generic one.
_dashscope_key = os.getenv('MEM0_DASHSCOPE_API_KEY', '') or os.getenv('DASHSCOPE_API_KEY', '')
if _dashscope_key:
    # Only override OPENAI_API_KEY when a DashScope key is actually available.
    os.environ['OPENAI_API_KEY'] = _dashscope_key
# Keep this import after the environment setup above; presumably mem0_client
# (or a dependency) reads these variables at import time — TODO confirm.
from mem0_client import Mem0Client
logger = logging.getLogger(__name__)
# Cold-start retrieval queries. cold_start_retrieval searches each one
# separately and merges the results.
COLD_START_QUERIES = [
    "current active tasks deployment progress",
    "pending work items todos",
    "recent important decisions configurations",
]
async def cold_start_retrieval(agent_id: str = "main", user_id: str = "default", top_k: int = 3) -> str:
    """
    Cold-start memory retrieval, called during session initialization.

    Runs every query in COLD_START_QUERIES against the memory client and
    merges the hits. Hits are then ranked by score and de-duplicated by memory
    text, keeping the highest-scoring copy. Finally the ``top_k`` best
    memories are rendered as a prompt fragment.

    Args:
        agent_id: Agent whose memories are searched.
        user_id: User whose memories are searched.
        top_k: Number of hits requested per query, and the number of
            memories kept in the final result.

    Returns:
        A formatted fragment to inject into the system prompt, or ``""``
        when no memories were found.
    """
    client = Mem0Client()
    await client.start()
    all_memories = []
    try:
        for query in COLD_START_QUERIES:
            try:
                memories = await client.pre_hook_search(
                    query=query,
                    user_id=user_id,
                    agent_id=agent_id,
                    top_k=top_k,
                )
            except Exception as e:
                # Best effort: one failed query must not block session start.
                logger.debug("检索失败 %s: %s", query, e)
                continue
            # Tolerate a search that returns None instead of an empty list.
            all_memories.extend(memories or [])
    finally:
        # Release the client even if a search raises or the task is cancelled.
        await client.shutdown()

    # Rank before de-duplicating. When several queries return the same memory,
    # the copy with the best score wins. Python's sort is stable, so ties keep
    # their retrieval order. A missing or None score counts as 0.
    ranked = sorted(all_memories, key=lambda m: m.get('score') or 0, reverse=True)
    seen = set()
    unique_memories = []
    for mem in ranked:
        mem_text = mem.get('memory', '')
        if mem_text and mem_text not in seen:
            seen.add(mem_text)
            unique_memories.append(mem)
    unique_memories = unique_memories[:top_k]

    if not unique_memories:
        return ""

    entries = []
    for i, mem in enumerate(unique_memories, 1):
        # metadata may be absent or explicitly None.
        metadata = mem.get('metadata') or {}
        agent = metadata.get('agent_id', 'unknown')
        entries.append(f"{i}. [{agent}] {mem.get('memory', '')}\n")
    prompt = (
        "\n\n=== 最近活跃上下文(自动加载) ===\n"
        + "".join(entries)
        + "=================================\n"
    )
    logger.info("冷启动检索完成:%d 条记忆", len(unique_memories))
    return prompt
async def main():
    """Smoke-test cold-start retrieval and print whatever context it returns."""
    snippet = await cold_start_retrieval(agent_id="main", user_id="wang 院长", top_k=5)
    if not snippet:
        print(" 无记忆记录(Qdrant 为空)")
        return
    print("✅ 冷启动检索成功:")
    print(snippet)


if __name__ == '__main__':
    asyncio.run(main())