35 changes: 17 additions & 18 deletions js/plugins/googleai/src/embedder.ts
@@ -23,11 +23,11 @@ import {
z,
type EmbedderAction,
type EmbedderReference,
type Genkit,
} from 'genkit';
import { embedderRef } from 'genkit/embedder';
import { getApiKeyFromEnvVar } from './common.js';
import type { PluginOptions } from './index.js';
import { embedder } from 'genkit/plugin';
import { getApiKeyFromEnvVar } from './common';
import type { PluginOptions } from './index';

export const TaskTypeSchema = z.enum([
'RETRIEVAL_DOCUMENT',
@@ -103,7 +103,6 @@ export const SUPPORTED_MODELS = {
};

export function defineGoogleAIEmbedder(
ai: Genkit,
name: string,
pluginOptions: PluginOptions
): EmbedderAction<any> {
@@ -117,29 +116,29 @@ export function defineGoogleAIEmbedder(
'For more details see https://genkit.dev/docs/plugins/google-genai'
);
}
const embedder: EmbedderReference =
SUPPORTED_MODELS[name] ??
// In v2, plugin internals use UNPREFIXED action names.
const actionName = name;

const embedderReference: EmbedderReference =
SUPPORTED_MODELS[actionName] ??
embedderRef({
name: name,
name: actionName,
configSchema: GeminiEmbeddingConfigSchema,
info: {
dimensions: 768,
label: `Google AI - ${name}`,
label: `Google AI - ${actionName}`,
supports: {
input: ['text', 'image', 'video'],
},
},
});
const apiModelName = embedder.name.startsWith('googleai/')
? embedder.name.substring('googleai/'.length)
: embedder.name;
return ai.defineEmbedder(
return embedder(
{
name: embedder.name,
name: actionName,
configSchema: GeminiEmbeddingConfigSchema,
info: embedder.info!,
info: embedderReference.info!,
},
async (input, options) => {
async ({ input, options }) => {
if (pluginOptions.apiKey === false && !options?.apiKey) {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
@@ -152,9 +151,9 @@
).getGenerativeModel({
model:
options?.version ||
embedder.config?.version ||
embedder.version ||
apiModelName,
embedderReference.config?.version ||
embedderReference.version ||
actionName,
});
const embeddings = await Promise.all(
input.map(async (doc) => {
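Note on the embedder.ts change: the diff migrates from the instance-bound ai.defineEmbedder(...) to the standalone embedder(...) helper imported from 'genkit/plugin', drops the Genkit parameter, switches plugin internals to unprefixed action names, and changes the handler signature from (input, options) to a single destructured ({ input, options }) argument. Below is a minimal TypeScript sketch of that call shape, assuming the v2 API introduced by this diff; the config schema, model name, and zero-vector embedding are placeholders, not the plugin's real implementation.

import { z } from 'genkit';
import { embedder } from 'genkit/plugin';

// Hypothetical stand-in for GeminiEmbeddingConfigSchema.
const EmbeddingConfigSketchSchema = z.object({
  taskType: z.string().optional(),
  outputDimensionality: z.number().optional(),
});

// v2 shape: no Genkit instance, unprefixed action name, destructured handler args.
export const textEmbeddingSketch = embedder(
  {
    name: 'text-embedding-004', // unprefixed; 'googleai/' is applied outside plugin internals
    configSchema: EmbeddingConfigSketchSchema,
    info: {
      dimensions: 768,
      label: 'Google AI - text-embedding-004',
      supports: { input: ['text'] },
    },
  },
  async ({ input, options }) => {
    // Stub: the real handler calls the @google/generative-ai SDK's
    // getGenerativeModel(...).embedContent(...) per input document.
    const embeddings = input.map(() => ({ embedding: new Array(768).fill(0) }));
    return { embeddings };
  }
);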
47 changes: 23 additions & 24 deletions js/plugins/googleai/src/gemini.ts
@@ -38,7 +38,7 @@ import {
type ToolConfig,
type UsageMetadata,
} from '@google/generative-ai';
import { GenkitError, z, type Genkit, type JSONSchema } from 'genkit';
import { GenkitError, z, type JSONSchema } from 'genkit';
import {
GenerationCommonConfigDescriptions,
GenerationCommonConfigSchema,
@@ -57,6 +57,7 @@ import {
type ToolResponsePart,
} from 'genkit/model';
import { downloadRequestMedia } from 'genkit/model/middleware';
import { model } from 'genkit/plugin';
import { runInNewSpan } from 'genkit/tracing';
import { getApiKeyFromEnvVar, getGenkitClientHeader } from './common';
import { handleCacheIfNeeded } from './context-caching';
@@ -1118,7 +1119,6 @@ export function cleanSchema(schema: JSONSchema): JSONSchema {
* Defines a new GoogleAI model.
*/
export function defineGoogleAIModel({
ai,
name,
apiKey: apiKeyOption,
apiVersion,
@@ -1127,7 +1127,6 @@ export function defineGoogleAIModel({
defaultConfig,
debugTraces,
}: {
ai: Genkit;
name: string;
apiKey?: string | false;
apiVersion?: string;
@@ -1150,16 +1149,15 @@
}
}

const apiModelName = name.startsWith('googleai/')
? name.substring('googleai/'.length)
: name;
// In v2, plugin internals use UNPREFIXED action names.
const actionName = name;

const model: ModelReference<z.ZodTypeAny> =
SUPPORTED_GEMINI_MODELS[apiModelName] ??
const modelReference: ModelReference<z.ZodTypeAny> =
SUPPORTED_GEMINI_MODELS[actionName] ??
modelRef({
name: `googleai/${apiModelName}`,
name: actionName,
info: {
label: `Google AI - ${apiModelName}`,
label: `Google AI - ${actionName}`,
supports: {
multiturn: true,
media: true,
@@ -1173,7 +1171,7 @@
});

const middleware: ModelMiddleware[] = [];
if (model.info?.supports?.media) {
if (modelReference.info?.supports?.media) {
// the gemini api doesn't support downloading media from http(s)
middleware.push(
downloadRequestMedia({
@@ -1199,12 +1197,11 @@
);
}

return ai.defineModel(
return model(
{
apiVersion: 'v2',
name: model.name,
...model.info,
configSchema: model.configSchema,
name: actionName,
...modelReference.info,
configSchema: modelReference.configSchema,
use: middleware,
},
async (request, { streamingRequested, sendChunk, abortSignal }) => {
@@ -1228,7 +1225,7 @@
// systemInstructions to be provided as a separate input. The first
// message detected with role=system will be used for systemInstructions.
let systemInstruction: GeminiMessage | undefined = undefined;
if (model.info?.supports?.systemRole) {
if (modelReference.info?.supports?.systemRole) {
const systemMessage = messages.find((m) => m.role === 'system');
if (systemMessage) {
messages.splice(messages.indexOf(systemMessage), 1);
@@ -1306,7 +1303,10 @@
generationConfig.responseSchema = cleanSchema(request.output.schema);
}

const msg = toGeminiMessage(messages[messages.length - 1], model);
const msg = toGeminiMessage(
messages[messages.length - 1],
modelReference
);

const fromJSONModeScopedGeminiCandidate = (
candidate: GeminiCandidate
@@ -1321,12 +1321,12 @@
toolConfig,
history: messages
.slice(0, -1)
.map((message) => toGeminiMessage(message, model)),
.map((message) => toGeminiMessage(message, modelReference)),
safetySettings: safetySettingsFromConfig,
} as StartChatParams;
const modelVersion = (versionFromConfig ||
model.version ||
apiModelName) as string;
modelReference.version ||
actionName) as string;
const cacheConfigDetails = extractCacheConfig(request);

const { chatRequest: updatedChatRequest, cache } =
@@ -1426,11 +1426,10 @@
};
};

// If debugTraces is enable, we wrap the actual model call with a span, add raw
// API params as for input.
// If debugTraces is enabled, we wrap the actual model call with a span, add raw
// API params as input.
return debugTraces
? await runInNewSpan(
ai.registry,
{
metadata: {
name: streamingRequested ? 'sendMessageStream' : 'sendMessage',
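Note on the gemini.ts change: the same migration applied to models. model(...) from 'genkit/plugin' replaces ai.defineModel(...), the explicit apiVersion: 'v2' property and the ai.registry argument to runInNewSpan are dropped, and the supported-model lookup, modelRef, and API model version now use the unprefixed action name directly. A minimal sketch of the resulting definition, assuming the API shape shown in this diff; the model id and stubbed response are placeholders rather than the real Gemini call.

import { GenerationCommonConfigSchema } from 'genkit/model';
import { model } from 'genkit/plugin';

// v2 shape: model() from 'genkit/plugin' replaces ai.defineModel(); the action
// name is unprefixed and no Genkit instance is threaded through.
export const geminiModelSketch = model(
  {
    name: 'gemini-1.5-flash', // unprefixed action name (placeholder model id)
    label: 'Google AI - gemini-1.5-flash',
    supports: { multiturn: true, media: true, tools: true, systemRole: true },
    configSchema: GenerationCommonConfigSchema,
  },
  async (request, { streamingRequested, sendChunk, abortSignal }) => {
    // Stub: the real handler builds StartChatParams, calls the Google AI SDK's
    // sendMessage()/sendMessageStream(), and maps candidates back to Genkit parts.
    return {
      message: { role: 'model', content: [{ text: 'stubbed response' }] },
      finishReason: 'stop',
    };
  }
);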
21 changes: 11 additions & 10 deletions js/plugins/googleai/src/imagen.ts
@@ -14,7 +14,7 @@
* limitations under the License.
*/

import { GenkitError, MessageData, z, type Genkit } from 'genkit';
import { GenkitError, MessageData, z } from 'genkit';
import {
getBasicUsageStats,
modelRef,
@@ -23,6 +23,7 @@ import {
type ModelInfo,
type ModelReference,
} from 'genkit/model';
import { model } from 'genkit/plugin';
import { getApiKeyFromEnvVar } from './common.js';
import { predictModel } from './predict.js';

@@ -109,7 +110,6 @@ export const GENERIC_IMAGEN_INFO = {
} as ModelInfo;

export function defineImagenModel(
ai: Genkit,
name: string,
apiKey?: string | false
): ModelAction {
@@ -124,20 +124,21 @@
});
}
}
const modelName = `googleai/${name}`;
const model: ModelReference<z.ZodTypeAny> = modelRef({
name: modelName,
// In v2, plugin internals use UNPREFIXED action names.
const actionName = name;
const modelReference: ModelReference<z.ZodTypeAny> = modelRef({
name: actionName,
info: {
...GENERIC_IMAGEN_INFO,
label: `Google AI - ${name}`,
label: `Google AI - ${actionName}`,
},
configSchema: ImagenConfigSchema,
});

return ai.defineModel(
return model(
{
name: modelName,
...model.info,
name: actionName,
...modelReference.info,
configSchema: ImagenConfigSchema,
},
async (request) => {
@@ -153,7 +154,7 @@
ImagenInstance,
ImagenPrediction,
ImagenParameters
>(model.version || name, apiKey as string, 'predict');
>(modelReference.version || actionName, apiKey as string, 'predict');
const response = await predictClient([instance], toParameters(request));

if (!response.predictions || response.predictions.length == 0) {
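Note on the imagen.ts change: previously the code built a prefixed name (`googleai/${name}`) before calling ai.defineModel; it now defines the model with model(...) from 'genkit/plugin' under the unprefixed action name and resolves the Imagen predict endpoint from modelReference.version || actionName. A minimal sketch under the same assumptions; the config schema, model id, and data-URL media part are placeholders for the real predictModel-based handler.

import { z } from 'genkit';
import { model } from 'genkit/plugin';

// Hypothetical stand-in for ImagenConfigSchema.
const ImagenConfigSketchSchema = z.object({
  numberOfImages: z.number().optional(),
});

export const imagenModelSketch = model(
  {
    name: 'imagen-3.0-generate-001', // unprefixed, placeholder model id
    label: 'Google AI - Imagen (sketch)',
    configSchema: ImagenConfigSketchSchema,
  },
  async (request) => {
    // Stub: the real handler calls predictModel(...) against the Imagen
    // 'predict' endpoint and converts each prediction into a media part.
    return {
      message: {
        role: 'model',
        content: [
          { media: { contentType: 'image/png', url: 'data:image/png;base64,' } },
        ],
      },
      finishReason: 'stop',
    };
  }
);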