You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

91 lines
2.7 KiB

#!/usr/bin/env python3
"""
Session Initialization Hook - Three-Phase Cold Start Memory Preload
Retrieves memories in three phases at session startup:
Phase 0 (public): Best practices and shared config for all agents
Phase 1 (project): Project-specific shared knowledge
Phase 2 (private): Agent's own recent context
Injects formatted memories into the System Prompt.
"""
import asyncio
import sys
import os
import logging
# Put this script's own directory first on the import path.
sys.path.insert(0, os.path.dirname(__file__))

# Point both spellings of the OpenAI base-URL variable at the DashScope
# OpenAI-compatible endpoint (unconditionally overrides any existing value).
os.environ.update(
    dict.fromkeys(
        ('OPENAI_API_BASE', 'OPENAI_BASE_URL'),
        'https://dashscope.aliyuncs.com/compatible-mode/v1',
    )
)

# Prefer the Mem0-specific key, fall back to the generic DashScope key; only
# export it as OPENAI_API_KEY when one of them is actually set.
_dashscope_key = (
    os.environ.get('MEM0_DASHSCOPE_API_KEY', '')
    or os.environ.get('DASHSCOPE_API_KEY', '')
)
if _dashscope_key:
    os.environ['OPENAI_API_KEY'] = _dashscope_key
# Project-local Mem0 client. NOTE(review): presumably resolved through the
# sys.path entry inserted earlier in this file, and imported only after the
# OPENAI_* environment variables are set -- confirm the ordering is required.
from mem0_client import Mem0Client
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
async def cold_start_retrieval(agent_id: str = "main",
                               user_id: str = "default",
                               top_k: int = 5,
                               timeout: float = 10.0) -> str:
    """Run the cold start memory search and format the hits as a prompt block.

    Uses ``Mem0Client.cold_start_search()``, which (per this module's notes)
    queries public -> project -> private memories in order.

    Args:
        agent_id: Agent identifier forwarded to the search; also included in
            the completion log line.
        user_id: User identifier forwarded to the search.
        top_k: Result-count argument forwarded to the search.
        timeout: Seconds to wait for the search before giving up. Defaults to
            the previously hard-coded 10 seconds.

    Returns:
        A text block listing each memory as ``N. [visibility/agent_id] text``,
        or an empty string when the search returned nothing, timed out, or
        raised an error (failures are logged, not propagated).
    """
    client = Mem0Client()
    await client.start()
    try:
        memories = await asyncio.wait_for(
            client.cold_start_search(
                agent_id=agent_id,
                user_id=user_id,
                top_k=top_k,
            ),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        logger.warning("Cold start search timed out (%gs)", timeout)
        memories = []
    except Exception as e:
        # Best-effort preload: a failed search must not break session startup.
        logger.error("Cold start search failed: %s", e)
        memories = []
    finally:
        # Runs even when the task is cancelled or interrupted
        # (CancelledError / KeyboardInterrupt are not caught above), so the
        # started client is always shut down.
        await client.shutdown()

    if not memories:
        return ""

    parts = ["\n\n=== Cold Start Context (auto-loaded) ===\n"]
    for i, mem in enumerate(memories, 1):
        if isinstance(mem, dict):
            mem_text = mem.get('memory', '')
            # 'metadata' may be present but None (or some non-dict value);
            # treat anything that is not a dict as "no metadata".
            metadata = mem.get('metadata')
            if not isinstance(metadata, dict):
                metadata = {}
        else:
            mem_text = str(mem)
            metadata = {}
        vis = metadata.get('visibility', 'unknown')
        agent = metadata.get('agent_id', 'unknown')
        parts.append(f"{i}. [{vis}/{agent}] {mem_text}\n")
    parts.append("========================================\n")

    logger.info("Cold start complete: %d memories (%s)", len(memories), agent_id)
    return "".join(parts)
async def main():
    """Run one cold start retrieval with hard-coded arguments and print the outcome."""
    prompt_block = await cold_start_retrieval(
        agent_id="main", user_id="wang_yuanzhang", top_k=5
    )

    # An empty string means nothing was retrieved.
    if not prompt_block:
        print("No memories found (Qdrant empty or timeout)")
        return

    print("Cold start retrieval succeeded:")
    print(prompt_block)
# Script entry point: drive the async main() to completion when executed directly.
if __name__ == "__main__":
    asyncio.run(main())