#!/usr/bin/env python3
"""
mem0 Client for OpenClaw - production-grade fully-async architecture.

Pre-hook retrieval injection + post-hook asynchronous writes.
Three-level visibility isolation (public / project / private).
Memory decay (expiration_date) + intelligent write filtering.
"""
import os
import re
import asyncio
import logging
import time
import yaml
from typing import List, Dict, Any
from collections import deque
from datetime import datetime, timedelta
from pathlib import Path

# ========== DashScope environment configuration ==========
# Embedding model: text-embedding-v4 (1024 dimensions).
os.environ['OPENAI_API_BASE'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
# NOTE: compatible mode requires OPENAI_BASE_URL as well.
os.environ['OPENAI_BASE_URL'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1'

_dashscope_key = os.getenv('MEM0_DASHSCOPE_API_KEY', '') or os.getenv('DASHSCOPE_API_KEY', '')
if _dashscope_key:
    os.environ['OPENAI_API_KEY'] = _dashscope_key
elif not os.environ.get('OPENAI_API_KEY'):
    logging.warning("MEM0_DASHSCOPE_API_KEY not set; mem0 embedding/LLM calls will fail")

try:
    from mem0 import Memory
    from mem0.configs.base import MemoryConfig, VectorStoreConfig, LlmConfig, EmbedderConfig
except ImportError as e:
    print(f"⚠️ mem0ai 导入失败:{e}")
    Memory = None

logger = logging.getLogger(__name__)


class AsyncMemoryQueue:
    """Pure-async memory write queue.

    Items are appended synchronously from request handlers and drained by a
    background asyncio task that forwards them, in batches, to a callback.
    """

    def __init__(self, max_size: int = 100):
        self.queue = deque(maxlen=max_size)
        self.lock = asyncio.Lock()
        self.running = False
        self._worker_task = None
        self.callback = None
        self.batch_size = 10
        self.flush_interval = 60
        # FIX: hold strong references to fire-and-forget batch tasks so the
        # event loop cannot garbage-collect them before they complete.
        self._pending_tasks = set()

    def add(self, item: Dict[str, Any]):
        """Add a write task to the queue (synchronous).

        FIX: the original refused the NEW item when the queue was full while
        logging that it dropped an *old* one.  A bounded deque evicts the
        oldest entry on append, which matches the documented intent, so we
        now always append and let the deque evict.
        """
        try:
            if len(self.queue) >= self.queue.maxlen:
                logger.warning("异步队列已满,丢弃旧任务")
            self.queue.append({
                'messages': item['messages'],
                'user_id': item['user_id'],
                'agent_id': item['agent_id'],
                'visibility': item.get('visibility', 'private'),
                'project_id': item.get('project_id'),
                'memory_type': item.get('memory_type', 'session'),
                'timestamp': item.get('timestamp', datetime.now().isoformat())
            })
        except Exception as e:
            logger.error(f"队列添加失败:{e}")

    async def get_batch(self, batch_size: int) -> List[Dict]:
        """Pop up to ``batch_size`` tasks from the queue (async)."""
        async with self.lock:
            batch = []
            while len(batch) < batch_size and self.queue:
                batch.append(self.queue.popleft())
            return batch

    def start_worker(self, callback, batch_size: int, flush_interval: int):
        """Start the background drain task. Must run inside an event loop."""
        self.running = True
        self.callback = callback
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self._worker_task = asyncio.create_task(self._worker_loop())
        logger.info(f"✅ 异步工作线程已启动 (batch_size={batch_size}, interval={flush_interval}s)")

    async def _worker_loop(self):
        """Background loop: drain the queue roughly once per second.

        FIX: the original also had an interval-triggered drain branch, but it
        was dead weight — the 1-second poll already drains the queue every
        iteration, so the interval branch could never do additional work.
        """
        while self.running:
            try:
                if self.queue:
                    batch = await self.get_batch(self.batch_size)
                    if batch:
                        task = asyncio.create_task(self._process_batch(batch))
                        self._pending_tasks.add(task)
                        task.add_done_callback(self._pending_tasks.discard)
                await asyncio.sleep(1)
            except asyncio.CancelledError:
                logger.info("异步工作线程已取消")
                break
            except Exception as e:
                logger.error(f"异步工作线程错误:{e}")
                await asyncio.sleep(5)

    async def _process_batch(self, batch: List[Dict]):
        """Forward a batch of tasks to the registered callback, one by one."""
        try:
            logger.debug(f"开始处理批量任务:{len(batch)} 条")
            for item in batch:
                await self.callback(item)
            logger.debug(f"批量任务处理完成")
        except Exception as e:
            logger.error(f"批量处理失败:{e}")

    async def stop(self):
        """Graceful shutdown: cancel the worker task and wait for it."""
        self.running = False
        if self._worker_task:
            self._worker_task.cancel()
            try:
                await self._worker_task
            except asyncio.CancelledError:
                pass
        logger.info("异步工作线程已关闭")


# Memory decay policy: TTL per memory type (None = never expires).
EXPIRATION_MAP = {
    'session': timedelta(days=7),
    'chat_summary': timedelta(days=30),
    'preference': None,
    'knowledge': None,
}

# Messages that are pure acknowledgements — not worth remembering.
SKIP_PATTERNS = re.compile(
    r'^(好的|收到|OK|ok|嗯|行|没问题|感谢|谢谢|了解|明白|知道了|👍|✅|❌)$',
    re.IGNORECASE
)
# Slash-prefixed system commands are skipped too.
SYSTEM_CMD_PATTERN = re.compile(r'^/')
# Keywords signalling a durable user preference/rule.
MEMORY_KEYWORDS = re.compile(
    r'(记住|以后|偏好|配置|设置|规则|永远|始终|总是|不要|禁止)',
)
# Keywords signalling cluster-wide (public) relevance.
PUBLIC_KEYWORDS = re.compile(
    r'(所有人|通知|全局|公告|大家|集群)',
)


def _load_project_registry() -> Dict:
    """Load the project registry from ``project_registry.yaml``.

    Best-effort: returns {} on a missing or unreadable file.
    """
    registry_path = Path(__file__).parent / 'project_registry.yaml'
    if registry_path.exists():
        try:
            with open(registry_path, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f) or {}
        except Exception:
            pass
    return {}


def get_agent_projects(agent_id: str) -> List[str]:
    """Return all project_ids the given agent belongs to.

    A project lists the agent explicitly in ``members`` or uses the '*'
    wildcard to include everyone.
    """
    registry = _load_project_registry()
    projects = registry.get('projects', {})
    result = []
    for pid, pconf in projects.items():
        members = pconf.get('members', [])
        if '*' in members or agent_id in members:
            result.append(pid)
    return result


class Mem0Client:
    """
    Production-grade mem0 client.

    Fully-async architecture + three-level visibility + memory decay +
    intelligent write filtering.
    """

    def __init__(self, config: Dict = None):
        self.config = config or self._load_default_config()
        self.local_memory = None
        self.async_queue = None
        self.cache = {}
        self._started = False
        self._init_memory()

    def _load_default_config(self) -> Dict:
        """Default configuration — single shared-collection architecture."""
        return {
            "qdrant": {
                "host": os.getenv('MEM0_QDRANT_HOST', 'localhost'),
                "port": int(os.getenv('MEM0_QDRANT_PORT', '6333')),
                # One shared collection for all agents.
                "collection_name": "mem0_v4_shared"
            },
            "llm": {
                "provider": "openai",
                "config": {
                    "model": os.getenv('MEM0_LLM_MODEL', 'qwen-plus')
                }
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": os.getenv('MEM0_EMBEDDER_MODEL', 'text-embedding-v4'),
                    # Max dimensions supported by DashScope text-embedding-v4.
                    "dimensions": 1024
                }
            },
            "retrieval": {
                "enabled": True,
                "top_k": 5,
                "min_confidence": 0.7,
                "timeout_ms": 5000
            },
            "async_write": {
                "enabled": True,
                "queue_size": 100,
                "batch_size": 10,
                "flush_interval": 60,
                "timeout_ms": 5000
            },
            "cache": {
                "enabled": True,
                "ttl": 300,
                "max_size": 1000
            },
            "fallback": {
                "enabled": True,
                "log_level": "WARNING",
                "retry_attempts": 2
            },
            "metadata": {
                "default_user_id": "default",
                "default_agent_id": "general"
            },
            "write_filter": {
                "enabled": True,
                "min_user_message_length": 5,
            }
        }

    def _init_memory(self):
        """Initialize mem0 (synchronous) — vector store + LLM + embedder.

        FIX: the embedder model name was hard-coded to "text-embedding-v4",
        silently ignoring the MEM0_EMBEDDER_MODEL value already loaded into
        ``self.config``; it is now read from the config.
        """
        if Memory is None:
            logger.warning("mem0ai 未安装")
            return
        try:
            config = MemoryConfig(
                vector_store=VectorStoreConfig(
                    provider="qdrant",
                    config={
                        "host": self.config['qdrant']['host'],
                        "port": self.config['qdrant']['port'],
                        "collection_name": self.config['qdrant']['collection_name'],
                        "on_disk": True,
                        # Force the Qdrant collection to 1024 dims.
                        "embedding_model_dims": 1024
                    }
                ),
                llm=LlmConfig(
                    provider="openai",
                    config=self.config['llm']['config']
                ),
                embedder=EmbedderConfig(
                    provider="openai",
                    config={
                        "model": self.config['embedder']['config'].get(
                            'model', 'text-embedding-v4'),
                        # Override mem0's 1536-dim default.
                        # api_base / api_key come from OPENAI_BASE_URL and
                        # OPENAI_API_KEY environment variables.
                        "embedding_dims": 1024
                    }
                )
            )
            self.local_memory = Memory(config=config)
            logger.info("✅ mem0 初始化成功(含 Embedder,1024 维度)")
        except Exception as e:
            logger.error(f"❌ mem0 初始化失败:{e}")
            self.local_memory = None

    async def start(self):
        """
        Explicitly start the async worker.

        Must be called inside an event loop: ``await mem0_client.start()``.
        """
        if self._started:
            logger.debug("mem0 Client 已启动")
            return
        if self.config['async_write']['enabled']:
            self.async_queue = AsyncMemoryQueue(
                max_size=self.config['async_write']['queue_size']
            )
            self.async_queue.start_worker(
                callback=self._async_write_memory,
                batch_size=self.config['async_write']['batch_size'],
                flush_interval=self.config['async_write']['flush_interval']
            )
        self._started = True
        logger.info("✅ mem0 Client 异步工作线程已启动")

    # ========== Pre-Hook: smart retrieval ==========
    async def pre_hook_search(self, query: str, user_id: str = None,
                              agent_id: str = None, top_k: int = None) -> List[Dict]:
        """Pre-hook: retrieve relevant memories before a conversation turn.

        Results are cached per (user, agent, query) for the configured TTL,
        and the underlying search is bounded by ``retrieval.timeout_ms``.
        """
        if not self.config['retrieval']['enabled'] or self.local_memory is None:
            return []
        cache_key = f"{user_id}:{agent_id}:{query}"
        if self.config['cache']['enabled'] and cache_key in self.cache:
            cached = self.cache[cache_key]
            if time.time() - cached['time'] < self.config['cache']['ttl']:
                logger.debug(f"Cache hit: {cache_key}")
                return cached['results']
        timeout_ms = self.config['retrieval']['timeout_ms']
        try:
            memories = await asyncio.wait_for(
                self._execute_search(query, user_id, agent_id,
                                     top_k or self.config['retrieval']['top_k']),
                timeout=timeout_ms / 1000
            )
            if self.config['cache']['enabled'] and memories:
                self.cache[cache_key] = {'results': memories, 'time': time.time()}
                self._cleanup_cache()
            logger.info(f"Pre-Hook 检索完成:{len(memories)} 条记忆")
            return memories
        except asyncio.TimeoutError:
            logger.warning(f"Pre-Hook 检索超时 ({timeout_ms}ms)")
            return []
        except Exception as e:
            logger.warning(f"Pre-Hook 检索失败:{e}")
            return []

    async def _execute_search(self, query: str, user_id: str,
                              agent_id: str, top_k: int) -> List[Dict]:
        """
        Three-phase retrieval — layered by visibility, merged and deduped.

        Phase 1: public  (visible to all agents)
        Phase 2: project (visible to members of the same project_id)
        Phase 3: private (visible to the owning agent_id only)
        """
        if self.local_memory is None:
            return []
        all_memories: Dict[str, Dict] = {}
        per_phase = max(top_k, 3)

        # Phase 1: public memories.
        try:
            public_mems = await asyncio.to_thread(
                self.local_memory.search, query,
                user_id=user_id,
                filters={"visibility": "public"},
                limit=per_phase
            )
            for mem in (public_mems or []):
                mid = mem.get('id') if isinstance(mem, dict) else None
                if mid and mid not in all_memories:
                    all_memories[mid] = mem
        except Exception as e:
            logger.debug(f"Public 记忆检索失败:{e}")

        # Phase 2: project memories (every project the agent belongs to).
        if agent_id and agent_id != 'general':
            agent_projects = get_agent_projects(agent_id)
            for project_id in agent_projects:
                if project_id == 'global':
                    continue
                try:
                    proj_mems = await asyncio.to_thread(
                        self.local_memory.search, query,
                        user_id=user_id,
                        filters={
                            "visibility": "project",
                            "project_id": project_id,
                        },
                        limit=per_phase
                    )
                    for mem in (proj_mems or []):
                        mid = mem.get('id') if isinstance(mem, dict) else None
                        if mid and mid not in all_memories:
                            all_memories[mid] = mem
                except Exception as e:
                    logger.debug(f"Project({project_id}) 记忆检索失败:{e}")

        # Phase 3: private memories.
        if agent_id and agent_id != 'general':
            try:
                private_mems = await asyncio.to_thread(
                    self.local_memory.search, query,
                    user_id=user_id,
                    filters={
                        "visibility": "private",
                        "agent_id": agent_id,
                    },
                    limit=per_phase
                )
                for mem in (private_mems or []):
                    mid = mem.get('id') if isinstance(mem, dict) else None
                    if mid and mid not in all_memories:
                        all_memories[mid] = mem
            except Exception as e:
                logger.debug(f"Private 记忆检索失败:{e}")

        # Fallback: legacy records written without a visibility field.
        if user_id:
            try:
                legacy_mems = await asyncio.to_thread(
                    self.local_memory.search, query,
                    user_id=user_id,
                    filters={"agent_id": agent_id}
                    if agent_id and agent_id != 'general' else None,
                    limit=per_phase
                )
                for mem in (legacy_mems or []):
                    mid = mem.get('id') if isinstance(mem, dict) else None
                    if mid and mid not in all_memories:
                        all_memories[mid] = mem
            except Exception as e:
                logger.debug(f"Legacy 记忆检索失败:{e}")

        # Confidence filter, then rank by score descending.
        min_confidence = self.config['retrieval']['min_confidence']
        filtered = [
            m for m in all_memories.values()
            if m.get('score', 1.0) >= min_confidence
        ]
        filtered.sort(key=lambda x: x.get('score', 0), reverse=True)
        return filtered[:top_k]

    def format_memories_for_prompt(self, memories: List[Dict]) -> str:
        """Render retrieved memories as a prompt fragment (empty if none)."""
        if not memories:
            return ""
        prompt = "\n\n=== 相关记忆 ===\n"
        for i, mem in enumerate(memories, 1):
            memory_text = mem.get('memory', '') if isinstance(mem, dict) else str(mem)
            created_at = mem.get('created_at', '') if isinstance(mem, dict) else ''
            prompt += f"{i}. {memory_text}"
            if created_at:
                prompt += f" (记录于:{created_at})"
            prompt += "\n"
        prompt += "===============\n"
        return prompt

    # ========== Post-Hook: async write ==========
    def _should_skip_memory(self, user_message: str, assistant_message: str) -> bool:
        """Return True when this exchange is not worth persisting."""
        if not self.config.get('write_filter', {}).get('enabled', True):
            return False
        stripped = user_message.strip()
        min_len = self.config.get('write_filter', {}).get('min_user_message_length', 5)
        if len(stripped) < min_len:
            return True
        if SKIP_PATTERNS.match(stripped):
            return True
        if SYSTEM_CMD_PATTERN.match(stripped):
            return True
        return False

    def _classify_memory_type(self, user_message: str, assistant_message: str) -> str:
        """Auto-classify the memory type, which decides the expiry policy."""
        combined = user_message + ' ' + assistant_message
        if MEMORY_KEYWORDS.search(combined):
            return 'preference'
        if any(kw in combined for kw in ('部署', '配置', '架构', '端口', '安装', '版本')):
            return 'knowledge'
        return 'session'

    def _classify_visibility(self, user_message: str, assistant_message: str,
                             agent_id: str = None) -> str:
        """Auto-classify memory visibility (public vs private)."""
        combined = user_message + ' ' + assistant_message
        if PUBLIC_KEYWORDS.search(combined):
            return 'public'
        return 'private'

    def post_hook_add(self, user_message: str, assistant_message: str,
                      user_id: str = None, agent_id: str = None,
                      visibility: str = None, project_id: str = None,
                      memory_type: str = None):
        """Post-hook: enqueue an async write after a turn (synchronous).

        Supports three-level visibility (public/project/private) and memory
        decay (expiration_date).  Built-in write filtering skips low-value
        exchanges.
        """
        if not self.config['async_write']['enabled']:
            return
        if self._should_skip_memory(user_message, assistant_message):
            logger.debug(f"Post-Hook 跳过(写入过滤):{user_message[:30]}")
            return
        if not user_id:
            user_id = self.config['metadata']['default_user_id']
        if not agent_id:
            agent_id = self.config['metadata']['default_agent_id']
        if not memory_type:
            memory_type = self._classify_memory_type(user_message, assistant_message)
        if not visibility:
            visibility = self._classify_visibility(user_message, assistant_message, agent_id)
        messages = [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_message}
        ]
        if self.async_queue:
            self.async_queue.add({
                'messages': messages,
                'user_id': user_id,
                'agent_id': agent_id,
                'visibility': visibility,
                'project_id': project_id,
                'memory_type': memory_type,
                'timestamp': datetime.now().isoformat()
            })
            logger.debug(f"Post-Hook 已提交:user={user_id}, agent={agent_id}, "
                         f"visibility={visibility}, type={memory_type}")
        else:
            logger.warning("异步队列未初始化")

    async def _async_write_memory(self, item: Dict):
        """Queue callback: write one memory with a timeout (background)."""
        if self.local_memory is None:
            return
        timeout_ms = self.config['async_write']['timeout_ms']
        try:
            await asyncio.wait_for(self._execute_write(item), timeout=timeout_ms / 1000)
            logger.debug(f"异步写入成功:user={item['user_id']}, agent={item['agent_id']}")
        except asyncio.TimeoutError:
            logger.warning(f"异步写入超时 ({timeout_ms}ms)")
        except Exception as e:
            logger.warning(f"异步写入失败:{e}")

    async def _execute_write(self, item: Dict):
        """
        Execute a write — three-level visibility + memory decay.

        metadata carries visibility / project_id / agent_id;
        expiration_date is derived from memory_type via EXPIRATION_MAP.
        """
        if self.local_memory is None:
            return
        visibility = item.get('visibility', 'private')
        memory_type = item.get('memory_type', 'session')
        custom_metadata = {
            "agent_id": item['agent_id'],
            "visibility": visibility,
            "project_id": item.get('project_id') or '',
            "business_type": item.get('business_type', item['agent_id']),
            "memory_type": memory_type,
            "source": "openclaw",
            "timestamp": item.get('timestamp'),
        }
        ttl = EXPIRATION_MAP.get(memory_type)
        expiration_date = (datetime.now() + ttl).isoformat() if ttl else None
        add_kwargs = dict(
            messages=item['messages'],
            user_id=item['user_id'],
            agent_id=item['agent_id'],
            metadata=custom_metadata,
        )
        if expiration_date:
            add_kwargs['expiration_date'] = expiration_date
        await asyncio.to_thread(self.local_memory.add, **add_kwargs)

    def _cleanup_cache(self):
        """Evict expired cache entries, then trim to max_size (oldest first)."""
        if not self.config['cache']['enabled']:
            return
        current_time = time.time()
        ttl = self.config['cache']['ttl']
        expired_keys = [k for k, v in self.cache.items()
                        if current_time - v['time'] > ttl]
        for key in expired_keys:
            del self.cache[key]
        if len(self.cache) > self.config['cache']['max_size']:
            oldest_keys = sorted(
                self.cache.keys(),
                key=lambda k: self.cache[k]['time']
            )[:len(self.cache) - self.config['cache']['max_size']]
            for key in oldest_keys:
                del self.cache[key]

    def get_status(self) -> Dict:
        """Return a status snapshot for health checks / diagnostics."""
        return {
            "initialized": self.local_memory is not None,
            "started": self._started,
            "async_queue_enabled": self.config['async_write']['enabled'],
            "queue_size": len(self.async_queue.queue) if self.async_queue else 0,
            "cache_size": len(self.cache),
            "qdrant": f"{self.config['qdrant']['host']}:{self.config['qdrant']['port']}"
        }

    # ========== Knowledge Publishing (Hub Agent) ==========
    async def publish_knowledge(self, content, category='knowledge',
                                visibility='public', project_id=None,
                                agent_id='main'):
        """Publish knowledge/best practices to shared memory.

        Used by the hub agent to share with all agents (public) or with
        project teams (project).  Raises ValueError when project visibility
        is requested without a project_id.
        """
        if visibility == 'project' and not project_id:
            raise ValueError("project visibility requires project_id")
        item = {
            'messages': [{'role': 'system', 'content': content}],
            'user_id': 'system',
            'agent_id': agent_id,
            'visibility': visibility,
            'project_id': project_id,
            'memory_type': category,
            'timestamp': datetime.now().isoformat(),
        }
        await self._execute_write(item)
        logger.info(f"Published {category} ({visibility}): {content[:80]}...")

    # ========== Cold Start (Three-Phase) ==========
    async def cold_start_search(self, agent_id='main', user_id='default', top_k=5):
        """Three-phase cold start: public -> project -> private.

        Returns merged memories ordered by phase then score.
        """
        if self.local_memory is None:
            return []
        all_mems = {}
        phase = {}
        for q in ["system best practices and conventions",
                  "shared configuration and architecture decisions"]:
            try:
                r = await asyncio.to_thread(
                    self.local_memory.search, q, user_id=user_id,
                    filters={"visibility": "public"}, limit=top_k)
                for m in (r or []):
                    mid = m.get('id') if isinstance(m, dict) else None
                    if mid and mid not in all_mems:
                        all_mems[mid] = m
                        phase[mid] = 0
            except Exception as e:
                logger.debug(f"Cold start public failed: {e}")
        for pid in get_agent_projects(agent_id):
            if pid == 'global':
                continue
            for q in ["project guidelines and shared knowledge",
                      "recent project decisions and updates"]:
                try:
                    r = await asyncio.to_thread(
                        self.local_memory.search, q, user_id=user_id,
                        filters={"visibility": "project", "project_id": pid},
                        limit=top_k)
                    for m in (r or []):
                        mid = m.get('id') if isinstance(m, dict) else None
                        if mid and mid not in all_mems:
                            all_mems[mid] = m
                            phase[mid] = 1
                except Exception as e:
                    logger.debug(f"Cold start project({pid}) failed: {e}")
        for q in ["current active tasks deployment progress",
                  "pending work items todos",
                  "recent important decisions configurations"]:
            try:
                r = await asyncio.to_thread(
                    self.local_memory.search, q, user_id=user_id,
                    filters={"visibility": "private", "agent_id": agent_id},
                    limit=top_k)
                for m in (r or []):
                    mid = m.get('id') if isinstance(m, dict) else None
                    if mid and mid not in all_mems:
                        all_mems[mid] = m
                        phase[mid] = 2
            except Exception as e:
                logger.debug(f"Cold start private failed: {e}")
        mc = self.config['retrieval']['min_confidence']
        out = [m for m in all_mems.values() if m.get('score', 1.0) >= mc]
        # Order by retrieval phase first, then score descending.
        out.sort(key=lambda m: (phase.get(m.get('id', ''), 9), -m.get('score', 0)))
        return out[:top_k]

    async def shutdown(self):
        """Stop the async queue and release background resources."""
        if self.async_queue:
            await self.async_queue.stop()
        logger.info("mem0 Client shutdown complete")


# Global client instance.
mem0_client = Mem0Client()