|
|
#!/usr/bin/env python3 |
|
|
""" |
|
|
mem0 Client for OpenClaw - 生产级纯异步架构 |
|
|
Pre-Hook 检索注入 + Post-Hook 异步写入 |
|
|
三级可见性隔离 (public / project / private) |
|
|
记忆衰减 (expiration_date) + 智能写入过滤 |
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
import asyncio |
|
|
import logging |
|
|
import time |
|
|
import yaml |
|
|
from typing import List, Dict, Any |
|
|
from collections import deque |
|
|
from datetime import datetime, timedelta |
|
|
from pathlib import Path |
|
|
|
|
|
# ========== DashScope environment configuration ==========
# Embedding model in use: text-embedding-v4 (1024 dimensions).
# Both variable spellings are set so that either one picks up the
# OpenAI-compatible endpoint.
os.environ.update({
    'OPENAI_API_BASE': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
    # Key point: compatible mode needs this variable as well.
    'OPENAI_BASE_URL': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
})

# Prefer the mem0-specific key; fall back to the generic DashScope key.
_dashscope_key = (
    os.getenv('MEM0_DASHSCOPE_API_KEY', '')
    or os.getenv('DASHSCOPE_API_KEY', '')
)

if _dashscope_key:
    os.environ['OPENAI_API_KEY'] = _dashscope_key
elif not os.environ.get('OPENAI_API_KEY'):
    # No key from any source: warn now rather than failing silently later.
    logging.warning("MEM0_DASHSCOPE_API_KEY not set; mem0 embedding/LLM calls will fail")
|
|
|
|
|
# mem0 is optional at import time: if it cannot be imported the module still
# loads, and Memory is left as None so later code can detect the absence.
try:
    from mem0 import Memory
    from mem0.configs.base import MemoryConfig, VectorStoreConfig, LlmConfig, EmbedderConfig
except ImportError as import_error:
    print(f"⚠️ mem0ai 导入失败:{import_error}")
    Memory = None

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class AsyncMemoryQueue:
    """Bounded in-memory write queue drained by a background asyncio task.

    Items are plain dicts appended synchronously via :meth:`add`; a worker
    task started by :meth:`start_worker` pops them in batches and awaits the
    configured ``callback`` once per item.
    """

    def __init__(self, max_size: int = 100):
        """Create an empty queue holding at most ``max_size`` items."""
        self.queue = deque(maxlen=max_size)
        self.lock = asyncio.Lock()
        self.running = False
        self._worker_task = None
        self.callback = None
        self.batch_size = 10
        self.flush_interval = 60
        # Strong references to in-flight batch tasks. asyncio only keeps weak
        # references to tasks, so an unreferenced task may be garbage
        # collected before it finishes.
        self._pending_tasks = set()

    def add(self, item: Dict[str, Any]):
        """Enqueue one write request (synchronous, never raises).

        ``item`` must contain ``messages``, ``user_id`` and ``agent_id``;
        ``visibility``, ``project_id``, ``memory_type`` and ``timestamp`` are
        optional. A malformed item is logged and dropped. When the queue is
        full the oldest entry is evicted, matching the logged warning and the
        ``deque(maxlen=...)`` semantics.
        """
        try:
            entry = {
                'messages': item['messages'],
                'user_id': item['user_id'],
                'agent_id': item['agent_id'],
                'visibility': item.get('visibility', 'private'),
                'project_id': item.get('project_id'),
                'memory_type': item.get('memory_type', 'session'),
                'timestamp': item.get('timestamp', datetime.now().isoformat()),
            }
        except Exception as e:
            logger.error(f"队列添加失败:{e}")
            return

        capacity = self.queue.maxlen
        if capacity is not None and len(self.queue) >= capacity:
            logger.warning("异步队列已满,丢弃旧任务")
        # A full deque(maxlen) discards its leftmost (oldest) element here.
        self.queue.append(entry)

    async def get_batch(self, batch_size: int) -> List[Dict]:
        """Pop and return up to ``batch_size`` items, oldest first."""
        async with self.lock:
            batch = []
            while len(batch) < batch_size and self.queue:
                batch.append(self.queue.popleft())
            return batch

    def start_worker(self, callback, batch_size: int, flush_interval: int):
        """Start the background worker; must be called from a running event loop.

        Args:
            callback: coroutine function awaited once per queued item.
            batch_size: maximum number of items taken from the queue at once.
            flush_interval: seconds between the additional periodic flushes.
        """
        self.running = True
        self.callback = callback
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self._worker_task = asyncio.create_task(self._worker_loop())
        logger.info(f"✅ 异步工作线程已启动 (batch_size={batch_size}, interval={flush_interval}s)")

    async def _dispatch_next_batch(self):
        """Take one batch off the queue (if any) and process it in a tracked task."""
        if not self.queue:
            return
        batch = await self.get_batch(self.batch_size)
        if not batch:
            return
        task = asyncio.create_task(self._process_batch(batch))
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)

    async def _worker_loop(self):
        """Poll the queue once per second until stopped or cancelled."""
        # Monotonic clock: interval arithmetic must not be affected by
        # wall-clock adjustments.
        last_flush = time.monotonic()

        while self.running:
            try:
                await self._dispatch_next_batch()

                if time.monotonic() - last_flush > self.flush_interval:
                    await self._dispatch_next_batch()
                    last_flush = time.monotonic()

                await asyncio.sleep(1)

            except asyncio.CancelledError:
                logger.info("异步工作线程已取消")
                break
            except Exception as e:
                logger.error(f"异步工作线程错误:{e}")
                await asyncio.sleep(5)

    async def _process_batch(self, batch: List[Dict]):
        """Await the callback for every item; one failure does not abort the rest."""
        logger.debug(f"开始处理批量任务:{len(batch)} 条")
        for item in batch:
            try:
                await self.callback(item)
            except Exception as e:
                logger.error(f"批量处理失败:{e}")
        logger.debug(f"批量任务处理完成")

    async def stop(self):
        """Shut down gracefully.

        Cancels the polling task, then processes whatever is still queued and
        waits for batch tasks that are already in flight, so accepted items
        are not lost on shutdown.
        """
        self.running = False
        if self._worker_task:
            self._worker_task.cancel()
            try:
                await self._worker_task
            except asyncio.CancelledError:
                pass

        # Drain what the worker did not get to (only possible once a callback
        # has been configured).
        if self.callback is not None:
            while self.queue:
                batch = await self.get_batch(self.batch_size)
                if not batch:
                    break
                await self._process_batch(batch)

        if self._pending_tasks:
            await asyncio.gather(*list(self._pending_tasks), return_exceptions=True)

        logger.info("异步工作线程已关闭")
|
|
|
|
|
|
|
|
# Retention period per memory type. A value of None yields no expiration date
# when the memory is written.
EXPIRATION_MAP = dict(
    session=timedelta(days=7),
    chat_summary=timedelta(days=30),
    preference=None,
    knowledge=None,
)

# Whole-message acknowledgements that are not worth storing.
SKIP_PATTERNS = re.compile(r'^(好的|收到|OK|ok|嗯|行|没问题|感谢|谢谢|了解|明白|知道了|👍|✅|❌)$', re.IGNORECASE)

# Messages that start with a slash are treated as system commands.
SYSTEM_CMD_PATTERN = re.compile(r'^/')

# Words that mark a message as a lasting preference or rule.
MEMORY_KEYWORDS = re.compile(r'(记住|以后|偏好|配置|设置|规则|永远|始终|总是|不要|禁止)')

# Words that mark a message as addressed to everyone (public visibility).
PUBLIC_KEYWORDS = re.compile(r'(所有人|通知|全局|公告|大家|集群)')
|
|
|
|
|
|
|
|
def _load_project_registry() -> Dict:
    """Load the project registry from ``project_registry.yaml`` next to this file.

    Returns:
        The parsed mapping, or an empty dict when the file is missing, cannot
        be read or parsed, or does not contain a mapping at its top level.
    """
    registry_path = Path(__file__).parent / 'project_registry.yaml'
    try:
        with open(registry_path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
    except FileNotFoundError:
        # A missing registry is a normal configuration: no projects defined.
        return {}
    except Exception as e:
        # Still best-effort, but no longer silent: a broken registry would
        # otherwise look exactly like an empty one.
        logger.warning(f"项目注册表加载失败:{e}")
        return {}
    # An empty file parses to None and a list-rooted file to a list; callers
    # rely on dict methods, so anything that is not a mapping becomes {}.
    return data if isinstance(data, dict) else {}
|
|
|
|
|
|
|
|
def get_agent_projects(agent_id: str) -> List[str]:
    """Return the ids of every project the given agent is a member of.

    A project matches when its ``members`` list contains ``agent_id`` or the
    wildcard ``'*'``. Malformed registry entries (null ``projects``, a project
    without a mapping, null ``members``) are skipped instead of raising.
    """
    registry = _load_project_registry()
    if not isinstance(registry, dict):
        return []
    # ``projects:`` with no value parses to None, so ``.get(key, {})`` alone
    # would not protect the loop below.
    projects = registry.get('projects') or {}
    if not isinstance(projects, dict):
        return []

    result = []
    for pid, pconf in projects.items():
        members = pconf.get('members') if isinstance(pconf, dict) else None
        if not members:
            continue
        if isinstance(members, str):
            # A scalar member is treated as a one-element list so that the
            # membership test is an exact match, not a substring match.
            members = [members]
        if '*' in members or agent_id in members:
            result.append(pid)
    return result
|
|
|
|
|
|
|
|
class Mem0Client:
    """Production mem0 client.

    Fully asynchronous architecture with three visibility levels
    (public / project / private), memory decay via ``expiration_date`` and a
    write filter that skips low-value conversation turns.
    """

    def __init__(self, config: Dict = None):
        """Store ``config`` (or the defaults) and initialise the mem0 backend."""
        self.config = config or self._load_default_config()
        self.local_memory = None
        self.async_queue = None
        self.cache = {}
        self._started = False
        self._init_memory()

    def _load_default_config(self) -> Dict:
        """Return the default configuration (single shared collection)."""
        return {
            "qdrant": {
                "host": os.getenv('MEM0_QDRANT_HOST', 'localhost'),
                "port": int(os.getenv('MEM0_QDRANT_PORT', '6333')),
                # Unified shared collection (the original note names two co-users).
                "collection_name": "mem0_v4_shared"
            },
            "llm": {
                "provider": "openai",
                "config": {
                    "model": os.getenv('MEM0_LLM_MODEL', 'qwen-plus')
                }
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": os.getenv('MEM0_EMBEDDER_MODEL', 'text-embedding-v4'),
                    # Per the original note: the largest size text-embedding-v4
                    # supports — TODO confirm against the provider docs.
                    "dimensions": 1024
                }
            },
            "retrieval": {
                "enabled": True,
                "top_k": 5,
                "min_confidence": 0.7,
                "timeout_ms": 5000
            },
            "async_write": {
                "enabled": True,
                "queue_size": 100,
                "batch_size": 10,
                "flush_interval": 60,
                "timeout_ms": 5000
            },
            "cache": {
                "enabled": True,
                "ttl": 300,
                "max_size": 1000
            },
            "fallback": {
                "enabled": True,
                "log_level": "WARNING",
                "retry_attempts": 2
            },
            "metadata": {
                "default_user_id": "default",
                "default_agent_id": "general"
            },
            "write_filter": {
                "enabled": True,
                "min_user_message_length": 5,
            }
        }

    def _init_memory(self):
        """Build the mem0 ``Memory`` instance (synchronous).

        On any failure ``self.local_memory`` stays ``None`` and every
        search/write method becomes a no-op.
        """
        if Memory is None:
            logger.warning("mem0ai 未安装")
            return

        try:
            qdrant_cfg = self.config['qdrant']
            # Honour the configured embedder instead of hard-coding it; the
            # defaults below reproduce the previously hard-coded values.
            embedder_cfg = (self.config.get('embedder') or {}).get('config') or {}
            embed_model = embedder_cfg.get('model') or 'text-embedding-v4'
            embed_dims = int(
                embedder_cfg.get('embedding_dims')
                or embedder_cfg.get('dimensions')
                or 1024
            )

            config = MemoryConfig(
                vector_store=VectorStoreConfig(
                    provider="qdrant",
                    config={
                        "host": qdrant_cfg['host'],
                        "port": qdrant_cfg['port'],
                        "collection_name": qdrant_cfg['collection_name'],
                        "on_disk": True,
                        # Must match the embedder's output size.
                        "embedding_model_dims": embed_dims
                    }
                ),
                llm=LlmConfig(
                    provider="openai",
                    config=self.config['llm']['config']
                ),
                embedder=EmbedderConfig(
                    provider="openai",
                    config={
                        "model": embed_model,
                        # Overrides the library default dimension.
                        "embedding_dims": embed_dims
                        # api_base / api_key are read from the OPENAI_BASE_URL
                        # and OPENAI_API_KEY environment variables.
                    }
                )
            )
            self.local_memory = Memory(config=config)
            logger.info(f"✅ mem0 初始化成功(含 Embedder,{embed_dims} 维度)")
        except Exception as e:
            logger.error(f"❌ mem0 初始化失败:{e}")
            self.local_memory = None

    async def start(self):
        """Start the background write worker.

        Must be awaited inside a running event loop:
        ``await mem0_client.start()``. Calling it again while started is a
        no-op.
        """
        if self._started:
            logger.debug("mem0 Client 已启动")
            return

        async_cfg = self.config['async_write']
        if async_cfg['enabled']:
            self.async_queue = AsyncMemoryQueue(max_size=async_cfg['queue_size'])
            self.async_queue.start_worker(
                callback=self._async_write_memory,
                batch_size=async_cfg['batch_size'],
                flush_interval=async_cfg['flush_interval']
            )
            self._started = True
            logger.info("✅ mem0 Client 异步工作线程已启动")

    # ========== Shared search helpers ==========

    @staticmethod
    def _normalize_results(raw) -> List[Dict]:
        """Return search output as a list of dict records.

        Accepts both a bare list and a ``{"results": [...]}`` wrapper, and
        drops entries that are not dicts.
        """
        if isinstance(raw, dict):
            raw = raw.get('results')
        if not raw:
            return []
        return [mem for mem in raw if isinstance(mem, dict)]

    @staticmethod
    def _score_of(mem: Dict, default):
        """Return the record's score, or ``default`` when absent or None."""
        score = mem.get('score')
        return default if score is None else score

    async def _collect(self, merged: Dict, query: str, user_id, filters,
                       limit: int, fail_prefix: str) -> List:
        """Run one filtered search and merge unseen records into ``merged``.

        Returns the ids that were newly added. A failing search is logged at
        debug level (``fail_prefix`` followed by the error) and contributes
        nothing, so one phase cannot break the others.
        """
        added = []
        try:
            raw = await asyncio.to_thread(
                self.local_memory.search,
                query,
                user_id=user_id,
                filters=filters,
                limit=limit
            )
            for mem in self._normalize_results(raw):
                mid = mem.get('id')
                if mid and mid not in merged:
                    merged[mid] = mem
                    added.append(mid)
        except Exception as e:
            logger.debug(f"{fail_prefix}{e}")
        return added

    def _above_confidence(self, records) -> List[Dict]:
        """Keep records whose score reaches ``retrieval.min_confidence``.

        Records without a score are kept (treated as score 1.0).
        """
        threshold = self.config['retrieval']['min_confidence']
        return [m for m in records if self._score_of(m, 1.0) >= threshold]

    # ========== Pre-Hook: retrieval ==========

    async def pre_hook_search(self, query: str, user_id: str = None, agent_id: str = None, top_k: int = None) -> List[Dict]:
        """Pre-hook: retrieve relevant memories before a conversation turn.

        Returns at most ``top_k`` (default ``retrieval.top_k``) records, or an
        empty list when retrieval is disabled, mem0 is unavailable, the search
        times out or it fails. Non-empty results are cached for ``cache.ttl``
        seconds.
        """
        if not self.config['retrieval']['enabled'] or self.local_memory is None:
            return []

        limit = top_k or self.config['retrieval']['top_k']
        cache_enabled = self.config['cache']['enabled']
        # The limit is part of the key: results are truncated to it, so a
        # cached answer for a smaller top_k must not serve a larger one.
        cache_key = f"{user_id}:{agent_id}:{limit}:{query}"

        if cache_enabled:
            cached = self.cache.get(cache_key)
            if cached is not None and time.time() - cached['time'] < self.config['cache']['ttl']:
                logger.debug(f"Cache hit: {cache_key}")
                return cached['results']

        timeout_ms = self.config['retrieval']['timeout_ms']

        try:
            memories = await asyncio.wait_for(
                self._execute_search(query, user_id, agent_id, limit),
                timeout=timeout_ms / 1000
            )

            if cache_enabled and memories:
                self.cache[cache_key] = {'results': memories, 'time': time.time()}
                self._cleanup_cache()

            logger.info(f"Pre-Hook 检索完成:{len(memories)} 条记忆")
            return memories

        except asyncio.TimeoutError:
            logger.warning(f"Pre-Hook 检索超时 ({timeout_ms}ms)")
            return []
        except Exception as e:
            logger.warning(f"Pre-Hook 检索失败:{e}")
            return []

    async def _execute_search(self, query: str, user_id: str, agent_id: str, top_k: int) -> List[Dict]:
        """Three-phase retrieval, layered by visibility, merged and de-duplicated.

        Phase 1: public  (visible to every agent)
        Phase 2: project (visible to members of the same project_id)
        Phase 3: private (visible only to the owning agent_id)

        A final fallback search without a visibility filter covers legacy
        records. Results are filtered by ``retrieval.min_confidence``, sorted
        by descending score and truncated to ``top_k``.
        """
        if self.local_memory is None:
            return []

        merged: Dict[str, Dict] = {}
        per_phase = max(top_k, 3)
        # Project and private phases only apply to a concrete, non-default agent.
        scoped_agent = agent_id if agent_id and agent_id != 'general' else None

        # Phase 1: public memories
        await self._collect(merged, query, user_id,
                            {"visibility": "public"}, per_phase,
                            "Public 记忆检索失败:")

        if scoped_agent:
            # Phase 2: project memories for every project the agent belongs to
            for project_id in get_agent_projects(scoped_agent):
                if project_id == 'global':
                    continue
                await self._collect(
                    merged, query, user_id,
                    {"visibility": "project", "project_id": project_id},
                    per_phase,
                    f"Project({project_id}) 记忆检索失败:")

            # Phase 3: the agent's own private memories
            await self._collect(
                merged, query, user_id,
                {"visibility": "private", "agent_id": scoped_agent},
                per_phase,
                "Private 记忆检索失败:")

        # Fallback: compatibility with legacy data that has no visibility field.
        # NOTE(review): this search carries no visibility filter, so it can
        # also return tagged records — confirm that is acceptable for isolation.
        if user_id:
            await self._collect(
                merged, query, user_id,
                {"agent_id": scoped_agent} if scoped_agent else None,
                per_phase,
                "Legacy 记忆检索失败:")

        ranked = self._above_confidence(merged.values())
        ranked.sort(key=lambda m: self._score_of(m, 0), reverse=True)
        return ranked[:top_k]

    def format_memories_for_prompt(self, memories: List[Dict]) -> str:
        """Render memories as a numbered prompt fragment ('' when there are none)."""
        if not memories:
            return ""

        parts = ["\n\n=== 相关记忆 ===\n"]
        for i, mem in enumerate(memories, 1):
            if isinstance(mem, dict):
                memory_text = mem.get('memory', '')
                created_at = mem.get('created_at', '')
            else:
                memory_text, created_at = str(mem), ''
            line = f"{i}. {memory_text}"
            if created_at:
                line += f" (记录于:{created_at})"
            parts.append(line + "\n")
        parts.append("===============\n")
        return "".join(parts)

    # ========== Post-Hook: asynchronous write ==========

    def _should_skip_memory(self, user_message: str, assistant_message: str) -> bool:
        """Return True when this exchange should not be written to memory.

        Skips user messages that are shorter than
        ``write_filter.min_user_message_length``, that are a bare
        acknowledgement, or that are a slash command. Always False when the
        write filter is disabled.
        """
        write_filter = self.config.get('write_filter', {})
        if not write_filter.get('enabled', True):
            return False

        text = (user_message or '').strip()
        if len(text) < write_filter.get('min_user_message_length', 5):
            return True
        return bool(SKIP_PATTERNS.match(text) or SYSTEM_CMD_PATTERN.match(text))

    def _classify_memory_type(self, user_message: str, assistant_message: str) -> str:
        """Classify the exchange as 'preference', 'knowledge' or 'session'.

        The type selects the expiration policy applied at write time.
        """
        combined = (user_message or '') + ' ' + (assistant_message or '')
        if MEMORY_KEYWORDS.search(combined):
            return 'preference'
        if any(kw in combined for kw in ('部署', '配置', '架构', '端口', '安装', '版本')):
            return 'knowledge'
        return 'session'

    def _classify_visibility(self, user_message: str, assistant_message: str, agent_id: str = None) -> str:
        """Return 'public' when the exchange contains a broadcast keyword, else 'private'."""
        combined = (user_message or '') + ' ' + (assistant_message or '')
        if PUBLIC_KEYWORDS.search(combined):
            return 'public'
        return 'private'

    def post_hook_add(self, user_message: str, assistant_message: str,
                      user_id: str = None, agent_id: str = None,
                      visibility: str = None, project_id: str = None,
                      memory_type: str = None):
        """Post-hook: queue one exchange for asynchronous writing (synchronous call).

        Missing ``user_id`` / ``agent_id`` fall back to the configured
        defaults; missing ``memory_type`` / ``visibility`` are classified from
        the message text. Exchanges rejected by the write filter are dropped.
        Nothing is written when async writes are disabled or the queue has not
        been started.
        """
        if not self.config['async_write']['enabled']:
            return

        if self._should_skip_memory(user_message, assistant_message):
            logger.debug(f"Post-Hook 跳过(写入过滤):{(user_message or '')[:30]}")
            return

        if not user_id:
            user_id = self.config['metadata']['default_user_id']
        if not agent_id:
            agent_id = self.config['metadata']['default_agent_id']

        if not memory_type:
            memory_type = self._classify_memory_type(user_message, assistant_message)
        if not visibility:
            visibility = self._classify_visibility(user_message, assistant_message, agent_id)

        messages = [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_message}
        ]

        if self.async_queue:
            self.async_queue.add({
                'messages': messages,
                'user_id': user_id,
                'agent_id': agent_id,
                'visibility': visibility,
                'project_id': project_id,
                'memory_type': memory_type,
                'timestamp': datetime.now().isoformat()
            })
            logger.debug(f"Post-Hook 已提交:user={user_id}, agent={agent_id}, "
                         f"visibility={visibility}, type={memory_type}")
        else:
            logger.warning("异步队列未初始化")

    async def _async_write_memory(self, item: Dict):
        """Queue callback: write one item with a timeout; failures are only logged."""
        if self.local_memory is None:
            return

        timeout_ms = self.config['async_write']['timeout_ms']

        try:
            await asyncio.wait_for(self._execute_write(item), timeout=timeout_ms / 1000)
            logger.debug(f"异步写入成功:user={item['user_id']}, agent={item['agent_id']}")
        except asyncio.TimeoutError:
            logger.warning(f"异步写入超时 ({timeout_ms}ms)")
        except Exception as e:
            logger.warning(f"异步写入失败:{e}")

    @staticmethod
    def _accepts_kwarg(func, name: str) -> bool:
        """Return True when ``func`` takes keyword ``name`` (or ``**kwargs``)."""
        import inspect

        try:
            params = inspect.signature(func).parameters
        except (TypeError, ValueError):
            # Not introspectable: keep the previous behaviour of passing it.
            return True
        if name in params:
            return True
        return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())

    async def _execute_write(self, item: Dict):
        """Write one item to mem0 with visibility metadata and memory decay.

        The metadata carries visibility / project_id / agent_id. The
        expiration date is derived from ``memory_type`` via ``EXPIRATION_MAP``.
        It is passed as ``expiration_date`` when the backend's ``add`` accepts
        that argument; otherwise it is stored in the metadata, so the write
        does not fail on an unexpected keyword.
        """
        if self.local_memory is None:
            return

        visibility = item.get('visibility', 'private')
        memory_type = item.get('memory_type', 'session')

        custom_metadata = {
            "agent_id": item['agent_id'],
            "visibility": visibility,
            "project_id": item.get('project_id') or '',
            "business_type": item.get('business_type', item['agent_id']),
            "memory_type": memory_type,
            "source": "openclaw",
            "timestamp": item.get('timestamp'),
        }

        ttl = EXPIRATION_MAP.get(memory_type)
        expiration_date = (datetime.now() + ttl).isoformat() if ttl else None

        add_kwargs = dict(
            messages=item['messages'],
            user_id=item['user_id'],
            agent_id=item['agent_id'],
            metadata=custom_metadata,
        )
        if expiration_date:
            if self._accepts_kwarg(self.local_memory.add, 'expiration_date'):
                add_kwargs['expiration_date'] = expiration_date
            else:
                custom_metadata['expiration_date'] = expiration_date

        await asyncio.to_thread(self.local_memory.add, **add_kwargs)

    def _cleanup_cache(self):
        """Drop expired cache entries, then the oldest ones beyond ``cache.max_size``."""
        cache_cfg = self.config['cache']
        if not cache_cfg['enabled']:
            return

        now = time.time()
        ttl = cache_cfg['ttl']
        for key in [k for k, v in self.cache.items() if now - v['time'] > ttl]:
            del self.cache[key]

        overflow = len(self.cache) - cache_cfg['max_size']
        if overflow > 0:
            oldest_first = sorted(self.cache, key=lambda k: self.cache[k]['time'])
            for key in oldest_first[:overflow]:
                del self.cache[key]

    def get_status(self) -> Dict:
        """Return a snapshot of the client's state for diagnostics."""
        return {
            "initialized": self.local_memory is not None,
            "started": self._started,
            "async_queue_enabled": self.config['async_write']['enabled'],
            "queue_size": len(self.async_queue.queue) if self.async_queue else 0,
            "cache_size": len(self.cache),
            "qdrant": f"{self.config['qdrant']['host']}:{self.config['qdrant']['port']}"
        }

    # ========== Knowledge Publishing (Hub Agent) ==========

    async def publish_knowledge(self, content, category='knowledge',
                                visibility='public', project_id=None,
                                agent_id='main'):
        """Publish knowledge/best practices to shared memory.

        Used by the hub agent to share with all agents (public) or with a
        project team. The write is performed immediately, not queued.

        Raises:
            ValueError: if ``visibility`` is 'project' and ``project_id`` is
                missing.
        """
        if visibility == 'project' and not project_id:
            raise ValueError("project visibility requires project_id")
        item = {
            'messages': [{'role': 'system', 'content': content}],
            'user_id': 'system',
            'agent_id': agent_id,
            'visibility': visibility,
            'project_id': project_id,
            'memory_type': category,
            'timestamp': datetime.now().isoformat(),
        }
        await self._execute_write(item)
        logger.info(f"Published {category} ({visibility}): {content[:80]}...")

    # ========== Cold Start (Three-Phase) ==========

    async def cold_start_search(self, agent_id='main', user_id='default', top_k=5):
        """Three-phase cold start: public -> project -> private.

        Runs a fixed set of seed queries per phase and returns the merged
        memories that pass ``retrieval.min_confidence``, ordered by phase and
        then by descending score, truncated to ``top_k``.
        """
        if self.local_memory is None:
            return []

        merged = {}
        phase_of = {}

        async def run_phase(queries, filters, phase_index, fail_prefix):
            for seed in queries:
                new_ids = await self._collect(
                    merged, seed, user_id, filters, top_k, fail_prefix)
                for mid in new_ids:
                    phase_of[mid] = phase_index

        await run_phase(
            ["system best practices and conventions",
             "shared configuration and architecture decisions"],
            {"visibility": "public"}, 0,
            "Cold start public failed: ")

        for pid in get_agent_projects(agent_id):
            if pid == 'global':
                continue
            await run_phase(
                ["project guidelines and shared knowledge",
                 "recent project decisions and updates"],
                {"visibility": "project", "project_id": pid}, 1,
                f"Cold start project({pid}) failed: ")

        await run_phase(
            ["current active tasks deployment progress",
             "pending work items todos",
             "recent important decisions configurations"],
            {"visibility": "private", "agent_id": agent_id}, 2,
            "Cold start private failed: ")

        out = self._above_confidence(merged.values())
        out.sort(key=lambda m: (phase_of.get(m.get('id', ''), 9), -self._score_of(m, 0)))
        return out[:top_k]

    async def shutdown(self):
        """Stop the background worker and allow a later ``start()`` to restart it."""
        if self.async_queue:
            await self.async_queue.stop()
        # Drop the stopped queue and clear the flag; otherwise start() would
        # stay a no-op and later post-hook writes would sit in a dead queue.
        self.async_queue = None
        self._started = False
        logger.info("mem0 Client shutdown complete")
|
|
|
|
|
|
|
|
# Global client instance. Constructing it runs Mem0Client.__init__, which
# calls _init_memory(), so the mem0 backend is set up when this module is
# imported.
mem0_client = Mem0Client()
|
|
|
|