You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

439 lines
16 KiB

#!/usr/bin/env python3
"""
mem0 Client for OpenClaw - 生产级纯异步架构
Pre-Hook 检索注入 + Post-Hook 异步写入
元数据维度隔离 (user_id + agent_id)
"""
import os
import asyncio
import logging
import time
from typing import List, Dict, Optional, Any
from collections import deque
from datetime import datetime
# ========== DashScope environment configuration ==========
# mem0's OpenAI-compatible client is pointed at DashScope's
# OpenAI-compatible endpoint through these environment variables.
_DASHSCOPE_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
os.environ['OPENAI_API_BASE'] = _DASHSCOPE_BASE_URL
os.environ['OPENAI_BASE_URL'] = _DASHSCOPE_BASE_URL  # key point: compatible mode needs this variable
# SECURITY: credentials must never be committed to source. The key that used
# to be hard-coded here as a fallback is public and must be revoked/rotated.
# The key is now taken from the environment only.
_dashscope_api_key = os.getenv('MEM0_DASHSCOPE_API_KEY') or os.getenv('DASHSCOPE_API_KEY')
if _dashscope_api_key:
    os.environ['OPENAI_API_KEY'] = _dashscope_api_key
else:
    # The base URL above now points at DashScope, so an unrelated OpenAI key
    # left in the environment would be sent to a third-party endpoint.
    # Remove it (the original code also always replaced this variable).
    os.environ.pop('OPENAI_API_KEY', None)
    logging.getLogger(__name__).warning(
        "MEM0_DASHSCOPE_API_KEY / DASHSCOPE_API_KEY not set; mem0 has no API key"
    )
logger = logging.getLogger(__name__)

# mem0ai is optional: when it is missing, every imported name is set to None
# so later code can test `Memory is None` instead of hitting a NameError.
try:
    from mem0 import Memory
    from mem0.configs.base import MemoryConfig, VectorStoreConfig, LlmConfig
except ImportError as e:
    # Report through logging rather than print(): this is library code and
    # should not write to stdout.
    logger.warning(f" mem0ai 导入失败:{e}")
    Memory = None
    MemoryConfig = VectorStoreConfig = LlmConfig = None
