Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Here's how to navigate it:

- [🧩 Agent Framework](#agent-framework): Learn how to create your own agents, register custom tools, and extend SpoonOS with minimal setup.


- [📊 Enhanced Graph System](#enhanced-graph-system): Discover the powerful graph-based workflow orchestration system for complex AI agent workflows.

- [🔌 API Integration](#api-integration): Plug in external APIs to enhance your agent workflows.
Expand Down Expand Up @@ -524,6 +525,10 @@ async def main():
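if __name__ == "__main__":
    asyncio.run(main())
```

**New: async context management.** Agents can be used as async context managers, so their cleanup hooks run automatically when the block exits (use this inside an `async` function):

```python
async with SpoonReactAI(name="my_agent") as agent:
    result = await agent.run("Hello world")
# Agent automatically cleaned up here
```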

Register your own tools, override `run()`, or extend agents with MCP integrations. See `docs/agent.md` or `docs/mcp_mode_usage.md`.

Expand Down Expand Up @@ -558,6 +563,11 @@ chatbot = ChatBot(
enable_prompt_cache=True
)
```
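**New: async context management.** Agents can also be used as async context managers, so cleanup runs automatically when the block exits (use this inside an `async` function):

```python
async with SpoonReactAI(name="my_agent") as agent:
    result = await agent.run("Hello world")
# Agent automatically cleaned up here
```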


## 🗂️ Project Structure

Expand Down
82 changes: 82 additions & 0 deletions spoon_ai/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,85 @@ async def _run_and_signal_done(self, request: Optional[str] = None):
logger.info(f"Resetting agent {self.name} state from {self.state} to IDLE")
self.state = AgentState.IDLE
self.current_step = 0

async def __aenter__(self):
    """Async context manager entry: run the optional ``initialize`` hook.

    If the agent has a callable ``initialize`` attribute it is called once.
    Both ``async def initialize`` and a plain ``def initialize`` are
    supported: the call's result is awaited only when it is awaitable.

    Returns:
        The agent itself, so ``async with Agent(...) as agent`` binds it.

    Raises:
        RuntimeError: if ``initialize`` raises; the original exception is
            chained as ``__cause__``. ``__aexit__`` is not invoked when
            entry fails, so nothing is cleaned up on that path.
    """
    import inspect  # local import keeps this hook self-contained

    logger.debug(f"Initializing agent '{self.name}' async context")

    # Initialize async resources
    initialize = getattr(self, 'initialize', None)
    if callable(initialize):
        try:
            result = initialize()
            # Awaiting a plain (non-awaitable) return value such as None would
            # raise TypeError, so only await coroutines/awaitables.
            if inspect.isawaitable(result):
                await result
            logger.info(f"Agent '{self.name}' initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize agent '{self.name}': {e}")
            raise RuntimeError(f"Agent initialization failed: {e}") from e

    return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Async context manager exit: run best-effort cleanup steps in order.

    Each step is isolated so a failure in one does not skip the others:

    1. ``disconnect()`` if the agent defines it.
    2. ``save_state()`` if the agent defines it.
    3. ``optimize_memory()`` if the agent defines it.
    4. Reset ``state`` to ``AgentState.IDLE`` and ``current_step`` to 0, and
       call ``task_done.set()`` if the agent has a ``task_done`` attribute.
    5. Delete ``_temp_data``, ``_cache`` and ``_session_data`` when they are
       stored on the instance.

    Hooks in steps 1-3 may be coroutine functions or plain functions; the
    call's result is awaited only when it is awaitable.

    Args:
        exc_type: Exception type raised in the ``async with`` body, or None.
        exc_val: The exception instance, or None.
        exc_tb: The traceback, or None.

    Returns:
        False, so an exception from the ``async with`` body is never
        suppressed.

    Raises:
        RuntimeError: if any cleanup step failed and the body exited
            normally. When the body raised, cleanup failures are only logged
            so they do not mask the original exception.
    """
    import inspect  # local import keeps this hook self-contained

    cleanup_errors = []

    async def _run_hook(hook_name):
        # Call self.<hook_name>() if it exists and is callable; return whether
        # it ran. Awaiting only awaitable results supports both sync and async
        # hooks: `await` on a plain return value raises TypeError, and calling
        # an `async def` hook without awaiting it never runs its body.
        hook = getattr(self, hook_name, None)
        if not callable(hook):
            return False
        result = hook()
        if inspect.isawaitable(result):
            await result
        return True

    # (hook name, debug message on success, label used in cleanup_errors)
    hooks = (
        ('disconnect', "Agent disconnected from external services", "Disconnect error"),
        ('save_state', "Agent state saved", "State save error"),
        ('optimize_memory', "Memory optimized", "Memory optimization error"),
    )

    try:
        logger.debug(f"Cleaning up agent '{self.name}' (exception: {exc_type is not None})")

        # 1-3. Disconnect from external services, save state/history, optimize memory
        for hook_name, done_message, error_label in hooks:
            try:
                if await _run_hook(hook_name):
                    logger.debug(done_message)
            except Exception as e:
                cleanup_errors.append(f"{error_label}: {e}")

        # 4. Reset agent state
        try:
            self.state = AgentState.IDLE
            self.current_step = 0
            if hasattr(self, 'task_done'):
                self.task_done.set()  # Signal completion
            logger.debug("Agent state reset")
        except Exception as e:
            cleanup_errors.append(f"State reset error: {e}")

        # 5. Clear sensitive data held on the instance.
        # NOTE(review): if the agent is reused afterwards, reading these
        # attributes may raise AttributeError (unless a class default exists).
        for attr in ('_temp_data', '_cache', '_session_data'):
            if not hasattr(self, attr):
                continue
            try:
                delattr(self, attr)
            except AttributeError:
                # hasattr() is also true for class-level defaults, which cannot be
                # deleted through the instance: nothing per-instance to clear.
                pass
            except Exception as e:
                cleanup_errors.append(f"Attribute cleanup error ({attr}): {e}")

        # Log cleanup results
        if cleanup_errors:
            logger.warning(f"Agent '{self.name}' cleanup completed with {len(cleanup_errors)} errors: {'; '.join(cleanup_errors)}")
        else:
            logger.info(f"Agent '{self.name}' cleaned up successfully")

    except Exception as critical_error:
        logger.error(f"Critical error during agent '{self.name}' cleanup: {critical_error}")
        cleanup_errors.append(f"Critical cleanup error: {critical_error}")

    # Don't mask an exception from the body: cleanup failures are only raised
    # when the body exited normally.
    if cleanup_errors and exc_type is None:
        raise RuntimeError(f"Cleanup failed: {'; '.join(cleanup_errors)}")

    # Return False to not suppress the original exception if there was one
    return False
27 changes: 26 additions & 1 deletion spoon_ai/agents/custom_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,4 +297,29 @@ def clear(self):
except Exception as e:
logger.error(f"Error checking tools after clear: {e}")

logger.debug(f"CustomAgent '{self.name}' fully cleared and validated")
logger.debug(f"CustomAgent '{self.name}' fully cleared and validated")

async def save_state(self):
    """Save agent state for context manager cleanup.

    Calls ``self.save_chat_history()`` and then writes a tool-configuration
    snapshot to ``agent_states/<name>_tools.json``, relative to the current
    working directory. The snapshot contains the tool count, the tool names
    from ``self.list_tools()``, and a local timestamp with its UTC offset.

    NOTE(review): the file I/O is synchronous and briefly blocks the event
    loop; acceptable for a small JSON file written once at exit.

    Raises:
        Exception: any error is logged and re-raised so the caller (e.g. the
            context manager's exit) can record it.
    """
    # Local imports so this method does not depend on module-level imports
    # of json / datetime / pathlib being present in this file.
    import datetime
    import json
    from pathlib import Path

    try:
        # Save chat history
        self.save_chat_history()

        # Snapshot the tool list once so count and names always agree.
        tool_names = self.list_tools()
        tool_config = {
            'tool_count': len(tool_names),
            'tool_names': tool_names,
            # astimezone() keeps local wall-clock time but records the UTC offset.
            'last_updated': datetime.datetime.now().astimezone().isoformat(),
        }

        config_dir = Path('agent_states')
        config_dir.mkdir(parents=True, exist_ok=True)

        # Path separators in the agent name would create or escape
        # directories (e.g. "../x"); replace them so the file stays in config_dir.
        safe_name = str(self.name).replace('/', '_').replace('\\', '_')

        # Serialize before writing: opening with 'w' truncates first, so a
        # serialization error would otherwise leave an empty/partial file.
        payload = json.dumps(tool_config, indent=2)
        (config_dir / f'{safe_name}_tools.json').write_text(payload, encoding='utf-8')

        logger.debug(f"Saved state for CustomAgent '{self.name}'")

    except Exception as e:
        logger.error(f"Error saving state for CustomAgent '{self.name}': {e}")
        raise
102 changes: 102 additions & 0 deletions tests/test_async_context_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import pytest
import asyncio
from unittest.mock import Mock, AsyncMock, patch
from spoon_ai.agents.custom_agent import CustomAgent
from spoon_ai.schema import AgentState

@pytest.mark.asyncio
class TestAsyncContextManagement:
    """Tests for the async context-manager protocol on agents."""

    @pytest.fixture(autouse=True)
    def _isolated_cwd(self, tmp_path, monkeypatch):
        """Run each test inside a fresh temporary working directory.

        CustomAgent.save_state runs during context exit and writes
        ``agent_states/<name>_tools.json`` relative to the current working
        directory; isolating the CWD keeps tests from writing into the
        repository checkout.
        """
        monkeypatch.chdir(tmp_path)

    async def test_async_context_manager_success(self):
        """The agent is usable inside the block and left IDLE afterwards."""
        mock_llm = Mock()

        async with CustomAgent(name="test_agent", llm=mock_llm) as agent:
            assert agent.name == "test_agent"
            assert agent.state == AgentState.IDLE

        # Agent should be cleaned up
        assert agent.state == AgentState.IDLE

    async def test_async_context_manager_with_exception(self):
        """Cleanup happens even when the body raises, and the error propagates."""
        mock_llm = Mock()

        with pytest.raises(ValueError, match="test error"):
            async with CustomAgent(name="test_agent", llm=mock_llm) as agent:
                raise ValueError("test error")

        # Agent should still be cleaned up
        assert agent.state == AgentState.IDLE

    async def test_initialization_during_context_entry(self):
        """``initialize`` is called during context entry if it exists."""
        mock_llm = Mock()

        class MockAgentWithInit(CustomAgent):
            async def initialize(self):
                self._initialized = True

        async with MockAgentWithInit(name="test_agent", llm=mock_llm) as agent:
            assert hasattr(agent, '_initialized')
            assert agent._initialized is True

    async def test_disconnect_during_context_exit(self):
        """``disconnect`` is called during cleanup."""
        mock_llm = Mock()

        class MockAgentWithDisconnect(CustomAgent):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.disconnected = False

            async def disconnect(self):
                self.disconnected = True

        agent_ref = None
        async with MockAgentWithDisconnect(name="test_agent", llm=mock_llm) as agent:
            agent_ref = agent
            assert not agent.disconnected

        assert agent_ref.disconnected is True

    @patch('spoon_ai.agents.base.logger')
    async def test_cleanup_error_handling(self, mock_logger):
        """Cleanup errors are collected, logged, and reported after a clean body.

        ``__aexit__`` keeps running the remaining cleanup steps after a
        failing hook, logs a summary warning, and - because the body itself
        did not raise - surfaces the failure as a RuntimeError.
        """
        mock_llm = Mock()

        class MockAgentWithFailingCleanup(CustomAgent):
            async def disconnect(self):
                raise RuntimeError("Disconnect failed")

        with pytest.raises(RuntimeError, match="Disconnect failed"):
            async with MockAgentWithFailingCleanup(name="test_agent", llm=mock_llm) as agent:
                pass

        # Later cleanup steps still ran despite the failing hook
        assert agent.state == AgentState.IDLE

        # Should have logged the cleanup error
        mock_logger.warning.assert_called()
        warning_call = mock_logger.warning.call_args[0][0]
        assert "cleanup completed with" in warning_call
        assert "Disconnect failed" in warning_call

    async def test_cleanup_error_does_not_mask_body_exception(self):
        """A failing cleanup hook must not replace an exception from the body."""
        mock_llm = Mock()

        class MockAgentWithFailingCleanup(CustomAgent):
            async def disconnect(self):
                raise RuntimeError("Disconnect failed")

        with pytest.raises(ValueError, match="body error"):
            async with MockAgentWithFailingCleanup(name="test_agent", llm=mock_llm):
                raise ValueError("body error")

    async def test_memory_optimization_during_cleanup(self):
        """``optimize_memory`` is called during cleanup."""
        mock_llm = Mock()

        class MockAgentWithMemoryOpt(CustomAgent):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.memory_optimized = False

            def optimize_memory(self):
                self.memory_optimized = True

        agent_ref = None
        async with MockAgentWithMemoryOpt(name="test_agent", llm=mock_llm) as agent:
            agent_ref = agent

        assert agent_ref.memory_optimized is True