#!/usr/bin/env python3
"""
mem0 client for OpenClaw — production-grade, fully asynchronous architecture.

Pre-hook retrieval injection + post-hook asynchronous writes.
Three-level visibility isolation (public / project / private).
Memory decay (expiration_date) + smart write filtering.
"""
import os
import re
import asyncio
import logging
import time
import yaml
from typing import List, Dict, Any, Optional
from collections import deque
from datetime import datetime, timedelta
from pathlib import Path

# ========== LLM / Embedding API configuration ==========
# Priority chain: OneAPI (.env) > pre-existing environment vars > DashScope default.
_api_base = (os.getenv('LLM_BASE_URL')
             or os.getenv('OPENAI_BASE_URL')
             or os.getenv('OPENAI_API_BASE')
             or 'https://dashscope.aliyuncs.com/compatible-mode/v1')
os.environ['OPENAI_API_BASE'] = _api_base
os.environ['OPENAI_BASE_URL'] = _api_base

_api_key = (os.getenv('LLM_API_KEY')
            or os.getenv('MEM0_DASHSCOPE_API_KEY')
            or os.getenv('DASHSCOPE_API_KEY')
            or '')
if _api_key:
    os.environ['OPENAI_API_KEY'] = _api_key
elif not os.environ.get('OPENAI_API_KEY'):
    logging.warning("No API key found (LLM_API_KEY / MEM0_DASHSCOPE_API_KEY); mem0 calls will fail")

try:
    from mem0 import Memory
    from mem0.configs.base import MemoryConfig, VectorStoreConfig, LlmConfig, EmbedderConfig
except ImportError as e:
    print(f"⚠️ mem0ai 导入失败:{e}")
    Memory = None

from local_search import LocalSearchFallback

logger = logging.getLogger(__name__)


class AsyncMemoryQueue:
    """Pure-async memory write queue.

    Producers call the synchronous ``add()``; a background asyncio task drains
    the queue in batches and hands each item to a caller-supplied coroutine.
    """

    def __init__(self, max_size: int = 100):
        # Bounded deque: appending past maxlen evicts the OLDEST entry.
        self.queue = deque(maxlen=max_size)
        self.lock = asyncio.Lock()
        self.running = False
        self._worker_task: Optional[asyncio.Task] = None
        self.callback = None
        self.batch_size = 10
        self.flush_interval = 60
        # Strong references to in-flight batch tasks. The event loop only
        # keeps weak refs to tasks, so without this set a fire-and-forget
        # create_task() result could be garbage-collected mid-execution.
        self._inflight: set = set()

    def add(self, item: Dict[str, Any]):
        """Enqueue a write task (synchronous, safe to call from sync code).

        FIX: the original refused to append when the queue was full (dropping
        the NEW task) while logging that the OLD task was discarded. A bounded
        deque already evicts the oldest entry on append, so we now always
        append — the oldest task is evicted, matching the log message.
        """
        try:
            if len(self.queue) == self.queue.maxlen:
                logger.warning("异步队列已满,丢弃旧任务")
            self.queue.append({
                'messages': item['messages'],
                'user_id': item['user_id'],
                'agent_id': item['agent_id'],
                'visibility': item.get('visibility', 'private'),
                'project_id': item.get('project_id'),
                'memory_type': item.get('memory_type', 'session'),
                'timestamp': item.get('timestamp', datetime.now().isoformat())
            })
        except Exception as e:
            logger.error(f"队列添加失败:{e}")

    async def get_batch(self, batch_size: int) -> List[Dict]:
        """Pop up to ``batch_size`` tasks from the queue (async, lock-guarded)."""
        async with self.lock:
            batch = []
            while len(batch) < batch_size and self.queue:
                batch.append(self.queue.popleft())
            return batch

    def start_worker(self, callback, batch_size: int, flush_interval: int):
        """Start the background drain loop. Must run inside an event loop."""
        self.running = True
        self.callback = callback
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self._worker_task = asyncio.create_task(self._worker_loop())
        logger.info(f"✅ 异步工作线程已启动 (batch_size={batch_size}, interval={flush_interval}s)")

    def _spawn_batch(self, batch: List[Dict]):
        """Launch a batch-processing task, retaining a strong reference to it."""
        task = asyncio.create_task(self._process_batch(batch))
        self._inflight.add(task)
        task.add_done_callback(self._inflight.discard)

    async def _worker_loop(self):
        """Background loop: drain when work is pending or the flush interval elapses."""
        last_flush = time.time()
        while self.running:
            try:
                if self.queue:
                    batch = await self.get_batch(self.batch_size)
                    if batch:
                        self._spawn_batch(batch)
                if time.time() - last_flush > self.flush_interval:
                    if self.queue:
                        batch = await self.get_batch(self.batch_size)
                        if batch:
                            self._spawn_batch(batch)
                    last_flush = time.time()
                await asyncio.sleep(1)
            except asyncio.CancelledError:
                logger.info("异步工作线程已取消")
                break
            except Exception as e:
                logger.error(f"异步工作线程错误:{e}")
                await asyncio.sleep(5)

    async def _process_batch(self, batch: List[Dict]):
        """Run the write callback for each item in a drained batch."""
        try:
            logger.debug(f"开始处理批量任务:{len(batch)} 条")
            for item in batch:
                await self.callback(item)
            logger.debug(f"批量任务处理完成")
        except Exception as e:
            logger.error(f"批量处理失败:{e}")

    async def stop(self):
        """Graceful shutdown: cancel the loop, then let in-flight batches finish."""
        self.running = False
        if self._worker_task:
            self._worker_task.cancel()
            try:
                await self._worker_task
            except asyncio.CancelledError:
                pass
        if self._inflight:
            # Don't drop writes that were already dispatched.
            await asyncio.gather(*self._inflight, return_exceptions=True)
        logger.info("异步工作线程已关闭")