class AsyncMemoryQueue:
    """Bounded in-memory queue that feeds memory-write jobs to an async callback.

    Producers call :meth:`add` synchronously. :meth:`start_worker` launches a
    background task that, roughly once per second, takes up to ``batch_size``
    jobs and processes them in a separate task. That task awaits
    ``callback(item)`` for each job.

    :meth:`stop` cancels the worker, flushes jobs that are still queued and
    waits for in-flight batches, so accepted jobs are not lost on shutdown.
    """

    def __init__(self, max_size: int = 100):
        # maxlen makes append() evict the oldest entry once the deque is full.
        self.queue = deque(maxlen=max_size)
        self.lock = asyncio.Lock()
        self.running = False
        self._worker_task = None
        # Strong references to in-flight batch tasks. The event loop keeps only
        # weak references to tasks, so a fire-and-forget task can be
        # garbage-collected before it finishes.
        self._batch_tasks = set()
        self.callback = None
        self.batch_size = 10
        self.flush_interval = 60

    def add(self, item: Dict[str, Any]):
        """Enqueue one write job (synchronous; never raises).

        ``item`` must provide ``messages``, ``user_id`` and ``agent_id``.
        ``timestamp`` defaults to the current local time in ISO format.

        When the queue is full, the OLDEST queued job is evicted to make room,
        as the warning states. The previous code dropped the new job instead.
        Malformed items are logged and ignored.
        """
        try:
            entry = {
                'messages': item['messages'],
                'user_id': item['user_id'],
                'agent_id': item['agent_id'],
                'timestamp': item.get('timestamp', datetime.now().isoformat()),
            }
        except Exception as e:  # deliberate best-effort: producers must not crash
            logger.error(f"队列添加失败:{e}")
            return
        if self.queue.maxlen is not None and len(self.queue) >= self.queue.maxlen:
            logger.warning("异步队列已满,丢弃旧任务")
        # With maxlen set, append() discards the oldest entry automatically.
        self.queue.append(entry)

    async def get_batch(self, batch_size: int) -> List[Dict]:
        """Pop up to ``batch_size`` jobs from the front of the queue."""
        async with self.lock:
            batch = []
            while len(batch) < batch_size and self.queue:
                batch.append(self.queue.popleft())
            return batch

    def start_worker(self, callback, batch_size: int, flush_interval: int):
        """Start the background dispatch task (requires a running event loop).

        Args:
            callback: coroutine function ``callback(item)``, awaited once per job.
            batch_size: maximum number of jobs handed to one batch task.
            flush_interval: seconds between the extra periodic flushes.
        """
        self.running = True
        self.callback = callback
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self._worker_task = asyncio.create_task(self._worker_loop())
        logger.info(f"✅ 异步工作线程已启动 (batch_size={batch_size}, interval={flush_interval}s)")

    def _dispatch(self, batch: List[Dict]):
        """Process ``batch`` in a background task while holding a reference to it."""
        if not batch:
            return
        task = asyncio.create_task(self._process_batch(batch))
        self._batch_tasks.add(task)
        task.add_done_callback(self._batch_tasks.discard)

    async def _worker_loop(self):
        """Dispatch queued jobs until ``running`` is cleared or the task is cancelled.

        On every tick (about one second) one batch is dispatched if the queue is
        non-empty. Once more than ``flush_interval`` seconds have passed since
        the last periodic flush, one extra batch is dispatched.
        """
        last_flush = time.time()
        while self.running:
            try:
                if self.queue:
                    self._dispatch(await self.get_batch(self.batch_size))
                if time.time() - last_flush > self.flush_interval:
                    if self.queue:
                        self._dispatch(await self.get_batch(self.batch_size))
                    last_flush = time.time()
                await asyncio.sleep(1)
            except asyncio.CancelledError:
                logger.info("异步工作线程已取消")
                break
            except Exception as e:
                logger.error(f"异步工作线程错误:{e}")
                await asyncio.sleep(5)

    async def _process_batch(self, batch: List[Dict]):
        """Await the callback for each job.

        A failing job is logged and does not prevent the remaining jobs in the
        batch from being processed.
        """
        logger.debug(f"开始处理批量任务:{len(batch)}")
        for item in batch:
            try:
                await self.callback(item)
            except Exception as e:
                logger.error(f"批量处理失败:{e}")
        logger.debug("批量任务处理完成")

    async def stop(self):
        """Graceful shutdown: stop the worker, flush queued jobs, await all batches.

        Jobs are flushed only if a callback was registered via
        :meth:`start_worker`; otherwise they stay in the queue.
        """
        self.running = False
        if self._worker_task:
            self._worker_task.cancel()
            try:
                await self._worker_task
            except asyncio.CancelledError:
                pass
            self._worker_task = None
        if self.callback is not None:
            # max(1, ...) guards against an endless loop when batch_size <= 0.
            flush_size = max(1, self.batch_size)
            while self.queue:
                self._dispatch(await self.get_batch(flush_size))
        if self._batch_tasks:
            await asyncio.gather(*self._batch_tasks, return_exceptions=True)
        logger.info("异步工作线程已关闭")
