|
|
#!/usr/bin/env python3 |
|
|
""" |
|
|
Session Initialization Hook - 冷启动记忆预加载 |
|
|
在每次 /new 或新会话启动时自动检索最近活跃上下文 |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import sys |
|
|
import os |
|
|
import logging |
|
|
|
|
|
sys.path.insert(0, os.path.dirname(__file__))

# Force both OpenAI-style base-URL variables to DashScope's compatible-mode
# endpoint. This runs before `mem0_client` is imported, presumably so the
# OpenAI client it creates picks these values up — confirm.
os.environ.update(dict.fromkeys(
    ('OPENAI_API_BASE', 'OPENAI_BASE_URL'),
    'https://dashscope.aliyuncs.com/compatible-mode/v1',
))

# Prefer a mem0-specific DashScope key and fall back to the generic one.
# When neither is set, any existing OPENAI_API_KEY is left untouched.
_dashscope_key = (
    os.getenv('MEM0_DASHSCOPE_API_KEY', '')
    or os.getenv('DASHSCOPE_API_KEY', '')
)
if _dashscope_key:
    os.environ['OPENAI_API_KEY'] = _dashscope_key
|
|
|
|
|
from mem0_client import Mem0Client |
|
|
|
|
|
logger = logging.getLogger(__name__)

# Search queries issued on cold start to surface recently active context:
# ongoing tasks, pending work, and recent decisions/configuration changes.
COLD_START_QUERIES = [
    "current active tasks deployment progress",
    "pending work items todos",
    "recent important decisions configurations",
]
|
|
|
|
|
def _rank_unique_memories(memories, limit):
    """Return the ``limit`` highest-scored memories with duplicate texts removed.

    Sorting happens before de-duplication, so when several queries returned
    the same memory text, the copy with the best score is the one kept.
    Entries whose ``memory`` field is missing or empty are discarded, and a
    missing or ``None`` score counts as 0.
    """
    # sorted() is stable, so equal scores keep their original retrieval order.
    ranked = sorted(memories, key=lambda m: m.get('score') or 0, reverse=True)
    seen = set()
    unique = []
    for memory in ranked:
        text = memory.get('memory', '')
        if text and text not in seen:
            seen.add(text)
            unique.append(memory)
    return unique[:limit]


def _format_context_prompt(memories):
    """Render memories as a numbered text block to be appended to a system prompt."""
    parts = ["\n\n=== 最近活跃上下文(自动加载) ===\n"]
    for index, memory in enumerate(memories, 1):
        # `metadata` may be present but None, so fall back to an empty dict.
        # NOTE(review): if pre_hook_search returns raw mem0 results, agent_id may
        # live at the top level rather than in metadata — confirm.
        metadata = memory.get('metadata') or {}
        agent = metadata.get('agent_id', 'unknown')
        parts.append(f"{index}. [{agent}] {memory.get('memory', '')}\n")
    parts.append("=================================\n")
    return "".join(parts)


async def cold_start_retrieval(agent_id: str = "main", user_id: str = "default", top_k: int = 3) -> str:
    """Retrieve recently active memories at session initialization (cold start).

    Runs every query in ``COLD_START_QUERIES`` through
    ``Mem0Client.pre_hook_search`` and merges the results. It then removes
    duplicate memory texts, keeps the ``top_k`` highest-scored entries and
    formats them as a prompt fragment meant to be injected into the system
    prompt.

    Args:
        agent_id: Agent whose memories are searched.
        user_id: User whose memories are searched.
        top_k: Number of results requested per query; also the cap on the
            number of memories in the returned fragment.

    Returns:
        The formatted prompt fragment, or ``""`` when no memory was found.
        A failing individual query is logged and skipped rather than raised.
    """
    client = Mem0Client()
    await client.start()

    all_memories = []
    try:
        for query in COLD_START_QUERIES:
            try:
                memories = await client.pre_hook_search(
                    query=query,
                    user_id=user_id,
                    agent_id=agent_id,
                    top_k=top_k,
                )
                # Treat a None result as "no hits" instead of failing the query.
                all_memories.extend(memories or [])
            except Exception as e:
                # Best-effort: one failing query must not abort session start-up,
                # but log at WARNING so an outage is not mistaken for "no memories".
                logger.warning("检索失败 %s: %s", query, e)
    finally:
        # Release the client even if an unexpected error escapes the loop.
        await client.shutdown()

    unique_memories = _rank_unique_memories(all_memories, top_k)
    if not unique_memories:
        return ""

    logger.info("冷启动检索完成:%d 条记忆", len(unique_memories))
    return _format_context_prompt(unique_memories)
|
|
|
|
|
|
|
|
async def main():
    """Manual smoke test: run one cold-start retrieval and print the outcome."""
    context = await cold_start_retrieval(agent_id="main", user_id="wang 院长", top_k=5)
    if not context:
        print("⚠️ 无记忆记录(Qdrant 为空)")
        return
    print("✅ 冷启动检索成功:")
    print(context)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Executing the file directly runs the one-off smoke test defined in main().
    asyncio.run(main())
|
|
|
|