Skip to content

feat(js/plugins/{googleai,vertexai}): pass through unknown config options directly to the underlying API #2848

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 38 additions & 29 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -206,32 +206,40 @@ export const GenerationCommonConfigDescriptions = {
/**
* Zod schema of a common config object.
*/
export const GenerationCommonConfigSchema = z.object({
  // Optional selector for a concrete version within a model family.
  version: z
    .string()
    .describe(
      'A specific version of a model family, e.g. `gemini-2.0-flash` ' +
        'for the `googleai` family.'
    )
    .optional(),
  // The numeric knobs below take their descriptions from
  // GenerationCommonConfigDescriptions.
  temperature: z
    .number()
    .describe(GenerationCommonConfigDescriptions.temperature)
    .optional(),
  maxOutputTokens: z
    .number()
    .describe(GenerationCommonConfigDescriptions.maxOutputTokens)
    .optional(),
  topK: z.number().describe(GenerationCommonConfigDescriptions.topK).optional(),
  topP: z.number().describe(GenerationCommonConfigDescriptions.topP).optional(),
  stopSequences: z
    .array(z.string())
    // `.max(5)` matches the documented "up to 5" limit; `.length(5)` would
    // reject any array that does not contain exactly five entries.
    .max(5)
    .describe(
      'Set of character sequences (up to 5) that will stop output generation.'
    )
    .optional(),
});
// Common generation config shared across models. Every field is optional, and
// `.passthrough()` keeps keys that are not declared here instead of stripping
// them during parsing.
export const GenerationCommonConfigSchema = z
  .object({
    // Optional selector for a concrete version within a model family.
    version: z
      .string()
      .describe(
        'A specific version of a model family, e.g. `gemini-2.0-flash` ' +
          'for the `googleai` family.'
      )
      .optional(),
    // The numeric knobs below take their descriptions from
    // GenerationCommonConfigDescriptions.
    temperature: z
      .number()
      .describe(GenerationCommonConfigDescriptions.temperature)
      .optional(),
    maxOutputTokens: z
      .number()
      .describe(GenerationCommonConfigDescriptions.maxOutputTokens)
      .optional(),
    topK: z
      .number()
      .describe(GenerationCommonConfigDescriptions.topK)
      .optional(),
    topP: z
      .number()
      .describe(GenerationCommonConfigDescriptions.topP)
      .optional(),
    stopSequences: z
      .array(z.string())
      // `.max(5)` matches the documented "up to 5" limit; `.length(5)` would
      // reject any array that does not contain exactly five entries.
      .max(5)
      .describe(
        'Set of character sequences (up to 5) that will stop output generation.'
      )
      .optional(),
  })
  .passthrough();

/**
* Common config object.
Expand Down Expand Up @@ -619,9 +627,10 @@ function getPartCounts(parts: Part[]): PartCounts {
);
}

/**
 * Union of the values accepted where a model is expected: a `ModelAction`,
 * a `ModelReference`, or a string. `CustomOptions` defaults to the type of
 * `GenerationCommonConfigSchema`.
 */
export type ModelArgument<
  CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
> =
  | ModelAction<CustomOptions>
  | ModelReference<CustomOptions>
  | string;
/**
 * Union of the values accepted where a model is expected: a `ModelAction`,
 * a `ModelReference`, or a string. `CustomOptions` defaults to
 * `z.ZodTypeAny`, so any Zod config schema is admitted when the parameter is
 * omitted.
 */
export type ModelArgument<
  CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
> = ModelAction<CustomOptions> | ModelReference<CustomOptions> | string;

export interface ResolvedModel<
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
Expand Down
39 changes: 19 additions & 20 deletions js/plugins/googleai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
"'gemini-2.0-flash-exp' model at present."
)
.optional(),
});
}).passthrough();
export type GeminiConfig = z.infer<typeof GeminiConfigSchema>;

