-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Integrate pricing with canonical model #6130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4d28872
f4b8859
96cb5ec
8234a5f
3aee1c6
77bd96e
3584478
70a9a1e
dec46a8
b084dda
df6313d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,10 +13,8 @@ use goose::config::{Config, ConfigError}; | |
| use goose::model::ModelConfig; | ||
| use goose::providers::auto_detect::detect_provider_from_api_key; | ||
| use goose::providers::base::{ProviderMetadata, ProviderType}; | ||
| use goose::providers::canonical::maybe_get_canonical_model; | ||
| use goose::providers::create_with_default_model; | ||
| use goose::providers::pricing::{ | ||
| get_all_pricing, get_model_pricing, parse_model_id, refresh_pricing, | ||
| }; | ||
| use goose::providers::providers as get_providers; | ||
| use goose::{ | ||
| agents::execute_commands, agents::ExtensionConfig, config::permission::PermissionLevel, | ||
|
|
@@ -470,7 +468,8 @@ pub struct PricingResponse { | |
|
|
||
| #[derive(Deserialize, ToSchema)] | ||
| pub struct PricingQuery { | ||
| pub configured_only: bool, | ||
| pub provider: String, | ||
| pub model: String, | ||
| } | ||
|
|
||
| #[utoipa::path( | ||
|
|
@@ -484,84 +483,28 @@ pub struct PricingQuery { | |
| pub async fn get_pricing( | ||
| Json(query): Json<PricingQuery>, | ||
| ) -> Result<Json<PricingResponse>, StatusCode> { | ||
| let configured_only = query.configured_only; | ||
|
|
||
| // If refresh requested (configured_only = false), refresh the cache | ||
| if !configured_only { | ||
| if let Err(e) = refresh_pricing().await { | ||
| tracing::error!("Failed to refresh pricing data: {}", e); | ||
| } | ||
| } | ||
| let canonical_model = | ||
| maybe_get_canonical_model(&query.provider, &query.model).ok_or(StatusCode::NOT_FOUND)?; | ||
|
|
||
| let mut pricing_data = Vec::new(); | ||
|
|
||
| if !configured_only { | ||
| // Get ALL pricing data from the cache | ||
| let all_pricing = get_all_pricing().await; | ||
|
|
||
| for (provider, models) in all_pricing { | ||
| for (model, pricing) in models { | ||
| pricing_data.push(PricingData { | ||
| provider: provider.clone(), | ||
| model: model.clone(), | ||
| input_token_cost: pricing.input_cost, | ||
| output_token_cost: pricing.output_cost, | ||
| currency: "$".to_string(), | ||
| context_length: pricing.context_length, | ||
| }); | ||
| } | ||
| } | ||
| } else { | ||
| for (metadata, provider_type) in get_providers().await { | ||
| // Skip unconfigured providers if filtering | ||
| if !check_provider_configured(&metadata, provider_type) { | ||
| continue; | ||
| } | ||
|
|
||
| for model_info in &metadata.known_models { | ||
| // Handle OpenRouter models specially - they store full provider/model names | ||
| let (lookup_provider, lookup_model) = if metadata.name == "openrouter" { | ||
| // For OpenRouter, parse the model name to extract real provider/model | ||
| if let Some((provider, model)) = parse_model_id(&model_info.name) { | ||
| (provider, model) | ||
| } else { | ||
| // Fallback if parsing fails | ||
| (metadata.name.clone(), model_info.name.clone()) | ||
| } | ||
| } else { | ||
| // For other providers, use names as-is | ||
| (metadata.name.clone(), model_info.name.clone()) | ||
| }; | ||
|
|
||
| // Only get pricing from OpenRouter cache | ||
| if let Some(pricing) = get_model_pricing(&lookup_provider, &lookup_model).await { | ||
| pricing_data.push(PricingData { | ||
| provider: metadata.name.clone(), | ||
| model: model_info.name.clone(), | ||
| input_token_cost: pricing.input_cost, | ||
| output_token_cost: pricing.output_cost, | ||
| currency: "$".to_string(), | ||
| context_length: pricing.context_length, | ||
| }); | ||
| } | ||
| // No fallback to hardcoded prices | ||
| } | ||
| } | ||
| if let (Some(input_cost), Some(output_cost)) = ( | ||
| canonical_model.pricing.prompt, | ||
| canonical_model.pricing.completion, | ||
| ) { | ||
| pricing_data.push(PricingData { | ||
| provider: query.provider.clone(), | ||
| model: query.model.clone(), | ||
| input_token_cost: input_cost, | ||
| output_token_cost: output_cost, | ||
| currency: "$".to_string(), | ||
| context_length: Some(canonical_model.context_length as u32), | ||
| }); | ||
| } | ||
|
Comment on lines
+491
to
503
|
||
|
|
||
| tracing::debug!( | ||
| "Returning pricing for {} models{}", | ||
| pricing_data.len(), | ||
| if configured_only { | ||
| " (configured providers only)" | ||
| } else { | ||
| " (all cached models)" | ||
| } | ||
| ); | ||
|
|
||
| Ok(Json(PricingResponse { | ||
| pricing: pricing_data, | ||
| source: "openrouter".to_string(), | ||
| source: "canonical".to_string(), | ||
| })) | ||
| } | ||
|
Comment on lines 483 to 509
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,3 +20,9 @@ impl ModelMapping { | |
| } | ||
| } | ||
| } | ||
|
|
||
| pub fn maybe_get_canonical_model(provider: &str, model: &str) -> Option<CanonicalModel> { | ||
| let registry = CanonicalModelRegistry::bundled().ok()?; | ||
| let canonical_id = map_to_canonical_model(provider, model, registry)?; | ||
| registry.get(&canonical_id).cloned() | ||
|
Comment on lines
+24
to
+27
|
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The cast from `usize` to `u32` could overflow on 64-bit systems if `context_length` exceeds `u32::MAX`.
Consider using `try_into()` with proper error handling, or validate that the value fits within the `u32` range.