# memory_type -> time-to-live; None means the memory never expires.
EXPIRATION_MAP = {
    'session': timedelta(days=7),
    'chat_summary': timedelta(days=30),
    'preference': None,
    'knowledge': None,
}

# Pure acknowledgements / pleasantries that carry no memorable content.
SKIP_PATTERNS = re.compile(
    r'^(好的|收到|OK|ok|嗯|行|没问题|感谢|谢谢|了解|明白|知道了|👍|✅|❌)$',
    re.IGNORECASE
)

# Slash-prefixed system commands are never persisted.
SYSTEM_CMD_PATTERN = re.compile(r'^/')

# Phrases that mark a durable user preference / rule.
MEMORY_KEYWORDS = re.compile(
    r'(记住|以后|偏好|配置|设置|规则|永远|始终|总是|不要|禁止)',
)

# Phrases that mark cluster-wide (public) announcements.
PUBLIC_KEYWORDS = re.compile(
    r'(所有人|通知|全局|公告|大家|集群)',
)

# Lazily-built {project_id: [keywords]} cache (see _build_project_keywords).
PROJECT_KEYWORDS_CACHE: Dict[str, List[str]] = {}


def _build_project_keywords() -> Dict[str, List[str]]:
    """Extract per-project keywords from project_registry.yaml for auto-classification."""
    global PROJECT_KEYWORDS_CACHE
    if PROJECT_KEYWORDS_CACHE:
        return PROJECT_KEYWORDS_CACHE
    registry = _load_project_registry()
    projects = registry.get('projects', {})
    result: Dict[str, List[str]] = {}
    for pid, pconf in projects.items():
        if pid == 'global':
            continue
        kws = [pid]
        name = pconf.get('name', '')
        if name and len(name) >= 2:
            kws.append(name)
        desc = pconf.get('description', '')
        # Split the description on CJK/latin punctuation; keep segments >= 2 chars.
        for seg in re.split(r'[,、。/\s]+', desc):
            seg = seg.strip()
            if len(seg) >= 2:
                kws.append(seg)
        result[pid] = kws
    PROJECT_KEYWORDS_CACHE = result
    return result


def _load_project_registry() -> Dict:
    """Load the project registry from project_registry.yaml (best-effort)."""
    registry_path = Path(__file__).parent / 'project_registry.yaml'
    if registry_path.exists():
        try:
            with open(registry_path, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f) or {}
        except Exception:
            # Best-effort: an unreadable registry degrades to "no projects".
            pass
    return {}


