You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
107 lines
3.8 KiB
107 lines
3.8 KiB
#!/usr/bin/env python3
"""
Mem0 Python integration script.

Called by the Node.js plugin to carry out the actual memory operations.
"""

import asyncio
import json
import os
import sys
from datetime import datetime

# Point OpenAI-compatible clients at DashScope's compatible-mode endpoint.
# Both variable names are written; presumably different client libraries read
# different ones -- confirm before dropping either. Any pre-existing values
# are overwritten unconditionally.
os.environ.update({
    'OPENAI_API_BASE': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
    'OPENAI_BASE_URL': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
})

# A mem0-specific DashScope key wins over the generic one. If both are unset
# or empty, OPENAI_API_KEY is left exactly as it was.
_key = (
    os.getenv('MEM0_DASHSCOPE_API_KEY', '')
    or os.getenv('DASHSCOPE_API_KEY', '')
)
if _key:
    os.environ['OPENAI_API_KEY'] = _key

# Put this script's directory first on the import path so the sibling
# mem0_client module can be found.
sys.path.insert(0, os.path.dirname(__file__))

# Imported only after the environment is configured -- presumably so the client
# sees these settings when it is created; verify before moving this line.
from mem0_client import mem0_client
def _parse_payload(raw):
    """Decode the optional JSON payload argument into a dict.

    Args:
        raw: The raw ``sys.argv[2]`` string ('' when the argument is absent).

    Returns:
        The decoded JSON object, or ``{}`` when *raw* is empty/whitespace.

    Raises:
        ValueError: If *raw* is not valid JSON or does not decode to an object.
    """
    if not raw.strip():
        return {}
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON payload: {e}") from e
    if not isinstance(data, dict):
        # Every handler reads its arguments with data.get(...), so anything
        # other than a JSON object would only fail later with an obscure error.
        raise ValueError(
            f"Payload must be a JSON object, got {type(data).__name__}"
        )
    return data


async def _handle_init(data):
    """Start the mem0 client and report the Qdrant host:port from its config."""
    await mem0_client.start()
    qdrant = mem0_client.config['qdrant']
    print(json.dumps({
        "status": "initialized",
        "qdrant": f"{qdrant['host']}:{qdrant['port']}",
    }))


async def _handle_search(data):
    """Search memories for ``query`` and print them together with their count."""
    # NOTE(review): unlike every other action this does not call
    # mem0_client.start() first (kept as in the original). Each action runs in
    # a fresh process, so confirm pre_hook_search works on an unstarted client.
    memories = await mem0_client.pre_hook_search(
        query=data.get('query', ''),
        user_id=data.get('user_id', 'default'),
        agent_id=data.get('agent_id', 'general'),
    )
    print(json.dumps({"memories": memories, "count": len(memories)}))


async def _handle_add(data):
    """Write one user/assistant exchange unless the client's filter skips it."""
    await mem0_client.start()
    user_msg = data.get('user_message', '')
    asst_msg = data.get('assistant_message', '')
    user_id = data.get('user_id', 'default')
    agent_id = data.get('agent_id', 'general')

    if mem0_client._should_skip_memory(user_msg, asst_msg):
        print(json.dumps({"status": "skipped", "reason": "write_filter"}))
        return

    memory_type = mem0_client._classify_memory_type(user_msg, asst_msg)
    visibility = mem0_client._classify_visibility(user_msg, asst_msg, agent_id)
    item = {
        'messages': [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": asst_msg},
        ],
        'user_id': user_id,
        'agent_id': agent_id,
        'visibility': visibility,
        'project_id': data.get('project_id'),
        'memory_type': memory_type,
        # Naive local time, as in the original payload format.
        'timestamp': datetime.now().isoformat(),
    }
    await mem0_client._execute_write(item)
    print(json.dumps({"status": "written", "visibility": visibility, "memory_type": memory_type}))


async def _handle_publish(data):
    """Publish ``content`` via mem0_client.publish_knowledge (public by default)."""
    await mem0_client.start()
    await mem0_client.publish_knowledge(
        content=data.get('content', ''),
        category=data.get('category', 'knowledge'),
        visibility=data.get('visibility', 'public'),
        project_id=data.get('project_id'),
        agent_id=data.get('agent_id', 'main'),
    )
    print(json.dumps({"status": "published"}))


async def _handle_cold_start(data):
    """Run mem0_client.cold_start_search (top_k defaults to 5) and print results."""
    await mem0_client.start()
    memories = await mem0_client.cold_start_search(
        agent_id=data.get('agent_id', 'main'),
        user_id=data.get('user_id', 'default'),
        top_k=data.get('top_k', 5),
    )
    print(json.dumps({"memories": memories, "count": len(memories)}))


# Action name (sys.argv[1]) -> coroutine handler taking the decoded payload.
_ACTIONS = {
    'init': _handle_init,
    'search': _handle_search,
    'add': _handle_add,
    'publish': _handle_publish,
    'cold_start': _handle_cold_start,
}


async def main():
    """Run the single memory action requested on the command line.

    Usage: ``<script> <action> [json_object_payload]``. Supported actions are
    the keys of ``_ACTIONS``: init, search, add, publish, cold_start.

    Exactly one JSON document is printed to stdout. Every failure is reported
    as ``{"error": "<message>"}``, never as a traceback, so the Node.js caller
    can always parse the output. This includes a missing or unknown action, a
    malformed payload, or an exception raised by the mem0 client.
    """
    if len(sys.argv) < 2:
        print(json.dumps({"error": "No action specified"}))
        return

    action = sys.argv[1]

    try:
        # Parsed inside the try so malformed JSON yields an error object
        # instead of crashing the process.
        data = _parse_payload(sys.argv[2] if len(sys.argv) > 2 else '')

        handler = _ACTIONS.get(action)
        if handler is None:
            print(json.dumps({"error": f"Unknown action: {action}"}))
            return

        await handler(data)
    except Exception as e:
        # Some exceptions stringify to '' (e.g. a bare TimeoutError). An empty
        # "error" is falsy in JS and could be mistaken for success, so fall
        # back to the exception type name.
        print(json.dumps({"error": str(e) or type(e).__name__}))
def _cli() -> None:
    """Execute ``main`` on a fresh asyncio event loop (script entry point)."""
    asyncio.run(main())


if __name__ == '__main__':
    _cli()