|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
"""
|
|
|
|
|
Session Initialization Hook - Three-Phase Cold Start Memory Preload
|
|
|
|
|
|
|
|
|
|
Retrieves memories in three phases at session startup:
|
|
|
|
|
Phase 0 (public): Best practices and shared config for all agents
|
|
|
|
|
Phase 1 (project): Project-specific shared knowledge
|
|
|
|
|
Phase 2 (private): Agent's own recent context
|
|
|
|
|
|
|
|
|
|
Formats the retrieved memories into a text block intended for injection into the system prompt.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
import sys
|
|
|
|
|
import os
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
# Make sibling modules (such as mem0_client) importable when this file is
# executed directly as a hook script.
sys.path.insert(0, os.path.dirname(__file__))

# Point OpenAI-compatible clients at DashScope's compatible-mode endpoint.
# Both variable names are set; they are applied in this order.
# NOTE(review): this overrides any pre-existing values even when no DashScope
# key is found below — confirm that is intended, since an existing
# OPENAI_API_KEY would then be sent to the DashScope endpoint.
os.environ.update({
    'OPENAI_API_BASE': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
    'OPENAI_BASE_URL': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
})

# Prefer the mem0-specific DashScope key, falling back to the generic one;
# only when one is found does it replace OPENAI_API_KEY.
_dashscope_key = (
    os.getenv('MEM0_DASHSCOPE_API_KEY', '')
    or os.getenv('DASHSCOPE_API_KEY', '')
)
if _dashscope_key:
    os.environ['OPENAI_API_KEY'] = _dashscope_key
|
|
|
|
|
|
|
|
|
|
from mem0_client import Mem0Client
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _fetch_cold_start_memories(agent_id: str, user_id: str,
                                     top_k: int, timeout: float):
    """Run Mem0Client.cold_start_search() with a timeout; return [] on failure.

    Once the client has started it is always shut down, even when the search
    times out, raises, or the surrounding task is cancelled.
    """
    # Construct and start inside the guard so an unreachable or misconfigured
    # memory backend degrades to "no memories" instead of crashing the hook.
    try:
        client = Mem0Client()
        await client.start()
    except Exception as e:
        logger.error("Cold start client startup failed: %s", e)
        return []

    try:
        return await asyncio.wait_for(
            client.cold_start_search(
                agent_id=agent_id,
                user_id=user_id,
                top_k=top_k,
            ),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        logger.warning("Cold start search timed out (%gs)", timeout)
        return []
    except Exception as e:
        logger.error("Cold start search failed: %s", e)
        return []
    finally:
        # Best-effort cleanup: a failing shutdown must not discard results
        # that were already retrieved.
        try:
            await client.shutdown()
        except Exception as e:
            logger.warning("Cold start client shutdown failed: %s", e)


def _format_cold_start_prompt(memories) -> str:
    """Render retrieved memories as a numbered, labelled plain-text block.

    Dict entries contribute their 'memory' text and a "visibility/agent_id"
    label taken from their 'metadata'; any other entry is rendered with str()
    and labelled "unknown/unknown".
    """
    lines = ["\n\n=== Cold Start Context (auto-loaded) ==="]
    for i, mem in enumerate(memories, 1):
        if isinstance(mem, dict):
            mem_text = mem.get('memory', '')
            # 'metadata' may be present but None; treat that like missing
            # instead of crashing on None.get().
            metadata = mem.get('metadata') or {}
        else:
            mem_text = str(mem)
            metadata = {}
        vis = metadata.get('visibility', 'unknown')
        agent = metadata.get('agent_id', 'unknown')
        lines.append(f"{i}. [{vis}/{agent}] {mem_text}")
    lines.append("========================================")
    # Every line, including the closing rule, ends with a newline.
    return "\n".join(lines) + "\n"


async def cold_start_retrieval(agent_id: str = "main",
                               user_id: str = "default",
                               top_k: int = 5,
                               timeout: float = 10.0) -> str:
    """Three-phase cold start retrieval.

    Uses Mem0Client.cold_start_search(), which queries public -> project ->
    private memories in order, and renders the results as a text block for
    the system prompt.

    Args:
        agent_id: Agent whose memories are requested.
        user_id: User whose memories are requested.
        top_k: Passed through to cold_start_search().
        timeout: Seconds to wait for the search before giving up.

    Returns:
        The formatted context block, or "" when nothing was retrieved
        (empty result, timeout, or any client error).
    """
    memories = await _fetch_cold_start_memories(agent_id, user_id, top_k, timeout)
    if not memories:
        return ""

    prompt = _format_cold_start_prompt(memories)
    logger.info("Cold start complete: %d memories (%s)", len(memories), agent_id)
    return prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def main():
    """Run cold-start retrieval for agent "main" and print the outcome.

    Prints the formatted context block on success, otherwise a short
    fallback message.
    """
    context_block = await cold_start_retrieval(
        agent_id="main",
        user_id="wang_yuanzhang",
        top_k=5,
    )
    if not context_block:
        print("No memories found (Qdrant empty or timeout)")
        return
    print("Cold start retrieval succeeded:")
    print(context_block)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly (e.g. as a session hook): run the preload once.
    asyncio.run(main())
|