def get_agent_projects(agent_id: str) -> List[str]:
    """Return all project_ids an agent belongs to ('*' membership matches everyone)."""
    registry = _load_project_registry()
    projects = registry.get('projects', {})
    result = []
    for pid, pconf in projects.items():
        members = pconf.get('members', [])
        if '*' in members or agent_id in members:
            result.append(pid)
    return result


class Mem0Client:
    """
    Production-grade mem0 client.

    Pure async architecture + three-level visibility + memory decay
    + smart write filtering.
    """

    def __init__(self, config: Optional[Dict] = None):
        self.config = config or self._load_default_config()
        self.local_memory = None
        self.async_queue: Optional[AsyncMemoryQueue] = None
        self.cache: Dict[str, Dict] = {}
        self._started = False
        self._local_search: Dict[str, LocalSearchFallback] = {}
        self._init_memory()

    def _load_default_config(self) -> Dict:
        """Default configuration — single shared-collection architecture."""
        return {
            "qdrant": {
                "host": os.getenv('MEM0_QDRANT_HOST', 'localhost'),
                "port": int(os.getenv('MEM0_QDRANT_PORT', '6333')),
                # Unified shared collection (all agents use the same one).
                "collection_name": "mem0_v4_shared"
            },
            "llm": {
                "provider": "openai",
                "config": {
                    "model": os.getenv('MEM0_LLM_MODEL') or os.getenv('LLM_MODEL_ID', 'qwen-plus')
                }
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": os.getenv('MEM0_EMBEDDER_MODEL') or os.getenv('EMBEDDING_MODEL_ID', 'text-embedding-v4'),
                    # Max dimensions supported by DashScope text-embedding-v4.
                    "dimensions": 1024
                }
            },
            "retrieval": {
                "enabled": True,
                "top_k": 5,
                "min_confidence": 0.7,
                "timeout_ms": 5000
            },
            "async_write": {
                "enabled": True,
                "queue_size": 100,
                "batch_size": 10,
                "flush_interval": 60,
                "timeout_ms": 5000
            },
            "cache": {
                "enabled": True,
                "ttl": 300,
                "max_size": 1000
            },
            "fallback": {
                "enabled": True,
                "log_level": "WARNING",
                "retry_attempts": 2
            },
            "metadata": {
                "default_user_id": "default",
                "default_agent_id": "general"
            },
            "write_filter": {
                "enabled": True,
                "min_user_message_length": 5,
            }
        }

    def _init_memory(self):
        """Initialize mem0 (synchronous) with vector store, LLM and embedder config."""
        if Memory is None:
            logger.warning("mem0ai 未安装")
            return
        try:
            config = MemoryConfig(
                vector_store=VectorStoreConfig(
                    provider="qdrant",
                    config={
                        "host": self.config['qdrant']['host'],
                        "port": self.config['qdrant']['port'],
                        "collection_name": self.config['qdrant']['collection_name'],
                        "on_disk": True,
                        # Force the Qdrant collection dimension to match the embedder.
                        "embedding_model_dims": 1024
                    }
                ),
                llm=LlmConfig(
                    provider="openai",
                    config=self.config['llm']['config']
                ),
                embedder=EmbedderConfig(
                    provider="openai",
                    config={
                        "model": os.getenv('MEM0_EMBEDDER_MODEL') or os.getenv('EMBEDDING_MODEL_ID', 'text-embedding-v4'),
                        # Core fix: override mem0's default 1536 dims.
                        "embedding_dims": 1024
                    }
                )
            )
            self.local_memory = Memory(config=config)
            logger.info("✅ mem0 初始化成功(含 Embedder,1024 维度)")
        except Exception as e:
            logger.error(f"❌ mem0 初始化失败:{e}")
            self.local_memory = None

    async def start(self):
        """
        Explicitly start the async worker.

        Must be awaited inside an event loop: ``await mem0_client.start()``.
        """
        if self._started:
            logger.debug("mem0 Client 已启动")
            return
        if self.config['async_write']['enabled']:
            self.async_queue = AsyncMemoryQueue(
                max_size=self.config['async_write']['queue_size']
            )
            self.async_queue.start_worker(
                callback=self._async_write_memory,
                batch_size=self.config['async_write']['batch_size'],
                flush_interval=self.config['async_write']['flush_interval']
            )
        self._started = True
        logger.info("✅ mem0 Client 异步工作线程已启动")

    # ========== Pre-Hook: smart retrieval ==========

    async def pre_hook_search(self, query: str, user_id: Optional[str] = None,
                              agent_id: Optional[str] = None,
                              top_k: Optional[int] = None) -> List[Dict]:
        """Pre-hook: retrieve relevant memories before the conversation turn."""
        if not self.config['retrieval']['enabled'] or self.local_memory is None:
            return []
        cache_key = f"{user_id}:{agent_id}:{query}"
        if self.config['cache']['enabled'] and cache_key in self.cache:
            cached = self.cache[cache_key]
            if time.time() - cached['time'] < self.config['cache']['ttl']:
                logger.debug(f"Cache hit: {cache_key}")
                return cached['results']
        timeout_ms = self.config['retrieval']['timeout_ms']
        try:
            memories = await asyncio.wait_for(
                self._execute_search(query, user_id, agent_id,
                                     top_k or self.config['retrieval']['top_k']),
                timeout=timeout_ms / 1000
            )
            if self.config['cache']['enabled'] and memories:
                self.cache[cache_key] = {'results': memories, 'time': time.time()}
                self._cleanup_cache()
            logger.info(f"Pre-Hook 检索完成:{len(memories)} 条记忆")
            return memories
        except asyncio.TimeoutError:
            logger.warning(f"Pre-Hook 检索超时 ({timeout_ms}ms)")
            return []
        except Exception as e:
            logger.warning(f"Pre-Hook 检索失败:{e}")
            return []

    def _get_local_search(self, agent_id: Optional[str]) -> LocalSearchFallback:
        """Get or create the Layer 3 FTS5 instance (lazy init + auto index build)."""
        aid = agent_id or 'main'
        if aid not in self._local_search:
            fb = LocalSearchFallback(agent_id=aid)
            fb.rebuild_index()
            self._local_search[aid] = fb
        return self._local_search[aid]

    async def _execute_search(self, query: str, user_id: Optional[str],
                              agent_id: Optional[str], top_k: int) -> List[Dict]:
        """
        Three-phase retrieval — layered by visibility, merged and deduplicated.

        Phase 1: public (visible to all agents)
        Phase 2: project (visible to members of the same project_id)
        Phase 3: private (visible only to the owning agent_id)
        Layer 3 fallback: switch to local FTS5 when Qdrant is fully unreachable.
        """
        all_memories: Dict[str, Dict] = {}
        per_phase = max(top_k, 3)
        qdrant_ok = False  # tracks whether Qdrant answered at least once

        if self.local_memory is None:
            # mem0 not initialized (Qdrant unreachable) — go straight to Layer 3.
            logger.warning("mem0 未初始化,直接使用 Layer 3 FTS5 本地检索")
            try:
                fb = await asyncio.to_thread(self._get_local_search, agent_id)
                local_results = await asyncio.to_thread(fb.search, query, top_k)
                return [{
                    'id': f"fts5:{r.get('source', '')}:{r.get('title', '')}",
                    'memory': r.get('snippet', ''),
                    'score': 0.5,
                    'metadata': {'source': 'layer3_fts5', 'file': r.get('source', '')},
                } for r in local_results]
            except Exception as e:
                logger.error(f"Layer 3 FTS5 fallback 失败:{e}")
                return []

        # Phase 1: public memories.
        try:
            public_mems = await asyncio.to_thread(
                self.local_memory.search, query,
                user_id=user_id,
                filters={"visibility": "public"},
                limit=per_phase
            )
            qdrant_ok = True
            for mem in (public_mems or []):
                mid = mem.get('id') if isinstance(mem, dict) else None
                if mid and mid not in all_memories:
                    all_memories[mid] = mem
        except Exception as e:
            logger.debug(f"Public 记忆检索失败:{e}")

        # Phase 2: project memories (every project the agent belongs to).
        if agent_id and agent_id != 'general':
            agent_projects = get_agent_projects(agent_id)
            for project_id in agent_projects:
                if project_id == 'global':
                    continue
                try:
                    proj_mems = await asyncio.to_thread(
                        self.local_memory.search, query,
                        user_id=user_id,
                        filters={
                            "visibility": "project",
                            "project_id": project_id,
                        },
                        limit=per_phase
                    )
                    qdrant_ok = True
                    for mem in (proj_mems or []):
                        mid = mem.get('id') if isinstance(mem, dict) else None
                        if mid and mid not in all_memories:
                            all_memories[mid] = mem
                except Exception as e:
                    logger.debug(f"Project({project_id}) 记忆检索失败:{e}")

        # Phase 3: private memories.
        if agent_id and agent_id != 'general':
            try:
                private_mems = await asyncio.to_thread(
                    self.local_memory.search, query,
                    user_id=user_id,
                    filters={
                        "visibility": "private",
                        "agent_id": agent_id,
                    },
                    limit=per_phase
                )
                qdrant_ok = True
                for mem in (private_mems or []):
                    mid = mem.get('id') if isinstance(mem, dict) else None
                    if mid and mid not in all_memories:
                        all_memories[mid] = mem
            except Exception as e:
                logger.debug(f"Private 记忆检索失败:{e}")

        # Legacy fallback: old records without a visibility field.
        if user_id:
            try:
                legacy_mems = await asyncio.to_thread(
                    self.local_memory.search, query,
                    user_id=user_id,
                    filters={"agent_id": agent_id} if agent_id and agent_id != 'general' else None,
                    limit=per_phase
                )
                qdrant_ok = True
                for mem in (legacy_mems or []):
                    mid = mem.get('id') if isinstance(mem, dict) else None
                    if mid and mid not in all_memories:
                        all_memories[mid] = mem
            except Exception as e:
                logger.debug(f"Legacy 记忆检索失败:{e}")

        # Layer 3 fallback: every Qdrant phase failed — use local FTS5.
        if not qdrant_ok:
            logger.warning("Qdrant 不可达,自动切换 Layer 3 FTS5 本地检索")
            try:
                fb = await asyncio.to_thread(self._get_local_search, agent_id)
                local_results = await asyncio.to_thread(fb.search, query, top_k)
                for r in local_results:
                    fid = f"fts5:{r.get('source', '')}:{r.get('title', '')}"
                    if fid not in all_memories:
                        all_memories[fid] = {
                            'id': fid,
                            'memory': r.get('snippet', ''),
                            'score': 0.5,
                            'metadata': {'source': 'layer3_fts5', 'file': r.get('source', '')},
                        }
            except Exception as e:
                logger.error(f"Layer 3 FTS5 fallback 失败:{e}")

        min_confidence = self.config['retrieval']['min_confidence']
        filtered = [
            m for m in all_memories.values()
            if m.get('score', 1.0) >= min_confidence
        ]
        filtered.sort(key=lambda x: x.get('score', 0), reverse=True)
        return filtered[:top_k]

    def format_memories_for_prompt(self, memories: List[Dict]) -> str:
        """Format retrieved memories as a prompt fragment (empty string if none)."""
        if not memories:
            return ""
        prompt = "\n\n=== 相关记忆 ===\n"
        for i, mem in enumerate(memories, 1):
            memory_text = mem.get('memory', '') if isinstance(mem, dict) else str(mem)
            created_at = mem.get('created_at', '') if isinstance(mem, dict) else ''
            prompt += f"{i}. {memory_text}"
            if created_at:
                prompt += f" (记录于:{created_at})"
            prompt += "\n"
        prompt += "===============\n"
        return prompt

    # ========== Post-Hook: async write ==========

    def _should_skip_memory(self, user_message: str, assistant_message: str) -> bool:
        """Return True when this exchange is not worth persisting."""
        if not self.config.get('write_filter', {}).get('enabled', True):
            return False
        min_len = self.config.get('write_filter', {}).get('min_user_message_length', 5)
        if len(user_message.strip()) < min_len:
            return True
        if SKIP_PATTERNS.match(user_message.strip()):
            return True
        if SYSTEM_CMD_PATTERN.match(user_message.strip()):
            return True
        return False

    def _classify_memory_type(self, user_message: str, assistant_message: str) -> str:
        """Auto-classify the memory type, which drives the expiration policy."""
        combined = user_message + ' ' + assistant_message
        if MEMORY_KEYWORDS.search(combined):
            return 'preference'
        if any(kw in combined for kw in ('部署', '配置', '架构', '端口', '安装', '版本')):
            return 'knowledge'
        return 'session'

    def _classify_visibility(self, user_message: str, assistant_message: str,
                             agent_id: Optional[str] = None):
        """Auto-classify visibility; returns (visibility, project_id)."""
        combined = user_message + ' ' + assistant_message
        if PUBLIC_KEYWORDS.search(combined):
            return 'public', None
        # Check whether the text matches keywords of a project the agent belongs to.
        if agent_id and agent_id != 'general':
            proj_kws = _build_project_keywords()
            agent_projects = get_agent_projects(agent_id)
            for pid in agent_projects:
                if pid == 'global' or pid not in proj_kws:
                    continue
                if any(kw in combined for kw in proj_kws[pid]):
                    return 'project', pid
        return 'private', None

    def post_hook_add(self, user_message: str, assistant_message: str,
                      user_id: Optional[str] = None, agent_id: Optional[str] = None,
                      visibility: Optional[str] = None, project_id: Optional[str] = None,
                      memory_type: Optional[str] = None):
        """Post-hook: enqueue an async write after the turn (sync, queue-only).

        Supports three-level visibility (public/project/private) and memory
        decay (expiration_date). Built-in write filtering skips low-value turns.
        """
        if not self.config['async_write']['enabled']:
            return
        if self._should_skip_memory(user_message, assistant_message):
            logger.debug(f"Post-Hook 跳过(写入过滤):{user_message[:30]}")
            return
        if not user_id:
            user_id = self.config['metadata']['default_user_id']
        if not agent_id:
            agent_id = self.config['metadata']['default_agent_id']
        if not memory_type:
            memory_type = self._classify_memory_type(user_message, assistant_message)
        if not visibility:
            visibility, auto_project_id = self._classify_visibility(
                user_message, assistant_message, agent_id)
            if not project_id and auto_project_id:
                project_id = auto_project_id
        messages = [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_message}
        ]
        if self.async_queue:
            self.async_queue.add({
                'messages': messages,
                'user_id': user_id,
                'agent_id': agent_id,
                'visibility': visibility,
                'project_id': project_id,
                'memory_type': memory_type,
                'timestamp': datetime.now().isoformat()
            })
            logger.debug(f"Post-Hook 已提交:user={user_id}, agent={agent_id}, "
                         f"visibility={visibility}, type={memory_type}")
        else:
            logger.warning("异步队列未初始化")

    async def _async_write_memory(self, item: Dict):
        """Background write with a timeout; failures are logged, never raised."""
        if self.local_memory is None:
            return
        timeout_ms = self.config['async_write']['timeout_ms']
        try:
            await asyncio.wait_for(self._execute_write(item), timeout=timeout_ms / 1000)
            logger.debug(f"异步写入成功:user={item['user_id']}, agent={item['agent_id']}")
        except asyncio.TimeoutError:
            logger.warning(f"异步写入超时 ({timeout_ms}ms)")
        except Exception as e:
            logger.warning(f"异步写入失败:{e}")

    async def _execute_write(self, item: Dict):
        """
        Execute the write — three-level visibility + memory decay.

        metadata carries visibility / project_id / agent_id;
        expiration_date is derived from memory_type via EXPIRATION_MAP.
        """
        if self.local_memory is None:
            return
        visibility = item.get('visibility', 'private')
        memory_type = item.get('memory_type', 'session')
        custom_metadata = {
            "agent_id": item['agent_id'],
            "visibility": visibility,
            "project_id": item.get('project_id') or '',
            "business_type": item.get('business_type', item['agent_id']),
            "memory_type": memory_type,
            "source": "openclaw",
            "timestamp": item.get('timestamp'),
        }
        ttl = EXPIRATION_MAP.get(memory_type)
        expiration_date = (datetime.now() + ttl).isoformat() if ttl else None
        add_kwargs = dict(
            messages=item['messages'],
            user_id=item['user_id'],
            agent_id=item['agent_id'],
            metadata=custom_metadata,
        )
        if expiration_date:
            add_kwargs['expiration_date'] = expiration_date
        await asyncio.to_thread(self.local_memory.add, **add_kwargs)

    def _cleanup_cache(self):
        """Evict expired cache entries, then trim to max_size (oldest first)."""
        if not self.config['cache']['enabled']:
            return
        current_time = time.time()
        ttl = self.config['cache']['ttl']
        expired_keys = [k for k, v in self.cache.items()
                        if current_time - v['time'] > ttl]
        for key in expired_keys:
            del self.cache[key]
        if len(self.cache) > self.config['cache']['max_size']:
            oldest_keys = sorted(
                self.cache.keys(),
                key=lambda k: self.cache[k]['time']
            )[:len(self.cache) - self.config['cache']['max_size']]
            for key in oldest_keys:
                del self.cache[key]

    def get_status(self) -> Dict:
        """Return a status snapshot for health checks / diagnostics."""
        return {
            "initialized": self.local_memory is not None,
            "started": self._started,
            "async_queue_enabled": self.config['async_write']['enabled'],
            "queue_size": len(self.async_queue.queue) if self.async_queue else 0,
            "cache_size": len(self.cache),
            "qdrant": f"{self.config['qdrant']['host']}:{self.config['qdrant']['port']}"
        }

    # ========== Knowledge Publishing (Hub Agent) ==========

    async def publish_knowledge(self, content, category='knowledge',
                                visibility='public', project_id=None,
                                agent_id='main'):
        """Publish knowledge/best practices to shared memory.
        Used by hub agent to share with all agents (public) or project teams."""
        if visibility == 'project' and not project_id:
            raise ValueError("project visibility requires project_id")
        item = {
            'messages': [{'role': 'system', 'content': content}],
            'user_id': 'system',
            'agent_id': agent_id,
            'visibility': visibility,
            'project_id': project_id,
            'memory_type': category,
            'timestamp': datetime.now().isoformat(),
        }
        await self._execute_write(item)
        logger.info(f"Published {category} ({visibility}): {content[:80]}...")

    # ========== Cold Start (Three-Phase) ==========

    async def cold_start_search(self, agent_id='main', user_id='default', top_k=5):
        """Three-phase cold start: public -> project -> private.
        Returns merged memories ordered by phase then score."""
        if self.local_memory is None:
            return []
        all_mems = {}
        phase = {}
        for q in ["system best practices and conventions",
                  "shared configuration and architecture decisions"]:
            try:
                r = await asyncio.to_thread(
                    self.local_memory.search, q, user_id=user_id,
                    filters={"visibility": "public"}, limit=top_k)
                for m in (r or []):
                    mid = m.get('id') if isinstance(m, dict) else None
                    if mid and mid not in all_mems:
                        all_mems[mid] = m
                        phase[mid] = 0
            except Exception as e:
                logger.debug(f"Cold start public failed: {e}")
        for pid in get_agent_projects(agent_id):
            if pid == 'global':
                continue
            for q in ["project guidelines and shared knowledge",
                      "recent project decisions and updates"]:
                try:
                    r = await asyncio.to_thread(
                        self.local_memory.search, q, user_id=user_id,
                        filters={"visibility": "project", "project_id": pid},
                        limit=top_k)
                    for m in (r or []):
                        mid = m.get('id') if isinstance(m, dict) else None
                        if mid and mid not in all_mems:
                            all_mems[mid] = m
                            phase[mid] = 1
                except Exception as e:
                    logger.debug(f"Cold start project({pid}) failed: {e}")
        for q in ["current active tasks deployment progress",
                  "pending work items todos",
                  "recent important decisions configurations"]:
            try:
                r = await asyncio.to_thread(
                    self.local_memory.search, q, user_id=user_id,
                    filters={"visibility": "private", "agent_id": agent_id},
                    limit=top_k)
                for m in (r or []):
                    mid = m.get('id') if isinstance(m, dict) else None
                    if mid and mid not in all_mems:
                        all_mems[mid] = m
                        phase[mid] = 2
            except Exception as e:
                logger.debug(f"Cold start private failed: {e}")
        mc = self.config['retrieval']['min_confidence']
        out = [m for m in all_mems.values() if m.get('score', 1.0) >= mc]
        out.sort(key=lambda m: (phase.get(m.get('id', ''), 9), -m.get('score', 0)))
        return out[:top_k]

    async def shutdown(self):
        """Stop the async queue (if started) and log completion."""
        if self.async_queue:
            await self.async_queue.stop()
        logger.info("mem0 Client shutdown complete")


# Global client instance.
mem0_client = Mem0Client()