Skip to content
42 changes: 16 additions & 26 deletions packages/firebase_ai/firebase_ai/lib/src/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -172,25 +172,6 @@ final class UsageMetadata {
final List<ModalityTokenCount>? candidatesTokensDetails;
}

/// Constructe a UsageMetadata with all it's fields.
///
/// Expose access to the private constructor for use within the package..
UsageMetadata createUsageMetadata({
required int? promptTokenCount,
required int? candidatesTokenCount,
required int? totalTokenCount,
required int? thoughtsTokenCount,
required List<ModalityTokenCount>? promptTokensDetails,
required List<ModalityTokenCount>? candidatesTokensDetails,
}) =>
UsageMetadata._(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
thoughtsTokenCount: thoughtsTokenCount,
promptTokensDetails: promptTokensDetails,
candidatesTokensDetails: candidatesTokensDetails);

/// Response candidate generated from a [GenerativeModel].
final class Candidate {
// TODO: token count?
Expand Down Expand Up @@ -1128,7 +1109,7 @@ final class VertexSerialization implements SerializationStrategy {
};
final usageMedata = switch (jsonObject) {
{'usageMetadata': final usageMetadata?} =>
_parseUsageMetadata(usageMetadata),
parseUsageMetadata(usageMetadata),
{'totalTokens': final int totalTokens} =>
UsageMetadata._(totalTokenCount: totalTokens),
_ => null,
Expand Down Expand Up @@ -1258,7 +1239,10 @@ PromptFeedback _parsePromptFeedback(Object jsonObject) {
};
}

UsageMetadata _parseUsageMetadata(Object jsonObject) {
/// Parses a UsageMetadata from a JSON object.
///
/// Expose access to the private helper for use within the package.
UsageMetadata parseUsageMetadata(Object jsonObject) {
if (jsonObject is! Map<String, Object?>) {
throw unhandledFormat('UsageMetadata', jsonObject);
}
Expand All @@ -1275,6 +1259,10 @@ UsageMetadata _parseUsageMetadata(Object jsonObject) {
{'totalTokenCount': final int totalTokenCount} => totalTokenCount,
_ => null,
};
final thoughtsTokenCount = switch (jsonObject) {
{'thoughtsTokenCount': final int thoughtsTokenCount} => thoughtsTokenCount,
_ => null,
};
final promptTokensDetails = switch (jsonObject) {
{'promptTokensDetails': final List<Object?> promptTokensDetails} =>
promptTokensDetails.map(_parseModalityTokenCount).toList(),
Expand All @@ -1286,11 +1274,13 @@ UsageMetadata _parseUsageMetadata(Object jsonObject) {
_ => null,
};
return UsageMetadata._(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
promptTokensDetails: promptTokensDetails,
candidatesTokensDetails: candidatesTokensDetails);
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
thoughtsTokenCount: thoughtsTokenCount,
promptTokensDetails: promptTokensDetails,
candidatesTokensDetails: candidatesTokensDetails,
);
}

ModalityTokenCount _parseModalityTokenCount(Object? jsonObject) {
Expand Down
36 changes: 2 additions & 34 deletions packages/firebase_ai/firebase_ai/lib/src/developer/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ import '../api.dart'
SafetyRating,
SafetySetting,
SerializationStrategy,
UsageMetadata,
createUsageMetadata;
parseUsageMetadata;
import '../content.dart'
show Content, FunctionCall, InlineDataPart, Part, TextPart;
import '../error.dart';
Expand Down Expand Up @@ -116,7 +115,7 @@ final class DeveloperSerialization implements SerializationStrategy {
};
final usageMedata = switch (jsonObject) {
{'usageMetadata': final usageMetadata?} =>
_parseUsageMetadata(usageMetadata),
parseUsageMetadata(usageMetadata),
_ => null,
};
return GenerateContentResponse(candidates, promptFeedback,
Expand Down Expand Up @@ -230,37 +229,6 @@ PromptFeedback _parsePromptFeedback(Object jsonObject) {
};
}

UsageMetadata _parseUsageMetadata(Object jsonObject) {
if (jsonObject is! Map<String, Object?>) {
throw unhandledFormat('UsageMetadata', jsonObject);
}
final promptTokenCount = switch (jsonObject) {
{'promptTokenCount': final int promptTokenCount} => promptTokenCount,
_ => null,
};
final candidatesTokenCount = switch (jsonObject) {
{'candidatesTokenCount': final int candidatesTokenCount} =>
candidatesTokenCount,
_ => null,
};
final totalTokenCount = switch (jsonObject) {
{'totalTokenCount': final int totalTokenCount} => totalTokenCount,
_ => null,
};
final thoughtsTokenCount = switch (jsonObject) {
{'thoughtsTokenCount': final int thoughtsTokenCount} => thoughtsTokenCount,
_ => null,
};
return createUsageMetadata(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount,
thoughtsTokenCount: thoughtsTokenCount,
promptTokensDetails: null,
candidatesTokensDetails: null,
);
}

SafetyRating _parseSafetyRating(Object? jsonObject) {
return switch (jsonObject) {
{
Expand Down
34 changes: 34 additions & 0 deletions packages/firebase_ai/firebase_ai/test/api_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,40 @@ void main() {
expect(response.usageMetadata!.candidatesTokensDetails, hasLength(1));
});

group('usageMetadata parsing', () {
test('parses usageMetadata when thoughtsTokenCount is set', () {
final json = {
'usageMetadata': {
'promptTokenCount': 10,
'candidatesTokenCount': 20,
'totalTokenCount': 30,
'thoughtsTokenCount': 5,
}
};
final response =
VertexSerialization().parseGenerateContentResponse(json);
expect(response.usageMetadata, isNotNull);
expect(response.usageMetadata!.promptTokenCount, 10);
expect(response.usageMetadata!.candidatesTokenCount, 20);
expect(response.usageMetadata!.totalTokenCount, 30);
expect(response.usageMetadata!.thoughtsTokenCount, 5);
});

test('parses usageMetadata when thoughtsTokenCount is missing', () {
final json = {
'usageMetadata': {
'promptTokenCount': 10,
'candidatesTokenCount': 20,
'totalTokenCount': 30,
}
};
final response =
VertexSerialization().parseGenerateContentResponse(json);
expect(response.usageMetadata, isNotNull);
expect(response.usageMetadata!.thoughtsTokenCount, isNull);
});
});

group('groundingMetadata parsing', () {
test('parses valid response with full grounding metadata', () {
final jsonResponse = {
Expand Down
41 changes: 41 additions & 0 deletions packages/firebase_ai/firebase_ai/test/developer_api_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ void main() {
'candidatesTokenCount': 5,
'totalTokenCount': 15,
'thoughtsTokenCount': 3,
'promptTokensDetails': [
{'modality': 'TEXT', 'tokenCount': 10}
],
'candidatesTokensDetails': [
{'modality': 'TEXT', 'tokenCount': 25}
],
}
};
final response =
Expand All @@ -48,6 +54,15 @@ void main() {
expect(response.usageMetadata!.candidatesTokenCount, 5);
expect(response.usageMetadata!.totalTokenCount, 15);
expect(response.usageMetadata!.thoughtsTokenCount, 3);
expect(response.usageMetadata!.promptTokensDetails, isNotNull);
expect(response.usageMetadata!.promptTokensDetails, hasLength(1));
expect(
response.usageMetadata!.promptTokensDetails!.first.tokenCount, 10);
expect(response.usageMetadata!.candidatesTokensDetails, isNotNull);
expect(response.usageMetadata!.candidatesTokensDetails, hasLength(1));
expect(
response.usageMetadata!.candidatesTokensDetails!.first.tokenCount,
25);
});

test('parses usageMetadata when thoughtsTokenCount is missing', () {
Expand All @@ -68,6 +83,12 @@ void main() {
'candidatesTokenCount': 5,
'totalTokenCount': 15,
// thoughtsTokenCount is missing
'promptTokensDetails': [
{'modality': 'TEXT', 'tokenCount': 10}
],
'candidatesTokensDetails': [
{'modality': 'TEXT', 'tokenCount': 25}
],
}
};
final response =
Expand Down Expand Up @@ -126,6 +147,26 @@ void main() {
expect(response.usageMetadata, isNull);
});

test('parses usageMetadata when token details are missing', () {
final jsonResponse = {
'usageMetadata': {
'promptTokenCount': 10,
'candidatesTokenCount': 25,
'totalTokenCount': 35,
}
};

final response =
DeveloperSerialization().parseGenerateContentResponse(jsonResponse);

expect(response.usageMetadata, isNotNull);
expect(response.usageMetadata!.promptTokenCount, 10);
expect(response.usageMetadata!.candidatesTokenCount, 25);
expect(response.usageMetadata!.totalTokenCount, 35);
expect(response.usageMetadata!.promptTokensDetails, isNull);
expect(response.usageMetadata!.candidatesTokensDetails, isNull);
});

test('parses inlineData part correctly', () {
final inlineData = Uint8List.fromList([1, 2, 3, 4]);
final jsonResponse = {
Expand Down
Loading