class Mem0Client:
    """Production mem0 client for OpenClaw.

    * Pre-hook: :meth:`pre_hook_search` retrieves memories before a turn,
      with a hard timeout and an in-process TTL cache.
    * Post-hook: :meth:`post_hook_add` enqueues the finished turn; an
      ``AsyncMemoryQueue`` worker writes it to mem0 in the background.
    * Blocking mem0 calls run in worker threads via ``asyncio.to_thread``.
    * Isolation: memories are stored under ``user_id`` and tagged with
      ``agent_id`` in their metadata.
    """

    def __init__(self, config: Dict = None):
        """Create the client and initialise mem0 (blocking).

        Args:
            config: optional configuration. Missing sections or keys fall back
                to :meth:`_load_default_config`, so a partial dict such as
                ``{"retrieval": {"top_k": 3}}`` is accepted.

        The background writer is NOT started here (no event loop is
        required); call ``await client.start()`` from a running loop.
        """
        self.config = self._merge_config(self._load_default_config(), config or {})
        self.local_memory = None
        self.async_queue = None
        self.cache = {}
        self._started = False
        self._init_memory()

    @staticmethod
    def _merge_config(base: Dict, override: Dict) -> Dict:
        """Return ``base`` recursively updated with ``override``; neither input is mutated."""
        merged = dict(base)
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = Mem0Client._merge_config(merged[key], value)
            else:
                merged[key] = value
        return merged

    def _load_default_config(self) -> Dict:
        """Return the default configuration; Qdrant host/port and LLM model come from env vars."""
        return {
            "qdrant": {
                "host": os.getenv('MEM0_QDRANT_HOST', 'localhost'),
                "port": int(os.getenv('MEM0_QDRANT_PORT', '6333')),
                "collection_name": "mem0_shared"
            },
            "llm": {
                "provider": "openai",
                "config": {"model": os.getenv('MEM0_LLM_MODEL', 'qwen-plus')}
            },
            "retrieval": {
                "enabled": True,
                "top_k": 5,
                "min_confidence": 0.7,
                "timeout_ms": 2000
            },
            "async_write": {
                "enabled": True,
                "queue_size": 100,
                "batch_size": 10,
                "flush_interval": 60,
                "timeout_ms": 5000
            },
            "cache": {
                "enabled": True,
                "ttl": 300,
                "max_size": 1000
            },
            "fallback": {
                "enabled": True,
                "log_level": "WARNING",
                "retry_attempts": 2
            },
            "metadata": {
                "default_user_id": "default",
                "default_agent_id": "general"
            }
        }

    def _init_memory(self):
        """Build the mem0 ``Memory`` instance (blocking).

        ``self.local_memory`` stays ``None`` when mem0ai is not installed or
        construction fails. Every hook then degrades to a no-op.
        """
        if Memory is None:
            logger.warning("mem0ai 未安装")
            return
        try:
            qdrant = self.config['qdrant']
            llm = self.config['llm']
            config = MemoryConfig(
                vector_store=VectorStoreConfig(
                    provider="qdrant",
                    config={
                        "host": qdrant['host'],
                        "port": qdrant['port'],
                        "collection_name": qdrant['collection_name'],
                        "on_disk": True,
                    },
                ),
                llm=LlmConfig(
                    # Honour the configured provider; it was hard-coded before.
                    provider=llm.get('provider', 'openai'),
                    config=llm['config'],
                ),
                # NOTE(review): no embedder is configured here, so mem0
                # presumably uses its default embedder against OPENAI_BASE_URL
                # (DashScope). Confirm that model and its vector size are
                # served there.
            )
            self.local_memory = Memory(config=config)
            logger.info("✅ mem0 初始化成功")
        except Exception as e:
            logger.error(f"❌ mem0 初始化失败:{e}")
            self.local_memory = None

    async def start(self):
        """Start the background writer; must be awaited inside a running event loop.

        Idempotent: further calls are no-ops until :meth:`shutdown` is called.
        """
        if self._started:
            logger.debug("mem0 Client 已启动")
            return
        write_cfg = self.config['async_write']
        if write_cfg['enabled']:
            self.async_queue = AsyncMemoryQueue(max_size=write_cfg['queue_size'])
            self.async_queue.start_worker(
                callback=self._async_write_memory,
                batch_size=write_cfg['batch_size'],
                flush_interval=write_cfg['flush_interval'],
            )
        self._started = True
        logger.info("✅ mem0 Client 异步工作线程已启动")

    # ========== Pre-Hook: retrieval ==========
    @staticmethod
    def _cache_key(user_id: Optional[str], agent_id: Optional[str], query: str, top_k: int) -> tuple:
        """Cache key for one search.

        ``top_k`` is part of the key, and a tuple avoids the collisions that
        ':'-joined strings allow.
        """
        return (user_id, agent_id, top_k, query)

    async def pre_hook_search(self, query: str, user_id: Optional[str] = None,
                              agent_id: Optional[str] = None,
                              top_k: Optional[int] = None) -> List[Dict]:
        """Pre-hook: retrieve memories relevant to ``query`` before a turn.

        Returns at most ``top_k`` (default ``retrieval.top_k``) memory dicts,
        highest score first.

        Never raises. Returns ``[]`` when retrieval is disabled, mem0 is not
        initialised, or the search times out (``retrieval.timeout_ms``) or
        fails. Non-empty results are cached for ``cache.ttl`` seconds.
        """
        if not self.config['retrieval']['enabled'] or self.local_memory is None:
            return []
        limit = top_k or self.config['retrieval']['top_k']
        cache_cfg = self.config['cache']
        cache_key = self._cache_key(user_id, agent_id, query, limit)
        if cache_cfg['enabled']:
            cached = self.cache.get(cache_key)
            if cached is not None:
                if time.time() - cached['time'] < cache_cfg['ttl']:
                    logger.debug(f"Cache hit: {cache_key}")
                    # Hand out a copy so callers cannot mutate the cached list.
                    return list(cached['results'])
                self.cache.pop(cache_key, None)  # expired entry
        timeout_ms = self.config['retrieval']['timeout_ms']
        try:
            memories = await asyncio.wait_for(
                self._execute_search(query, user_id, agent_id, limit),
                timeout=timeout_ms / 1000,
            )
            if cache_cfg['enabled'] and memories:
                self.cache[cache_key] = {'results': list(memories), 'time': time.time()}
                self._cleanup_cache()
            logger.info(f"Pre-Hook 检索完成:{len(memories)} 条记忆")
            return memories
        except asyncio.TimeoutError:
            logger.warning(f"Pre-Hook 检索超时 ({timeout_ms}ms)")
            return []
        except Exception as e:
            logger.warning(f"Pre-Hook 检索失败:{e}")
            return []

    @staticmethod
    def _normalize_search_results(raw: Any) -> List[Dict]:
        """Return the hit list from a ``Memory.search`` response.

        Depending on the mem0 version, ``search`` returns either a plain list
        or a dict shaped like ``{"results": [...]}``. Concatenating two dict
        responses raises TypeError, which made every retrieval return [].
        """
        if isinstance(raw, dict):
            raw = raw.get('results', [])
        return list(raw) if raw else []

    async def _execute_search(self, query: str, user_id: Optional[str],
                              agent_id: Optional[str], top_k: int) -> List[Dict]:
        """Search in two scopes and return up to ``top_k`` merged hits.

        1. User scope: ``search(query, user_id=...)``, only when ``user_id`` is set.
        2. Agent scope: the same search, additionally filtered on the
           ``agent_id`` metadata written by :meth:`_execute_write`. Skipped for
           the ``'general'`` agent.

        Hits are de-duplicated by ``id``; hits without an id are dropped. They
        are then filtered by ``retrieval.min_confidence`` (a missing ``score``
        counts as 1.0) and sorted by score, highest first. A failing scope is
        logged at debug level and treated as empty.
        """
        if self.local_memory is None:
            return []
        user_memories: List[Dict] = []
        if user_id:
            try:
                user_memories = self._normalize_search_results(await asyncio.to_thread(
                    self.local_memory.search,
                    query,
                    user_id=user_id,
                    limit=top_k,
                ))
            except Exception as e:
                logger.debug(f"用户记忆检索失败:{e}")
        agent_memories: List[Dict] = []
        if agent_id and agent_id != 'general':
            try:
                agent_memories = self._normalize_search_results(await asyncio.to_thread(
                    self.local_memory.search,
                    query,
                    user_id=user_id,
                    filters={"agent_id": agent_id},  # metadata filter -> vertical isolation
                    limit=top_k,
                ))
            except Exception as e:
                logger.debug(f"业务记忆检索失败:{e}")
        merged: Dict[Any, Dict] = {}
        for mem in user_memories + agent_memories:
            mem_id = mem.get('id') if isinstance(mem, dict) else None
            if mem_id and mem_id not in merged:
                merged[mem_id] = mem
        min_confidence = self.config['retrieval']['min_confidence']
        hits = [m for m in merged.values() if m.get('score', 1.0) >= min_confidence]
        hits.sort(key=lambda m: m.get('score', 0), reverse=True)
        return hits[:top_k]

    def format_memories_for_prompt(self, memories: List[Dict]) -> str:
        """Render memories as a numbered prompt section; ``""`` when there are none.

        Dict entries contribute ``memory`` and, if present, ``created_at``.
        Any other entry is rendered with ``str()``.
        """
        if not memories:
            return ""
        parts = ["\n\n=== 相关记忆 ===\n"]
        for i, mem in enumerate(memories, 1):
            if isinstance(mem, dict):
                memory_text, created_at = mem.get('memory', ''), mem.get('created_at', '')
            else:
                memory_text, created_at = str(mem), ''
            line = f"{i}. {memory_text}"
            if created_at:
                line += f" (记录于:{created_at})"
            parts.append(line + "\n")
        parts.append("===============\n")
        return "".join(parts)

    # ========== Post-Hook: asynchronous write ==========
    def post_hook_add(self, user_message: str, assistant_message: str,
                      user_id: Optional[str] = None, agent_id: Optional[str] = None):
        """Post-hook: enqueue the finished turn for a background write (synchronous).

        Missing ids fall back to ``metadata.default_user_id`` and
        ``metadata.default_agent_id``. Nothing is written if async writes are
        disabled; a warning is logged if :meth:`start` has not been awaited.
        """
        if not self.config['async_write']['enabled']:
            return
        if not user_id:
            user_id = self.config['metadata']['default_user_id']
        if not agent_id:
            agent_id = self.config['metadata']['default_agent_id']
        messages = [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_message}
        ]
        if self.async_queue:
            self.async_queue.add({
                'messages': messages,
                'user_id': user_id,
                'agent_id': agent_id,
                'timestamp': datetime.now().isoformat()
            })
            logger.debug(f"Post-Hook 已提交:user={user_id}, agent={agent_id}")
        else:
            logger.warning("异步队列未初始化")

    async def _async_write_memory(self, item: Dict):
        """Queue callback: write one job with a timeout; failures are logged, never raised.

        The user's cached search results are dropped afterwards so the new
        memory can be found by the next pre-hook.
        """
        if self.local_memory is None:
            return
        timeout_ms = self.config['async_write']['timeout_ms']
        try:
            await asyncio.wait_for(self._execute_write(item), timeout=timeout_ms / 1000)
            logger.debug(f"异步写入成功:user={item['user_id']}, agent={item['agent_id']}")
        except asyncio.TimeoutError:
            # The worker thread cannot be cancelled, so the write may still
            # complete after this warning.
            logger.warning(f"异步写入超时 ({timeout_ms}ms)")
        except Exception as e:
            logger.warning(f"异步写入失败:{e}")
        finally:
            self._invalidate_user_cache(item.get('user_id'))

    async def _execute_write(self, item: Dict):
        """Write one job to mem0 in a worker thread.

        ``agent_id`` is passed inside the ``metadata`` dict rather than as a
        direct argument, so that searches can filter on it.
        """
        if self.local_memory is None:
            return
        custom_metadata = {
            "agent_id": item['agent_id'],
            "source": "openclaw",
            "timestamp": item.get('timestamp'),
            "business_type": item['agent_id']
        }
        # mem0's add() is blocking; run it off the event loop.
        await asyncio.to_thread(
            self.local_memory.add,
            messages=item['messages'],
            user_id=item['user_id'],  # user id natively supported by mem0
            metadata=custom_metadata  # custom business dimensions
        )

    def _invalidate_user_cache(self, user_id: Optional[str]):
        """Drop every cached search result that belongs to ``user_id``."""
        for key in [k for k in self.cache if isinstance(k, tuple) and k[0] == user_id]:
            del self.cache[key]

    def _cleanup_cache(self):
        """Evict expired entries, then the oldest entries beyond ``cache.max_size``."""
        cache_cfg = self.config['cache']
        if not cache_cfg['enabled']:
            return
        now = time.time()
        ttl = cache_cfg['ttl']
        for key in [k for k, v in self.cache.items() if now - v['time'] > ttl]:
            del self.cache[key]
        overflow = len(self.cache) - cache_cfg['max_size']
        if overflow > 0:
            for key in sorted(self.cache, key=lambda k: self.cache[k]['time'])[:overflow]:
                del self.cache[key]

    def get_status(self) -> Dict:
        """Return a small status snapshot for health checks."""
        return {
            "initialized": self.local_memory is not None,
            "started": self._started,
            "async_queue_enabled": self.config['async_write']['enabled'],
            "queue_size": len(self.async_queue.queue) if self.async_queue else 0,
            "cache_size": len(self.cache),
            "qdrant": f"{self.config['qdrant']['host']}:{self.config['qdrant']['port']}"
        }

    async def shutdown(self):
        """Stop the background writer, flushing pending jobs; the client can be started again."""
        if self.async_queue:
            await self.async_queue.stop()
            self.async_queue = None
        self._started = False
        logger.info("mem0 Client 已关闭")
# Module-level shared client instance. Mem0Client.__init__ calls
# _init_memory(), so mem0 is initialised as soon as this module is imported.
# The background writer still has to be started with
# `await mem0_client.start()` from a running event loop.
mem0_client = Mem0Client(config=None)