Collect and report per translate token count. (#785)
curtisman authored Mar 5, 2025
1 parent 0c81e51 commit 4935a69
Showing 9 changed files with 204 additions and 57 deletions.
5 changes: 5 additions & 0 deletions ts/packages/aiclient/src/models.ts
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

import { PromptSection, Result, TypeChatLanguageModel } from "typechat";
import { CompletionUsageStats } from "./openai";

/**
* Translation settings for Chat models
@@ -43,6 +44,8 @@ export type FunctionCallingResult = {
arguments: any;
};

export type CompleteUsageStatsCallback = (usage: CompletionUsageStats) => void;

/**
* A TypeChat language model with greater control on settings
*/
@@ -56,6 +59,7 @@ export interface ChatModel extends TypeChatLanguageModel {
*/
complete(
prompt: string | PromptSection[],
usageCallback?: CompleteUsageStatsCallback,
jsonSchema?: CompletionJsonSchema,
): Promise<Result<string>>;
}
@@ -68,6 +72,7 @@ export interface ChatModelWithStreaming extends ChatModel {
*/
completeStream(
prompt: string | PromptSection[],
usageCallback?: CompleteUsageStatsCallback,
jsonSchema?: CompletionJsonSchema,
): Promise<Result<AsyncIterableIterator<string>>>;
}
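For callers, the change is additive: `complete` (and `completeStream`) now take an optional usage callback ahead of the existing JSON-schema argument, and it is invoked with the provider-reported counts once a completion finishes. A minimal sketch of the new calling convention; the model factory and prompt here are placeholders, not part of this commit:

// Hypothetical factory; any object implementing ChatModel works here.
const model = createMyChatModel();

const result = await model.complete("translate: play some jazz", (usage) => {
    console.log(
        `prompt=${usage.prompt_tokens} completion=${usage.completion_tokens} total=${usage.total_tokens}`,
    );
});
if (result.success) {
    console.log(result.data);
}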
5 changes: 5 additions & 0 deletions ts/packages/aiclient/src/openai.ts
@@ -9,6 +9,7 @@ import {
ImageModel,
ImageGeneration,
CompletionJsonSchema,
CompleteUsageStatsCallback,
} from "./models";
import { callApi, callJsonApi, FetchThrottler } from "./restClient";
import { getEnvSetting } from "./common";
@@ -468,6 +469,7 @@ function createAzureOpenAIChatModel(
}
async function complete(
prompt: string | PromptSection[],
usageCallback?: CompleteUsageStatsCallback,
jsonSchema?: CompletionJsonSchema,
): Promise<Result<string>> {
verifyPromptLength(settings, prompt);
@@ -516,6 +518,7 @@
}
// track token usage
TokenCounter.getInstance().add(data.usage, tags);
usageCallback?.(data.usage);
} catch {}

if (Array.isArray(jsonSchema)) {
@@ -542,6 +545,7 @@

async function completeStream(
prompt: string | PromptSection[],
usageCallback?: CompleteUsageStatsCallback,
jsonSchema?: CompletionJsonSchema,
): Promise<Result<AsyncIterableIterator<string>>> {
verifyPromptLength(settings, prompt);
@@ -657,6 +661,7 @@ function createAzureOpenAIChatModel(
data.usage,
tags,
);
usageCallback?.(data.usage);
} catch {}
}
}
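One detail worth calling out: in both the non-streaming and streaming paths, the callback runs inside the same try/catch that guards the `TokenCounter` accounting, so a throwing usage callback loses that single report but never fails the completion. Illustratively, with a hypothetical caller:

const result = await model.complete(prompt, () => {
    throw new Error("usage sink unavailable"); // swallowed by the catch {} above
});
// result still carries the completion; only this one usage report is dropped.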
68 changes: 64 additions & 4 deletions ts/packages/cli/src/commands/test/translate.ts
@@ -17,7 +17,7 @@ import {
import chalk from "chalk";
import fs from "node:fs";
import { getElapsedString } from "common-utils";
import { getChatModelNames } from "aiclient";
import { getChatModelNames, openai as ai } from "aiclient";

type TestResult = {
request: string;
@@ -86,8 +86,20 @@ const defaultAppAgentProviders = getDefaultAppAgentProviders(getInstanceDir());
const schemaNames = getSchemaNamesForActionConfigProvider(
await createActionConfigProvider(defaultAppAgentProviders),
);

const defaultRepeat = 5;

function addTokenUsage(
total: ai.CompletionUsageStats,
usage: ai.CompletionUsageStats,
) {
total.prompt_tokens += usage.prompt_tokens;
total.completion_tokens += usage.completion_tokens;
total.total_tokens += usage.total_tokens;
}

function getTokenUsageStr(usage: ai.CompletionUsageStats, count: number = 1) {
return `${Math.round(usage.prompt_tokens / count)}+${Math.round(usage.completion_tokens / count)}=${Math.round(usage.total_tokens / count)}`;
}
export default class TestTranslateCommand extends Command {
static args = {
files: Args.string({
@@ -327,6 +339,11 @@ export default class TestTranslateCommand extends Command {

let totalExecTime = 0;
let maxExecTime = 0;
const totalTokenUsage: ai.CompletionUsageStats = {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
};
async function worker() {
const dispatcher = await createDispatcher("cli test translate", {
appAgentProviders: defaultAppAgentProviders,
@@ -367,20 +384,42 @@

let currentTotalExecTime = 0;
let currentMaxExecTime = 0;
const currentTokenUsage: ai.CompletionUsageStats = {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
};
let maxTokenUsage: ai.CompletionUsageStats = {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
};

for (let i = 0; i < repeat; i++) {
const time = performance.now();
const commandResult =
await dispatcher.processCommand(request);
const execTime = performance.now() - time;
currentMaxExecTime = Math.max(currentMaxExecTime, execTime);
currentTotalExecTime += execTime;

const tokenUsage = commandResult?.tokenUsage;
if (tokenUsage) {
addTokenUsage(currentTokenUsage, tokenUsage);
addTokenUsage(totalTokenUsage, tokenUsage);

if (
tokenUsage.total_tokens > maxTokenUsage.total_tokens
) {
maxTokenUsage = tokenUsage;
}
}
results.push(commandResult?.actions);
}

maxExecTime = Math.max(maxExecTime, currentMaxExecTime);
totalExecTime += currentTotalExecTime;

const timeStr = `${getElapsedString(currentTotalExecTime)} (${getElapsedString(currentTotalExecTime / repeat)}/call) Max: ${getElapsedString(currentMaxExecTime)}`;
const expected = results[0];
let failed = false;
for (let i = 1; i < results.length; i++) {
@@ -456,7 +495,25 @@
noActions++;
msg = "Passed (no actions)";
}
print(`${chalk.green(msg)} ${chalk.grey(timeStr)}`);
const timeStr =
repeat === 1
? getElapsedString(currentTotalExecTime)
: `${getElapsedString(currentTotalExecTime)} (${getElapsedString(currentTotalExecTime / repeat)}/call) Max: ${getElapsedString(currentMaxExecTime)}`;
const avgTokenStr = getTokenUsageStr(
currentTokenUsage,
repeat,
);
const maxTokenStr = getTokenUsageStr(maxTokenUsage);

const tokenStr =
repeat === 1
? `(Token: ${avgTokenStr})`
: avgTokenStr === maxTokenStr
? `(Token Avg: ${avgTokenStr})`
: `(Token Avg: ${avgTokenStr} Max: ${maxTokenStr})`;
print(
`${chalk.green(msg)} ${chalk.grey(timeStr)} ${tokenStr}`,
);
}
}
await dispatcher.close();
@@ -515,5 +572,8 @@
console.log(
`Execution Time: ${getElapsedString(totalExecTime)}, Avg: ${getElapsedString(executionTimePerRequest)} (${getElapsedString(executionTimePerCall)}/call) Max: ${getElapsedString(maxExecTime)}`,
);
console.log(
`Token Usage: ${getTokenUsageStr(totalTokenUsage)}, Avg per call: ${getTokenUsageStr(totalTokenUsage, processed * repeat)}`,
);
}
}
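The two helpers near the top of this file are small enough to check by hand: `addTokenUsage` sums field-wise into a running total, and `getTokenUsageStr` renders `prompt+completion=total`, dividing each field by `count` to turn a total into a per-call average. A worked example with made-up numbers:

const usage = { prompt_tokens: 1200, completion_tokens: 300, total_tokens: 1500 };
getTokenUsageStr(usage);    // "1200+300=1500"
getTokenUsageStr(usage, 2); // "600+150=750", the average over two calls

The summary line at the end uses the same renderer, passing `processed * repeat` as the count so the final figure is an average per model call.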
97 changes: 61 additions & 36 deletions ts/packages/commonUtils/src/jsonTranslator.ts
@@ -14,7 +14,11 @@ import {
import { createTypeScriptJsonValidator } from "typechat/ts";
import { TypeChatConstraintsValidator } from "./constraints.js";
import registerDebug from "debug";
import { openai as ai, CompletionJsonSchema } from "aiclient";
import {
openai as ai,
CompleteUsageStatsCallback,
CompletionJsonSchema,
} from "aiclient";
import {
createIncrementalJsonParser,
IncrementalJsonParser,
@@ -59,57 +63,68 @@ export interface TypeChatJsonTranslatorWithStreaming<T extends object>
translate: (
request: string,
promptPreamble?: string | PromptSection[],
cb?: IncrementalJsonValueCallBack,
attachments?: CachedImageWithDetails[] | undefined,
cb?: IncrementalJsonValueCallBack,
usageCallback?: CompleteUsageStatsCallback,
) => Promise<Result<T>>;
}

// This relies on the fact that the prompt preamble passed to typechat is copied to the final prompt.
// Add an internal section so we can pass information from the caller to the model.complete function.
type StreamingSection = {
role: "streaming";
content: IncrementalJsonParser;
type ModelParamSection = {
role: "model";
content: {
parser: IncrementalJsonParser | undefined;
usageCallback: CompleteUsageStatsCallback | undefined;
};
};

function initializeStreamingParser(
function addModelParamSection(
promptPreamble?: string | PromptSection[],
cb?: IncrementalJsonValueCallBack,
usageCallback?: CompleteUsageStatsCallback,
) {
if (cb === undefined) {
if (cb === undefined && usageCallback === undefined) {
return promptPreamble;
}
const prompts: (PromptSection | StreamingSection)[] =
const prompts: (PromptSection | ModelParamSection)[] =
typeof promptPreamble === "string"
? [{ role: "user", content: promptPreamble }]
: promptPreamble
? [...promptPreamble] // Make a copy so that we don't modify the original array
: [];
const parser = createIncrementalJsonParser(cb, {
partial: true,
});
const parser = cb
? createIncrementalJsonParser(cb, {
partial: true,
})
: undefined;
prompts.unshift({
role: "streaming",
content: parser,
role: "model",
content: {
parser,
usageCallback,
},
});

return prompts as PromptSection[];
}

function getStreamingParser(
prompt: string | ReadonlyArray<PromptSection | StreamingSection>,
function getModelParams(
prompt: string | ReadonlyArray<PromptSection | ModelParamSection>,
) {
if (typeof prompt === "string") {
return undefined;
}
const internalIndex = prompt.findIndex((p) => p.role === "streaming");
const internalIndex = prompt.findIndex((p) => p.role === "model");
if (internalIndex === -1) {
return undefined;
}
// Make a copy so that we don't modify the original array;
const newPrompt = [...prompt];
const internal = newPrompt.splice(internalIndex, 1) as [StreamingSection];
const internal = newPrompt.splice(internalIndex, 1) as [ModelParamSection];
return {
parser: internal[0].content,
parser: internal[0].content.parser,
usageCallback: internal[0].content.usageCallback,
actualPrompt: newPrompt as PromptSection[],
};
}
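Seen in isolation, the round trip is: `addModelParamSection` prepends the out-of-band `{ role: "model" }` section on the caller side, and `getModelParams` strips it back out before the prompt reaches the provider. Both helpers are module-internal, so the following is illustrative only, assuming typechat forwards the preamble array untouched (the invariant the comment above relies on):

// Caller side: smuggle a usage callback through the prompt preamble.
const preamble = addModelParamSection(
    [{ role: "system", content: "You translate requests into actions." }],
    undefined, // no incremental-JSON callback in this example
    (usage) => console.log("tokens:", usage.total_tokens),
) as PromptSection[];

// Model side (inside the wrapped complete/completeStream): recover the
// callback and the clean prompt that actually goes to the provider.
const params = getModelParams(preamble)!;
// params.usageCallback is the function above; params.actualPrompt no longer
// contains the internal { role: "model" } section.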
@@ -122,15 +137,18 @@ export function enableJsonTranslatorStreaming<T extends object>(
throw new Error("Model does not support streaming");
}

const originalComplete = model.complete;
const originalComplete = model.complete.bind(model);
model.complete = async (prompt: string | PromptSection[]) => {
const streamingParser = getStreamingParser(prompt);
if (streamingParser === undefined) {
const modelParams = getModelParams(prompt);
if (modelParams === undefined) {
return originalComplete(prompt);
}
const { parser, actualPrompt } = streamingParser;
const { parser, usageCallback, actualPrompt } = modelParams;
if (parser === undefined) {
return originalComplete(actualPrompt, usageCallback);
}
const chunks = [];
const result = await model.completeStream(actualPrompt);
const result = await model.completeStream(actualPrompt, usageCallback);
if (!result.success) {
return result;
}
@@ -142,19 +160,20 @@
return success(chunks.join(""));
};

const originalTranslate = translator.translate;
const originalTranslate = translator.translate.bind(translator);
const translatorWithStreaming =
translator as TypeChatJsonTranslatorWithStreaming<T>;
translatorWithStreaming.translate = async (
request: string,
promptPreamble?: string | PromptSection[],
cb?: IncrementalJsonValueCallBack,
attachments?: CachedImageWithDetails[],
cb?: IncrementalJsonValueCallBack,
usageCallback?: CompleteUsageStatsCallback,
) => {
await attachAttachments(attachments, promptPreamble);
return originalTranslate(
request,
initializeStreamingParser(promptPreamble, cb),
addModelParamSection(promptPreamble, cb, usageCallback),
);
};
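Note the reordered public signature: `attachments` now comes before `cb`, with `usageCallback` appended last, so positional callers of the streaming translate need a matching update. A sketch of the new call shape, where the translator value and action type are assumed:

const streaming = enableJsonTranslatorStreaming(translator); // TypeChatJsonTranslator<MyAction>

const result = await streaming.translate(
    "play some jazz",
    undefined, // promptPreamble
    undefined, // attachments
    (prop, value) => console.log("partial:", prop, value), // incremental JSON values
    (usage) => console.log("tokens:", usage.total_tokens), // per-completion usage
);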

@@ -246,25 +265,31 @@
`typeagent:translate:${name}:jsonschema`,
);
const debugResult = registerDebug(`typeagent:translate:${name}:result`);
const complete = model.complete.bind(model);
model.complete = async (prompt: string | PromptSection[]) => {
const originalComplete = model.complete.bind(model);
model.complete = async (
prompt: string | PromptSection[],
usageCallback?: CompleteUsageStatsCallback,
) => {
debugPrompt(prompt);
const jsonSchema = validator.getJsonSchema?.();
if (jsonSchema !== undefined) {
debugJsonSchema(jsonSchema);
}
return complete(prompt, jsonSchema);
return originalComplete(prompt, usageCallback, jsonSchema);
};

if (ai.supportsStreaming(model)) {
const completeStream = model.completeStream.bind(model);
model.completeStream = async (prompt: string | PromptSection[]) => {
const originalCompleteStream = model.completeStream.bind(model);
model.completeStream = async (
prompt: string | PromptSection[],
usageCallback?: CompleteUsageStatsCallback,
) => {
debugPrompt(prompt);
const jsonSchema = validator.getJsonSchema?.();
if (jsonSchema !== undefined) {
debugJsonSchema(jsonSchema);
}
return completeStream(prompt, jsonSchema);
return originalCompleteStream(prompt, usageCallback, jsonSchema);
};
}

@@ -288,12 +313,12 @@
return;
}

const streamingParser = getStreamingParser(prompt);
if (streamingParser === undefined) {
const parser = getModelParams(prompt)?.parser;
if (parser === undefined) {
return;
}
const callback = streamingParser.parser.callback;
streamingParser.parser.callback = Array.isArray(jsonSchema)
const callback = parser.callback;
parser.callback = Array.isArray(jsonSchema)
? (prop, value, delta) => {
let actualPropName = "actionName";
if (prop !== "name") {