|
|
#!/usr/bin/env python3 |
|
|
""" |
|
|
mem0 Client for OpenClaw - 生产级纯异步架构 |
|
|
Pre-Hook 检索注入 + Post-Hook 异步写入 |
|
|
元数据维度隔离 (user_id + agent_id) |
|
|
""" |
|
|
|
|
|
import os |
|
|
import asyncio |
|
|
import logging |
|
|
import time |
|
|
from typing import List, Dict, Optional, Any |
|
|
from collections import deque |
|
|
from datetime import datetime |
|
|
|
|
|
# ========== DashScope environment configuration ==========
# mem0's OpenAI-compatible client reads these variables; point them at the
# DashScope compatible-mode endpoint.
os.environ['OPENAI_API_BASE'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
os.environ['OPENAI_BASE_URL'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1'  # compatible mode requires this variable
# SECURITY FIX: a live API key was previously hard-coded here as the getenv()
# fallback — that key is now public and must be rotated. The key must come
# exclusively from the MEM0_DASHSCOPE_API_KEY environment variable; an empty
# default makes a missing key fail loudly at the API instead of silently
# using a leaked credential.
os.environ['OPENAI_API_KEY'] = os.getenv('MEM0_DASHSCOPE_API_KEY', '')
|
|
|
|
|
# mem0ai is an optional dependency: degrade to a None sentinel instead of
# crashing at import time; the rest of the module checks `Memory is None`
# before touching mem0.
try:
    from mem0 import Memory
    from mem0.configs.base import MemoryConfig, VectorStoreConfig, LlmConfig, EmbedderConfig
except ImportError as e:
    # NOTE(review): logger is not defined yet at this point, hence print().
    print(f"⚠️ mem0ai 导入失败:{e}")
    Memory = None  # sentinel: mem0 features disabled


logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class AsyncMemoryQueue:
    """Pure-async memory write queue.

    Producers enqueue write tasks with the synchronous ``add()``; a single
    background asyncio task (started via ``start_worker``) drains the queue
    in batches and passes each item to the configured async callback.
    """

    def __init__(self, max_size: int = 100):
        """
        Args:
            max_size: maximum number of pending items; when the queue is
                full, further ``add()`` calls drop the incoming item.
        """
        self.queue = deque(maxlen=max_size)
        self.lock = asyncio.Lock()  # serializes batch extraction
        self.running = False
        self._worker_task: Optional[asyncio.Task] = None  # drain-loop task handle
        self.callback = None        # async callable(item) set by start_worker()
        self.batch_size = 10
        self.flush_interval = 60    # seconds between forced drains
        # FIX: asyncio keeps only weak references to tasks, so fire-and-forget
        # create_task() results can be garbage-collected mid-run. Hold strong
        # references to in-flight batch tasks until they complete.
        self._batch_tasks = set()

    def add(self, item: Dict[str, Any]):
        """Enqueue a write task (synchronous; safe to call from hook code).

        Drops the incoming item (with a warning) when the queue is full.
        """
        try:
            if len(self.queue) < self.queue.maxlen:
                self.queue.append({
                    'messages': item['messages'],
                    'user_id': item['user_id'],
                    'agent_id': item['agent_id'],
                    'timestamp': item.get('timestamp', datetime.now().isoformat())
                })
            else:
                # FIX: the code rejects the *incoming* task when full; the old
                # message claimed an old task was discarded, which was wrong.
                logger.warning("异步队列已满,丢弃新任务")
        except Exception as e:
            logger.error(f"队列添加失败:{e}")

    async def get_batch(self, batch_size: int) -> List[Dict]:
        """Pop and return up to ``batch_size`` items (async, lock-protected)."""
        async with self.lock:
            batch = []
            while len(batch) < batch_size and self.queue:
                batch.append(self.queue.popleft())
            return batch

    def start_worker(self, callback, batch_size: int, flush_interval: int):
        """Start the background drain loop.

        Must be called from inside a running event loop.

        Args:
            callback: async callable invoked once per queued item.
            batch_size: max items handed to one batch task.
            flush_interval: seconds between forced drains of a quiet queue.
        """
        self.running = True
        self.callback = callback
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self._worker_task = asyncio.create_task(self._worker_loop())
        logger.info(f"✅ 异步工作线程已启动 (batch_size={batch_size}, interval={flush_interval}s)")

    def _spawn_batch(self, batch: List[Dict]):
        """Launch batch processing, holding a strong task reference until done."""
        task = asyncio.create_task(self._process_batch(batch))
        self._batch_tasks.add(task)
        task.add_done_callback(self._batch_tasks.discard)

    async def _worker_loop(self):
        """Drain loop: dispatch batches while items are pending, plus a
        periodic forced flush every ``flush_interval`` seconds."""
        last_flush = time.time()

        while self.running:
            try:
                if self.queue:
                    batch = await self.get_batch(self.batch_size)
                    if batch:
                        self._spawn_batch(batch)

                if time.time() - last_flush > self.flush_interval:
                    if self.queue:
                        batch = await self.get_batch(self.batch_size)
                        if batch:
                            self._spawn_batch(batch)
                    last_flush = time.time()

                await asyncio.sleep(1)

            except asyncio.CancelledError:
                logger.info("异步工作线程已取消")
                break
            except Exception as e:
                logger.error(f"异步工作线程错误:{e}")
                await asyncio.sleep(5)

    async def _process_batch(self, batch: List[Dict]):
        """Run the callback for every item in the batch.

        FIX: errors are isolated per item so that one failing write no longer
        silently discards the rest of the batch.
        """
        logger.debug(f"开始处理批量任务:{len(batch)} 条")
        for item in batch:
            try:
                await self.callback(item)
            except Exception as e:
                logger.error(f"批量处理失败:{e}")
        logger.debug(f"批量任务处理完成")

    async def stop(self):
        """Graceful shutdown: cancel the drain loop, then wait for any
        in-flight batch tasks to finish."""
        self.running = False
        if self._worker_task:
            self._worker_task.cancel()
            try:
                await self._worker_task
            except asyncio.CancelledError:
                pass
        # FIX: previously in-flight batch tasks were abandoned on shutdown.
        if self._batch_tasks:
            await asyncio.gather(*self._batch_tasks, return_exceptions=True)
        logger.info("异步工作线程已关闭")
|
|
|
|
|
|
|
|
class Mem0Client:
    """
    Production-grade mem0 client.

    Pure-async architecture: blocking mem0 calls are isolated in worker
    threads (``asyncio.to_thread``), writes go through an async queue, and
    memory isolation across dimensions is achieved via ``user_id`` plus an
    ``agent_id`` metadata field.
    """

    def __init__(self, config: Dict = None):
        """
        Args:
            config: optional configuration dict; when omitted, built-in
                defaults (env-var overridable) are used.
        """
        self.config = config or self._load_default_config()
        self.local_memory = None   # mem0 Memory instance; None until _init_memory succeeds
        self.async_queue = None    # AsyncMemoryQueue; created lazily in start()
        self.cache = {}            # search cache: "user:agent:query" -> {'results', 'time'}
        self._started = False
        # Do not launch async tasks in __init__ — it may run outside any
        # event loop. Callers must ``await start()`` explicitly.
        self._init_memory()

    def _load_default_config(self) -> Dict:
        """Build the default configuration (selected values env-overridable)."""
        return {
            "qdrant": {
                "host": os.getenv('MEM0_QDRANT_HOST', 'localhost'),
                "port": int(os.getenv('MEM0_QDRANT_PORT', '6333')),
                "collection_name": "mem0_shared"
            },
            "llm": {
                "provider": "openai",
                "config": {
                    "model": os.getenv('MEM0_LLM_MODEL', 'qwen-plus')
                }
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": os.getenv('MEM0_EMBEDDER_MODEL', 'text-embedding-v3'),
                    "api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1"
                }
            },
            "retrieval": {
                "enabled": True,
                "top_k": 5,
                "min_confidence": 0.7,   # minimum similarity score to keep a hit
                "timeout_ms": 2000
            },
            "async_write": {
                "enabled": True,
                "queue_size": 100,
                "batch_size": 10,
                "flush_interval": 60,    # seconds
                "timeout_ms": 5000
            },
            "cache": {
                "enabled": True,
                "ttl": 300,              # seconds
                "max_size": 1000
            },
            "fallback": {
                "enabled": True,
                "log_level": "WARNING",
                "retry_attempts": 2
            },
            "metadata": {
                "default_user_id": "default",
                "default_agent_id": "general"
            }
        }

    def _init_memory(self):
        """Initialize mem0 synchronously with the full vector-store + LLM +
        embedder configuration. On any failure, ``local_memory`` stays None
        and the client degrades to a no-op (fail-open)."""
        if Memory is None:
            logger.warning("mem0ai 未安装")
            return

        try:
            config = MemoryConfig(
                vector_store=VectorStoreConfig(
                    provider="qdrant",
                    config={
                        "host": self.config['qdrant']['host'],
                        "port": self.config['qdrant']['port'],
                        "collection_name": self.config['qdrant']['collection_name'],
                        "on_disk": True
                    }
                ),
                llm=LlmConfig(
                    provider="openai",
                    config=self.config['llm']['config']
                ),
                embedder=EmbedderConfig(
                    provider="openai",
                    config={
                        "model": "text-embedding-v3"  # explicitly pin an embedding model DashScope supports
                    }
                )
            )
            self.local_memory = Memory(config=config)
            logger.info("✅ mem0 初始化成功(含 Embedder)")
        except Exception as e:
            logger.error(f"❌ mem0 初始化失败:{e}")
            self.local_memory = None

    async def start(self):
        """
        Explicitly start the async write worker.

        Must be awaited from inside a running event loop:
        ``await mem0_client.start()``. Idempotent.
        """
        if self._started:
            logger.debug("mem0 Client 已启动")
            return

        if self.config['async_write']['enabled']:
            self.async_queue = AsyncMemoryQueue(
                max_size=self.config['async_write']['queue_size']
            )
            self.async_queue.start_worker(
                callback=self._async_write_memory,
                batch_size=self.config['async_write']['batch_size'],
                flush_interval=self.config['async_write']['flush_interval']
            )
        self._started = True
        logger.info("✅ mem0 Client 异步工作线程已启动")

    # ========== Pre-Hook: retrieval ==========

    async def pre_hook_search(self, query: str, user_id: str = None, agent_id: str = None, top_k: int = None) -> List[Dict]:
        """Pre-hook: retrieve relevant memories before a conversation turn.

        Fail-open: returns [] when retrieval is disabled, mem0 is missing,
        the search times out, or any error occurs. Successful non-empty
        results are cached per (user_id, agent_id, query) for the TTL.
        """
        if not self.config['retrieval']['enabled'] or self.local_memory is None:
            return []

        cache_key = f"{user_id}:{agent_id}:{query}"
        if self.config['cache']['enabled'] and cache_key in self.cache:
            cached = self.cache[cache_key]
            if time.time() - cached['time'] < self.config['cache']['ttl']:
                logger.debug(f"Cache hit: {cache_key}")
                return cached['results']

        timeout_ms = self.config['retrieval']['timeout_ms']

        try:
            memories = await asyncio.wait_for(
                self._execute_search(query, user_id, agent_id, top_k or self.config['retrieval']['top_k']),
                timeout=timeout_ms / 1000
            )

            if self.config['cache']['enabled'] and memories:
                self.cache[cache_key] = {'results': memories, 'time': time.time()}
                self._cleanup_cache()

            logger.info(f"Pre-Hook 检索完成:{len(memories)} 条记忆")
            return memories

        except asyncio.TimeoutError:
            logger.warning(f"Pre-Hook 检索超时 ({timeout_ms}ms)")
            return []
        except Exception as e:
            logger.warning(f"Pre-Hook 检索失败:{e}")
            return []

    async def _execute_search(self, query: str, user_id: str, agent_id: str, top_k: int) -> List[Dict]:
        """
        Run the actual search — metadata filters implement dimension isolation.

        Two passes against mem0 (each a blocking call moved to a thread):
        global user memories, then agent-scoped memories filtered by the
        ``agent_id`` metadata field. Results are merged, de-duplicated by id,
        filtered by the configured confidence threshold, and sorted by score.
        """
        if self.local_memory is None:
            return []

        # Strategy 1: global per-user memories.
        user_memories = []
        if user_id:
            try:
                user_memories = await asyncio.to_thread(
                    self.local_memory.search,
                    query,
                    user_id=user_id,
                    limit=top_k
                )
            except Exception as e:
                logger.debug(f"用户记忆检索失败:{e}")

        # Strategy 2: business-domain memories (metadata filter).
        agent_memories = []
        if agent_id and agent_id != 'general':
            try:
                agent_memories = await asyncio.to_thread(
                    self.local_memory.search,
                    query,
                    user_id=user_id,
                    filters={"agent_id": agent_id},  # metadata filter: vertical isolation
                    limit=top_k
                )
            except Exception as e:
                logger.debug(f"业务记忆检索失败:{e}")

        # Merge results, de-duplicating by memory id.
        # NOTE(review): assumes search() returns a list of dicts with 'id'/'score'
        # keys — confirm against the installed mem0 version (some versions wrap
        # hits in a {'results': [...]} envelope).
        all_memories = {}
        for mem in user_memories + agent_memories:
            mem_id = mem.get('id') if isinstance(mem, dict) else None
            if mem_id and mem_id not in all_memories:
                all_memories[mem_id] = mem

        # Drop hits below the confidence threshold (missing score counts as 1.0).
        min_confidence = self.config['retrieval']['min_confidence']
        filtered = [
            m for m in all_memories.values()
            if m.get('score', 1.0) >= min_confidence
        ]

        # Highest-confidence first.
        filtered.sort(key=lambda x: x.get('score', 0), reverse=True)

        return filtered[:top_k]

    def format_memories_for_prompt(self, memories: List[Dict]) -> str:
        """Render retrieved memories as a prompt fragment; "" when empty."""
        if not memories:
            return ""

        prompt = "\n\n=== 相关记忆 ===\n"
        for i, mem in enumerate(memories, 1):
            memory_text = mem.get('memory', '') if isinstance(mem, dict) else str(mem)
            created_at = mem.get('created_at', '') if isinstance(mem, dict) else ''
            prompt += f"{i}. {memory_text}"
            if created_at:
                prompt += f" (记录于:{created_at})"
            prompt += "\n"
        prompt += "===============\n"

        return prompt

    # ========== Post-Hook: async write ==========

    def post_hook_add(self, user_message: str, assistant_message: str, user_id: str = None, agent_id: str = None):
        """Post-hook: queue the turn for async persistence (synchronous —
        only enqueues; the actual write happens in the background worker).

        Missing user_id/agent_id fall back to the configured defaults.
        """
        if not self.config['async_write']['enabled']:
            return

        if not user_id:
            user_id = self.config['metadata']['default_user_id']
        if not agent_id:
            agent_id = self.config['metadata']['default_agent_id']

        messages = [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_message}
        ]

        if self.async_queue:
            self.async_queue.add({
                'messages': messages,
                'user_id': user_id,
                'agent_id': agent_id,
                'timestamp': datetime.now().isoformat()
            })
            logger.debug(f"Post-Hook 已提交:user={user_id}, agent={agent_id}")
        else:
            # start() was never awaited (or async_write just got enabled).
            logger.warning("异步队列未初始化")

    async def _async_write_memory(self, item: Dict):
        """Background write with timeout; failures are logged, never raised
        (queue worker callback — see AsyncMemoryQueue)."""
        if self.local_memory is None:
            return

        timeout_ms = self.config['async_write']['timeout_ms']

        try:
            await asyncio.wait_for(self._execute_write(item), timeout=timeout_ms / 1000)
            logger.debug(f"异步写入成功:user={item['user_id']}, agent={item['agent_id']}")
        except asyncio.TimeoutError:
            logger.warning(f"异步写入超时 ({timeout_ms}ms)")
        except Exception as e:
            logger.warning(f"异步写入失败:{e}")

    async def _execute_write(self, item: Dict):
        """
        Perform the write — metadata implements dimension isolation.

        Key point: ``agent_id`` travels inside the metadata dict rather than
        as a direct mem0 parameter, so agent/business scoping is purely a
        metadata concern.
        """
        if self.local_memory is None:
            return

        # Build metadata carrying the business dimension.
        custom_metadata = {
            "agent_id": item['agent_id'],
            "source": "openclaw",
            "timestamp": item.get('timestamp'),
            "business_type": item['agent_id']
        }

        # Blocking call — run in the default thread pool.
        await asyncio.to_thread(
            self.local_memory.add,
            messages=item['messages'],
            user_id=item['user_id'],  # natively supported global user key
            metadata=custom_metadata  # inject the custom business dimension
        )

    def _cleanup_cache(self):
        """Evict expired entries, then trim oldest entries down to max_size."""
        if not self.config['cache']['enabled']:
            return

        current_time = time.time()
        ttl = self.config['cache']['ttl']

        expired_keys = [k for k, v in self.cache.items() if current_time - v['time'] > ttl]
        for key in expired_keys:
            del self.cache[key]

        # Still over capacity after TTL eviction: drop the oldest entries.
        if len(self.cache) > self.config['cache']['max_size']:
            oldest_keys = sorted(self.cache.keys(), key=lambda k: self.cache[k]['time'])[:len(self.cache) - self.config['cache']['max_size']]
            for key in oldest_keys:
                del self.cache[key]

    def get_status(self) -> Dict:
        """Return a snapshot of client health/config for diagnostics."""
        return {
            "initialized": self.local_memory is not None,
            "started": self._started,
            "async_queue_enabled": self.config['async_write']['enabled'],
            "queue_size": len(self.async_queue.queue) if self.async_queue else 0,
            "cache_size": len(self.cache),
            "qdrant": f"{self.config['qdrant']['host']}:{self.config['qdrant']['port']}"
        }

    async def shutdown(self):
        """Graceful shutdown: stop the async write queue (if started)."""
        if self.async_queue:
            await self.async_queue.stop()
        logger.info("mem0 Client 已关闭")
|
|
|
|
|
|
|
|
# Global client instance shared by the hooks.
# NOTE(review): constructing at import time runs _init_memory() (vector store /
# LLM / embedder setup) as an import side effect — confirm this is intended;
# start() must still be awaited separately inside an event loop.
mem0_client = Mem0Client()
|
|
|
|