4 changes: 3 additions & 1 deletion pyproject.toml
@@ -72,6 +72,8 @@ dependencies = [ # Complete dependencies from requirements.txt for out-of-the-bo
"botocore==1.35.99",
"fastmcp==2.2.5",
"googleapis-common-protos",
"toml>=0.10.2",
"pytest>=8.4.1",
]

[project.urls]
@@ -82,4 +84,4 @@ dependencies = [ # Complete dependencies from requirements.txt for out-of-the-bo
[tool.setuptools.packages.find]
where = ["."] # Look in the project root directory
include = ["spoon_ai*"] # Include only the 'spoon_ai' package and its subpackages
exclude = ["tests*", "test*", "cookbook*", "api*", "cli*", "agents*", "models*", "chat_logs*", "migrations*", "react_logs*", "notebooks*"] # Explicitly exclude other top-level dirs
377 changes: 377 additions & 0 deletions spoon_ai/llm/providers/anthropic.py

Large diffs are not rendered by default.
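
Because the spoon_ai/llm/providers/anthropic.py diff is collapsed, its contents are not visible in this view. Below is a minimal sketch of what an Anthropic provider following the same registration and config pattern as the DeepSeek and Gemini providers further down might look like; the class name, config defaults, key-lookup order, and default model are assumptions, and only the Anthropic SDK calls (`AsyncAnthropic`, `client.messages.create`) are standard API. This is not the actual content of the 377-line file.

```python
import os
from typing import List, Optional
from logging import getLogger

from anthropic import AsyncAnthropic  # official Anthropic SDK

from spoon_ai.llm.base import LLMBase
from spoon_ai.llm.factory import LLMFactory
from spoon_ai.schema import Message, LLMConfig, LLMResponse
from spoon_ai.utils.config_manager import ConfigManager

logger = getLogger(__name__)

class AnthropicConfig(LLMConfig):
    """Anthropic-specific configuration (assumed defaults)"""
    model: str = "claude-3-5-sonnet-20241022"

@LLMFactory.register("anthropic")
class AnthropicProvider(LLMBase):
    """Hypothetical Anthropic provider mirroring the structure of the providers below"""

    def __init__(self, config_path: str = "config.json", config_name: str = "llm"):
        self.config_manager = ConfigManager()
        self.config = self._load_config_from_json()
        # Assumed key lookup: config.json first, then the environment
        api_key = self.config_manager.get_api_key("anthropic") or os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("Anthropic API key not found in config.json or ANTHROPIC_API_KEY environment variable")
        self.client = AsyncAnthropic(api_key=api_key)

    def _load_config_from_json(self) -> AnthropicConfig:
        model_name = self.config_manager.get("model_name") or "claude-3-5-sonnet-20241022"
        return AnthropicConfig(model=model_name, max_tokens=4096, temperature=0.3)

    def _load_config(self, config_path: str, config_name: str) -> AnthropicConfig:
        # Compatibility with the LLMBase interface, as in the other providers
        return self._load_config_from_json()

    async def chat(
        self,
        messages: List[Message],
        system_msgs: Optional[List[Message]] = None,
        **kwargs
    ) -> LLMResponse:
        # Anthropic takes the system prompt as a separate parameter, not a message role
        formatted = [{"role": m.role, "content": m.content} for m in messages]
        create_kwargs = dict(
            model=self.config.model,
            max_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
            messages=formatted,
            **kwargs,
        )
        if system_msgs:
            create_kwargs["system"] = "\n".join(m.content for m in system_msgs)
        try:
            response = await self.client.messages.create(**create_kwargs)
            content = response.content[0].text if response.content else ""
            return LLMResponse(content=content, text=content)
        except Exception as e:
            logger.error(f"Anthropic API request failed: {str(e)}")
            return LLMResponse(content=f"API request failed: {str(e)}", text=f"API request failed: {str(e)}")
```
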

207 changes: 207 additions & 0 deletions spoon_ai/llm/providers/deepseek.py
@@ -0,0 +1,207 @@
import os
from typing import List, Optional, Dict, Literal
from logging import getLogger

from openai import AsyncOpenAI

from spoon_ai.llm.base import LLMBase
from spoon_ai.llm.factory import LLMFactory
from spoon_ai.schema import Message, LLMConfig, LLMResponse, ToolCall, Function
from spoon_ai.utils.config_manager import ConfigManager

logger = getLogger(__name__)

class DeepSeekConfig(LLMConfig):
"""DeepSeek-specific configuration"""
model: str = "deepseek-reasoner"
base_url: str = "https://api.deepseek.com"

@LLMFactory.register("deepseek")
class DeepSeekProvider(LLMBase):
"""DeepSeek provider implementation using OpenAI-compatible API"""

def __init__(self, config_path: str = "config.json", config_name: str = "llm"):
# Use ConfigManager for all configuration (no TOML)
self.config_manager = ConfigManager()

# Load configuration using ConfigManager instead of TOML
self.config = self._load_config_from_json()

# DeepSeek-specific key lookup with fallback to openai key for compatibility
api_key = self.config_manager.get_api_key("deepseek") or self.config_manager.get_api_key("openai") or os.getenv("DEEPSEEK_API_KEY") or os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("DeepSeek API key not found in config.json or DEEPSEEK_API_KEY/OPENAI_API_KEY environment variable")

# Use DeepSeek's official API endpoint (ignore environment overrides)
base_url = self.config.base_url

# Initialize DeepSeek client (OpenAI-compatible)
self.client = AsyncOpenAI(
api_key=api_key,
base_url=base_url
)

def _load_config_from_json(self) -> DeepSeekConfig:
"""Load configuration from config.json via ConfigManager"""
# Get model from config.json or use default
model_name = self.config_manager.get("model_name") or "deepseek-reasoner"

return DeepSeekConfig(
model=model_name,
base_url="https://api.deepseek.com", # Official DeepSeek endpoint
max_tokens=4096,
temperature=0.3
)

def _load_config(self, config_path: str, config_name: str) -> DeepSeekConfig:
"""Load DeepSeek-specific configuration (for compatibility only)"""
# This method is for compatibility with LLMBase interface
# We use _load_config_from_json() instead in the constructor
return self._load_config_from_json()

async def chat(
self,
messages: List[Message],
system_msgs: Optional[List[Message]] = None,
**kwargs
) -> LLMResponse:
"""Send chat request to DeepSeek"""
# Format messages for OpenAI-compatible API
formatted_messages = []

# Add system messages first
if system_msgs:
for sys_msg in system_msgs:
formatted_messages.append({
"role": "system",
"content": sys_msg.content
})

# Add user/assistant messages
for message in messages:
msg_dict = {"role": message.role}
if message.content:
msg_dict["content"] = message.content
if message.tool_calls:
msg_dict["tool_calls"] = [tc.model_dump() for tc in message.tool_calls]
if message.name:
msg_dict["name"] = message.name
if message.tool_call_id:
msg_dict["tool_call_id"] = message.tool_call_id
formatted_messages.append(msg_dict)

try:
response = await self.client.chat.completions.create(
messages=formatted_messages,
model=self.config.model,
max_tokens=self.config.max_tokens,
temperature=self.config.temperature,
**kwargs
)

content = response.choices[0].message.content or ""
return LLMResponse(content=content, text=content)

except Exception as e:
logger.error(f"DeepSeek API request failed: {str(e)}")
return LLMResponse(
content=f"API request failed: {str(e)}",
text=f"API request failed: {str(e)}"
)

async def completion(self, prompt: str, **kwargs) -> LLMResponse:
"""Send text completion request to DeepSeek"""
# Create a user message
message = Message(role="user", content=prompt)
return await self.chat(messages=[message], **kwargs)

async def chat_with_tools(
self,
messages: List[Message],
system_msgs: Optional[List[Message]] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Literal["none", "auto", "required"] = "auto",
**kwargs
) -> LLMResponse:
"""Send chat request with tools to DeepSeek"""
# Format messages for OpenAI-compatible API
formatted_messages = []

# Add system messages first
if system_msgs:
for sys_msg in system_msgs:
formatted_messages.append({
"role": "system",
"content": sys_msg.content
})

# Add user/assistant messages
for message in messages:
msg_dict = {"role": message.role}
if message.content:
msg_dict["content"] = message.content
if message.tool_calls:
msg_dict["tool_calls"] = [tc.model_dump() for tc in message.tool_calls]
if message.name:
msg_dict["name"] = message.name
if message.tool_call_id:
msg_dict["tool_call_id"] = message.tool_call_id
formatted_messages.append(msg_dict)

try:
response = await self.client.chat.completions.create(
messages=formatted_messages,
model=self.config.model,
max_tokens=self.config.max_tokens,
temperature=self.config.temperature,
tools=tools,
tool_choice=tool_choice,
**kwargs
)

# Extract message and finish_reason from OpenAI-compatible response
message = response.choices[0].message
finish_reason = response.choices[0].finish_reason

# Convert tool calls to our ToolCall format
tool_calls = []
if message.tool_calls:
for tool_call in message.tool_calls:
tool_calls.append(ToolCall(
id=tool_call.id,
type=tool_call.type,
function=Function(
name=tool_call.function.name,
arguments=tool_call.function.arguments
)
))

# Map finish reasons to standardized values
standardized_finish_reason = finish_reason
if finish_reason == "stop":
standardized_finish_reason = "stop"
elif finish_reason == "length":
standardized_finish_reason = "length"
elif finish_reason == "tool_calls":
standardized_finish_reason = "tool_calls"
elif finish_reason == "content_filter":
standardized_finish_reason = "content_filter"

content = message.content or ""

return LLMResponse(
content=content,
text=content,
tool_calls=tool_calls,
finish_reason=standardized_finish_reason,
native_finish_reason=finish_reason
)

except Exception as e:
logger.error(f"DeepSeek API request failed: {str(e)}")
return LLMResponse(
content=f"API request failed: {str(e)}",
text=f"API request failed: {str(e)}",
tool_calls=[],
finish_reason="error"
)
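
For reference, a minimal usage sketch of the new DeepSeek provider, grounded in the constructor and method signatures above; the API key source is as documented in the constructor, and the `get_weather` tool schema is a placeholder for illustration.

```python
import asyncio

from spoon_ai.llm.providers.deepseek import DeepSeekProvider
from spoon_ai.schema import Message

async def main():
    # Requires a key in config.json or the DEEPSEEK_API_KEY / OPENAI_API_KEY
    # environment variables, as checked in the constructor above.
    provider = DeepSeekProvider()

    # Plain chat
    reply = await provider.chat(messages=[Message(role="user", content="Hello")])
    print(reply.content)

    # Chat with tools; the schema follows the OpenAI function-calling format
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "description": "Get the weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]
    tool_reply = await provider.chat_with_tools(
        messages=[Message(role="user", content="What's the weather in Paris?")],
        tools=tools,
    )
    print(tool_reply.finish_reason, tool_reply.tool_calls)

if __name__ == "__main__":
    asyncio.run(main())
```
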
spoon_ai/llm/providers/gemini.py
@@ -6,14 +6,12 @@

from google import genai
from google.genai import types
from pydantic import Field
import logging

from spoon_ai.schema import Message
from spoon_ai.llm.base import LLMBase, LLMConfig, LLMResponse
from spoon_ai.schema import Message, LLMConfig, LLMResponse
from spoon_ai.llm.base import LLMBase
from spoon_ai.llm.factory import LLMFactory
from logging import getLogger
logger = getLogger(__name__)
from spoon_ai.utils.config_manager import ConfigManager

logger = logging.getLogger(__name__)

@@ -26,31 +24,46 @@ class GeminiConfig(LLMConfig):

@LLMFactory.register("gemini")
class GeminiProvider(LLMBase):
"""Gemini Provider Implementation"""
"""Gemini Provider Implementation with ConfigManager integration"""

def __init__(self, config_path: str = "config/config.toml", config_name: str = "chitchat"):
def __init__(self, config_path: str = "config.json", config_name: str = "llm"):
"""Initialize Gemini Provider

Args:
config_path: Configuration file path
config_name: Configuration name
"""
super().__init__(config_path, config_name)
# Use ConfigManager for all configuration (no TOML)
self.config_manager = ConfigManager()

# Load configuration using ConfigManager instead of TOML
self.config = self._load_config_from_json()

# Get API key with config.json -> environment fallback
api_key = self.config_manager.get_api_key("gemini") or self.config_manager.get_api_key("google") or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
if not api_key:
raise ValueError("Gemini API key not found in config.json or GEMINI_API_KEY/GOOGLE_API_KEY environment variables")

# Initialize Gemini API client
self.client = genai.Client(api_key=self.config.api_key)
self.client = genai.Client(api_key=api_key)

def _load_config(self, config_path: str, config_name: str) -> GeminiConfig:
"""Load configuration
def _load_config_from_json(self) -> GeminiConfig:
"""Load configuration from config.json via ConfigManager"""
# Get model from config.json or use default
model_name = self.config_manager.get("model_name") or "gemini-2.5-pro"

Args:
config_path: Configuration file path
config_name: Configuration name

Returns:
GeminiConfig: Gemini configuration
"""
config = super()._load_config(config_path, config_name)
return GeminiConfig(**config.model_dump())
return GeminiConfig(
model=model_name,
api_key="", # API key handled separately
max_tokens=4096,
temperature=0.3
)

def _load_config(self, config_path: str, config_name: str) -> GeminiConfig:
"""Load Gemini-specific configuration (for compatibility only)"""
# This method is for compatibility with LLMBase interface
# We use _load_config_from_json() instead in the constructor
return self._load_config_from_json()

async def chat(
self,
@@ -70,10 +83,6 @@ async def chat(
Returns:
LLMResponse: LLM response
"""


for msg in messages:
role = msg.role if hasattr(msg, 'role') else 'unknown'

# Get the last user message
user_message = ""
@@ -139,54 +148,30 @@ async def chat(
generate_config.response_schema = schema
generate_config.response_mime_type = 'application/json' # Set MIME type to JSON

# Send request
# Send request - use non-streaming for basic chat
logger.debug(f"Gemini request model: {self.config.model}")

content = ""
buffer = ""
is_content = False
stream = self.client.models.generate_content_stream(
response = self.client.models.generate_content(
model=self.config.model,
contents=contents,
config=generate_config
)

for part_response in stream:
chunk = part_response.candidates[0].content.parts[0].text
buffer += chunk
if is_content:
try:
json.loads(buffer)
content = json.loads(buffer)["response"]
# Compare buffer and content already put in queue, don't include any JSON boundary symbols in final output
await self.output_queue.put(chunk.strip("}").strip().strip('"'))
except json.JSONDecodeError as e:

await self.output_queue.put(chunk)
continue
except Exception as e:

logger.error(f"Gemini API request parsing failed: {str(e)}")
logger.error(f"Current buffer: {buffer}")
buffer = ""
is_content = False
elif '"response":' in buffer:
# Truncate from here, save the following content to content
try:
parts = buffer.split('"response":', 1)
if len(parts) > 1:
chunk = parts[1].strip()
is_content = True
await self.output_queue.put(chunk.strip("}").strip().strip('"'))

except Exception as e:
logger.error(f"Gemini API request parsing failed: {str(e)}")
logger.error(f"Current buffer: {buffer}")
buffer = ""
is_content = False
await self.output_queue.put(None)
self.task_done.set()
return LLMResponse(content=content, text=buffer)

# Parse response content
content = ""
if hasattr(response, "candidates") and response.candidates:
candidate = response.candidates[0]
if hasattr(candidate, "content") and candidate.content:
# Iterate through all parts
for part in candidate.content.parts:
# Check if there is text content
if hasattr(part, "text") and part.text:
if content:
content += "\n" + part.text
else:
content = part.text

return LLMResponse(content=content, text=content)
except Exception as e:
error_msg = f"Gemini API request failed: {str(e)}"
logger.error(error_msg)