#!/usr/bin/env python3
"""
mem0 Client for OpenClaw - production-grade pure-async architecture.

Pre-hook retrieval injection + post-hook asynchronous memory writes.
"""
import os
import asyncio
import logging
import time
from typing import List, Dict, Optional, Any
from collections import deque
from datetime import datetime

# Point the OpenAI-compatible endpoint at DashScope before mem0 is imported.
os.environ['OPENAI_API_BASE'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
# SECURITY: a hardcoded API key is committed as the fallback below. This
# secret should be rotated and removed from source; rely solely on the
# MEM0_DASHSCOPE_API_KEY environment variable.
os.environ['OPENAI_API_KEY'] = os.getenv('MEM0_DASHSCOPE_API_KEY', 'sk-c1715ee0479841399fd359c574647648')

try:
    from mem0 import Memory
    from mem0.configs.base import MemoryConfig, VectorStoreConfig, LlmConfig
except ImportError as e:
    print(f"⚠️ mem0ai 导入失败:{e}")
    Memory = None

logger = logging.getLogger(__name__)


class AsyncMemoryQueue:
    """Bounded pure-async write queue drained in batches by a background task."""

    def __init__(self, max_size: int = 100):
        # Bounded deque: with maxlen set, append() evicts from the LEFT when
        # full, which implements the "drop oldest" overflow policy.
        self.queue = deque(maxlen=max_size)
        self.lock = asyncio.Lock()
        self.running = False
        self._worker_task = None
        self.callback = None
        self.batch_size = 10
        self.flush_interval = 60
        # asyncio keeps only weak references to tasks; without strong refs,
        # fire-and-forget batch tasks could be garbage-collected mid-run.
        self._pending_tasks = set()

    def add(self, item: Dict[str, Any]):
        """Enqueue a write task; when the queue is full the OLDEST entry is dropped.

        Fix: the original checked the length and silently discarded the NEW
        item while logging "discarding old task". Appending unconditionally
        lets the bounded deque evict the oldest, matching the log message.
        """
        try:
            if len(self.queue) >= self.queue.maxlen:
                logger.warning("异步队列已满,丢弃旧任务")
            self.queue.append({
                'messages': item['messages'],
                'user_id': item['user_id'],
                'agent_id': item['agent_id'],
                'timestamp': datetime.now().isoformat()
            })
        except Exception as e:
            logger.error(f"队列添加失败:{e}")

    async def get_batch(self, batch_size: int) -> List[Dict]:
        """Pop up to *batch_size* queued items (FIFO) while holding the lock."""
        async with self.lock:
            batch = []
            while len(batch) < batch_size and self.queue:
                batch.append(self.queue.popleft())
            return batch

    def start_worker(self, callback, batch_size: int, flush_interval: int):
        """Start the background drain loop; must be called inside a running loop."""
        self.running = True
        self.callback = callback
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self._worker_task = asyncio.create_task(self._worker_loop())
        logger.info(f"✅ 异步工作线程已启动 (batch_size={batch_size}, interval={flush_interval}s)")

    def _dispatch_batch(self, batch: List[Dict]):
        """Schedule batch processing, retaining a strong task reference."""
        task = asyncio.create_task(self._process_batch(batch))
        self._pending_tasks.add(task)
        # Drop the reference once the task finishes so the set stays bounded.
        task.add_done_callback(self._pending_tasks.discard)

    async def _worker_loop(self):
        """Drain loop: dispatch a pending batch roughly once per second.

        The original had two duplicated dispatch paths (per-tick and
        per-flush-interval); a single per-tick drain subsumes both because
        the 1s tick is always shorter than any configured flush interval.
        """
        while self.running:
            try:
                if self.queue:
                    batch = await self.get_batch(self.batch_size)
                    if batch:
                        self._dispatch_batch(batch)
                await asyncio.sleep(1)
            except asyncio.CancelledError:
                logger.info("异步工作线程已取消")
                break
            except Exception as e:
                logger.error(f"异步工作线程错误:{e}")
                await asyncio.sleep(5)

    async def _process_batch(self, batch: List[Dict]):
        """Run the write callback per item; one failure no longer aborts the batch."""
        logger.debug(f"开始处理批量任务:{len(batch)} 条")
        for item in batch:
            try:
                await self.callback(item)
            except Exception as e:
                logger.error(f"批量处理失败:{e}")
        logger.debug(f"批量任务处理完成")

    async def stop(self):
        """Graceful shutdown: cancel the worker and wait for it to unwind."""
        self.running = False
        if self._worker_task:
            self._worker_task.cancel()
            try:
                await self._worker_task
            except asyncio.CancelledError:
                pass
        logger.info("异步工作线程已关闭")


class Mem0Client:
    """Production-grade mem0 client: cached pre-hook retrieval + queued async writes."""

    def __init__(self, config: Dict = None):
        self.config = config or self._load_default_config()
        self.local_memory = None      # mem0 Memory instance, or None on failure
        self.async_queue = None       # AsyncMemoryQueue, lazily started
        self.cache = {}               # query -> {'results': [...], 'time': epoch}
        self._init_memory()
        self._start_async_worker()

    def _load_default_config(self) -> Dict:
        """Default configuration, overridable via MEM0_* environment variables."""
        return {
            "qdrant": {
                "host": os.getenv('MEM0_QDRANT_HOST', '100.115.94.1'),
                "port": int(os.getenv('MEM0_QDRANT_PORT', '6333')),
                "collection_name": "mem0_shared"
            },
            "llm": {
                "provider": "openai",
                "config": {"model": os.getenv('MEM0_LLM_MODEL', 'qwen-plus')}
            },
            "retrieval": {
                "enabled": True,
                "top_k": 5,
                "min_confidence": 0.7,
                "timeout_ms": 2000
            },
            "async_write": {
                "enabled": True,
                "queue_size": 100,
                "batch_size": 10,
                "flush_interval": 60,
                "timeout_ms": 5000
            },
            "cache": {"enabled": True, "ttl": 300, "max_size": 1000},
            "fallback": {"enabled": True, "log_level": "WARNING", "retry_attempts": 2},
            "metadata": {"default_user_id": "default", "default_agent_id": "general"}
        }

    def _init_memory(self):
        """Build the local mem0 Memory instance (Qdrant vector store + LLM)."""
        if Memory is None:
            logger.warning("mem0ai 未安装")
            return
        try:
            config = MemoryConfig(
                vector_store=VectorStoreConfig(
                    provider="qdrant",
                    config={
                        "host": self.config['qdrant']['host'],
                        "port": self.config['qdrant']['port'],
                        "collection_name": self.config['qdrant']['collection_name'],
                        "on_disk": True
                    }
                ),
                llm=LlmConfig(
                    provider="openai",
                    config=self.config['llm']['config']
                )
            )
            self.local_memory = Memory(config=config)
            logger.info("✅ mem0 初始化成功")
        except Exception as e:
            logger.error(f"❌ mem0 初始化失败:{e}")
            self.local_memory = None

    def _start_async_worker(self):
        """Start the async write worker; no-op until an event loop is running.

        Uses asyncio.get_running_loop() (the deprecated get_event_loop()
        could create a loop yet still fail at create_task, leaving a queue
        with no worker). Outside a loop this raises RuntimeError and the
        worker is retried lazily from post_hook_add().
        """
        if not self.config['async_write']['enabled']:
            return
        try:
            asyncio.get_running_loop()
            self.async_queue = AsyncMemoryQueue(
                max_size=self.config['async_write']['queue_size']
            )
            self.async_queue.start_worker(
                callback=self._async_write_memory,
                batch_size=self.config['async_write']['batch_size'],
                flush_interval=self.config['async_write']['flush_interval']
            )
        except RuntimeError:
            logger.debug("当前无事件循环,异步队列将在首次使用时初始化")

    async def pre_hook_search(self, query: str, user_id: str = None,
                              agent_id: str = None, top_k: int = None) -> List[Dict]:
        """Pre-hook: retrieve relevant memories before the conversation turn.

        Returns a (possibly cached) list of memory dicts, or [] on any
        failure/timeout — retrieval must never block the conversation.
        """
        if not self.config['retrieval']['enabled'] or self.local_memory is None:
            return []

        cache_key = f"{user_id}:{agent_id}:{query}"
        if self.config['cache']['enabled'] and cache_key in self.cache:
            cached = self.cache[cache_key]
            if time.time() - cached['time'] < self.config['cache']['ttl']:
                logger.debug(f"Cache hit: {cache_key}")
                return cached['results']

        timeout_ms = self.config['retrieval']['timeout_ms']
        try:
            memories = await asyncio.wait_for(
                self._execute_search(query, user_id, agent_id,
                                     top_k or self.config['retrieval']['top_k']),
                timeout=timeout_ms / 1000
            )
            if self.config['cache']['enabled'] and memories:
                self.cache[cache_key] = {'results': memories, 'time': time.time()}
                self._cleanup_cache()
            logger.info(f"Pre-Hook 检索完成:{len(memories)} 条记忆")
            return memories
        except asyncio.TimeoutError:
            logger.warning(f"Pre-Hook 检索超时 ({timeout_ms}ms)")
            return []
        except Exception as e:
            logger.warning(f"Pre-Hook 检索失败:{e}")
            return []

    async def _execute_search(self, query: str, user_id: str,
                              agent_id: str, top_k: int) -> List[Dict]:
        """Run the blocking mem0 searches via asyncio.to_thread and merge results.

        Searches the user scope and (when agent_id is business-specific) the
        "user:agent" scope, de-duplicates by memory id, filters by the
        configured confidence floor, and returns the top_k best by score.
        """
        if self.local_memory is None:
            return []

        user_memories = []
        if user_id:
            try:
                user_memories = await asyncio.to_thread(
                    self.local_memory.search, query, user_id=user_id, limit=top_k
                )
            except Exception as e:
                logger.debug(f"用户记忆检索失败:{e}")

        agent_memories = []
        if agent_id and agent_id != 'general':
            try:
                agent_memories = await asyncio.to_thread(
                    self.local_memory.search, query,
                    user_id=f"{user_id}:{agent_id}" if user_id else agent_id,
                    limit=top_k
                )
            except Exception as e:
                logger.debug(f"业务记忆检索失败:{e}")

        # De-duplicate by memory id; first occurrence (user scope) wins.
        all_memories = {}
        for mem in user_memories + agent_memories:
            mem_id = mem.get('id') if isinstance(mem, dict) else None
            if mem_id and mem_id not in all_memories:
                all_memories[mem_id] = mem

        min_confidence = self.config['retrieval']['min_confidence']
        # Missing scores default to 1.0 here (keep) but 0 when sorting (last).
        filtered = [m for m in all_memories.values()
                    if m.get('score', 1.0) >= min_confidence]
        filtered.sort(key=lambda x: x.get('score', 0), reverse=True)
        return filtered[:top_k]

    def format_memories_for_prompt(self, memories: List[Dict]) -> str:
        """Render retrieved memories as a prompt fragment; "" when empty."""
        if not memories:
            return ""
        prompt = "\n\n=== 相关记忆 ===\n"
        for i, mem in enumerate(memories, 1):
            memory_text = mem.get('memory', '') if isinstance(mem, dict) else str(mem)
            created_at = mem.get('created_at', '') if isinstance(mem, dict) else ''
            prompt += f"{i}. {memory_text}"
            if created_at:
                prompt += f" (记录于:{created_at})"
            prompt += "\n"
        prompt += "===============\n"
        return prompt

    def post_hook_add(self, user_message: str, assistant_message: str,
                      user_id: str = None, agent_id: str = None):
        """Post-hook: queue the exchange for asynchronous memory writing.

        Lazily starts the worker if construction happened outside an event
        loop (the promised "initialize at first use" path, previously absent).
        """
        if not self.config['async_write']['enabled']:
            return
        if not user_id:
            user_id = self.config['metadata']['default_user_id']
        if not agent_id:
            agent_id = self.config['metadata']['default_agent_id']

        messages = [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_message}
        ]

        if self.async_queue is None:
            self._start_async_worker()

        if self.async_queue:
            self.async_queue.add({
                'messages': messages,
                'user_id': user_id,
                'agent_id': agent_id
            })
            logger.debug(f"Post-Hook 已提交:user={user_id}, agent={agent_id}")
        else:
            logger.warning("异步队列未初始化")

    async def _async_write_memory(self, item: Dict):
        """Queue callback: write one item with the configured timeout; never raises."""
        if self.local_memory is None:
            return
        timeout_ms = self.config['async_write']['timeout_ms']
        try:
            await asyncio.wait_for(self._execute_write(item), timeout=timeout_ms / 1000)
            logger.debug(f"异步写入成功:user={item['user_id']}, agent={item['agent_id']}")
        except asyncio.TimeoutError:
            logger.warning(f"异步写入超时 ({timeout_ms}ms)")
        except Exception as e:
            logger.warning(f"异步写入失败:{e}")

    async def _execute_write(self, item: Dict):
        """Run the blocking mem0 add via asyncio.to_thread under a composite id."""
        if self.local_memory is None:
            return
        # Composite "user:agent" id keeps per-agent memories separable while
        # matching the scope used by _execute_search.
        full_user_id = f"{item['user_id']}:{item['agent_id']}"
        await asyncio.to_thread(
            self.local_memory.add,
            messages=item['messages'],
            user_id=full_user_id,
            agent_id=item['agent_id']
        )

    def _cleanup_cache(self):
        """Evict expired entries, then trim oldest entries above max_size."""
        if not self.config['cache']['enabled']:
            return
        current_time = time.time()
        ttl = self.config['cache']['ttl']
        expired_keys = [k for k, v in self.cache.items()
                        if current_time - v['time'] > ttl]
        for key in expired_keys:
            del self.cache[key]
        max_size = self.config['cache']['max_size']
        if len(self.cache) > max_size:
            oldest_keys = sorted(self.cache.keys(),
                                 key=lambda k: self.cache[k]['time'])[:len(self.cache) - max_size]
            for key in oldest_keys:
                del self.cache[key]

    def get_status(self) -> Dict:
        """Snapshot of client health for diagnostics."""
        return {
            "initialized": self.local_memory is not None,
            "async_queue_enabled": self.config['async_write']['enabled'],
            "queue_size": len(self.async_queue.queue) if self.async_queue else 0,
            "cache_size": len(self.cache),
            "qdrant": f"{self.config['qdrant']['host']}:{self.config['qdrant']['port']}"
        }

    async def shutdown(self):
        """Graceful shutdown: stop the async write queue."""
        if self.async_queue:
            await self.async_queue.stop()
        logger.info("mem0 Client 已关闭")


# Global client instance (worker starts lazily on first post_hook_add
# inside a running event loop).
mem0_client = Mem0Client()