|
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
"""
|
|
|
|
|
|
mem0 Client for OpenClaw - 生产级纯异步架构
|
|
|
|
|
|
Pre-Hook 检索注入 + Post-Hook 异步写入
|
|
|
|
|
|
三级可见性隔离 (public / project / private)
|
|
|
|
|
|
记忆衰减 (expiration_date) + 智能写入过滤
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
import re
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
import logging
|
|
|
|
|
|
import time
|
|
|
|
|
|
import yaml
|
|
|
|
|
|
from typing import List, Dict, Any
|
|
|
|
|
|
from collections import deque
|
|
|
|
|
|
from datetime import datetime, timedelta
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
# ========== LLM / Embedding API configuration ==========
# Resolution order: OneAPI (.env) > pre-existing environment variables > DashScope default.


def _first_env(*names: str, default: str = '') -> str:
    """Return the first non-empty environment variable among *names*, else *default*."""
    for name in names:
        value = os.getenv(name)
        if value:
            return value
    return default


_api_base = _first_env('LLM_BASE_URL', 'OPENAI_BASE_URL', 'OPENAI_API_BASE',
                       default='https://dashscope.aliyuncs.com/compatible-mode/v1')
# Export under both names used by OpenAI-compatible SDKs.
os.environ['OPENAI_API_BASE'] = _api_base
os.environ['OPENAI_BASE_URL'] = _api_base

_api_key = _first_env('LLM_API_KEY', 'MEM0_DASHSCOPE_API_KEY', 'DASHSCOPE_API_KEY')
if not _api_key:
    # Only warn when no OpenAI key was provided by some other means either.
    if not os.environ.get('OPENAI_API_KEY'):
        logging.warning("No API key found (LLM_API_KEY / MEM0_DASHSCOPE_API_KEY); mem0 calls will fail")
else:
    os.environ['OPENAI_API_KEY'] = _api_key
|
|
|
|
|
|
|
|
|
|
|
|
try:
    from mem0 import Memory
    from mem0.configs.base import MemoryConfig, VectorStoreConfig, LlmConfig, EmbedderConfig
except ImportError as e:
    print(f"⚠️ mem0ai 导入失败:{e}")
    # Bind every name this import would have provided, so later references see
    # None instead of raising NameError. Callers test `Memory is None` first.
    Memory = None
    MemoryConfig = VectorStoreConfig = LlmConfig = EmbedderConfig = None
|
|
|
|
|
|
|
|
|
|
|
|
from local_search import LocalSearchFallback
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger used throughout this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AsyncMemoryQueue:
    """Pure-asyncio write-behind queue for memory items.

    Items are enqueued synchronously with ``add`` and drained by a background
    task started with ``start_worker``. Each drained item is handed to the async
    ``callback``. ``stop`` flushes anything still queued before returning.
    """

    def __init__(self, max_size: int = 100):
        # Bounded buffer: on overflow the oldest pending entry is evicted.
        self.queue = deque(maxlen=max_size)
        self.lock = asyncio.Lock()
        self.running = False
        self._worker_task = None
        # Strong references to in-flight batch tasks. The event loop only keeps
        # weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._inflight = set()
        self.callback = None
        self.batch_size = 10
        self.flush_interval = 60

    def add(self, item: Dict[str, Any]):
        """Enqueue one write request (synchronous method).

        ``item`` must contain 'messages', 'user_id' and 'agent_id'. Optional keys
        default to visibility='private', project_id=None, memory_type='session'
        and timestamp=now. When the queue is full, the oldest pending entry is
        dropped to make room. Malformed items are logged and ignored.
        """
        try:
            entry = {
                'messages': item['messages'],
                'user_id': item['user_id'],
                'agent_id': item['agent_id'],
                'visibility': item.get('visibility', 'private'),
                'project_id': item.get('project_id'),
                'memory_type': item.get('memory_type', 'session'),
                'timestamp': item.get('timestamp', datetime.now().isoformat())
            }
            if len(self.queue) >= self.queue.maxlen:
                logger.warning("异步队列已满,丢弃旧任务")
            # deque(maxlen=...) discards the leftmost (oldest) entry on overflow.
            self.queue.append(entry)
        except Exception as e:
            logger.error(f"队列添加失败:{e}")

    async def get_batch(self, batch_size: int) -> List[Dict]:
        """Pop up to ``batch_size`` items from the front of the queue."""
        async with self.lock:
            batch = []
            while len(batch) < batch_size and self.queue:
                batch.append(self.queue.popleft())
            return batch

    def start_worker(self, callback, batch_size: int, flush_interval: int):
        """Start the background worker task. Must be called inside a running event loop."""
        self.running = True
        self.callback = callback
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self._worker_task = asyncio.create_task(self._worker_loop())
        logger.info(f"✅ 异步工作线程已启动 (batch_size={batch_size}, interval={flush_interval}s)")

    def _dispatch(self, batch: List[Dict]):
        """Process ``batch`` in a tracked background task (no-op for an empty batch)."""
        if not batch:
            return
        task = asyncio.create_task(self._process_batch(batch))
        self._inflight.add(task)
        task.add_done_callback(self._inflight.discard)

    async def _worker_loop(self):
        """Poll the queue about once per second and dispatch batches."""
        last_flush = time.time()

        while self.running:
            try:
                if self.queue:
                    self._dispatch(await self.get_batch(self.batch_size))

                if time.time() - last_flush > self.flush_interval:
                    if self.queue:
                        self._dispatch(await self.get_batch(self.batch_size))
                    last_flush = time.time()

                await asyncio.sleep(1)

            except asyncio.CancelledError:
                logger.info("异步工作线程已取消")
                break
            except Exception as e:
                logger.error(f"异步工作线程错误:{e}")
                await asyncio.sleep(5)

    async def _process_batch(self, batch: List[Dict]):
        """Run the callback on every item. One failing item does not abort the rest."""
        logger.debug(f"开始处理批量任务:{len(batch)} 条")
        for item in batch:
            try:
                await self.callback(item)
            except Exception as e:
                logger.error(f"批量处理失败:{e}")
        logger.debug("批量任务处理完成")

    async def stop(self, drain_timeout: float = 30.0):
        """Shut down gracefully.

        Stops the worker, processes any items still queued, and waits up to
        ``drain_timeout`` seconds (None = no limit) for in-flight batches.
        Batches still unfinished after the timeout are cancelled.
        """
        self.running = False
        if self._worker_task:
            self._worker_task.cancel()
            try:
                await self._worker_task
            except asyncio.CancelledError:
                pass
            self._worker_task = None

        # Flush whatever was enqueued but not yet picked up by the worker.
        leftover = await self.get_batch(len(self.queue))
        if leftover and self.callback is not None:
            self._dispatch(leftover)

        if self._inflight:
            _, not_done = await asyncio.wait(set(self._inflight), timeout=drain_timeout)
            if not_done:
                logger.warning(f"异步队列关闭超时,取消 {len(not_done)} 个未完成任务")
                for task in not_done:
                    task.cancel()
        logger.info("异步工作线程已关闭")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Retention period per memory_type; None means the memory never expires.
EXPIRATION_MAP = dict(
    session=timedelta(days=7),
    chat_summary=timedelta(days=30),
    preference=None,
    knowledge=None,
)

# A user message consisting solely of one of these acknowledgements is not worth storing.
SKIP_PATTERNS = re.compile(r'^(好的|收到|OK|ok|嗯|行|没问题|感谢|谢谢|了解|明白|知道了|👍|✅|❌)$', re.IGNORECASE)

# Messages starting with "/" are treated as system commands.
SYSTEM_CMD_PATTERN = re.compile(r'^/')

# Cue words that classify a turn as a long-lived preference/rule.
MEMORY_KEYWORDS = re.compile(r'(记住|以后|偏好|配置|设置|规则|永远|始终|总是|不要|禁止)')

# Cue words that mark a turn as addressed to everyone (public visibility).
PUBLIC_KEYWORDS = re.compile(r'(所有人|通知|全局|公告|大家|集群)')

# Filled lazily by _build_project_keywords(): project_id -> keyword list.
PROJECT_KEYWORDS_CACHE: Dict[str, List[str]] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_project_keywords() -> Dict[str, List[str]]:
    """Extract per-project keywords from project_registry.yaml for auto-classification.

    For every project except ``global``, the keyword list contains:
    - the project id (as a string);
    - its ``name``, if at least 2 characters long;
    - every ``description`` segment of at least 2 characters, splitting on
      commas, 、, 。, slashes and whitespace.

    Null or non-mapping project entries and null name/description values are
    tolerated. The result is memoised in ``PROJECT_KEYWORDS_CACHE``. An empty
    result is not treated as cached, so the registry is re-read on the next call.

    Returns:
        Mapping of project_id -> list of keyword strings.
    """
    global PROJECT_KEYWORDS_CACHE
    if PROJECT_KEYWORDS_CACHE:
        return PROJECT_KEYWORDS_CACHE

    registry = _load_project_registry()
    projects = registry.get('projects') if isinstance(registry, dict) else None
    if not isinstance(projects, dict):
        return {}

    result: Dict[str, List[str]] = {}
    for pid, pconf in projects.items():
        if pid == 'global':
            continue
        if not isinstance(pconf, dict):
            pconf = {}
        # Keywords are later tested with `kw in text`, so they must all be str.
        keywords = [str(pid)]
        name = str(pconf.get('name') or '')
        if len(name) >= 2:
            keywords.append(name)
        description = str(pconf.get('description') or '')
        for segment in re.split(r'[,、。/\s]+', description):
            segment = segment.strip()
            if len(segment) >= 2:
                keywords.append(segment)
        result[pid] = keywords
    PROJECT_KEYWORDS_CACHE = result
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_project_registry() -> Dict:
|
|
|
|
|
|
"""从 project_registry.yaml 加载项目注册表"""
|
|
|
|
|
|
registry_path = Path(__file__).parent / 'project_registry.yaml'
|
|
|
|
|
|
if registry_path.exists():
|
|
|
|
|
|
try:
|
|
|
|
|
|
with open(registry_path, 'r', encoding='utf-8') as f:
|
|
|
|
|
|
return yaml.safe_load(f) or {}
|
|
|
|
|
|
except Exception:
|
|
|
|
|
|
pass
|
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_agent_projects(agent_id: str) -> List[str]:
    """Return every project_id whose ``members`` include *agent_id* or the wildcard ``'*'``.

    ``members`` may be a list/tuple/set of agent ids or a single string. A
    single string is treated as a one-element list, not as a substring match.
    Projects with a null entry or without members are skipped.

    Returns:
        Matching project ids, in registry order.
    """
    registry = _load_project_registry()
    projects = registry.get('projects') if isinstance(registry, dict) else None
    if not isinstance(projects, dict):
        return []

    result = []
    for pid, pconf in projects.items():
        members = pconf.get('members') if isinstance(pconf, dict) else None
        if isinstance(members, str):
            members = [members]
        elif not isinstance(members, (list, tuple, set)):
            continue
        if '*' in members or agent_id in members:
            result.append(pid)
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Mem0Client:
    """
    Production-grade mem0 client.

    Pure-asyncio architecture with three visibility tiers (public / project /
    private), memory decay via per-type expiration, and a heuristic write filter.

    * Pre-hook  (``pre_hook_search``): retrieve relevant memories before a turn.
    * Post-hook (``post_hook_add``): enqueue the finished turn for a background
      write. ``await start()`` must be called first to launch the writer.
    """

    # user_id that publish_knowledge() stores shared knowledge under. mem0 scopes
    # every search by user_id, so public and project searches also query this id;
    # otherwise published knowledge would be invisible to every real user.
    _SHARED_USER_ID = 'system'

    def __init__(self, config: Dict = None):
        """Build the client. mem0 itself is initialised synchronously here.

        Args:
            config: Optional, possibly partial, configuration. It is deep-merged
                over ``_load_default_config()``, so omitted sections keep their
                defaults instead of raising ``KeyError`` later.
        """
        defaults = self._load_default_config()
        self.config = self._merge_config(defaults, config) if config else defaults
        self.local_memory = None
        self.async_queue = None
        # cache_key -> {'results': [...], 'time': epoch seconds}
        self.cache = {}
        self._started = False
        self._local_search: Dict[str, "LocalSearchFallback"] = {}
        # Lazily probed: does the installed Memory.add() accept expiration_date?
        self._add_accepts_expiration = None
        self._init_memory()

    @staticmethod
    def _merge_config(base: Dict, override: Dict) -> Dict:
        """Return a copy of *base* with *override* deep-merged on top.

        Nested dicts are merged recursively; any other value in *override* replaces
        the one in *base*.
        """
        merged = dict(base)
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = Mem0Client._merge_config(merged[key], value)
            else:
                merged[key] = value
        return merged

    def _load_default_config(self) -> Dict:
        """Return the default configuration (single shared-collection architecture).

        Qdrant connection settings and model names come from environment variables
        read at call time. Everything else is a literal default.
        """
        return {
            "qdrant": {
                "host": os.getenv('MEM0_QDRANT_HOST', 'localhost'),
                "port": int(os.getenv('MEM0_QDRANT_PORT', '6333')),
                # One collection shared by all agents. Isolation is done through
                # visibility / project_id / agent_id metadata filters.
                "collection_name": "mem0_v4_shared"
            },
            "llm": {
                "provider": "openai",
                "config": {
                    "model": os.getenv('MEM0_LLM_MODEL') or os.getenv('LLM_MODEL_ID', 'qwen-plus')
                }
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": os.getenv('MEM0_EMBEDDER_MODEL') or os.getenv('EMBEDDING_MODEL_ID', 'text-embedding-v4'),
                    # Largest dimension supported by DashScope text-embedding-v4 (per original note).
                    "dimensions": 1024
                }
            },
            "retrieval": {
                "enabled": True,
                "top_k": 5,
                "min_confidence": 0.7,
                "timeout_ms": 5000
            },
            "async_write": {
                "enabled": True,
                "queue_size": 100,
                "batch_size": 10,
                "flush_interval": 60,
                "timeout_ms": 5000
            },
            "cache": {
                "enabled": True,
                "ttl": 300,
                "max_size": 1000
            },
            "fallback": {
                "enabled": True,
                "log_level": "WARNING",
                "retry_attempts": 2
            },
            "metadata": {
                "default_user_id": "default",
                "default_agent_id": "general"
            },
            "write_filter": {
                "enabled": True,
                "min_user_message_length": 5,
            }
        }

    def _init_memory(self):
        """Initialise the mem0 ``Memory`` (synchronous): Qdrant + LLM + embedder.

        The embedder model and vector size are read from ``config['embedder']``,
        so the Qdrant collection and the embedder agree on dimensionality. Any
        failure is logged and leaves ``self.local_memory`` as ``None``. Searches
        then go through the Layer 3 FTS5 fallback.
        """
        if Memory is None:
            logger.warning("mem0ai 未安装")
            return

        try:
            qdrant_cfg = self.config['qdrant']
            llm_section = self.config['llm']
            emb_section = self.config.get('embedder') or {}
            emb_cfg = emb_section.get('config') or {}
            # Accept mem0's key (embedding_dims) as well as this file's default key
            # (dimensions). 1024 is the previously hard-coded value.
            embedding_dims = emb_cfg.get('embedding_dims') or emb_cfg.get('dimensions') or 1024
            embedding_model = (emb_cfg.get('model')
                               or os.getenv('MEM0_EMBEDDER_MODEL')
                               or os.getenv('EMBEDDING_MODEL_ID', 'text-embedding-v4'))

            config = MemoryConfig(
                vector_store=VectorStoreConfig(
                    provider="qdrant",
                    config={
                        "host": qdrant_cfg['host'],
                        "port": qdrant_cfg['port'],
                        "collection_name": qdrant_cfg['collection_name'],
                        "on_disk": True,
                        # Must equal the embedder's output size.
                        "embedding_model_dims": embedding_dims
                    }
                ),
                llm=LlmConfig(
                    provider=llm_section.get('provider', 'openai'),
                    config=llm_section['config']
                ),
                embedder=EmbedderConfig(
                    provider=emb_section.get('provider', 'openai'),
                    config={
                        "model": embedding_model,
                        # Overrides mem0's default embedding size (1536 per original note).
                        "embedding_dims": embedding_dims
                    }
                )
            )
            self.local_memory = Memory(config=config)
            logger.info(f"✅ mem0 初始化成功(含 Embedder,{embedding_dims} 维度)")
        except Exception as e:
            logger.error(f"❌ mem0 初始化失败:{e}")
            self.local_memory = None

    async def start(self):
        """
        Start the background write worker explicitly.

        Must be called from inside an event loop: ``await mem0_client.start()``.
        Calling it again while started is a no-op.
        """
        if self._started:
            logger.debug("mem0 Client 已启动")
            return

        if self.config['async_write']['enabled']:
            self.async_queue = AsyncMemoryQueue(
                max_size=self.config['async_write']['queue_size']
            )
            self.async_queue.start_worker(
                callback=self._async_write_memory,
                batch_size=self.config['async_write']['batch_size'],
                flush_interval=self.config['async_write']['flush_interval']
            )
        self._started = True
        logger.info("✅ mem0 Client 异步工作线程已启动")

    # ========== Pre-Hook: retrieval ==========

    async def pre_hook_search(self, query: str, user_id: str = None, agent_id: str = None, top_k: int = None) -> List[Dict]:
        """Pre-hook: retrieve relevant memories before a conversation turn.

        Results are cached per (user_id, agent_id, top_k, query) for
        ``cache.ttl`` seconds. When mem0 is unavailable, the Layer 3 FTS5
        fallback is used. Timeouts and errors are logged and yield ``[]``.
        """
        if not self.config['retrieval']['enabled']:
            return []

        # Use the same default as post_hook_add, so reads see what writes stored.
        user_id = user_id or self.config['metadata']['default_user_id']
        limit = top_k or self.config['retrieval']['top_k']

        cache_cfg = self.config['cache']
        cache_key = f"{user_id}:{agent_id}:{limit}:{query}"
        if cache_cfg['enabled']:
            cached = self.cache.get(cache_key)
            if cached and time.time() - cached['time'] < cache_cfg['ttl']:
                logger.debug(f"Cache hit: {cache_key}")
                # Return a copy so callers cannot mutate the cached list.
                return list(cached['results'])

        timeout_ms = self.config['retrieval']['timeout_ms']
        try:
            memories = await asyncio.wait_for(
                self._execute_search(query, user_id, agent_id, limit),
                timeout=timeout_ms / 1000
            )

            if cache_cfg['enabled'] and memories:
                self.cache[cache_key] = {'results': list(memories), 'time': time.time()}
                self._cleanup_cache()

            logger.info(f"Pre-Hook 检索完成:{len(memories)} 条记忆")
            return memories

        except asyncio.TimeoutError:
            logger.warning(f"Pre-Hook 检索超时 ({timeout_ms}ms)")
            return []
        except Exception as e:
            logger.warning(f"Pre-Hook 检索失败:{e}")
            return []

    def _get_local_search(self, agent_id: str) -> "LocalSearchFallback":
        """Return the Layer 3 FTS5 searcher for *agent_id*, creating and indexing it on first use.

        A falsy agent_id maps to ``'main'``. Instances are cached per agent.
        """
        aid = agent_id or 'main'
        searcher = self._local_search.get(aid)
        if searcher is None:
            searcher = LocalSearchFallback(agent_id=aid)
            searcher.rebuild_index()
            self._local_search[aid] = searcher
        return searcher

    async def _fts5_search(self, query: str, agent_id: str, top_k: int) -> List[Dict]:
        """Layer 3: query the local FTS5 index and shape hits like mem0 results (synthetic score 0.5)."""
        searcher = await asyncio.to_thread(self._get_local_search, agent_id)
        hits = await asyncio.to_thread(searcher.search, query, top_k)
        return [{
            'id': f"fts5:{r.get('source', '')}:{r.get('title', '')}",
            'memory': r.get('snippet', ''),
            'score': 0.5,
            'metadata': {'source': 'layer3_fts5', 'file': r.get('source', '')},
        } for r in (hits or [])]

    @staticmethod
    def _normalize_results(raw) -> List[Dict]:
        """Turn a ``Memory.search`` return value into a list of memory dicts.

        Newer mem0 releases return ``{"results": [...]}``; older ones return a
        bare list. Non-dict entries are dropped.
        """
        if isinstance(raw, dict):
            raw = raw.get('results') or []
        return [m for m in (raw or []) if isinstance(m, dict)]

    async def _search_memories(self, query: str, user_id: str, filters, limit: int) -> List[Dict]:
        """Run one mem0 search in a worker thread and normalise the result."""
        raw = await asyncio.to_thread(
            self.local_memory.search,
            query,
            user_id=user_id,
            filters=filters,
            limit=limit
        )
        return self._normalize_results(raw)

    @staticmethod
    def _merge_unique(target: Dict[str, Dict], memories: List[Dict], phases: Dict = None, phase: int = None):
        """Add memories to *target* keyed by id. The first occurrence wins; entries without an id are skipped.

        When *phases* is given, the phase of each newly added id is recorded there.
        """
        for mem in memories:
            mid = mem.get('id')
            if mid and mid not in target:
                target[mid] = mem
                if phases is not None:
                    phases[mid] = phase

    @staticmethod
    def _score_of(mem: Dict) -> float:
        """Numeric score of a memory for sorting. Missing or non-numeric scores count as 0."""
        score = mem.get('score')
        return score if isinstance(score, (int, float)) else 0

    @staticmethod
    def _passes_confidence(mem: Dict, min_confidence: float) -> bool:
        """True if *mem* should survive the ``min_confidence`` filter.

        Layer 3 FTS5 hits carry a synthetic score (0.5) that is not a similarity
        score, so they are always kept. Missing or non-numeric scores are kept too.
        """
        metadata = mem.get('metadata')
        if isinstance(metadata, dict) and metadata.get('source') == 'layer3_fts5':
            return True
        score = mem.get('score')
        return not isinstance(score, (int, float)) or score >= min_confidence

    def _shared_user_ids(self, user_id: str) -> List[str]:
        """user_ids to search for shared (public/project) memories: the caller's id plus the publisher id."""
        if user_id == self._SHARED_USER_ID:
            return [user_id]
        return [user_id, self._SHARED_USER_ID]

    async def _execute_search(self, query: str, user_id: str, agent_id: str, top_k: int) -> List[Dict]:
        """
        Tiered retrieval by visibility, merged and de-duplicated by memory id.

        Phase 1: public  (visible to every agent)
        Phase 2: project (visible to members of the same project_id)
        Phase 3: private (visible only to agent_id itself)
        Legacy:  old memories without a visibility field
        Layer 3: FTS5 local search when mem0 is uninitialised or every Qdrant query failed
        """
        if self.local_memory is None:
            # mem0 not initialised (e.g. Qdrant unreachable at startup): go straight to Layer 3.
            logger.warning("mem0 未初始化,直接使用 Layer 3 FTS5 本地检索")
            try:
                return await self._fts5_search(query, agent_id, top_k)
            except Exception as e:
                logger.error(f"Layer 3 FTS5 fallback 失败:{e}")
                return []

        per_phase = max(top_k, 3)
        has_agent = bool(agent_id) and agent_id != 'general'

        # (log label, user_id, filters), in merge-priority order.
        plan = [("Public", uid, {"visibility": "public"}) for uid in self._shared_user_ids(user_id)]
        if has_agent:
            for project_id in get_agent_projects(agent_id):
                if project_id == 'global':
                    continue
                for uid in self._shared_user_ids(user_id):
                    plan.append((f"Project({project_id})", uid,
                                 {"visibility": "project", "project_id": project_id}))
            plan.append(("Private", user_id, {"visibility": "private", "agent_id": agent_id}))
        if user_id:
            plan.append(("Legacy", user_id, {"agent_id": agent_id} if has_agent else None))

        all_memories: Dict[str, Dict] = {}
        qdrant_ok = False  # becomes True once any Qdrant query succeeds
        for label, uid, filters in plan:
            try:
                found = await self._search_memories(query, uid, filters, per_phase)
            except Exception as e:
                logger.debug(f"{label} 记忆检索失败:{e}")
                continue
            qdrant_ok = True
            self._merge_unique(all_memories, found)

        if not qdrant_ok:
            logger.warning("Qdrant 不可达,自动切换 Layer 3 FTS5 本地检索")
            try:
                self._merge_unique(all_memories, await self._fts5_search(query, agent_id, top_k))
            except Exception as e:
                logger.error(f"Layer 3 FTS5 fallback 失败:{e}")

        min_confidence = self.config['retrieval']['min_confidence']
        filtered = [m for m in all_memories.values() if self._passes_confidence(m, min_confidence)]
        filtered.sort(key=self._score_of, reverse=True)
        return filtered[:top_k]

    def format_memories_for_prompt(self, memories: List[Dict]) -> str:
        """Render memories as a numbered prompt section. Returns ``""`` for no memories."""
        if not memories:
            return ""

        lines = []
        for i, mem in enumerate(memories, 1):
            if isinstance(mem, dict):
                memory_text = mem.get('memory', '')
                created_at = mem.get('created_at', '')
            else:
                memory_text, created_at = str(mem), ''
            line = f"{i}. {memory_text}"
            if created_at:
                line += f" (记录于:{created_at})"
            lines.append(line + "\n")
        return "\n\n=== 相关记忆 ===\n" + "".join(lines) + "===============\n"

    # ========== Post-Hook: async write ==========

    def _should_skip_memory(self, user_message: str, assistant_message: str) -> bool:
        """Return True when the turn is not worth storing.

        A turn is skipped when the write filter is enabled and the stripped user
        message is too short, is a bare acknowledgement, or is a ``/`` command.
        """
        write_filter = self.config.get('write_filter') or {}
        if not write_filter.get('enabled', True):
            return False
        text = (user_message or '').strip()
        if len(text) < write_filter.get('min_user_message_length', 5):
            return True
        return bool(SKIP_PATTERNS.match(text) or SYSTEM_CMD_PATTERN.match(text))

    def _classify_memory_type(self, user_message: str, assistant_message: str) -> str:
        """Pick a memory_type (which decides expiration): 'preference', 'knowledge' or 'session'."""
        combined = f"{user_message or ''} {assistant_message or ''}"
        if MEMORY_KEYWORDS.search(combined):
            return 'preference'
        if any(kw in combined for kw in ('部署', '配置', '架构', '端口', '安装', '版本')):
            return 'knowledge'
        return 'session'

    def _classify_visibility(self, user_message: str, assistant_message: str, agent_id: str = None):
        """Pick visibility automatically. Returns ``(visibility, project_id)``.

        'public' if a public cue word appears; otherwise 'project' for the first
        project of the agent whose keywords appear; otherwise 'private'.
        """
        combined = f"{user_message or ''} {assistant_message or ''}"
        if PUBLIC_KEYWORDS.search(combined):
            return 'public', None

        if agent_id and agent_id != 'general':
            proj_kws = _build_project_keywords()
            for pid in get_agent_projects(agent_id):
                if pid == 'global' or pid not in proj_kws:
                    continue
                if any(kw in combined for kw in proj_kws[pid]):
                    return 'project', pid

        return 'private', None

    def post_hook_add(self, user_message: str, assistant_message: str,
                      user_id: str = None, agent_id: str = None,
                      visibility: str = None, project_id: str = None,
                      memory_type: str = None):
        """Post-hook: enqueue the finished turn for a background write (synchronous; only enqueues).

        Supports three visibility tiers (public/project/private) and per-type
        expiration. Low-value turns are dropped by the write filter. Missing
        user_id/agent_id fall back to ``config['metadata']`` defaults. Missing
        memory_type/visibility are classified automatically.
        """
        if not self.config['async_write']['enabled']:
            return

        user_message = user_message or ''
        assistant_message = assistant_message or ''

        if self._should_skip_memory(user_message, assistant_message):
            logger.debug(f"Post-Hook 跳过(写入过滤):{user_message[:30]}")
            return

        if not user_id:
            user_id = self.config['metadata']['default_user_id']
        if not agent_id:
            agent_id = self.config['metadata']['default_agent_id']

        if not memory_type:
            memory_type = self._classify_memory_type(user_message, assistant_message)
        if not visibility:
            visibility, auto_project_id = self._classify_visibility(user_message, assistant_message, agent_id)
            if not project_id and auto_project_id:
                project_id = auto_project_id

        messages = [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_message}
        ]

        if self.async_queue:
            self.async_queue.add({
                'messages': messages,
                'user_id': user_id,
                'agent_id': agent_id,
                'visibility': visibility,
                'project_id': project_id,
                'memory_type': memory_type,
                'timestamp': datetime.now().isoformat()
            })
            logger.debug(f"Post-Hook 已提交:user={user_id}, agent={agent_id}, "
                         f"visibility={visibility}, type={memory_type}")
        else:
            logger.warning("异步队列未初始化")

    async def _async_write_memory(self, item: Dict):
        """Queue callback: write one item with the configured timeout. Failures are logged, not raised."""
        if self.local_memory is None:
            return

        timeout_ms = self.config['async_write']['timeout_ms']
        try:
            await asyncio.wait_for(self._execute_write(item), timeout=timeout_ms / 1000)
            logger.debug(f"异步写入成功:user={item['user_id']}, agent={item['agent_id']}")
        except asyncio.TimeoutError:
            logger.warning(f"异步写入超时 ({timeout_ms}ms)")
        except Exception as e:
            logger.warning(f"异步写入失败:{e}")

    def _supports_expiration_kwarg(self) -> bool:
        """Whether the installed ``Memory.add`` accepts ``expiration_date`` (probed once, then cached)."""
        if getattr(self, '_add_accepts_expiration', None) is None:
            import inspect
            try:
                params = inspect.signature(self.local_memory.add).parameters.values()
                self._add_accepts_expiration = any(
                    p.name == 'expiration_date' or p.kind is inspect.Parameter.VAR_KEYWORD
                    for p in params
                )
            except (TypeError, ValueError):
                self._add_accepts_expiration = False
        return self._add_accepts_expiration

    async def _execute_write(self, item: Dict):
        """
        Perform the write, with visibility tiers and expiration.

        The metadata carries visibility / project_id / agent_id. The expiration
        date is derived from memory_type via EXPIRATION_MAP and always stored in
        metadata. It is passed as the ``expiration_date`` kwarg only when the
        installed mem0 ``Memory.add`` accepts one; otherwise the call would raise
        TypeError.
        """
        if self.local_memory is None:
            return

        visibility = item.get('visibility', 'private')
        memory_type = item.get('memory_type', 'session')

        ttl = EXPIRATION_MAP.get(memory_type)
        expiration_date = (datetime.now() + ttl).isoformat() if ttl else None

        custom_metadata = {
            "agent_id": item['agent_id'],
            "visibility": visibility,
            "project_id": item.get('project_id') or '',
            "business_type": item.get('business_type', item['agent_id']),
            "memory_type": memory_type,
            "source": "openclaw",
            "timestamp": item.get('timestamp'),
        }
        if expiration_date:
            custom_metadata["expiration_date"] = expiration_date

        add_kwargs = dict(
            messages=item['messages'],
            user_id=item['user_id'],
            agent_id=item['agent_id'],
            metadata=custom_metadata,
        )
        if expiration_date and self._supports_expiration_kwarg():
            add_kwargs['expiration_date'] = expiration_date

        await asyncio.to_thread(self.local_memory.add, **add_kwargs)

    def _cleanup_cache(self):
        """Drop expired cache entries, then the oldest ones beyond ``cache.max_size``."""
        if not self.config['cache']['enabled']:
            return

        current_time = time.time()
        ttl = self.config['cache']['ttl']

        expired_keys = [k for k, v in self.cache.items() if current_time - v['time'] > ttl]
        for key in expired_keys:
            del self.cache[key]

        max_size = self.config['cache']['max_size']
        overflow = len(self.cache) - max_size
        if overflow > 0:
            oldest_keys = sorted(self.cache, key=lambda k: self.cache[k]['time'])[:overflow]
            for key in oldest_keys:
                del self.cache[key]

    def get_status(self) -> Dict:
        """Return a snapshot of client state for health checks."""
        return {
            "initialized": self.local_memory is not None,
            "started": self._started,
            "async_queue_enabled": self.config['async_write']['enabled'],
            "queue_size": len(self.async_queue.queue) if self.async_queue else 0,
            "cache_size": len(self.cache),
            "qdrant": f"{self.config['qdrant']['host']}:{self.config['qdrant']['port']}"
        }

    # ========== Knowledge Publishing (Hub Agent) ==========

    async def publish_knowledge(self, content, category='knowledge',
                                visibility='public', project_id=None,
                                agent_id='main'):
        """Publish knowledge/best practices to shared memory.

        Used by the hub agent to share with all agents (public) or with project
        teams (project). Writes synchronously, bypassing the queue, under the
        shared publisher user_id.

        Raises:
            ValueError: if visibility is 'project' and no project_id is given.
        """
        if visibility == 'project' and not project_id:
            raise ValueError("project visibility requires project_id")
        if self.local_memory is None:
            logger.warning("mem0 未初始化,知识发布已跳过")
            return
        item = {
            'messages': [{'role': 'system', 'content': content}],
            'user_id': self._SHARED_USER_ID,
            'agent_id': agent_id,
            'visibility': visibility,
            'project_id': project_id,
            'memory_type': category,
            'timestamp': datetime.now().isoformat(),
        }
        await self._execute_write(item)
        logger.info(f"Published {category} ({visibility}): {content[:80]}...")

    # ========== Cold Start (Three-Phase) ==========

    async def cold_start_search(self, agent_id='main', user_id='default', top_k=5):
        """Three-phase cold start: public -> project -> private.

        Returns merged memories ordered by phase, then by descending score.
        """
        if self.local_memory is None:
            return []

        # (phase rank, log label, user_id, filters, query)
        plan = []
        for q in ["system best practices and conventions",
                  "shared configuration and architecture decisions"]:
            for uid in self._shared_user_ids(user_id):
                plan.append((0, "public", uid, {"visibility": "public"}, q))

        for pid in get_agent_projects(agent_id):
            if pid == 'global':
                continue
            for q in ["project guidelines and shared knowledge",
                      "recent project decisions and updates"]:
                for uid in self._shared_user_ids(user_id):
                    plan.append((1, f"project({pid})", uid,
                                 {"visibility": "project", "project_id": pid}, q))

        for q in ["current active tasks deployment progress",
                  "pending work items todos",
                  "recent important decisions configurations"]:
            plan.append((2, "private", user_id,
                         {"visibility": "private", "agent_id": agent_id}, q))

        all_mems: Dict[str, Dict] = {}
        phase: Dict[str, int] = {}
        for rank, label, uid, filters, q in plan:
            try:
                found = await self._search_memories(q, uid, filters, top_k)
            except Exception as e:
                logger.debug(f"Cold start {label} failed: {e}")
                continue
            self._merge_unique(all_mems, found, phase, rank)

        mc = self.config['retrieval']['min_confidence']
        out = [m for m in all_mems.values() if self._passes_confidence(m, mc)]
        out.sort(key=lambda m: (phase.get(m.get('id', ''), 9), -self._score_of(m)))
        return out[:top_k]

    async def shutdown(self):
        """Stop the write worker (flushing queued items) and mark the client as not started.

        The queue is detached, so later ``post_hook_add`` calls warn instead of
        silently enqueueing into a stopped queue. ``start()`` can be called again.
        """
        if self.async_queue:
            await self.async_queue.stop()
            self.async_queue = None
        self._started = False
        logger.info("mem0 Client shutdown complete")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Global client instance, created at import time.
# Constructing Mem0Client runs _init_memory(), so importing this module tries
# to build the mem0 Memory immediately. The background write worker does not
# run until `await mem0_client.start()` is called.
mem0_client = Mem0Client()
|