diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index b731b5347..7ae05ee58 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -206,32 +206,40 @@ export const GenerationCommonConfigDescriptions = { /** * Zod schema of a common config object. */ -export const GenerationCommonConfigSchema = z.object({ - version: z - .string() - .describe( - 'A specific version of a model family, e.g. `gemini-2.0-flash` ' + - 'for the `googleai` family.' - ) - .optional(), - temperature: z - .number() - .describe(GenerationCommonConfigDescriptions.temperature) - .optional(), - maxOutputTokens: z - .number() - .describe(GenerationCommonConfigDescriptions.maxOutputTokens) - .optional(), - topK: z.number().describe(GenerationCommonConfigDescriptions.topK).optional(), - topP: z.number().describe(GenerationCommonConfigDescriptions.topP).optional(), - stopSequences: z - .array(z.string()) - .length(5) - .describe( - 'Set of character sequences (up to 5) that will stop output generation.' - ) - .optional(), -}); +export const GenerationCommonConfigSchema = z + .object({ + version: z + .string() + .describe( + 'A specific version of a model family, e.g. `gemini-2.0-flash` ' + + 'for the `googleai` family.' + ) + .optional(), + temperature: z + .number() + .describe(GenerationCommonConfigDescriptions.temperature) + .optional(), + maxOutputTokens: z + .number() + .describe(GenerationCommonConfigDescriptions.maxOutputTokens) + .optional(), + topK: z + .number() + .describe(GenerationCommonConfigDescriptions.topK) + .optional(), + topP: z + .number() + .describe(GenerationCommonConfigDescriptions.topP) + .optional(), + stopSequences: z + .array(z.string()) + .length(5) + .describe( + 'Set of character sequences (up to 5) that will stop output generation.' + ) + .optional(), + }) + .passthrough(); /** * Common config object. 
@@ -619,9 +627,10 @@ function getPartCounts(parts: Part[]): PartCounts { ); } -export type ModelArgument< - CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, -> = ModelAction<CustomOptions> | ModelReference<CustomOptions> | string; +export type ModelArgument = + | ModelAction + | ModelReference + | string; export interface ResolvedModel< CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts index 750db3730..8121bbac6 100644 --- a/js/plugins/googleai/src/gemini.ts +++ b/js/plugins/googleai/src/gemini.ts @@ -132,7 +132,7 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ "'gemini-2.0-flash-exp' model at present." ) .optional(), -}); +}).passthrough(); export type GeminiConfig = z.infer<typeof GeminiConfigSchema>; export const gemini10Pro = modelRef({ @@ -911,7 +911,16 @@ export function defineGoogleAIModel({ }); } - if (requestConfig.codeExecution) { + const { + apiKey: apiKeyFromConfig, + safetySettings: safetySettingsFromConfig, + codeExecution: codeExecutionFromConfig, + version: versionFromConfig, + functionCallingConfig, + ...restOfConfigOptions + } = requestConfig; + + if (codeExecutionFromConfig) { tools.push({ codeExecution: request.config.codeExecution === true @@ -921,12 +930,11 @@ export function defineGoogleAIModel({ } let toolConfig: ToolConfig | undefined; - if (requestConfig.functionCallingConfig) { + if (functionCallingConfig) { toolConfig = { functionCallingConfig: { - allowedFunctionNames: - requestConfig.functionCallingConfig.allowedFunctionNames, - mode: toFunctionModeEnum(requestConfig.functionCallingConfig.mode), + allowedFunctionNames: functionCallingConfig.allowedFunctionNames, + mode: toFunctionModeEnum(functionCallingConfig.mode), }, }; } else if (request.toolChoice) { @@ -944,19 +952,10 @@ export function defineGoogleAIModel({ tools.length === 0); const generationConfig: GenerationConfig = { + ...restOfConfigOptions, candidateCount: request.candidates || undefined, -
temperature: requestConfig.temperature, - maxOutputTokens: requestConfig.maxOutputTokens, - topK: requestConfig.topK, - topP: requestConfig.topP, - stopSequences: requestConfig.stopSequences, responseMimeType: jsonMode ? 'application/json' : undefined, }; - if (requestConfig.responseModalities) { - // HACK: cast to any since this isn't officially supported in the old SDK yet - (generationConfig as any).responseModalities = - requestConfig.responseModalities; - } if (request.output?.constrained && jsonMode) { generationConfig.responseSchema = cleanSchema(request.output.schema); @@ -978,9 +977,9 @@ export function defineGoogleAIModel({ history: messages .slice(0, -1) .map((message) => toGeminiMessage(message, model)), - safetySettings: requestConfig.safetySettings, + safetySettings: safetySettingsFromConfig, } as StartChatParams; - const modelVersion = (request.config?.version || + const modelVersion = (versionFromConfig || model.version || apiModelName) as string; const cacheConfigDetails = extractCacheConfig(request); @@ -994,14 +993,14 @@ export function defineGoogleAIModel({ cacheConfigDetails ); - if (!requestConfig.apiKey && !apiKey) { + if (!apiKeyFromConfig && !apiKey) { throw new GenkitError({ status: 'INVALID_ARGUMENT', message: 'GoogleAI plugin was initialized with {apiKey: false} but no apiKey configuration was passed at call time.', }); } - const client = new GoogleGenerativeAI(requestConfig.apiKey || apiKey!); + const client = new GoogleGenerativeAI(apiKeyFromConfig || apiKey!); let genModel: GenerativeModel; if (cache) { diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts index 5eba25469..72bb8426a 100644 --- a/js/plugins/vertexai/src/gemini.ts +++ b/js/plugins/vertexai/src/gemini.ts @@ -25,6 +25,7 @@ import { GenerativeModelPreview, HarmBlockThreshold, HarmCategory, + SafetySetting, Schema, StartChatParams, ToolConfig, @@ -66,9 +67,25 @@ import { PluginOptions } from './common/types.js'; import { handleCacheIfNeeded } 
from './context-caching/index.js'; import { extractCacheConfig } from './context-caching/utils.js'; -const SafetySettingsSchema = z.object({ - category: z.nativeEnum(HarmCategory), - threshold: z.nativeEnum(HarmBlockThreshold), +export const SafetySettingsSchema = z.object({ + category: z.enum([ + /** The harm category is unspecified. */ + 'HARM_CATEGORY_UNSPECIFIED', + /** The harm category is hate speech. */ + 'HARM_CATEGORY_HATE_SPEECH', + /** The harm category is dangerous content. */ + 'HARM_CATEGORY_DANGEROUS_CONTENT', + /** The harm category is harassment. */ + 'HARM_CATEGORY_HARASSMENT', + /** The harm category is sexually explicit content. */ + 'HARM_CATEGORY_SEXUALLY_EXPLICIT', + ]), + threshold: z.enum([ + 'BLOCK_LOW_AND_ABOVE', + 'BLOCK_MEDIUM_AND_ABOVE', + 'BLOCK_ONLY_HIGH', + 'BLOCK_NONE', + ]), }); const VertexRetrievalSchema = z.object({ @@ -250,7 +267,7 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ 'With NONE, the model is prohibited from making function calls.' ) .optional(), -}); +}).passthrough(); /** * Known model names, to allow code completion for convenience. Allows other model names. @@ -1019,17 +1036,29 @@ export function defineGeminiModel({ } } + const requestConfig = request.config as z.infer< + typeof GeminiConfigSchema + >; + const { + functionCallingConfig, + version: versionFromConfig, + googleSearchRetrieval, + vertexRetrieval, + location, // location can be overridden via config, take it out. + safetySettings, + ...restOfConfig + } = requestConfig; + const tools = request.tools?.length ? 
[{ functionDeclarations: request.tools.map(toGeminiTool) }] : []; let toolConfig: ToolConfig | undefined; - if (request?.config?.functionCallingConfig) { + if (functionCallingConfig) { toolConfig = { functionCallingConfig: { - allowedFunctionNames: - request.config.functionCallingConfig.allowedFunctionNames, - mode: toFunctionModeEnum(request.config.functionCallingConfig.mode), + allowedFunctionNames: functionCallingConfig.allowedFunctionNames, + mode: toFunctionModeEnum(functionCallingConfig.mode), }, }; } else if (request.toolChoice) { @@ -1053,19 +1082,15 @@ export function defineGeminiModel({ .slice(0, -1) .map((message) => toGeminiMessage(message, modelInfo)), generationConfig: { + ...restOfConfig, candidateCount: request.candidates || undefined, - temperature: request.config?.temperature, - maxOutputTokens: request.config?.maxOutputTokens, - topK: request.config?.topK, - topP: request.config?.topP, responseMimeType: jsonMode ? 'application/json' : undefined, - stopSequences: request.config?.stopSequences, }, - safetySettings: request.config?.safetySettings, + safetySettings: toGeminiSafetySettings(safetySettings), }; // Handle cache - const modelVersion = (request.config?.version || version) as string; + const modelVersion = (versionFromConfig || version) as string; const cacheConfigDetails = extractCacheConfig(request); const apiClient = new ApiClient( @@ -1092,15 +1117,13 @@ export function defineGeminiModel({ ); } - if (request.config?.googleSearchRetrieval) { + if (googleSearchRetrieval) { updatedChatRequest.tools?.push({ - googleSearchRetrieval: request.config - .googleSearchRetrieval as GoogleSearchRetrieval, + googleSearchRetrieval: googleSearchRetrieval as GoogleSearchRetrieval, }); } - if (request.config?.vertexRetrieval) { - const vertexRetrieval = request.config.vertexRetrieval; + if (vertexRetrieval) { const _projectId = vertexRetrieval.datastore.projectId || options.projectId; const _location = @@ -1247,6 +1270,18 @@ function toFunctionModeEnum( 
} } +function toGeminiSafetySettings( + genkitSettings?: z.infer<typeof SafetySettingsSchema>[] +): SafetySetting[] | undefined { + if (!genkitSettings) return undefined; + return genkitSettings.map((s) => { + return { + category: s.category as HarmCategory, + threshold: s.threshold as HarmBlockThreshold, + }; + }); +} + /** Converts mode from genkit tool choice. */ function toGeminiFunctionModeEnum( genkitMode: 'auto' | 'required' | 'none' diff --git a/js/plugins/vertexai/src/index.ts b/js/plugins/vertexai/src/index.ts index 42c8187af..c1c64c0db 100644 --- a/js/plugins/vertexai/src/index.ts +++ b/js/plugins/vertexai/src/index.ts @@ -60,6 +60,7 @@ import { gemini25ProExp0325, gemini25ProPreview0325, GeminiConfigSchema, + SafetySettingsSchema, SUPPORTED_GEMINI_MODELS, type GeminiConfig, type GeminiVersionString, @@ -90,6 +91,7 @@ export { imagen3, imagen3Fast, multimodalEmbedding001, + SafetySettingsSchema, textEmbedding004, textEmbedding005, textEmbeddingGecko003,