# (removed: stray repository web-UI residue — topic-limit notice and file-size listing)

#!/usr/bin/env python3
"""
mem0 Client for OpenClaw - 生产级纯异步架构
Pre-Hook 检索注入 + Post-Hook 异步写入
三级可见性隔离 (public / project / private)
记忆衰减 (expiration_date) + 智能写入过滤
"""
import os
import re
import asyncio
import logging
import time
import yaml
from typing import List, Dict, Any
from collections import deque
from datetime import datetime, timedelta
from pathlib import Path
# ========== DashScope environment configuration ==========
# Embedding model: text-embedding-v4 (1024 dims), served through DashScope's
# OpenAI-compatible endpoint, so the OpenAI client env vars are pointed there.
_DASHSCOPE_ENDPOINT = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
os.environ['OPENAI_API_BASE'] = _DASHSCOPE_ENDPOINT
os.environ['OPENAI_BASE_URL'] = _DASHSCOPE_ENDPOINT  # critical: compat mode requires this variable
# Key resolution order: MEM0_DASHSCOPE_API_KEY, then DASHSCOPE_API_KEY.
_dashscope_key = os.getenv('MEM0_DASHSCOPE_API_KEY', '') or os.getenv('DASHSCOPE_API_KEY', '')
if _dashscope_key:
    os.environ['OPENAI_API_KEY'] = _dashscope_key
elif not os.environ.get('OPENAI_API_KEY'):
    logging.warning("MEM0_DASHSCOPE_API_KEY not set; mem0 embedding/LLM calls will fail")
# mem0ai is an optional dependency: when it is missing, Memory is set to None
# and Mem0Client degrades gracefully (every operation becomes a no-op).
try:
    from mem0 import Memory
    from mem0.configs.base import MemoryConfig, VectorStoreConfig, LlmConfig, EmbedderConfig
except ImportError as e:
    print(f" mem0ai 导入失败:{e}")
    Memory = None
