Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions go/plugins/googlegenai/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,27 @@ func configToMap(config any) map[string]any {
r := jsonschema.Reflector{
DoNotReference: true, // Prevent $ref usage
ExpandedStruct: true, // Include all fields directly
// Prevent stack overflow panic due type traversal recursion (circular references)
// [genai.Schema] should not be used at this point since Schema is provided later
// NOTE: keep track of updated fields in [genai.GenerateContentConfig] since
// they could create runtime panics when parsing fields with type recursion
IgnoredTypes: []any{genai.Schema{}},
IgnoredTypes: []any{
genai.Schema{},
genai.Tool{},
genai.ToolConfig{},
genai.HTTPOptions{},
},
}

schema := r.Reflect(config)
result := base.SchemaAsMap(schema)

// prevent users to override Genkit primitive features
if propertiesMap, ok := result["properties"].(map[string]any); ok {
delete(propertiesMap, "cachedContent")
delete(propertiesMap, "systemInstruction")
delete(propertiesMap, "responseMimeType")
delete(propertiesMap, "responseJsonSchema")
delete(propertiesMap, "candidateCount")
}
return result
}

Expand Down Expand Up @@ -140,9 +153,8 @@ func newModel(client *genai.Client, name string, opts ai.ModelOptions) ai.Model

