Skip to content

Commit f7d9fff

Browse files
committed
refactor: Introducing Models.resolve to simplify and unify model resolution in chat, embedding, and image methods
Now `assume_model_exists` can be used in `paint` and `embed` methods too.
1 parent 9ccf3e7 commit f7d9fff

File tree

5 files changed

+41
-26
lines changed

5 files changed

+41
-26
lines changed

lib/ruby_llm/chat.rb

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,8 @@ def with_tools(*tools)
6060
self
6161
end
6262

63-
def with_model(model_id, provider: nil, assume_exists: false) # rubocop:disable Metrics/AbcSize,Metrics/MethodLength
64-
assume_exists = true if provider && Provider.providers[provider.to_sym].local?
65-
66-
if assume_exists
67-
raise ArgumentError, 'Provider must be specified if assume_exists is true' unless provider
68-
69-
@provider = Provider.providers[provider.to_sym] || raise(Error, "Unknown provider: #{provider.to_sym}")
70-
@model = Struct.new(:id, :provider, :supports_functions, :supports_vision).new(model_id, provider, true, true)
71-
RubyLLM.logger.warn "Assuming model '#{model_id}' exists for provider '#{provider}'. " \
72-
'Capabilities may not be accurately reflected.'
73-
else
74-
@model = Models.find model_id, provider
75-
@provider = Provider.providers[@model.provider.to_sym] || raise(Error, "Unknown provider: #{@model.provider}")
76-
end
63+
def with_model(model_id, provider: nil, assume_exists: false)
64+
@model, @provider = Models.resolve(model_id, provider:, assume_exists:)
7765
self
7866
end
7967

lib/ruby_llm/embedding.rb

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,17 @@ def initialize(vectors:, model:, input_tokens: 0)
1212
@input_tokens = input_tokens
1313
end
1414

15-
def self.embed(text, model: nil, provider: nil, context: nil, dimensions: nil)
15+
def self.embed(text, # rubocop:disable Metrics/ParameterLists,Metrics/CyclomaticComplexity
16+
model: nil,
17+
provider: nil,
18+
assume_model_exists: false,
19+
context: nil,
20+
dimensions: nil)
1621
config = context&.config || RubyLLM.config
17-
model_id = model || config.default_embedding_model
18-
Models.find(model_id, provider)
22+
model, provider = Models.resolve(model, provider: provider, assume_exists: assume_model_exists) if model
23+
model_id = model&.id || config.default_embedding_model
1924

20-
provider = Provider.for(model_id)
25+
provider = Provider.for(model_id) if provider.nil?
2126
connection = context ? context.connection_for(provider) : provider.connection(config)
2227
provider.embed(text, model: model_id, connection:, dimensions:)
2328
end

lib/ruby_llm/image.rb

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,17 @@ def save(path)
3636
path
3737
end
3838

39-
def self.paint(prompt, model: nil, provider: nil, size: '1024x1024', context: nil)
39+
def self.paint(prompt, # rubocop:disable Metrics/ParameterLists,Metrics/CyclomaticComplexity
40+
model: nil,
41+
provider: nil,
42+
assume_model_exists: false,
43+
size: '1024x1024',
44+
context: nil)
4045
config = context&.config || RubyLLM.config
41-
model_id = model || config.default_image_model
42-
Models.find(model_id, provider) # Validate model exists
46+
model, provider = Models.resolve(model, provider: provider, assume_exists: assume_model_exists) if model
47+
model_id = model&.id || config.default_image_model
4348

44-
provider = Provider.for(model_id)
49+
provider = Provider.for(model_id) if provider.nil?
4550
connection = context ? context.connection_for(provider) : provider.connection(config)
4651
provider.paint(prompt, model: model_id, size:, connection:)
4752
end

lib/ruby_llm/models.rb

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ module RubyLLM
99
# RubyLLM.models.chat_models # Models that support chat
1010
# RubyLLM.models.by_provider('openai').chat_models # OpenAI chat models
1111
# RubyLLM.models.find('claude-3') # Get info about a specific model
12-
class Models
12+
class Models # rubocop:disable Metrics/ClassLength
1313
include Enumerable
1414

1515
# Delegate class methods to the singleton instance
@@ -46,6 +46,23 @@ def refresh! # rubocop:disable Metrics/AbcSize,Metrics/CyclomaticComplexity,Metr
4646
@instance
4747
end
4848

49+
def resolve(model_id, provider: nil, assume_exists: false) # rubocop:disable Metrics/AbcSize,Metrics/MethodLength
50+
assume_exists = true if provider && Provider.providers[provider.to_sym].local?
51+
52+
if assume_exists
53+
raise ArgumentError, 'Provider must be specified if assume_exists is true' unless provider
54+
55+
provider = Provider.providers[provider.to_sym] || raise(Error, "Unknown provider: #{provider.to_sym}")
56+
model = Struct.new(:id, :provider, :supports_functions, :supports_vision).new(model_id, provider, true, true)
57+
RubyLLM.logger.warn "Assuming model '#{model_id}' exists for provider '#{provider}'. " \
58+
'Capabilities may not be accurately reflected.'
59+
else
60+
model = Models.find model_id, provider
61+
provider = Provider.providers[model.provider.to_sym] || raise(Error, "Unknown provider: #{model.provider}")
62+
end
63+
[model, provider]
64+
end
65+
4966
def method_missing(method, ...)
5067
if instance.respond_to?(method)
5168
instance.send(method, ...)

spec/ruby_llm/chat_assume_model_exists_spec.rb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
describe '#assume_model_exists' do
99
let(:real_model) { 'gpt-4.1-nano' }
1010
let(:custom_model) { 'my-custom-model' }
11-
let(:provider) { :openai }
11+
let(:provider) { 'openai' }
1212
# Keep a reference to the original models for cleanup
1313
let!(:original_models) { RubyLLM::Models.instance.all.dup }
1414

@@ -33,7 +33,7 @@
3333
)
3434

3535
expect(chat.model.id).to eq(custom_model)
36-
expect(chat.model.provider).to eq(provider)
36+
expect(chat.model.provider.slug).to eq(provider)
3737
end
3838

3939
it 'works with RubyLLM.chat convenience method' do # rubocop:disable RSpec/ExampleLength
@@ -74,7 +74,7 @@
7474
chat.with_model(custom_model, provider: provider, assume_exists: true)
7575

7676
expect(chat.model.id).to eq(custom_model)
77-
expect(chat.model.provider).to eq(provider)
77+
expect(chat.model.provider.slug).to eq(provider)
7878
end
7979
end
8080
end

0 commit comments

Comments
 (0)