# Module-level logger, shared by the queue and the client.
logger = logging.getLogger(__name__)
class AsyncMemoryQueue:
    """Pure-async memory write queue.

    Producers enqueue write items via the synchronous ``add``; a background
    asyncio task drains the queue in batches and hands every item of each
    batch to the ``callback`` coroutine installed by ``start_worker``.
    """
    def __init__(self, max_size: int = 100):
        # Bounded deque; add() checks the bound itself so overflow is logged
        # instead of the deque silently evicting the oldest entry.
        self.queue = deque(maxlen=max_size)
        self.lock = asyncio.Lock()
        self.running = False
        self._worker_task = None
        self.callback = None          # async callable(item), set by start_worker()
        self.batch_size = 10
        self.flush_interval = 60      # seconds between forced flushes
        # BUGFIX: asyncio keeps only a weak reference to tasks created with
        # create_task(); without strong refs an in-flight batch task could be
        # garbage-collected before it finishes. Hold refs until completion.
        self._batch_tasks = set()

    def add(self, item: Dict[str, Any]):
        """Enqueue one write item (synchronous; safe to call from hooks).

        Optional fields default to: visibility='private',
        memory_type='session', timestamp=now. When the queue is full the
        new item is dropped with a warning (shed load rather than block).
        """
        try:
            if len(self.queue) < self.queue.maxlen:
                self.queue.append({
                    'messages': item['messages'],
                    'user_id': item['user_id'],
                    'agent_id': item['agent_id'],
                    'visibility': item.get('visibility', 'private'),
                    'project_id': item.get('project_id'),
                    'memory_type': item.get('memory_type', 'session'),
                    'timestamp': item.get('timestamp', datetime.now().isoformat())
                })
            else:
                logger.warning("异步队列已满,丢弃旧任务")
        except Exception as e:
            logger.error(f"队列添加失败:{e}")

    async def get_batch(self, batch_size: int) -> List[Dict]:
        """Pop up to *batch_size* items from the head of the queue (FIFO)."""
        async with self.lock:
            batch = []
            while len(batch) < batch_size and self.queue:
                batch.append(self.queue.popleft())
            return batch

    def start_worker(self, callback, batch_size: int, flush_interval: int):
        """Start the background drain loop (must run inside an event loop)."""
        self.running = True
        self.callback = callback
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self._worker_task = asyncio.create_task(self._worker_loop())
        logger.info(f"✅ 异步工作线程已启动 (batch_size={batch_size}, interval={flush_interval}s)")

    def _spawn_batch(self, batch: List[Dict]):
        """Process *batch* concurrently while keeping a strong task reference."""
        task = asyncio.create_task(self._process_batch(batch))
        self._batch_tasks.add(task)
        task.add_done_callback(self._batch_tasks.discard)

    async def _worker_loop(self):
        """Poll the queue once per second; force a drain every flush_interval s."""
        last_flush = time.time()
        while self.running:
            try:
                if self.queue:
                    batch = await self.get_batch(self.batch_size)
                    if batch:
                        self._spawn_batch(batch)
                if time.time() - last_flush > self.flush_interval:
                    if self.queue:
                        batch = await self.get_batch(self.batch_size)
                        if batch:
                            self._spawn_batch(batch)
                    last_flush = time.time()
                await asyncio.sleep(1)
            except asyncio.CancelledError:
                logger.info("异步工作线程已取消")
                break
            except Exception as e:
                logger.error(f"异步工作线程错误:{e}")
                await asyncio.sleep(5)

    async def _process_batch(self, batch: List[Dict]):
        """Feed every item of *batch* to the installed callback, in order."""
        try:
            logger.debug(f"开始处理批量任务:{len(batch)}")
            for item in batch:
                await self.callback(item)
            logger.debug(f"批量任务处理完成")
        except Exception as e:
            logger.error(f"批量处理失败:{e}")

    async def stop(self):
        """Graceful shutdown: cancel the loop, then wait for in-flight batches."""
        self.running = False
        if self._worker_task:
            self._worker_task.cancel()
            try:
                await self._worker_task
            except asyncio.CancelledError:
                pass
        # BUGFIX: previously in-flight batch tasks could be abandoned on
        # shutdown; await them so already-dequeued writes are not lost.
        if self._batch_tasks:
            await asyncio.gather(*self._batch_tasks, return_exceptions=True)
        logger.info("异步工作线程已关闭")
# Memory-decay policy: TTL per memory_type; None means the memory never expires.
EXPIRATION_MAP = {
    'session': timedelta(days=7),
    'chat_summary': timedelta(days=30),
    'preference': None,
    'knowledge': None,
}
# Whole-message acknowledgements ("ok", "thanks", emoji reactions, …) that
# carry no memorable content — such exchanges are skipped by the write filter.
SKIP_PATTERNS = re.compile(
    r'^(好的|收到|OK|ok|嗯|行|没问题|感谢|谢谢|了解|明白|知道了|👍|✅|❌)$',
    re.IGNORECASE
)
# Slash-prefixed messages (e.g. "/help") are system commands, never memorized.
SYSTEM_CMD_PATTERN = re.compile(r'^/')
# Keywords signalling a durable user preference/rule ("remember", "always",
# "never", "forbid", …) — classified as memory_type 'preference' (no expiry).
MEMORY_KEYWORDS = re.compile(
    r'(记住|以后|偏好|配置|设置|规则|永远|始终|总是|不要|禁止)',
)
# Keywords signalling cluster-wide relevance ("everyone", "broadcast",
# "global", …) — classified as 'public' visibility.
PUBLIC_KEYWORDS = re.compile(
    r'(所有人|通知|全局|公告|大家|集群)',
)
def _load_project_registry() -> Dict:
"""从 project_registry.yaml 加载项目注册表"""
registry_path = Path(__file__).parent / 'project_registry.yaml'
if registry_path.exists():
try:
with open(registry_path, 'r', encoding='utf-8') as f:
return yaml.safe_load(f) or {}
except Exception:
pass
return {}
def get_agent_projects(agent_id: str) -> List[str]:
    """Return every project_id whose member list contains *agent_id*.

    A wildcard member ``'*'`` grants membership to all agents.
    """
    projects = _load_project_registry().get('projects', {})
    return [
        pid
        for pid, conf in projects.items()
        if '*' in conf.get('members', []) or agent_id in conf.get('members', [])
    ]