export const gemini10Pro = modelRef({
Expand Down Expand Up @@ -911,7 +911,16 @@ export function defineGoogleAIModel({
});
}

if (requestConfig.codeExecution) {
const {
apiKey: apiKeyFromConfig,
safetySettings: safetySettingsFromConfig,
codeExecution: codeExecutionFromConfig,
version: versionFromConfig,
functionCallingConfig,
...restOfConfigOptions
} = requestConfig;

if (codeExecutionFromConfig) {
tools.push({
codeExecution:
request.config.codeExecution === true
Expand All @@ -921,12 +930,11 @@ export function defineGoogleAIModel({
}

let toolConfig: ToolConfig | undefined;
if (requestConfig.functionCallingConfig) {
if (functionCallingConfig) {
toolConfig = {
functionCallingConfig: {
allowedFunctionNames:
requestConfig.functionCallingConfig.allowedFunctionNames,
mode: toFunctionModeEnum(requestConfig.functionCallingConfig.mode),
allowedFunctionNames: functionCallingConfig.allowedFunctionNames,
mode: toFunctionModeEnum(functionCallingConfig.mode),
},
};
} else if (request.toolChoice) {
Expand All @@ -944,19 +952,10 @@ export function defineGoogleAIModel({
tools.length === 0);

const generationConfig: GenerationConfig = {
...restOfConfigOptions,
candidateCount: request.candidates || undefined,
temperature: requestConfig.temperature,
maxOutputTokens: requestConfig.maxOutputTokens,
topK: requestConfig.topK,
topP: requestConfig.topP,
stopSequences: requestConfig.stopSequences,
responseMimeType: jsonMode ? 'application/json' : undefined,
};
if (requestConfig.responseModalities) {
// HACK: cast to any since this isn't officially supported in the old SDK yet
(generationConfig as any).responseModalities =
requestConfig.responseModalities;
}

if (request.output?.constrained && jsonMode) {
generationConfig.responseSchema = cleanSchema(request.output.schema);
Expand All @@ -978,9 +977,9 @@ export function defineGoogleAIModel({
history: messages
.slice(0, -1)
.map((message) => toGeminiMessage(message, model)),
safetySettings: requestConfig.safetySettings,
safetySettings: safetySettingsFromConfig,
} as StartChatParams;
const modelVersion = (request.config?.version ||
const modelVersion = (versionFromConfig ||
model.version ||
apiModelName) as string;
const cacheConfigDetails = extractCacheConfig(request);
Expand All @@ -994,14 +993,14 @@ export function defineGoogleAIModel({
cacheConfigDetails
);

if (!requestConfig.apiKey && !apiKey) {
if (!apiKeyFromConfig && !apiKey) {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message:
'GoogleAI plugin was initialized with {apiKey: false} but no apiKey configuration was passed at call time.',
});
}
const client = new GoogleGenerativeAI(requestConfig.apiKey || apiKey!);
const client = new GoogleGenerativeAI(apiKeyFromConfig || apiKey!);
let genModel: GenerativeModel;

if (cache) {
Expand Down
75 changes: 55 additions & 20 deletions js/plugins/vertexai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
GenerativeModelPreview,
HarmBlockThreshold,
HarmCategory,
SafetySetting,
Schema,
StartChatParams,
ToolConfig,
Expand Down Expand Up @@ -66,9 +67,25 @@ import { PluginOptions } from './common/types.js';
import { handleCacheIfNeeded } from './context-caching/index.js';
import { extractCacheConfig } from './context-caching/utils.js';

const SafetySettingsSchema = z.object({
category: z.nativeEnum(HarmCategory),
threshold: z.nativeEnum(HarmBlockThreshold),
/**
 * Zod schema for one safety setting: a harm `category` paired with the
 * blocking `threshold` applied to it. Both fields are validated against fixed
 * lists of string literals.
 */
export const SafetySettingsSchema = z.object({
  category: z.enum([
    // The harm category is unspecified.
    'HARM_CATEGORY_UNSPECIFIED',
    // The harm category is hate speech.
    'HARM_CATEGORY_HATE_SPEECH',
    // The harm category is dangerous content.
    'HARM_CATEGORY_DANGEROUS_CONTENT',
    // The harm category is harassment.
    'HARM_CATEGORY_HARASSMENT',
    // The harm category is sexually explicit content.
    'HARM_CATEGORY_SEXUALLY_EXPLICIT',
  ]),
  // Accepted blocking threshold names.
  threshold: z.enum([
    'BLOCK_LOW_AND_ABOVE',
    'BLOCK_MEDIUM_AND_ABOVE',
    'BLOCK_ONLY_HIGH',
    'BLOCK_NONE',
  ]),
});

const VertexRetrievalSchema = z.object({
Expand Down Expand Up @@ -250,7 +267,7 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
'With NONE, the model is prohibited from making function calls.'
)
.optional(),
});
}).passthrough();

/**
* Known model names, to allow code completion for convenience. Allows other model names.
Expand Down Expand Up @@ -1019,17 +1036,29 @@ export function defineGeminiModel({
}
}

const requestConfig = request.config as z.infer<
typeof GeminiConfigSchema
>;
const {
functionCallingConfig,
version: versionFromConfig,
googleSearchRetrieval,
vertexRetrieval,
location, // location can be overridden via config, take it out.
safetySettings,
...restOfConfig
} = requestConfig;

const tools = request.tools?.length
? [{ functionDeclarations: request.tools.map(toGeminiTool) }]
: [];

let toolConfig: ToolConfig | undefined;
if (request?.config?.functionCallingConfig) {
if (functionCallingConfig) {
toolConfig = {
functionCallingConfig: {
allowedFunctionNames:
request.config.functionCallingConfig.allowedFunctionNames,
mode: toFunctionModeEnum(request.config.functionCallingConfig.mode),
allowedFunctionNames: functionCallingConfig.allowedFunctionNames,
mode: toFunctionModeEnum(functionCallingConfig.mode),
},
};
} else if (request.toolChoice) {
Expand All @@ -1053,19 +1082,15 @@ export function defineGeminiModel({
.slice(0, -1)
.map((message) => toGeminiMessage(message, modelInfo)),
generationConfig: {
...restOfConfig,
candidateCount: request.candidates || undefined,
temperature: request.config?.temperature,
maxOutputTokens: request.config?.maxOutputTokens,
topK: request.config?.topK,
topP: request.config?.topP,
responseMimeType: jsonMode ? 'application/json' : undefined,
stopSequences: request.config?.stopSequences,
},
safetySettings: request.config?.safetySettings,
safetySettings: toGeminiSafetySettings(safetySettings),
};

// Handle cache
const modelVersion = (request.config?.version || version) as string;
const modelVersion = (versionFromConfig || version) as string;
const cacheConfigDetails = extractCacheConfig(request);

const apiClient = new ApiClient(
Expand All @@ -1092,15 +1117,13 @@ export function defineGeminiModel({
);
}

if (request.config?.googleSearchRetrieval) {
if (googleSearchRetrieval) {
updatedChatRequest.tools?.push({
googleSearchRetrieval: request.config
.googleSearchRetrieval as GoogleSearchRetrieval,
googleSearchRetrieval: googleSearchRetrieval as GoogleSearchRetrieval,
});
}

if (request.config?.vertexRetrieval) {
const vertexRetrieval = request.config.vertexRetrieval;
if (vertexRetrieval) {
const _projectId =
vertexRetrieval.datastore.projectId || options.projectId;
const _location =
Expand Down Expand Up @@ -1247,6 +1270,18 @@ function toFunctionModeEnum(
}
}

/**
 * Converts safety settings shaped like `SafetySettingsSchema` into the
 * imported `SafetySetting` type.
 *
 * @param genkitSettings Settings to convert; may be omitted.
 * @returns A new array carrying only `category` and `threshold` for each
 *   entry, or `undefined` when no settings were supplied.
 */
function toGeminiSafetySettings(
  genkitSettings?: z.infer<typeof SafetySettingsSchema>[]
): SafetySetting[] | undefined {
  // The string literals are only re-typed via casts; the values themselves
  // are copied over unchanged.
  const convert = ({
    category,
    threshold,
  }: z.infer<typeof SafetySettingsSchema>): SafetySetting => ({
    category: category as HarmCategory,
    threshold: threshold as HarmBlockThreshold,
  });

  return genkitSettings
    ? genkitSettings.map((setting) => convert(setting))
    : undefined;
}

/** Converts mode from genkit tool choice. */
function toGeminiFunctionModeEnum(
genkitMode: 'auto' | 'required' | 'none'
Expand Down
2 changes: 2 additions & 0 deletions js/plugins/vertexai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import {
gemini25ProExp0325,
gemini25ProPreview0325,
GeminiConfigSchema,
SafetySettingsSchema,
SUPPORTED_GEMINI_MODELS,
type GeminiConfig,
type GeminiVersionString,
Expand Down Expand Up @@ -90,6 +91,7 @@ export {
imagen3,
imagen3Fast,
multimodalEmbedding001,
SafetySettingsSchema,
textEmbedding004,
textEmbedding005,
textEmbeddingGecko003,
Expand Down