Skip to content

Commit df51ca8

Browse files
committed
refactor(plugins/googleai): migrate to v2 API
1 parent 2b43449 commit df51ca8

File tree

6 files changed

+194
-281
lines changed

6 files changed

+194
-281
lines changed

js/plugins/googleai/src/embedder.ts

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ import {
2323
z,
2424
type EmbedderAction,
2525
type EmbedderReference,
26-
type Genkit,
2726
} from 'genkit';
2827
import { embedderRef } from 'genkit/embedder';
28+
import { embedder } from 'genkit/plugin';
2929
import { getApiKeyFromEnvVar } from './common.js';
3030
import type { PluginOptions } from './index.js';
3131

@@ -97,13 +97,12 @@ export const geminiEmbedding001 = embedderRef({
9797
});
9898

9999
export const SUPPORTED_MODELS = {
100-
'embedding-001': textEmbeddingGecko001,
101-
'text-embedding-004': textEmbedding004,
102-
'gemini-embedding-001': geminiEmbedding001,
100+
'googleai/embedding-001': textEmbeddingGecko001,
101+
'googleai/text-embedding-004': textEmbedding004,
102+
'googleai/gemini-embedding-001': geminiEmbedding001,
103103
};
104104

105105
export function defineGoogleAIEmbedder(
106-
ai: Genkit,
107106
name: string,
108107
pluginOptions: PluginOptions
109108
): EmbedderAction<any> {
@@ -117,7 +116,7 @@ export function defineGoogleAIEmbedder(
117116
'For more details see https://genkit.dev/docs/plugins/google-genai'
118117
);
119118
}
120-
const embedder: EmbedderReference =
119+
const embedderReference: EmbedderReference =
121120
SUPPORTED_MODELS[name] ??
122121
embedderRef({
123122
name: name,
@@ -130,16 +129,16 @@ export function defineGoogleAIEmbedder(
130129
},
131130
},
132131
});
133-
const apiModelName = embedder.name.startsWith('googleai/')
134-
? embedder.name.substring('googleai/'.length)
135-
: embedder.name;
136-
return ai.defineEmbedder(
132+
const apiModelName = embedderReference.name.startsWith('googleai/')
133+
? embedderReference.name.substring('googleai/'.length)
134+
: embedderReference.name;
135+
return embedder(
137136
{
138-
name: embedder.name,
137+
name: embedderReference.name,
139138
configSchema: GeminiEmbeddingConfigSchema,
140-
info: embedder.info!,
139+
info: embedderReference.info!,
141140
},
142-
async (input, options) => {
141+
async ({ input, options }) => {
143142
if (pluginOptions.apiKey === false && !options?.apiKey) {
144143
throw new GenkitError({
145144
status: 'INVALID_ARGUMENT',
@@ -152,8 +151,8 @@ export function defineGoogleAIEmbedder(
152151
).getGenerativeModel({
153152
model:
154153
options?.version ||
155-
embedder.config?.version ||
156-
embedder.version ||
154+
embedderReference.config?.version ||
155+
embedderReference.version ||
157156
apiModelName,
158157
});
159158
const embeddings = await Promise.all(

js/plugins/googleai/src/gemini.ts

Lines changed: 40 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import {
3838
type ToolConfig,
3939
type UsageMetadata,
4040
} from '@google/generative-ai';
41-
import { GenkitError, z, type Genkit, type JSONSchema } from 'genkit';
41+
import { GenkitError, z, type JSONSchema } from 'genkit';
4242
import {
4343
GenerationCommonConfigDescriptions,
4444
GenerationCommonConfigSchema,
@@ -57,8 +57,8 @@ import {
5757
type ToolResponsePart,
5858
} from 'genkit/model';
5959
import { downloadRequestMedia } from 'genkit/model/middleware';
60-
import { runInNewSpan } from 'genkit/tracing';
61-
import { getApiKeyFromEnvVar, getGenkitClientHeader } from './common';
60+
import { model } from 'genkit/plugin';
61+
import { getApiKeyFromEnvVar, getGenkitClientHeader } from './common.js';
6262
import { handleCacheIfNeeded } from './context-caching';
6363
import { extractCacheConfig } from './context-caching/utils';
6464

@@ -644,26 +644,26 @@ export const gemma3ne4bit = modelRef({
644644
});
645645

646646
export const SUPPORTED_GEMINI_MODELS = {
647-
'gemini-1.5-pro': gemini15Pro,
648-
'gemini-1.5-flash': gemini15Flash,
649-
'gemini-1.5-flash-8b': gemini15Flash8b,
650-
'gemini-2.0-pro-exp-02-05': gemini20ProExp0205,
651-
'gemini-2.0-flash': gemini20Flash,
652-
'gemini-2.0-flash-lite': gemini20FlashLite,
653-
'gemini-2.0-flash-exp': gemini20FlashExp,
654-
'gemini-2.5-pro-exp-03-25': gemini25ProExp0325,
655-
'gemini-2.5-pro-preview-03-25': gemini25ProPreview0325,
656-
'gemini-2.5-pro-preview-tts': gemini25ProPreviewTts,
657-
'gemini-2.5-flash-preview-04-17': gemini25FlashPreview0417,
658-
'gemini-2.5-flash-preview-tts': gemini25FlashPreviewTts,
659-
'gemini-2.5-flash': gemini25Flash,
660-
'gemini-2.5-flash-lite': gemini25FlashLite,
661-
'gemini-2.5-pro': gemini25Pro,
662-
'gemma-3-12b-it': gemma312bit,
663-
'gemma-3-1b-it': gemma31bit,
664-
'gemma-3-27b-it': gemma327bit,
665-
'gemma-3-4b-it': gemma34bit,
666-
'gemma-3n-e4b-it': gemma3ne4bit,
647+
'googleai/gemini-1.5-pro': gemini15Pro,
648+
'googleai/gemini-1.5-flash': gemini15Flash,
649+
'googleai/gemini-1.5-flash-8b': gemini15Flash8b,
650+
'googleai/gemini-2.0-pro-exp-02-05': gemini20ProExp0205,
651+
'googleai/gemini-2.0-flash': gemini20Flash,
652+
'googleai/gemini-2.0-flash-lite': gemini20FlashLite,
653+
'googleai/gemini-2.0-flash-exp': gemini20FlashExp,
654+
'googleai/gemini-2.5-pro-exp-03-25': gemini25ProExp0325,
655+
'googleai/gemini-2.5-pro-preview-03-25': gemini25ProPreview0325,
656+
'googleai/gemini-2.5-pro-preview-tts': gemini25ProPreviewTts,
657+
'googleai/gemini-2.5-flash-preview-04-17': gemini25FlashPreview0417,
658+
'googleai/gemini-2.5-flash-preview-tts': gemini25FlashPreviewTts,
659+
'googleai/gemini-2.5-flash': gemini25Flash,
660+
'googleai/gemini-2.5-flash-lite': gemini25FlashLite,
661+
'googleai/gemini-2.5-pro': gemini25Pro,
662+
'googleai/gemma-3-12b-it': gemma312bit,
663+
'googleai/gemma-3-1b-it': gemma31bit,
664+
'googleai/gemma-3-27b-it': gemma327bit,
665+
'googleai/gemma-3-4b-it': gemma34bit,
666+
'googleai/gemma-3n-e4b-it': gemma3ne4bit,
667667
};
668668

669669
export const GENERIC_GEMINI_MODEL = modelRef({
@@ -705,7 +705,7 @@ export type GeminiVersionString =
705705
* ```js
706706
* await ai.generate({
707707
* prompt: 'hi',
708-
* model: gemini('gemini-1.5-flash')
708+
* model: gemini('googleai/gemini-1.5-flash')
709709
* });
710710
* ```
711711
*/
@@ -1118,7 +1118,6 @@ export function cleanSchema(schema: JSONSchema): JSONSchema {
11181118
* Defines a new GoogleAI model.
11191119
*/
11201120
export function defineGoogleAIModel({
1121-
ai,
11221121
name,
11231122
apiKey: apiKeyOption,
11241123
apiVersion,
@@ -1127,7 +1126,6 @@ export function defineGoogleAIModel({
11271126
defaultConfig,
11281127
debugTraces,
11291128
}: {
1130-
ai: Genkit;
11311129
name: string;
11321130
apiKey?: string | false;
11331131
apiVersion?: string;
@@ -1154,10 +1152,10 @@ export function defineGoogleAIModel({
11541152
? name.substring('googleai/'.length)
11551153
: name;
11561154

1157-
const model: ModelReference<z.ZodTypeAny> =
1155+
const modelReference: ModelReference<z.ZodTypeAny> =
11581156
SUPPORTED_GEMINI_MODELS[apiModelName] ??
11591157
modelRef({
1160-
name: `googleai/${apiModelName}`,
1158+
name: name, // Keep the full name for the model reference
11611159
info: {
11621160
label: `Google AI - ${apiModelName}`,
11631161
supports: {
@@ -1173,7 +1171,7 @@ export function defineGoogleAIModel({
11731171
});
11741172

11751173
const middleware: ModelMiddleware[] = [];
1176-
if (model.info?.supports?.media) {
1174+
if (modelReference.info?.supports?.media) {
11771175
// the gemini api doesn't support downloading media from http(s)
11781176
middleware.push(
11791177
downloadRequestMedia({
@@ -1199,12 +1197,11 @@ export function defineGoogleAIModel({
11991197
);
12001198
}
12011199

1202-
return ai.defineModel(
1200+
return model(
12031201
{
1204-
apiVersion: 'v2',
1205-
name: model.name,
1206-
...model.info,
1207-
configSchema: model.configSchema,
1202+
name: modelReference.name,
1203+
...modelReference.info,
1204+
configSchema: modelReference.configSchema,
12081205
use: middleware,
12091206
},
12101207
async (request, { streamingRequested, sendChunk, abortSignal }) => {
@@ -1228,7 +1225,7 @@ export function defineGoogleAIModel({
12281225
// systemInstructions to be provided as a separate input. The first
12291226
// message detected with role=system will be used for systemInstructions.
12301227
let systemInstruction: GeminiMessage | undefined = undefined;
1231-
if (model.info?.supports?.systemRole) {
1228+
if (modelReference.info?.supports?.systemRole) {
12321229
const systemMessage = messages.find((m) => m.role === 'system');
12331230
if (systemMessage) {
12341231
messages.splice(messages.indexOf(systemMessage), 1);
@@ -1306,7 +1303,10 @@ export function defineGoogleAIModel({
13061303
generationConfig.responseSchema = cleanSchema(request.output.schema);
13071304
}
13081305

1309-
const msg = toGeminiMessage(messages[messages.length - 1], model);
1306+
const msg = toGeminiMessage(
1307+
messages[messages.length - 1],
1308+
modelReference
1309+
);
13101310

13111311
const fromJSONModeScopedGeminiCandidate = (
13121312
candidate: GeminiCandidate
@@ -1321,11 +1321,11 @@ export function defineGoogleAIModel({
13211321
toolConfig,
13221322
history: messages
13231323
.slice(0, -1)
1324-
.map((message) => toGeminiMessage(message, model)),
1324+
.map((message) => toGeminiMessage(message, modelReference)),
13251325
safetySettings: safetySettingsFromConfig,
13261326
} as StartChatParams;
13271327
const modelVersion = (versionFromConfig ||
1328-
model.version ||
1328+
modelReference.version ||
13291329
apiModelName) as string;
13301330
const cacheConfigDetails = extractCacheConfig(request);
13311331

@@ -1426,31 +1426,8 @@ export function defineGoogleAIModel({
14261426
};
14271427
};
14281428

1429-
// If debugTraces is enable, we wrap the actual model call with a span, add raw
1430-
// API params as for input.
1431-
return debugTraces
1432-
? await runInNewSpan(
1433-
ai.registry,
1434-
{
1435-
metadata: {
1436-
name: streamingRequested ? 'sendMessageStream' : 'sendMessage',
1437-
},
1438-
},
1439-
async (metadata) => {
1440-
metadata.input = {
1441-
sdk: '@google/generative-ai',
1442-
cache: cache,
1443-
model: genModel.model,
1444-
chatOptions: updatedChatRequest,
1445-
parts: msg.parts,
1446-
options,
1447-
};
1448-
const response = await callGemini();
1449-
metadata.output = response.custom;
1450-
return response;
1451-
}
1452-
)
1453-
: await callGemini();
1429+
// TODO v2: no ai.registry available here; run without the debug span wrapper.
1430+
return await callGemini();
14541431
}
14551432
);
14561433
}

js/plugins/googleai/src/imagen.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
import { GenkitError, MessageData, z, type Genkit } from 'genkit';
17+
import { GenkitError, MessageData, z } from 'genkit';
1818
import {
1919
getBasicUsageStats,
2020
modelRef,
@@ -23,6 +23,7 @@ import {
2323
type ModelInfo,
2424
type ModelReference,
2525
} from 'genkit/model';
26+
import { model } from 'genkit/plugin';
2627
import { getApiKeyFromEnvVar } from './common.js';
2728
import { predictModel } from './predict.js';
2829

@@ -109,7 +110,6 @@ export const GENERIC_IMAGEN_INFO = {
109110
} as ModelInfo;
110111

111112
export function defineImagenModel(
112-
ai: Genkit,
113113
name: string,
114114
apiKey?: string | false
115115
): ModelAction {
@@ -125,7 +125,7 @@ export function defineImagenModel(
125125
}
126126
}
127127
const modelName = `googleai/${name}`;
128-
const model: ModelReference<z.ZodTypeAny> = modelRef({
128+
const modelReference: ModelReference<z.ZodTypeAny> = modelRef({
129129
name: modelName,
130130
info: {
131131
...GENERIC_IMAGEN_INFO,
@@ -134,10 +134,10 @@ export function defineImagenModel(
134134
configSchema: ImagenConfigSchema,
135135
});
136136

137-
return ai.defineModel(
137+
return model(
138138
{
139139
name: modelName,
140-
...model.info,
140+
...modelReference.info,
141141
configSchema: ImagenConfigSchema,
142142
},
143143
async (request) => {
@@ -153,7 +153,7 @@ export function defineImagenModel(
153153
ImagenInstance,
154154
ImagenPrediction,
155155
ImagenParameters
156-
>(model.version || name, apiKey as string, 'predict');
156+
>(modelReference.version || name, apiKey as string, 'predict');
157157
const response = await predictClient([instance], toParameters(request));
158158

159159
if (!response.predictions || response.predictions.length == 0) {

0 commit comments

Comments
 (0)