Commit
Refactor autocommit script to use the OpenAI API directly instead of langchain, and update token calculation methods to support new model types. Additionally, improve error handling and streamline the configuration checks for Azure OpenAI services.
shanginn committed Jul 19, 2024
1 parent 54cc333 commit f7896f1
Showing 2 changed files with 61 additions and 127 deletions.
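
At its core, the commit swaps langchain's ChatOpenAI wrapper for the official openai client. A minimal sketch of the new call shape, independent of this repo (the model name, key handling, and message text are illustrative placeholders, not values from the commit):

import { OpenAI } from "openai";

const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

const response = await openai.chat.completions.create({
    model: "gpt-4o-mini",
    messages: [
        { role: "system", content: "You write git commit messages." },
        { role: "user", content: "Summarize this diff: ..." },
    ],
});

console.log(response.choices[0].message.content.trim());

The diffs below show how autocommit.js and count_tokens.js were reworked around this API.
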
autocommit.js: 144 changes (44 additions, 100 deletions)
@@ -1,16 +1,10 @@
-#!/usr/bin/env node
+#!/usr/bin/env bun

import { execSync, spawn } from "child_process";
import rc from 'rc';
-import {
-    ChatPromptTemplate,
-    HumanMessagePromptTemplate,
-    PromptTemplate,
-    SystemMessagePromptTemplate
-} from "langchain/prompts";
import defaultConfig from './config.js';
-import {ChatOpenAI} from "langchain/chat_models/openai";
-import {getModelContextSize} from "./count_tokens.js";
+import { OpenAI } from "openai";
+import {calculateMaxTokens, getModelContextSize} from "./count_tokens.js";

const config = rc(
    'git-aicommit',
@@ -22,26 +16,26 @@ const config = rc(
);

try {
    execSync(
        'git rev-parse --is-inside-work-tree',
        {encoding: 'utf8', stdio: 'ignore'}
    );
} catch (e) {
console.error("This is not a git repository");
process.exit(1);
console.error("This is not a git repository");
process.exit(1);
}

if (!config.openAiKey && !config.azureOpenAiKey) {
console.error("Please set OPENAI_API_KEY or AZURE_OPENAI_API_KEY");
process.exit(1);
console.error("Please set OPENAI_API_KEY or AZURE_OPENAI_API_KEY");
process.exit(1);
}

// If any Azure-related settings are set, error out when the others are missing.
if (config.azureOpenAiKey && !(
    config.azureOpenAiInstanceName && config.azureOpenAiDeploymentName && config.azureOpenAiVersion
)){
    console.error("Please set AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_INSTANCE_NAME, AZURE_OPENAI_API_DEPLOYMENT_NAME, AZURE_OPENAI_API_VERSION when Azure OpenAI Service.");
    process.exit(1);
}

const excludeFromDiff = config.excludeFromDiff || [];
@@ -50,8 +44,8 @@ const diffCommand = `git diff --staged \
--no-ext-diff \
--diff-filter=${diffFilter} \
-- ${excludeFromDiff.map(
        (pattern) => `':(exclude)${pattern}'`
    ).join(' ')}
`;

let diff = execSync(diffCommand, {encoding: 'utf8'});
@@ -61,97 +55,47 @@ if (!diff) {
    process.exit(1);
}

-const openai = new ChatOpenAI({
-    modelName: config.modelName,
-    openAIApiKey: config.openAiKey,
-    azureOpenAIApiKey: config.azureOpenAiKey,
-    azureOpenAIApiInstanceName: config.azureOpenAiInstanceName,
-    azureOpenAIApiDeploymentName: config.azureOpenAiDeploymentName,
-    azureOpenAIApiVersion: config.azureOpenAiVersion,
-    temperature: config.temperature,
-    maxTokens: config.maxTokens,
+const openai = new OpenAI({
+    apiKey: config.openAiKey,
+    baseURL: config.azureOpenAiKey ? `https://${config.azureOpenAiInstanceName}.openai.azure.com/openai/deployments/${config.azureOpenAiDeploymentName}` : undefined,
+    defaultHeaders: config.azureOpenAiKey ? { 'api-key': config.azureOpenAiKey } : undefined
});

-const systemMessagePromptTemplate = SystemMessagePromptTemplate.fromTemplate(
-    config.systemMessagePromptTemplate
-);
+async function getChatCompletion(messages) {
+    const response = await openai.chat.completions.create({
+        model: config.modelName || 'gpt-4o-mini',
+        messages: messages,
+        temperature: config.temperature,
+        max_tokens: config.maxTokens,
+    });

-const humanPromptTemplate = HumanMessagePromptTemplate.fromTemplate(
-    config.humanPromptTemplate
-);
+    return response.choices[0].message.content.trim();
+}

+const systemMessage = { role: "system", content: config.systemMessagePromptTemplate };
+const userMessage = { role: "user", content: config.humanPromptTemplate.replace("{diff}", diff).replace("{language}", config.language) };

-const chatPrompt = ChatPromptTemplate.fromPromptMessages([
-    systemMessagePromptTemplate,
-    humanPromptTemplate,
-]);
+const chatMessages = [systemMessage, userMessage];

-const chatMessages = await chatPrompt.formatMessages({
-    diff: diff,
-    language: config.language,
+const tokenCount = await calculateMaxTokens({
+    prompt: diff,
+    modelName: config.modelName || 'gpt-4o-mini'
});

-const tokenCount = (await openai.getNumTokensFromMessages(chatMessages)).totalCount
-const contextSize = getModelContextSize(config.modelName)
+const contextSize = getModelContextSize(config.modelName || 'gpt-4o-mini');

if (tokenCount > contextSize) {
-    console.log('Diff is too long. Splitting into multiple requests.')
-    // TODO: split smarter
-    const filenameRegex = /^a\/(.+?)\s+b\/(.+?)/;
-    const diffByFiles = diff
-        .split('diff ' + '--git ') // Wierd string concat in order to avoid splitting on this line when using autocommit in this repo :)
-        .filter((fileDiff) => fileDiff.length > 0)
-        .map((fileDiff) => {
-            const match = fileDiff.match(filenameRegex);
-            const filename = match ? match[1] : 'Unknown file';
-
-            const content = fileDiff
-                .replaceAll(filename, '')
-                .replaceAll('a/ b/\n', '')
-
-            return chatPrompt
-                .formatMessages({
-                    diff: content,
-                    language: config.language,
-                })
-                .then((prompt) => {
-                    return openai.call(prompt)
-                        .then((res) => {
-                            return {
-                                filename: filename,
-                                changes: res.text.trim(),
-                            }
-                        })
-                        .catch((e) => {
-                            console.error(`Error during OpenAI request: ${e.message}`);
-                            process.exit(1);
-                        });
-                });
-        });
-
-    // wait for all promises to resolve
-    const mergeChanges = await Promise.all(diffByFiles);
-
-    diff = mergeChanges
-        .map((fileDiff) => {
-            return `diff --git ${fileDiff.filename}\n${fileDiff.changes}`
-
-        })
-        .join('\n\n')
-}
-
-const prompt = await chatPrompt.formatMessages({
-    diff: diff,
-    language: config.language,
-})
+    console.log('Diff is too long. Please lower the amount of changes in the commit or switch to a model with bigger context size');

-const res = await openai.call(prompt)
-    .catch((e) => {
-        console.error(`Error during OpenAI request: ${e.message}`);
-        process.exit(1);
-    });
+    process.exit(1);
+}

-const commitMessage = res.text.trim();
+const messages = [
+    { role: "system", content: config.systemMessagePromptTemplate },
+    { role: "user", content: config.humanPromptTemplate.replace("{diff}", diff).replace("{language}", config.language) }
+];
+
+const commitMessage = await getChatCompletion(messages);

if (!config.autocommit) {
    console.log(`Autocommit is disabled. Here is the message:\n ${commitMessage}`);
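
The Azure path above leans on two openai client options: baseURL points requests at the Azure deployment, and defaultHeaders carries the api-key header Azure expects instead of a Bearer token. A standalone sketch of that mechanism (instance, deployment, and api-version values are hypothetical; the committed code builds the URL from the rc config and does not set a query version):

import { OpenAI } from "openai";

const instance = "my-instance";      // hypothetical Azure resource name
const deployment = "my-deployment";  // hypothetical deployment name

const client = new OpenAI({
    apiKey: process.env.AZURE_OPENAI_API_KEY,
    // Route requests to the Azure deployment instead of api.openai.com.
    baseURL: `https://${instance}.openai.azure.com/openai/deployments/${deployment}`,
    // Azure authenticates with an api-key header rather than a Bearer token.
    defaultHeaders: { "api-key": process.env.AZURE_OPENAI_API_KEY },
    // Azure endpoints are versioned; a client typically pins one (hypothetical value).
    defaultQuery: { "api-version": "2024-02-01" },
});
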
count_tokens.js: 44 changes (17 additions, 27 deletions)
@@ -1,27 +1,27 @@
// langchain/dist/base_language/count_tokens.js
export const getModelNameForTiktoken = (modelName) => {
    if (modelName.startsWith("gpt-3.5-turbo-16k")) {
        return "gpt-3.5-turbo-16k";
    }
+
    if (modelName.startsWith("gpt-3.5-turbo-")) {
        return "gpt-3.5-turbo";
    }
+
    if (modelName.startsWith("gpt-4-32k-")) {
        return "gpt-4-32k";
    }
+
    if (modelName.startsWith("gpt-4-")) {
        return "gpt-4";
    }
-    return modelName;
-};
-export const getEmbeddingContextSize = (modelName) => {
-    switch (modelName) {
-        case "text-embedding-ada-002":
-            return 8191;
-        default:
-            return 2046;
+
+    if (modelName.startsWith("gpt-4o-")) {
+        return "gpt-4o";
    }
+    return modelName;
};


export const getModelContextSize = (modelName) => {
    switch (getModelNameForTiktoken(modelName)) {
        case "gpt-3.5-turbo-16k":
@@ -32,33 +32,24 @@ export const getModelContextSize = (modelName) => {
            return 32768;
        case "gpt-4":
            return 8192;
-        case "text-davinci-003":
-            return 4097;
-        case "text-curie-001":
-            return 2048;
-        case "text-babbage-001":
-            return 2048;
-        case "text-ada-001":
-            return 2048;
-        case "code-davinci-002":
-            return 8000;
-        case "code-cushman-001":
-            return 2048;
+        case "gpt-4o":
+            return 128000;
        default:
-            return 4097;
+            return 4096;
    }
};

export const importTiktoken = async () => {
    try {
        const { encoding_for_model } = await import("@dqbd/tiktoken");
        return { encoding_for_model };
-    }
-    catch (error) {
+    } catch (error) {
        console.log(error);
        return { encoding_for_model: null };
    }
};
-export const calculateMaxTokens = async ({ prompt, modelName, }) => {
+
+export const calculateMaxTokens = async ({ prompt, modelName }) => {
    const { encoding_for_model } = await importTiktoken();
    // fallback to approximate calculation if tiktoken is not available
    let numTokens = Math.ceil(prompt.length / 4);
@@ -69,8 +60,7 @@ export const calculateMaxTokens = async ({ prompt, modelName, }) => {
            numTokens = tokenized.length;
            encoding.free();
        }
-    }
-    catch (error) {
+    } catch (error) {
        console.warn("Failed to calculate number of tokens with tiktoken, falling back to approximate count", error);
    }
    const maxTokens = getModelContextSize(modelName);
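
The two surviving exports can be exercised on their own. Since the file header says it was lifted from langchain's count_tokens.js, calculateMaxTokens should return the model's context size minus the prompt's token count, falling back to ceil(prompt.length / 4) when @dqbd/tiktoken is unavailable. A quick sketch (the prompt string is arbitrary):

import { calculateMaxTokens, getModelContextSize } from "./count_tokens.js";

const prompt = "feat: add user login endpoint";
const modelName = "gpt-4o-2024-05-13"; // any name starting with "gpt-4o-" maps to gpt-4o

// 128000, via the new gpt-4o branch in getModelNameForTiktoken.
console.log(getModelContextSize(modelName));

// Context size minus the prompt's token count (exact with tiktoken, approximate without).
console.log(await calculateMaxTokens({ prompt, modelName }));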
