You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
90 lines
2.7 KiB
90 lines
2.7 KiB
#!/usr/bin/env python3 |
|
""" |
|
Session Initialization Hook - Three-Phase Cold Start Memory Preload |
|
|
|
Retrieves memories in three phases at session startup: |
|
Phase 0 (public): Best practices and shared config for all agents |
|
Phase 1 (project): Project-specific shared knowledge |
|
Phase 2 (private): Agent's own recent context |
|
|
|
Injects formatted memories into the System Prompt. |
|
""" |
|
|
|
import asyncio |
|
import sys |
|
import os |
|
import logging |
|
|
|
# Put this file's directory at the front of the import path so the sibling
# module `mem0_client` can be imported below.
sys.path.insert(0, os.path.dirname(__file__))

# Point OpenAI-compatible clients at DashScope's compatible-mode endpoint.
# NOTE(review): both variable names are set, presumably because different
# client libraries/versions read different ones — confirm.
_DASHSCOPE_COMPAT_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
os.environ.update(dict.fromkeys(('OPENAI_API_BASE', 'OPENAI_BASE_URL'), _DASHSCOPE_COMPAT_URL))

# Prefer the mem0-specific DashScope key, falling back to the generic one.
# OPENAI_API_KEY is only overwritten when one of them is non-empty.
_dashscope_key = (os.getenv('MEM0_DASHSCOPE_API_KEY', '')
                  or os.getenv('DASHSCOPE_API_KEY', ''))
if _dashscope_key:
    os.environ['OPENAI_API_KEY'] = _dashscope_key

# Imported only after the environment is configured above — presumably so
# mem0_client picks these variables up at import time (confirm).
from mem0_client import Mem0Client  # noqa: E402

logger = logging.getLogger(__name__)
|
|
|
|
|
def _format_cold_start_prompt(memories) -> str:
    """Render retrieved memories as a numbered, labelled context block.

    Each entry is labelled ``[visibility/agent_id]`` taken from the memory's
    ``metadata`` dict; missing fields render as ``unknown``. Non-dict entries
    are rendered with ``str()`` and labelled ``unknown/unknown``.
    """
    lines = ["", "", "=== Cold Start Context (auto-loaded) ==="]
    for i, mem in enumerate(memories, 1):
        if isinstance(mem, dict):
            mem_text = mem.get('memory', '')
            metadata = mem.get('metadata')
            # Tolerate an explicit ``metadata: None`` (or any non-dict value)
            # instead of crashing on ``.get``.
            if not isinstance(metadata, dict):
                metadata = {}
        else:
            mem_text, metadata = str(mem), {}
        vis = metadata.get('visibility', 'unknown')
        agent = metadata.get('agent_id', 'unknown')
        lines.append(f"{i}. [{vis}/{agent}] {mem_text}")
    lines.append("========================================")
    # Leading "\n\n" and trailing "\n" match the original concatenated layout.
    return "\n".join(lines) + "\n"


async def cold_start_retrieval(agent_id: str = "main",
                               user_id: str = "default",
                               top_k: int = 5,
                               timeout: float = 10.0) -> str:
    """Three-phase cold start retrieval, formatted for the system prompt.

    Uses ``Mem0Client.cold_start_search()``, which queries
    public -> project -> private memories in order.

    Args:
        agent_id: Agent whose memories are searched.
        user_id: User whose memories are searched.
        top_k: Passed through to ``cold_start_search``.
        timeout: Seconds to wait for the search before giving up
            (default 10).

    Returns:
        The formatted context block, or ``""`` when nothing was retrieved.
        A timeout or a search error is logged and treated as "nothing
        retrieved". Errors raised by ``client.start()`` still propagate.
    """
    client = Mem0Client()
    await client.start()
    try:
        memories = await asyncio.wait_for(
            client.cold_start_search(
                agent_id=agent_id,
                user_id=user_id,
                top_k=top_k,
            ),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        logger.warning("Cold start search timed out (%gs)", timeout)
        memories = []
    except Exception as e:
        # Best effort: a failed preload must not block session startup.
        logger.error("Cold start search failed: %s", e)
        memories = []
    finally:
        # Release the client on every path, including task cancellation,
        # which is not caught by ``except Exception``.
        await client.shutdown()

    if not memories:
        return ""

    prompt = _format_cold_start_prompt(memories)
    logger.info("Cold start complete: %d memories (%s)", len(memories), agent_id)
    return prompt
|
|
|
|
|
async def main():
    """Run one cold-start retrieval for a fixed demo user and print the result."""
    context = await cold_start_retrieval(agent_id="main", user_id="wang_yuanzhang", top_k=5)

    # Empty string means nothing was retrieved (or the search failed/timed out).
    if not context:
        print("No memories found (Qdrant empty or timeout)")
        return

    print("Cold start retrieval succeeded:")
    print(context)


if __name__ == '__main__':
    asyncio.run(main())
|
|
|