var config any
config = &genai.GenerateContentConfig{}
if imageOpts, found := supportedImagenModels[name]; found {
if strings.Contains(name, "imagen") {
config = &genai.GenerateImagesConfig{}
opts = imageOpts
}
meta := &ai.ModelOptions{
Label: opts.Label,
Expand Down Expand Up @@ -711,7 +723,6 @@ func toGeminiParts(parts []*ai.Part) ([]*genai.Part, error) {

// toGeminiPart converts a [ai.Part] to a [genai.Part].
func toGeminiPart(p *ai.Part) (*genai.Part, error) {

switch {
case p.IsReasoning():
// TODO: go-genai does not support genai.NewPartFromThought()
Expand Down
101 changes: 81 additions & 20 deletions go/plugins/googlegenai/googlegenai.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,37 @@ func (ga *GoogleAI) ListActions(ctx context.Context) []core.ActionDesc {
"systemRole": true,
"tools": true,
"toolChoice": true,
"constrained": true,
"constrained": "no-tools",
},
"versions": []string{},
"stage": string(ai.ModelStageStable),
"versions": []string{},
"stage": string(ai.ModelStageStable),
"customOptions": configToMap(&genai.GenerateContentConfig{}),
},
}
metadata["label"] = fmt.Sprintf("%s - %s", googleAILabelPrefix, name)

actions = append(actions, core.ActionDesc{
Type: core.ActionTypeModel,
Name: fmt.Sprintf("%s/%s", googleAIProvider, name),
Key: fmt.Sprintf("/%s/%s/%s", core.ActionTypeModel, googleAIProvider, name),
Metadata: metadata,
})
}

for _, name := range models.imagen {
metadata := map[string]any{
"model": map[string]any{
"supports": map[string]any{
"media": true,
"multiturn": true,
"systemRole": false,
"tools": false,
"toolChoice": false,
"constrained": "no-tools",
},
"versions": []string{},
"stage": string(ai.ModelStageStable),
"customOptions": configToMap(&genai.GenerateImagesConfig{}),
},
}
metadata["label"] = fmt.Sprintf("%s - %s", googleAILabelPrefix, name)
Expand All @@ -352,20 +379,24 @@ func (ga *GoogleAI) ListActions(ctx context.Context) []core.ActionDesc {
}

func (ga *GoogleAI) ResolveAction(atype core.ActionType, name string) core.Action {
var config any
switch atype {
case core.ActionTypeEmbedder:
return newEmbedder(ga.gclient, name, &ai.EmbedderOptions{}).(core.Action)
case core.ActionTypeModel:
var supports *ai.ModelSupports
if strings.Contains(name, "gemini") || strings.Contains(name, "gemma") {
supports = &Multimodal
supports := &Multimodal
config = &genai.GenerateContentConfig{}
if strings.Contains(name, "imagen") {
supports = &Media
config = &genai.GenerateImagesConfig{}
}

return newModel(ga.gclient, name, ai.ModelOptions{
Label: fmt.Sprintf("%s - %s", googleAILabelPrefix, name),
Stage: ai.ModelStageStable,
Versions: []string{},
Supports: supports,
Label: fmt.Sprintf("%s - %s", googleAILabelPrefix, name),
Stage: ai.ModelStageStable,
Versions: []string{},
Supports: supports,
ConfigSchema: configToMap(config),
}).(core.Action)
}

Expand All @@ -388,10 +419,36 @@ func (v *VertexAI) ListActions(ctx context.Context) []core.ActionDesc {
"systemRole": true,
"tools": true,
"toolChoice": true,
"constrained": true,
"constrained": "no-tools",
},
"versions": []string{},
"stage": string(ai.ModelStageStable),
"customOptions": configToMap(&genai.GenerateContentConfig{}),
},
}
metadata["label"] = fmt.Sprintf("%s - %s", vertexAILabelPrefix, name)
actions = append(actions, core.ActionDesc{
Type: core.ActionTypeModel,
Name: fmt.Sprintf("%s/%s", vertexAIProvider, name),
Key: fmt.Sprintf("/%s/%s/%s", core.ActionTypeModel, vertexAIProvider, name),
Metadata: metadata,
})
}

for _, name := range models.imagen {
metadata := map[string]any{
"model": map[string]any{
"supports": map[string]any{
"media": true,
"multiturn": true,
"systemRole": false,
"tools": false,
"toolChoice": false,
"constrained": "no-tools",
},
"versions": []string{},
"stage": string(ai.ModelStageStable),
"versions": []string{},
"stage": string(ai.ModelStageStable),
"customOptions": configToMap(&genai.GenerateImagesConfig{}),
},
}
metadata["label"] = fmt.Sprintf("%s - %s", vertexAILabelPrefix, name)
Expand All @@ -415,20 +472,24 @@ func (v *VertexAI) ListActions(ctx context.Context) []core.ActionDesc {
}

func (v *VertexAI) ResolveAction(atype core.ActionType, name string) core.Action {
var config any
switch atype {
case core.ActionTypeEmbedder:
return newEmbedder(v.gclient, name, &ai.EmbedderOptions{}).(core.Action)
case core.ActionTypeModel:
var supports *ai.ModelSupports
if strings.Contains(name, "gemini") {
supports = &Multimodal
supports := &Multimodal
config = &genai.GenerateContentConfig{}
if strings.Contains(name, "imagen") {
supports = &Media
config = &genai.GenerateImagesConfig{}
}

return newModel(v.gclient, name, ai.ModelOptions{
Label: fmt.Sprintf("%s - %s", vertexAILabelPrefix, name),
Stage: ai.ModelStageStable,
Versions: []string{},
Supports: supports,
Label: fmt.Sprintf("%s - %s", vertexAILabelPrefix, name),
Stage: ai.ModelStageStable,
Versions: []string{},
Supports: supports,
ConfigSchema: configToMap(config),
}).(core.Action)
}
return nil
Expand Down
25 changes: 11 additions & 14 deletions go/plugins/googlegenai/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,6 @@ type genaiModels struct {
func listGenaiModels(ctx context.Context, client *genai.Client) (genaiModels, error) {
models := genaiModels{}
allowedModels := []string{"gemini", "gemma"}
allowedImagenModels := []string{"imagen"}

for item, err := range client.Models.All(ctx) {
var name string
Expand All @@ -428,22 +427,20 @@ func listGenaiModels(ctx context.Context, client *genai.Client) (genaiModels, er
continue
}

found := slices.ContainsFunc(allowedModels, func(s string) bool {
return strings.Contains(name, s)
})
// filter out: Aqa, Text-bison, Chat, learnlm
if found {
models.gemini = append(models.gemini, name)
if slices.Contains(item.SupportedActions, "predict") && strings.Contains(name, "imagen") {
models.imagen = append(models.imagen, name)
continue
}

found = slices.ContainsFunc(allowedImagenModels, func(s string) bool {
return strings.Contains(name, s)
})
// filter out: Aqa, Text-bison, Chat, learnlm
if found {
models.imagen = append(models.imagen, name)
continue
if slices.Contains(item.SupportedActions, "generateContent") {
found := slices.ContainsFunc(allowedModels, func(s string) bool {
return strings.Contains(name, s)
})
// filter out: Aqa, Text-bison, Chat, learnlm
if found {
models.gemini = append(models.gemini, name)
continue
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions go/samples/prompts/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func SimplePrompt(ctx context.Context, g *genkit.Genkit) {
// Define prompt with default model and system text.
helloPrompt := genkit.DefinePrompt(
g, "SimplePrompt",
ai.WithModelName("vertexai/gemini-2.0-flash-lite"), // Override the default model.
ai.WithModelName("vertexai/gemini-2.5-pro"), // Override the default model.
ai.WithSystem("You are a helpful AI assistant named Walt. Greet the user."),
ai.WithPrompt("Hello, who are you?"),
)
Expand Down Expand Up @@ -272,7 +272,7 @@ func PromptWithExecuteOverrides(ctx context.Context, g *genkit.Genkit) {

// Call the model and add additional messages from the user.
resp, err := helloPrompt.Execute(ctx,
ai.WithModel(googlegenai.VertexAIModel(g, "gemini-2.0-flash-lite")),
ai.WithModel(googlegenai.VertexAIModel(g, "gemini-2.5-pro")),
ai.WithMessages(ai.NewUserTextMessage("And I like turtles.")),
)
if err != nil {
Expand Down
Loading