|
|
|
@ -2,6 +2,7 @@ |
|
|
|
""" |
|
|
|
""" |
|
|
|
mem0 Client for OpenClaw - 生产级纯异步架构 |
|
|
|
mem0 Client for OpenClaw - 生产级纯异步架构 |
|
|
|
Pre-Hook 检索注入 + Post-Hook 异步写入 |
|
|
|
Pre-Hook 检索注入 + Post-Hook 异步写入 |
|
|
|
|
|
|
|
元数据维度隔离 (user_id + agent_id) |
|
|
|
""" |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
|
import os |
|
|
|
@ -12,8 +13,9 @@ from typing import List, Dict, Optional, Any |
|
|
|
from collections import deque |
|
|
|
from collections import deque |
|
|
|
from datetime import datetime |
|
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
|
|
|
# 设置环境变量 |
|
|
|
# ========== DashScope 环境变量配置 ========== |
|
|
|
os.environ['OPENAI_API_BASE'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1' |
|
|
|
os.environ['OPENAI_API_BASE'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1' |
|
|
|
|
|
|
|
os.environ['OPENAI_BASE_URL'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1' # 关键:兼容模式需要此变量 |
|
|
|
os.environ['OPENAI_API_KEY'] = os.getenv('MEM0_DASHSCOPE_API_KEY', 'sk-c1715ee0479841399fd359c574647648') |
|
|
|
os.environ['OPENAI_API_KEY'] = os.getenv('MEM0_DASHSCOPE_API_KEY', 'sk-c1715ee0479841399fd359c574647648') |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
try: |
|
|
|
@ -39,14 +41,14 @@ class AsyncMemoryQueue: |
|
|
|
self.flush_interval = 60 |
|
|
|
self.flush_interval = 60 |
|
|
|
|
|
|
|
|
|
|
|
def add(self, item: Dict[str, Any]): |
|
|
|
def add(self, item: Dict[str, Any]): |
|
|
|
"""添加任务到队列""" |
|
|
|
"""添加任务到队列(同步方法)""" |
|
|
|
try: |
|
|
|
try: |
|
|
|
if len(self.queue) < self.queue.maxlen: |
|
|
|
if len(self.queue) < self.queue.maxlen: |
|
|
|
self.queue.append({ |
|
|
|
self.queue.append({ |
|
|
|
'messages': item['messages'], |
|
|
|
'messages': item['messages'], |
|
|
|
'user_id': item['user_id'], |
|
|
|
'user_id': item['user_id'], |
|
|
|
'agent_id': item['agent_id'], |
|
|
|
'agent_id': item['agent_id'], |
|
|
|
'timestamp': datetime.now().isoformat() |
|
|
|
'timestamp': item.get('timestamp', datetime.now().isoformat()) |
|
|
|
}) |
|
|
|
}) |
|
|
|
else: |
|
|
|
else: |
|
|
|
logger.warning("异步队列已满,丢弃旧任务") |
|
|
|
logger.warning("异步队列已满,丢弃旧任务") |
|
|
|
@ -54,7 +56,7 @@ class AsyncMemoryQueue: |
|
|
|
logger.error(f"队列添加失败:{e}") |
|
|
|
logger.error(f"队列添加失败:{e}") |
|
|
|
|
|
|
|
|
|
|
|
async def get_batch(self, batch_size: int) -> List[Dict]: |
|
|
|
async def get_batch(self, batch_size: int) -> List[Dict]: |
|
|
|
"""获取批量任务""" |
|
|
|
"""获取批量任务(异步方法)""" |
|
|
|
async with self.lock: |
|
|
|
async with self.lock: |
|
|
|
batch = [] |
|
|
|
batch = [] |
|
|
|
while len(batch) < batch_size and self.queue: |
|
|
|
while len(batch) < batch_size and self.queue: |
|
|
|
@ -120,21 +122,25 @@ class AsyncMemoryQueue: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Mem0Client: |
|
|
|
class Mem0Client: |
|
|
|
"""生产级 mem0 客户端""" |
|
|
|
""" |
|
|
|
|
|
|
|
生产级 mem0 客户端 |
|
|
|
|
|
|
|
纯异步架构 + 阻塞操作隔离 + 元数据维度隔离 |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, config: Dict = None): |
|
|
|
def __init__(self, config: Dict = None): |
|
|
|
self.config = config or self._load_default_config() |
|
|
|
self.config = config or self._load_default_config() |
|
|
|
self.local_memory = None |
|
|
|
self.local_memory = None |
|
|
|
self.async_queue = None |
|
|
|
self.async_queue = None |
|
|
|
self.cache = {} |
|
|
|
self.cache = {} |
|
|
|
|
|
|
|
self._started = False |
|
|
|
|
|
|
|
# 不在 __init__ 中启动异步任务 |
|
|
|
self._init_memory() |
|
|
|
self._init_memory() |
|
|
|
self._start_async_worker() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_default_config(self) -> Dict: |
|
|
|
def _load_default_config(self) -> Dict: |
|
|
|
"""加载默认配置""" |
|
|
|
"""加载默认配置""" |
|
|
|
return { |
|
|
|
return { |
|
|
|
"qdrant": { |
|
|
|
"qdrant": { |
|
|
|
"host": os.getenv('MEM0_QDRANT_HOST', '100.115.94.1'), |
|
|
|
"host": os.getenv('MEM0_QDRANT_HOST', 'localhost'), |
|
|
|
"port": int(os.getenv('MEM0_QDRANT_PORT', '6333')), |
|
|
|
"port": int(os.getenv('MEM0_QDRANT_PORT', '6333')), |
|
|
|
"collection_name": "mem0_shared" |
|
|
|
"collection_name": "mem0_shared" |
|
|
|
}, |
|
|
|
}, |
|
|
|
@ -172,7 +178,7 @@ class Mem0Client: |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def _init_memory(self): |
|
|
|
def _init_memory(self): |
|
|
|
"""初始化 mem0""" |
|
|
|
"""初始化 mem0(同步操作)""" |
|
|
|
if Memory is None: |
|
|
|
if Memory is None: |
|
|
|
logger.warning("mem0ai 未安装") |
|
|
|
logger.warning("mem0ai 未安装") |
|
|
|
return |
|
|
|
return |
|
|
|
@ -199,13 +205,16 @@ class Mem0Client: |
|
|
|
logger.error(f"❌ mem0 初始化失败:{e}") |
|
|
|
logger.error(f"❌ mem0 初始化失败:{e}") |
|
|
|
self.local_memory = None |
|
|
|
self.local_memory = None |
|
|
|
|
|
|
|
|
|
|
|
def _start_async_worker(self): |
|
|
|
async def start(self): |
|
|
|
"""启动异步写入工作线程""" |
|
|
|
""" |
|
|
|
if not self.config['async_write']['enabled']: |
|
|
|
显式启动异步工作线程 |
|
|
|
|
|
|
|
必须在事件循环中调用:await mem0_client.start() |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
if self._started: |
|
|
|
|
|
|
|
logger.debug("mem0 Client 已启动") |
|
|
|
return |
|
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
if self.config['async_write']['enabled']: |
|
|
|
loop = asyncio.get_event_loop() |
|
|
|
|
|
|
|
self.async_queue = AsyncMemoryQueue( |
|
|
|
self.async_queue = AsyncMemoryQueue( |
|
|
|
max_size=self.config['async_write']['queue_size'] |
|
|
|
max_size=self.config['async_write']['queue_size'] |
|
|
|
) |
|
|
|
) |
|
|
|
@ -214,8 +223,10 @@ class Mem0Client: |
|
|
|
batch_size=self.config['async_write']['batch_size'], |
|
|
|
batch_size=self.config['async_write']['batch_size'], |
|
|
|
flush_interval=self.config['async_write']['flush_interval'] |
|
|
|
flush_interval=self.config['async_write']['flush_interval'] |
|
|
|
) |
|
|
|
) |
|
|
|
except RuntimeError: |
|
|
|
self._started = True |
|
|
|
logger.debug("当前无事件循环,异步队列将在首次使用时初始化") |
|
|
|
logger.info("✅ mem0 Client 异步工作线程已启动") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== Pre-Hook: 智能检索 ========== |
|
|
|
|
|
|
|
|
|
|
|
async def pre_hook_search(self, query: str, user_id: str = None, agent_id: str = None, top_k: int = None) -> List[Dict]: |
|
|
|
async def pre_hook_search(self, query: str, user_id: str = None, agent_id: str = None, top_k: int = None) -> List[Dict]: |
|
|
|
"""Pre-Hook: 对话前智能检索""" |
|
|
|
"""Pre-Hook: 对话前智能检索""" |
|
|
|
@ -252,39 +263,54 @@ class Mem0Client: |
|
|
|
return [] |
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
|
|
async def _execute_search(self, query: str, user_id: str, agent_id: str, top_k: int) -> List[Dict]: |
|
|
|
async def _execute_search(self, query: str, user_id: str, agent_id: str, top_k: int) -> List[Dict]: |
|
|
|
"""执行检索(使用 asyncio.to_thread 隔离阻塞)""" |
|
|
|
""" |
|
|
|
|
|
|
|
执行检索 - 使用 metadata 过滤器实现维度隔离 |
|
|
|
|
|
|
|
""" |
|
|
|
if self.local_memory is None: |
|
|
|
if self.local_memory is None: |
|
|
|
return [] |
|
|
|
return [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 策略 1: 检索全局用户记忆 |
|
|
|
user_memories = [] |
|
|
|
user_memories = [] |
|
|
|
if user_id: |
|
|
|
if user_id: |
|
|
|
try: |
|
|
|
try: |
|
|
|
user_memories = await asyncio.to_thread( |
|
|
|
user_memories = await asyncio.to_thread( |
|
|
|
self.local_memory.search, query, user_id=user_id, limit=top_k |
|
|
|
self.local_memory.search, |
|
|
|
|
|
|
|
query, |
|
|
|
|
|
|
|
user_id=user_id, |
|
|
|
|
|
|
|
limit=top_k |
|
|
|
) |
|
|
|
) |
|
|
|
except Exception as e: |
|
|
|
except Exception as e: |
|
|
|
logger.debug(f"用户记忆检索失败:{e}") |
|
|
|
logger.debug(f"用户记忆检索失败:{e}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 策略 2: 检索业务域记忆(使用 metadata 过滤器) |
|
|
|
agent_memories = [] |
|
|
|
agent_memories = [] |
|
|
|
if agent_id and agent_id != 'general': |
|
|
|
if agent_id and agent_id != 'general': |
|
|
|
try: |
|
|
|
try: |
|
|
|
agent_memories = await asyncio.to_thread( |
|
|
|
agent_memories = await asyncio.to_thread( |
|
|
|
self.local_memory.search, |
|
|
|
self.local_memory.search, |
|
|
|
query, |
|
|
|
query, |
|
|
|
user_id=f"{user_id}:{agent_id}" if user_id else agent_id, |
|
|
|
user_id=user_id, |
|
|
|
|
|
|
|
filters={"agent_id": agent_id}, # metadata 过滤,实现垂直隔离 |
|
|
|
limit=top_k |
|
|
|
limit=top_k |
|
|
|
) |
|
|
|
) |
|
|
|
except Exception as e: |
|
|
|
except Exception as e: |
|
|
|
logger.debug(f"业务记忆检索失败:{e}") |
|
|
|
logger.debug(f"业务记忆检索失败:{e}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 合并结果(去重) |
|
|
|
all_memories = {} |
|
|
|
all_memories = {} |
|
|
|
for mem in user_memories + agent_memories: |
|
|
|
for mem in user_memories + agent_memories: |
|
|
|
mem_id = mem.get('id') if isinstance(mem, dict) else None |
|
|
|
mem_id = mem.get('id') if isinstance(mem, dict) else None |
|
|
|
if mem_id and mem_id not in all_memories: |
|
|
|
if mem_id and mem_id not in all_memories: |
|
|
|
all_memories[mem_id] = mem |
|
|
|
all_memories[mem_id] = mem |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 按置信度过滤 |
|
|
|
min_confidence = self.config['retrieval']['min_confidence'] |
|
|
|
min_confidence = self.config['retrieval']['min_confidence'] |
|
|
|
filtered = [m for m in all_memories.values() if m.get('score', 1.0) >= min_confidence] |
|
|
|
filtered = [ |
|
|
|
|
|
|
|
m for m in all_memories.values() |
|
|
|
|
|
|
|
if m.get('score', 1.0) >= min_confidence |
|
|
|
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 按置信度排序 |
|
|
|
filtered.sort(key=lambda x: x.get('score', 0), reverse=True) |
|
|
|
filtered.sort(key=lambda x: x.get('score', 0), reverse=True) |
|
|
|
|
|
|
|
|
|
|
|
return filtered[:top_k] |
|
|
|
return filtered[:top_k] |
|
|
|
@ -306,8 +332,10 @@ class Mem0Client: |
|
|
|
|
|
|
|
|
|
|
|
return prompt |
|
|
|
return prompt |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== Post-Hook: 异步写入 ========== |
|
|
|
|
|
|
|
|
|
|
|
def post_hook_add(self, user_message: str, assistant_message: str, user_id: str = None, agent_id: str = None): |
|
|
|
def post_hook_add(self, user_message: str, assistant_message: str, user_id: str = None, agent_id: str = None): |
|
|
|
"""Post-Hook: 对话后异步写入""" |
|
|
|
"""Post-Hook: 对话后异步写入(同步方法,仅添加到队列)""" |
|
|
|
if not self.config['async_write']['enabled']: |
|
|
|
if not self.config['async_write']['enabled']: |
|
|
|
return |
|
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
@ -325,14 +353,15 @@ class Mem0Client: |
|
|
|
self.async_queue.add({ |
|
|
|
self.async_queue.add({ |
|
|
|
'messages': messages, |
|
|
|
'messages': messages, |
|
|
|
'user_id': user_id, |
|
|
|
'user_id': user_id, |
|
|
|
'agent_id': agent_id |
|
|
|
'agent_id': agent_id, |
|
|
|
|
|
|
|
'timestamp': datetime.now().isoformat() |
|
|
|
}) |
|
|
|
}) |
|
|
|
logger.debug(f"Post-Hook 已提交:user={user_id}, agent={agent_id}") |
|
|
|
logger.debug(f"Post-Hook 已提交:user={user_id}, agent={agent_id}") |
|
|
|
else: |
|
|
|
else: |
|
|
|
logger.warning("异步队列未初始化") |
|
|
|
logger.warning("异步队列未初始化") |
|
|
|
|
|
|
|
|
|
|
|
async def _async_write_memory(self, item: Dict): |
|
|
|
async def _async_write_memory(self, item: Dict): |
|
|
|
"""异步写入记忆""" |
|
|
|
"""异步写入记忆(后台任务)""" |
|
|
|
if self.local_memory is None: |
|
|
|
if self.local_memory is None: |
|
|
|
return |
|
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
@ -347,17 +376,27 @@ class Mem0Client: |
|
|
|
logger.warning(f"异步写入失败:{e}") |
|
|
|
logger.warning(f"异步写入失败:{e}") |
|
|
|
|
|
|
|
|
|
|
|
async def _execute_write(self, item: Dict): |
|
|
|
async def _execute_write(self, item: Dict): |
|
|
|
"""执行写入(使用 asyncio.to_thread 隔离阻塞)""" |
|
|
|
""" |
|
|
|
|
|
|
|
执行写入 - 使用 metadata 实现维度隔离 |
|
|
|
|
|
|
|
关键:通过 metadata 字典传递 agent_id,而非直接参数 |
|
|
|
|
|
|
|
""" |
|
|
|
if self.local_memory is None: |
|
|
|
if self.local_memory is None: |
|
|
|
return |
|
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
full_user_id = f"{item['user_id']}:{item['agent_id']}" |
|
|
|
# 构建元数据,实现业务隔离 |
|
|
|
|
|
|
|
custom_metadata = { |
|
|
|
|
|
|
|
"agent_id": item['agent_id'], |
|
|
|
|
|
|
|
"source": "openclaw", |
|
|
|
|
|
|
|
"timestamp": item.get('timestamp'), |
|
|
|
|
|
|
|
"business_type": item['agent_id'] |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 阻塞操作,放入线程池执行 |
|
|
|
await asyncio.to_thread( |
|
|
|
await asyncio.to_thread( |
|
|
|
self.local_memory.add, |
|
|
|
self.local_memory.add, |
|
|
|
messages=item['messages'], |
|
|
|
messages=item['messages'], |
|
|
|
user_id=full_user_id, |
|
|
|
user_id=item['user_id'], # 原生支持的全局用户标识 |
|
|
|
agent_id=item['agent_id'] |
|
|
|
metadata=custom_metadata # 注入自定义业务维度 |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def _cleanup_cache(self): |
|
|
|
def _cleanup_cache(self): |
|
|
|
@ -381,6 +420,7 @@ class Mem0Client: |
|
|
|
"""获取状态""" |
|
|
|
"""获取状态""" |
|
|
|
return { |
|
|
|
return { |
|
|
|
"initialized": self.local_memory is not None, |
|
|
|
"initialized": self.local_memory is not None, |
|
|
|
|
|
|
|
"started": self._started, |
|
|
|
"async_queue_enabled": self.config['async_write']['enabled'], |
|
|
|
"async_queue_enabled": self.config['async_write']['enabled'], |
|
|
|
"queue_size": len(self.async_queue.queue) if self.async_queue else 0, |
|
|
|
"queue_size": len(self.async_queue.queue) if self.async_queue else 0, |
|
|
|
"cache_size": len(self.cache), |
|
|
|
"cache_size": len(self.cache), |
|
|
|
|