class Mem0Client:
    """
    Production-grade mem0 client.

    Pure-async architecture + three-level visibility (public/project/private)
    + memory decay (expiration_date) + intelligent write filtering.
    Retrieval runs in the pre-hook; writes are queued by the post-hook and
    flushed by a background AsyncMemoryQueue worker.
    """
    def __init__(self, config: Dict = None):
        # Falls back to the built-in defaults when no config dict is given.
        self.config = config or self._load_default_config()
        self.local_memory = None   # mem0 Memory instance; stays None when init fails
        self.async_queue = None    # AsyncMemoryQueue; created lazily in start()
        self.cache = {}            # search cache: "{user}:{agent}:{query}" -> {'results', 'time'}
        self._started = False
        self._init_memory()

    def _load_default_config(self) -> Dict:
        """Load the default configuration — single shared-collection architecture."""
        return {
            "qdrant": {
                "host": os.getenv('MEM0_QDRANT_HOST', 'localhost'),
                "port": int(os.getenv('MEM0_QDRANT_PORT', '6333')),
                "collection_name": "mem0_v4_shared"  # unified shared collection (all agents share it)
            },
            "llm": {
                "provider": "openai",
                "config": {
                    "model": os.getenv('MEM0_LLM_MODEL', 'qwen-plus')
                }
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": os.getenv('MEM0_EMBEDDER_MODEL', 'text-embedding-v4'),
                    "dimensions": 1024  # max dimensions supported by DashScope text-embedding-v4
                }
            },
            "retrieval": {
                "enabled": True,
                "top_k": 5,
                "min_confidence": 0.7,
                "timeout_ms": 5000
            },
            "async_write": {
                "enabled": True,
                "queue_size": 100,
                "batch_size": 10,
                "flush_interval": 60,
                "timeout_ms": 5000
            },
            "cache": {
                "enabled": True,
                "ttl": 300,
                "max_size": 1000
            },
            "fallback": {
                "enabled": True,
                "log_level": "WARNING",
                "retry_attempts": 2
            },
            "metadata": {
                "default_user_id": "default",
                "default_agent_id": "general"
            },
            "write_filter": {
                "enabled": True,
                "min_user_message_length": 5,
            }
        }

    def _init_memory(self):
        """Initialise mem0 (synchronous) — full vector-store/LLM/embedder config."""
        if Memory is None:
            logger.warning("mem0ai 未安装")
            return
        try:
            config = MemoryConfig(
                vector_store=VectorStoreConfig(
                    provider="qdrant",
                    config={
                        "host": self.config['qdrant']['host'],
                        "port": self.config['qdrant']['port'],
                        "collection_name": self.config['qdrant']['collection_name'],
                        "on_disk": True,
                        "embedding_model_dims": 1024  # force the Qdrant collection dims to match the embedder
                    }
                ),
                llm=LlmConfig(
                    provider="openai",
                    config=self.config['llm']['config']
                ),
                embedder=EmbedderConfig(
                    provider="openai",
                    config={
                        "model": "text-embedding-v4",
                        "embedding_dims": 1024  # core fix: override the default 1536 dims
                        # api_base / api_key are read from the OPENAI_BASE_URL / OPENAI_API_KEY env vars
                    }
                )
            )
            self.local_memory = Memory(config=config)
            logger.info("✅ mem0 初始化成功(含 Embedder,1024 维度)")
        except Exception as e:
            logger.error(f"❌ mem0 初始化失败:{e}")
            self.local_memory = None

    async def start(self):
        """
        Explicitly start the async worker.

        Must be awaited inside a running event loop:
        ``await mem0_client.start()``. Idempotent: a second call is a no-op.
        """
        if self._started:
            logger.debug("mem0 Client 已启动")
            return
        if self.config['async_write']['enabled']:
            self.async_queue = AsyncMemoryQueue(
                max_size=self.config['async_write']['queue_size']
            )
            self.async_queue.start_worker(
                callback=self._async_write_memory,
                batch_size=self.config['async_write']['batch_size'],
                flush_interval=self.config['async_write']['flush_interval']
            )
        self._started = True
        logger.info(f"✅ mem0 Client 异步工作线程已启动")

    # ========== Pre-Hook: smart retrieval ==========
    async def pre_hook_search(self, query: str, user_id: str = None, agent_id: str = None, top_k: int = None) -> List[Dict]:
        """Pre-Hook: retrieve relevant memories before the conversation turn.

        Results are cached per (user, agent, query) for ``cache.ttl`` seconds.
        Returns [] on timeout, on error, or when retrieval is disabled.
        """
        if not self.config['retrieval']['enabled'] or self.local_memory is None:
            return []
        cache_key = f"{user_id}:{agent_id}:{query}"
        if self.config['cache']['enabled'] and cache_key in self.cache:
            cached = self.cache[cache_key]
            if time.time() - cached['time'] < self.config['cache']['ttl']:
                logger.debug(f"Cache hit: {cache_key}")
                return cached['results']
        timeout_ms = self.config['retrieval']['timeout_ms']
        try:
            memories = await asyncio.wait_for(
                self._execute_search(query, user_id, agent_id, top_k or self.config['retrieval']['top_k']),
                timeout=timeout_ms / 1000
            )
            # Only non-empty result sets are cached, so misses are retried.
            if self.config['cache']['enabled'] and memories:
                self.cache[cache_key] = {'results': memories, 'time': time.time()}
                self._cleanup_cache()
            logger.info(f"Pre-Hook 检索完成:{len(memories)} 条记忆")
            return memories
        except asyncio.TimeoutError:
            logger.warning(f"Pre-Hook 检索超时 ({timeout_ms}ms)")
            return []
        except Exception as e:
            logger.warning(f"Pre-Hook 检索失败:{e}")
            return []

    async def _execute_search(self, query: str, user_id: str, agent_id: str, top_k: int) -> List[Dict]:
        """
        Three-phase retrieval, merged and de-duplicated by memory id:

        Phase 1: public  (visible to every agent)
        Phase 2: project (visible to members of each project_id)
        Phase 3: private (visible to the owning agent_id only)

        A legacy fallback pass covers old records without a visibility field.
        """
        if self.local_memory is None:
            return []
        all_memories: Dict[str, Dict] = {}
        per_phase = max(top_k, 3)
        # Phase 1: public memories
        try:
            public_mems = await asyncio.to_thread(
                self.local_memory.search,
                query,
                user_id=user_id,
                filters={"visibility": "public"},
                limit=per_phase
            )
            for mem in (public_mems or []):
                mid = mem.get('id') if isinstance(mem, dict) else None
                if mid and mid not in all_memories:
                    all_memories[mid] = mem
        except Exception as e:
            logger.debug(f"Public 记忆检索失败:{e}")
        # Phase 2: project memories (every project the agent belongs to)
        if agent_id and agent_id != 'general':
            agent_projects = get_agent_projects(agent_id)
            for project_id in agent_projects:
                if project_id == 'global':
                    continue
                try:
                    proj_mems = await asyncio.to_thread(
                        self.local_memory.search,
                        query,
                        user_id=user_id,
                        filters={
                            "visibility": "project",
                            "project_id": project_id,
                        },
                        limit=per_phase
                    )
                    for mem in (proj_mems or []):
                        mid = mem.get('id') if isinstance(mem, dict) else None
                        if mid and mid not in all_memories:
                            all_memories[mid] = mem
                except Exception as e:
                    logger.debug(f"Project({project_id}) 记忆检索失败:{e}")
        # Phase 3: private memories
        if agent_id and agent_id != 'general':
            try:
                private_mems = await asyncio.to_thread(
                    self.local_memory.search,
                    query,
                    user_id=user_id,
                    filters={
                        "visibility": "private",
                        "agent_id": agent_id,
                    },
                    limit=per_phase
                )
                for mem in (private_mems or []):
                    mid = mem.get('id') if isinstance(mem, dict) else None
                    if mid and mid not in all_memories:
                        all_memories[mid] = mem
            except Exception as e:
                logger.debug(f"Private 记忆检索失败:{e}")
        # Fallback: legacy records written before the visibility field existed
        if user_id:
            try:
                legacy_mems = await asyncio.to_thread(
                    self.local_memory.search,
                    query,
                    user_id=user_id,
                    filters={"agent_id": agent_id} if agent_id and agent_id != 'general' else None,
                    limit=per_phase
                )
                for mem in (legacy_mems or []):
                    mid = mem.get('id') if isinstance(mem, dict) else None
                    if mid and mid not in all_memories:
                        all_memories[mid] = mem
            except Exception as e:
                logger.debug(f"Legacy 记忆检索失败:{e}")
        min_confidence = self.config['retrieval']['min_confidence']
        # Entries without a score are kept (score defaults to 1.0).
        filtered = [
            m for m in all_memories.values()
            if m.get('score', 1.0) >= min_confidence
        ]
        filtered.sort(key=lambda x: x.get('score', 0), reverse=True)
        return filtered[:top_k]

    def format_memories_for_prompt(self, memories: List[Dict]) -> str:
        """Render retrieved memories as a prompt fragment ("" when empty)."""
        if not memories:
            return ""
        prompt = "\n\n=== 相关记忆 ===\n"
        for i, mem in enumerate(memories, 1):
            memory_text = mem.get('memory', '') if isinstance(mem, dict) else str(mem)
            created_at = mem.get('created_at', '') if isinstance(mem, dict) else ''
            prompt += f"{i}. {memory_text}"
            if created_at:
                prompt += f" (记录于:{created_at})"
            prompt += "\n"
        prompt += "===============\n"
        return prompt

    # ========== Post-Hook: async write ==========
    def _should_skip_memory(self, user_message: str, assistant_message: str) -> bool:
        """Return True when this exchange should not be memorized."""
        if not self.config.get('write_filter', {}).get('enabled', True):
            return False
        min_len = self.config.get('write_filter', {}).get('min_user_message_length', 5)
        if len(user_message.strip()) < min_len:
            return True
        if SKIP_PATTERNS.match(user_message.strip()):
            return True
        if SYSTEM_CMD_PATTERN.match(user_message.strip()):
            return True
        return False

    def _classify_memory_type(self, user_message: str, assistant_message: str) -> str:
        """Auto-classify the memory type, which drives the expiration policy."""
        combined = user_message + ' ' + assistant_message
        if MEMORY_KEYWORDS.search(combined):
            return 'preference'
        # deployment/config/architecture/port/install/version keywords -> 'knowledge'
        if any(kw in combined for kw in ('部署', '配置', '架构', '端口', '安装', '版本')):
            return 'knowledge'
        return 'session'

    def _classify_visibility(self, user_message: str, assistant_message: str, agent_id: str = None) -> str:
        """Auto-classify memory visibility: 'public' on broadcast keywords, else 'private'."""
        combined = user_message + ' ' + assistant_message
        if PUBLIC_KEYWORDS.search(combined):
            return 'public'
        return 'private'

    def post_hook_add(self, user_message: str, assistant_message: str,
                      user_id: str = None, agent_id: str = None,
                      visibility: str = None, project_id: str = None,
                      memory_type: str = None):
        """Post-Hook: queue this exchange for async write (synchronous; enqueue only).

        Supports three-level visibility (public/project/private) and memory
        decay (expiration_date). Low-value exchanges are dropped by the
        built-in write filter; unset classification fields are auto-derived.
        """
        if not self.config['async_write']['enabled']:
            return
        if self._should_skip_memory(user_message, assistant_message):
            logger.debug(f"Post-Hook 跳过(写入过滤):{user_message[:30]}")
            return
        if not user_id:
            user_id = self.config['metadata']['default_user_id']
        if not agent_id:
            agent_id = self.config['metadata']['default_agent_id']
        if not memory_type:
            memory_type = self._classify_memory_type(user_message, assistant_message)
        if not visibility:
            visibility = self._classify_visibility(user_message, assistant_message, agent_id)
        messages = [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_message}
        ]
        if self.async_queue:
            self.async_queue.add({
                'messages': messages,
                'user_id': user_id,
                'agent_id': agent_id,
                'visibility': visibility,
                'project_id': project_id,
                'memory_type': memory_type,
                'timestamp': datetime.now().isoformat()
            })
            logger.debug(f"Post-Hook 已提交:user={user_id}, agent={agent_id}, "
                         f"visibility={visibility}, type={memory_type}")
        else:
            logger.warning("异步队列未初始化")

    async def _async_write_memory(self, item: Dict):
        """Background task: write one queued item, guarded by a timeout."""
        if self.local_memory is None:
            return
        timeout_ms = self.config['async_write']['timeout_ms']
        try:
            await asyncio.wait_for(self._execute_write(item), timeout=timeout_ms / 1000)
            logger.debug(f"异步写入成功:user={item['user_id']}, agent={item['agent_id']}")
        except asyncio.TimeoutError:
            logger.warning(f"异步写入超时 ({timeout_ms}ms)")
        except Exception as e:
            logger.warning(f"异步写入失败:{e}")

    async def _execute_write(self, item: Dict):
        """
        Execute the write: three-level visibility + memory decay.

        metadata carries visibility / project_id / agent_id;
        expiration_date is derived from memory_type via EXPIRATION_MAP
        (None in the map means the memory never expires).
        """
        if self.local_memory is None:
            return
        visibility = item.get('visibility', 'private')
        memory_type = item.get('memory_type', 'session')
        custom_metadata = {
            "agent_id": item['agent_id'],
            "visibility": visibility,
            "project_id": item.get('project_id') or '',
            "business_type": item.get('business_type', item['agent_id']),
            "memory_type": memory_type,
            "source": "openclaw",
            "timestamp": item.get('timestamp'),
        }
        ttl = EXPIRATION_MAP.get(memory_type)
        expiration_date = (datetime.now() + ttl).isoformat() if ttl else None
        add_kwargs = dict(
            messages=item['messages'],
            user_id=item['user_id'],
            agent_id=item['agent_id'],
            metadata=custom_metadata,
        )
        if expiration_date:
            add_kwargs['expiration_date'] = expiration_date
        # Memory.add is blocking; run it off the event loop.
        await asyncio.to_thread(self.local_memory.add, **add_kwargs)

    def _cleanup_cache(self):
        """Drop expired cache entries, then evict the oldest past max_size."""
        if not self.config['cache']['enabled']:
            return
        current_time = time.time()
        ttl = self.config['cache']['ttl']
        expired_keys = [k for k, v in self.cache.items() if current_time - v['time'] > ttl]
        for key in expired_keys:
            del self.cache[key]
        if len(self.cache) > self.config['cache']['max_size']:
            oldest_keys = sorted(self.cache.keys(), key=lambda k: self.cache[k]['time'])[:len(self.cache) - self.config['cache']['max_size']]
            for key in oldest_keys:
                del self.cache[key]

    def get_status(self) -> Dict:
        """Return a status snapshot for health checks / debugging."""
        return {
            "initialized": self.local_memory is not None,
            "started": self._started,
            "async_queue_enabled": self.config['async_write']['enabled'],
            "queue_size": len(self.async_queue.queue) if self.async_queue else 0,
            "cache_size": len(self.cache),
            "qdrant": f"{self.config['qdrant']['host']}:{self.config['qdrant']['port']}"
        }

    async def shutdown(self):
        """Graceful shutdown of the async write queue."""
        if self.async_queue:
            await self.async_queue.stop()
        logger.info("mem0 Client 已关闭")
# Module-level singleton shared by all importers.
# NOTE(review): instantiation at import time runs _init_memory(), which may
# contact Qdrant / log errors when the backend is unavailable — confirm that
# import-time side effects are intended here.
mem0_client = Mem0Client()