diff --git a/skills/mem0-integration/__pycache__/mem0_client.cpython-312.pyc b/skills/mem0-integration/__pycache__/mem0_client.cpython-312.pyc index eebe998..9ff6650 100644 Binary files a/skills/mem0-integration/__pycache__/mem0_client.cpython-312.pyc and b/skills/mem0-integration/__pycache__/mem0_client.cpython-312.pyc differ diff --git a/skills/mem0-integration/__pycache__/openclaw_interceptor.cpython-312.pyc b/skills/mem0-integration/__pycache__/openclaw_interceptor.cpython-312.pyc new file mode 100644 index 0000000..1d2d00e Binary files /dev/null and b/skills/mem0-integration/__pycache__/openclaw_interceptor.cpython-312.pyc differ diff --git a/skills/mem0-integration/mem0_client.py b/skills/mem0-integration/mem0_client.py index ce1e307..9a356dc 100644 --- a/skills/mem0-integration/mem0_client.py +++ b/skills/mem0-integration/mem0_client.py @@ -1,31 +1,177 @@ #!/usr/bin/env python3 """ -mem0 Client for OpenClaw (v1.0 兼容) +mem0 Client for OpenClaw - 生产级纯异步架构 +Pre-Hook 检索注入 + Post-Hook 异步写入 """ import os +import asyncio import logging -from typing import List, Dict, Optional +import time +from typing import List, Dict, Optional, Any +from collections import deque +from datetime import datetime -# 设置环境变量(在导入 mem0 之前) +# 设置环境变量 os.environ['OPENAI_API_BASE'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1' -os.environ['OPENAI_API_KEY'] = 'sk-c1715ee0479841399fd359c574647648' +os.environ['OPENAI_API_KEY'] = os.getenv('MEM0_DASHSCOPE_API_KEY', 'sk-c1715ee0479841399fd359c574647648') try: from mem0 import Memory from mem0.configs.base import MemoryConfig, VectorStoreConfig, LlmConfig except ImportError as e: - print(f"⚠️ mem0ai 导入失败:{e}") + print(f"⚠️ mem0ai 导入失败:{e}") Memory = None logger = logging.getLogger(__name__) + +class AsyncMemoryQueue: + """纯异步记忆写入队列""" + + def __init__(self, max_size: int = 100): + self.queue = deque(maxlen=max_size) + self.lock = asyncio.Lock() + self.running = False + self._worker_task = None + self.callback = None + self.batch_size = 10 + self.flush_interval = 60 + + def add(self, item: Dict[str, Any]): + """添加任务到队列""" + try: + if len(self.queue) < self.queue.maxlen: + self.queue.append({ + 'messages': item['messages'], + 'user_id': item['user_id'], + 'agent_id': item['agent_id'], + 'timestamp': datetime.now().isoformat() + }) + else: + logger.warning("异步队列已满,丢弃旧任务") + except Exception as e: + logger.error(f"队列添加失败:{e}") + + async def get_batch(self, batch_size: int) -> List[Dict]: + """获取批量任务""" + async with self.lock: + batch = [] + while len(batch) < batch_size and self.queue: + batch.append(self.queue.popleft()) + return batch + + def start_worker(self, callback, batch_size: int, flush_interval: int): + """启动异步后台任务""" + self.running = True + self.callback = callback + self.batch_size = batch_size + self.flush_interval = flush_interval + self._worker_task = asyncio.create_task(self._worker_loop()) + logger.info(f"✅ 异步工作线程已启动 (batch_size={batch_size}, interval={flush_interval}s)") + + async def _worker_loop(self): + """异步工作循环""" + last_flush = time.time() + + while self.running: + try: + if self.queue: + batch = await self.get_batch(self.batch_size) + if batch: + asyncio.create_task(self._process_batch(batch)) + + if time.time() - last_flush > self.flush_interval: + if self.queue: + batch = await self.get_batch(self.batch_size) + if batch: + asyncio.create_task(self._process_batch(batch)) + last_flush = time.time() + + await asyncio.sleep(1) + + except asyncio.CancelledError: + logger.info("异步工作线程已取消") + break + except Exception as e: + logger.error(f"异步工作线程错误:{e}") + await asyncio.sleep(5) + + async def _process_batch(self, batch: List[Dict]): + """处理批量任务""" + try: + logger.debug(f"开始处理批量任务:{len(batch)} 条") + for item in batch: + await self.callback(item) + logger.debug(f"批量任务处理完成") + except Exception as e: + logger.error(f"批量处理失败:{e}") + + async def stop(self): + """优雅关闭""" + self.running = False + if self._worker_task: + self._worker_task.cancel() + try: + await self._worker_task + except asyncio.CancelledError: + pass + logger.info("异步工作线程已关闭") + + class Mem0Client: - def __init__(self): + """生产级 mem0 客户端""" + + def __init__(self, config: Dict = None): + self.config = config or self._load_default_config() self.local_memory = None - self.init_memory() + self.async_queue = None + self.cache = {} + self._init_memory() + self._start_async_worker() - def init_memory(self): + def _load_default_config(self) -> Dict: + """加载默认配置""" + return { + "qdrant": { + "host": os.getenv('MEM0_QDRANT_HOST', '100.115.94.1'), + "port": int(os.getenv('MEM0_QDRANT_PORT', '6333')), + "collection_name": "mem0_shared" + }, + "llm": { + "provider": "openai", + "config": {"model": os.getenv('MEM0_LLM_MODEL', 'qwen-plus')} + }, + "retrieval": { + "enabled": True, + "top_k": 5, + "min_confidence": 0.7, + "timeout_ms": 2000 + }, + "async_write": { + "enabled": True, + "queue_size": 100, + "batch_size": 10, + "flush_interval": 60, + "timeout_ms": 5000 + }, + "cache": { + "enabled": True, + "ttl": 300, + "max_size": 1000 + }, + "fallback": { + "enabled": True, + "log_level": "WARNING", + "retry_attempts": 2 + }, + "metadata": { + "default_user_id": "default", + "default_agent_id": "general" + } + } + + def _init_memory(self): """初始化 mem0""" if Memory is None: logger.warning("mem0ai 未安装") @@ -36,69 +182,217 @@ class Mem0Client: vector_store=VectorStoreConfig( provider="qdrant", config={ - "host": os.getenv('MEM0_QDRANT_HOST', 'localhost'), - "port": int(os.getenv('MEM0_QDRANT_PORT', '6333')), - "collection_name": "mem0_local", + "host": self.config['qdrant']['host'], + "port": self.config['qdrant']['port'], + "collection_name": self.config['qdrant']['collection_name'], "on_disk": True } ), llm=LlmConfig( provider="openai", - config={"model": "qwen-plus"} + config=self.config['llm']['config'] ) ) - self.local_memory = Memory(config=config) - logger.info("✅ 本地记忆初始化成功") + logger.info("✅ mem0 初始化成功") except Exception as e: - logger.error(f"❌ 初始化失败:{e}") + logger.error(f"❌ mem0 初始化失败:{e}") self.local_memory = None - def add(self, messages: List[Dict], user_id: str) -> Optional[Dict]: - """添加记忆""" - if self.local_memory is None: - return {"error": "mem0 not initialized"} + def _start_async_worker(self): + """启动异步写入工作线程""" + if not self.config['async_write']['enabled']: + return try: - result = self.local_memory.add(messages, user_id=user_id) - return {"success": True} - except Exception as e: - return {"error": str(e)} + loop = asyncio.get_event_loop() + self.async_queue = AsyncMemoryQueue( + max_size=self.config['async_write']['queue_size'] + ) + self.async_queue.start_worker( + callback=self._async_write_memory, + batch_size=self.config['async_write']['batch_size'], + flush_interval=self.config['async_write']['flush_interval'] + ) + except RuntimeError: + logger.debug("当前无事件循环,异步队列将在首次使用时初始化") - def search(self, query: str, user_id: str, limit: int = 5) -> List[Dict]: - """搜索记忆""" - if self.local_memory is None: + async def pre_hook_search(self, query: str, user_id: str = None, agent_id: str = None, top_k: int = None) -> List[Dict]: + """Pre-Hook: 对话前智能检索""" + if not self.config['retrieval']['enabled'] or self.local_memory is None: return [] + cache_key = f"{user_id}:{agent_id}:{query}" + if self.config['cache']['enabled'] and cache_key in self.cache: + cached = self.cache[cache_key] + if time.time() - cached['time'] < self.config['cache']['ttl']: + logger.debug(f"Cache hit: {cache_key}") + return cached['results'] + + timeout_ms = self.config['retrieval']['timeout_ms'] + try: - return self.local_memory.search(query, user_id=user_id, limit=limit) - except Exception: + memories = await asyncio.wait_for( + self._execute_search(query, user_id, agent_id, top_k or self.config['retrieval']['top_k']), + timeout=timeout_ms / 1000 + ) + + if self.config['cache']['enabled'] and memories: + self.cache[cache_key] = {'results': memories, 'time': time.time()} + self._cleanup_cache() + + logger.info(f"Pre-Hook 检索完成:{len(memories)} 条记忆") + return memories + + except asyncio.TimeoutError: + logger.warning(f"Pre-Hook 检索超时 ({timeout_ms}ms)") + return [] + except Exception as e: + logger.warning(f"Pre-Hook 检索失败:{e}") return [] - def get_all(self, user_id: str) -> List[Dict]: - """获取所有记忆""" + async def _execute_search(self, query: str, user_id: str, agent_id: str, top_k: int) -> List[Dict]: + """执行检索(使用 asyncio.to_thread 隔离阻塞)""" if self.local_memory is None: return [] - try: - return self.local_memory.get_all(user_id=user_id) - except Exception: - return [] + user_memories = [] + if user_id: + try: + user_memories = await asyncio.to_thread( + self.local_memory.search, query, user_id=user_id, limit=top_k + ) + except Exception as e: + logger.debug(f"用户记忆检索失败:{e}") + + agent_memories = [] + if agent_id and agent_id != 'general': + try: + agent_memories = await asyncio.to_thread( + self.local_memory.search, + query, + user_id=f"{user_id}:{agent_id}" if user_id else agent_id, + limit=top_k + ) + except Exception as e: + logger.debug(f"业务记忆检索失败:{e}") + + all_memories = {} + for mem in user_memories + agent_memories: + mem_id = mem.get('id') if isinstance(mem, dict) else None + if mem_id and mem_id not in all_memories: + all_memories[mem_id] = mem + + min_confidence = self.config['retrieval']['min_confidence'] + filtered = [m for m in all_memories.values() if m.get('score', 1.0) >= min_confidence] + filtered.sort(key=lambda x: x.get('score', 0), reverse=True) + + return filtered[:top_k] + + def format_memories_for_prompt(self, memories: List[Dict]) -> str: + """格式化记忆为 Prompt 片段""" + if not memories: + return "" + + prompt = "\n\n=== 相关记忆 ===\n" + for i, mem in enumerate(memories, 1): + memory_text = mem.get('memory', '') if isinstance(mem, dict) else str(mem) + created_at = mem.get('created_at', '') if isinstance(mem, dict) else '' + prompt += f"{i}. {memory_text}" + if created_at: + prompt += f" (记录于:{created_at})" + prompt += "\n" + prompt += "===============\n" + + return prompt + + def post_hook_add(self, user_message: str, assistant_message: str, user_id: str = None, agent_id: str = None): + """Post-Hook: 对话后异步写入""" + if not self.config['async_write']['enabled']: + return + + if not user_id: + user_id = self.config['metadata']['default_user_id'] + if not agent_id: + agent_id = self.config['metadata']['default_agent_id'] + + messages = [ + {"role": "user", "content": user_message}, + {"role": "assistant", "content": assistant_message} + ] + + if self.async_queue: + self.async_queue.add({ + 'messages': messages, + 'user_id': user_id, + 'agent_id': agent_id + }) + logger.debug(f"Post-Hook 已提交:user={user_id}, agent={agent_id}") + else: + logger.warning("异步队列未初始化") - def delete(self, memory_id: str, user_id: str) -> bool: - """删除记忆""" + async def _async_write_memory(self, item: Dict): + """异步写入记忆""" if self.local_memory is None: - return False + return + + timeout_ms = self.config['async_write']['timeout_ms'] try: - self.local_memory.delete(memory_id, user_id=user_id) - return True - except Exception: - return False + await asyncio.wait_for(self._execute_write(item), timeout=timeout_ms / 1000) + logger.debug(f"异步写入成功:user={item['user_id']}, agent={item['agent_id']}") + except asyncio.TimeoutError: + logger.warning(f"异步写入超时 ({timeout_ms}ms)") + except Exception as e: + logger.warning(f"异步写入失败:{e}") + + async def _execute_write(self, item: Dict): + """执行写入(使用 asyncio.to_thread 隔离阻塞)""" + if self.local_memory is None: + return + + full_user_id = f"{item['user_id']}:{item['agent_id']}" + + await asyncio.to_thread( + self.local_memory.add, + messages=item['messages'], + user_id=full_user_id, + agent_id=item['agent_id'] + ) + + def _cleanup_cache(self): + """清理过期缓存""" + if not self.config['cache']['enabled']: + return + + current_time = time.time() + ttl = self.config['cache']['ttl'] + + expired_keys = [k for k, v in self.cache.items() if current_time - v['time'] > ttl] + for key in expired_keys: + del self.cache[key] + + if len(self.cache) > self.config['cache']['max_size']: + oldest_keys = sorted(self.cache.keys(), key=lambda k: self.cache[k]['time'])[:len(self.cache) - self.config['cache']['max_size']] + for key in oldest_keys: + del self.cache[key] def get_status(self) -> Dict: """获取状态""" return { "initialized": self.local_memory is not None, - "qdrant": f"{os.getenv('MEM0_QDRANT_HOST', 'localhost')}:{os.getenv('MEM0_QDRANT_PORT', '6333')}" + "async_queue_enabled": self.config['async_write']['enabled'], + "queue_size": len(self.async_queue.queue) if self.async_queue else 0, + "cache_size": len(self.cache), + "qdrant": f"{self.config['qdrant']['host']}:{self.config['qdrant']['port']}" } + + async def shutdown(self): + """优雅关闭""" + if self.async_queue: + await self.async_queue.stop() + logger.info("mem0 Client 已关闭") + + +# 全局客户端实例 +mem0_client = Mem0Client() diff --git a/skills/mem0-integration/openclaw_interceptor.py b/skills/mem0-integration/openclaw_interceptor.py new file mode 100644 index 0000000..4645086 --- /dev/null +++ b/skills/mem0-integration/openclaw_interceptor.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +"""OpenClaw 拦截器:Pre-Hook + Post-Hook""" + +import asyncio +import logging +import sys +sys.path.insert(0, '/root/.openclaw/workspace/skills/mem0-integration') +from mem0_client import mem0_client + +logger = logging.getLogger(__name__) + + +class ConversationInterceptor: + def __init__(self): + self.enabled = True + + async def pre_hook(self, query: str, context: dict) -> str: + if not self.enabled: + return None + try: + user_id = context.get('user_id', 'default') + agent_id = context.get('agent_id', 'general') + memories = await mem0_client.pre_hook_search(query=query, user_id=user_id, agent_id=agent_id) + if memories: + return mem0_client.format_memories_for_prompt(memories) + return None + except Exception as e: + logger.error(f"Pre-Hook 失败:{e}") + return None + + async def post_hook(self, user_message: str, assistant_message: str, context: dict): + if not self.enabled: + return + try: + user_id = context.get('user_id', 'default') + agent_id = context.get('agent_id', 'general') + await mem0_client.post_hook_add(user_message, assistant_message, user_id, agent_id) + logger.debug(f"Post-Hook: 已提交对话") + except Exception as e: + logger.error(f"Post-Hook 失败:{e}") + + +interceptor = ConversationInterceptor() + + +async def intercept_before_llm(query: str, context: dict): + return await interceptor.pre_hook(query, context) + + +async def intercept_after_response(user_msg: str, assistant_msg: str, context: dict): + await interceptor.post_hook(user_msg, assistant_msg, context) diff --git a/skills/mem0-integration/test_integration.py b/skills/mem0-integration/test_integration.py new file mode 100644 index 0000000..ffe70ff --- /dev/null +++ b/skills/mem0-integration/test_integration.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +# /root/.openclaw/workspace/skills/mem0-integration/test_integration.py + +import asyncio +import logging +import sys +sys.path.insert(0, '/root/.openclaw/workspace/skills/mem0-integration') + +from mem0_client import mem0_client +from openclaw_interceptor import intercept_before_llm, intercept_after_response + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +async def mock_llm(system_prompt, user_message): + """模拟 LLM 调用""" + return f"这是模拟回复:{user_message[:20]}..." + +async def test_full_flow(): + """测试完整对话流程""" + print("=" * 60) + print("🧪 测试 mem0 集成架构") + print("=" * 60) + + context = { + 'user_id': '5237946060', + 'agent_id': 'general' + } + + # ========== 测试 1: Pre-Hook 检索 ========== + print("\n1️⃣ 测试 Pre-Hook 检索...") + query = "我平时喜欢用什么时区?" + memory_prompt = await intercept_before_llm(query, context) + + if memory_prompt: + print(f"✅ 检索到记忆:\n{memory_prompt}") + else: + print("⚠️ 未检索到记忆(正常,首次对话)") + + # ========== 测试 2: 完整对话流程 ========== + print("\n2️⃣ 测试完整对话流程...") + user_message = "我平时喜欢使用 UTC 时区,请用简体中文和我交流" + + print(f"用户:{user_message}") + response = await mock_llm("system", user_message) + print(f"助手:{response}") + + # ========== 测试 3: Post-Hook 异步写入 ========== + print("\n3️⃣ 测试 Post-Hook 异步写入...") + await intercept_after_response(user_message, response, context) + print(f"✅ 对话已提交到异步队列") + print(f" 队列大小:{len(mem0_client.async_queue.queue) if mem0_client.async_queue else 0}") + + # ========== 等待异步写入完成 ========== + print("\n4️⃣ 等待异步写入 (5 秒)...") + await asyncio.sleep(5) + + # ========== 测试 4: 验证记忆已存储 ========== + print("\n5️⃣ 验证记忆已存储...") + memories = await mem0_client.pre_hook_search("时区", **context) + print(f"✅ 检索到 {len(memories)} 条记忆") + for i, mem in enumerate(memories, 1): + print(f" {i}. {mem.get('memory', 'N/A')[:100]}") + + # ========== 状态报告 ========== + print("\n" + "=" * 60) + print("📊 系统状态:") + status = mem0_client.get_status() + for key, value in status.items(): + print(f" {key}: {value}") + print("=" * 60) + print("✅ 测试完成") + +if __name__ == '__main__': + asyncio.run(test_full_flow())