Skip to content

Commit

Permalink
Simplified the code calling Replicate. Closes #5
Browse files Browse the repository at this point in the history
  • Loading branch information
dsebastien committed Aug 19, 2024
1 parent ea5ff07 commit 72e8d5b
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 150 deletions.
24 changes: 20 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,27 @@ Future: text generation.
- Replicate.com API Key: the Replicate.com API key to use
- Copy output to clipboard: if you want the generated output to be automatically copied to the clipboard

### Image generation
### Image generation model

- Image generation model: the name of the image generation model to use (Replicate parameter: model)
- Image generation model version: if you want to enforce using a specific model version (Replicate parameter: version)
- Image generation model configuration: JSON object to pass as input to the model. This varies depending on the chosen model and is documented on Replicate.com's website
Image generation model: the name of the image generation model to use, either with or without the version.

The two possible forms are:

- `<model_owner>/<model_name>`
- `<model_owner>/<model_name>:<version>`

Examples:

- black-forest-labs/flux-dev
- black-forest-labs/flux-dev:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa

The advantage of specifying the version is that you can lock the model to a specific version, which is useful if you want to ensure that the output remains consistent over time. If you don't specify the version, the latest version will be used.

You can find the existing versions here using the method described here: https://replicate.com/docs/reference/http#list-model-versions

### Image generation model configuration

A JSON object to pass as input to the image generation model. This varies depending on the chosen model and is documented on Replicate's website

## News & support

Expand Down
11 changes: 0 additions & 11 deletions apps/plugin/src/app/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,6 @@ export class ReplicatePlugin extends Plugin {
needToSaveSettings = true;
}

if (loadedSettings.imageGenerationModelVersion) {
draft.imageGenerationModelVersion =
loadedSettings.imageGenerationModelVersion;
} else {
log(
'The loaded settings miss the [imageGenerationModelVersion] property',
'debug'
);
needToSaveSettings = true;
}

if (loadedSettings.imageGenerationConfiguration) {
draft.imageGenerationConfiguration =
loadedSettings.imageGenerationConfiguration;
Expand Down
28 changes: 1 addition & 27 deletions apps/plugin/src/app/settingTab/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ export class SettingsTab extends PluginSettingTab {
imageGenerationSettingsGroup.setHeading();

this.renderImageGenerationModel(containerEl);
this.renderImageGenerationModelVersion(containerEl);
this.renderImageGenerationModelConfiguration(containerEl);

this.renderFollowButton(containerEl);
Expand Down Expand Up @@ -98,32 +97,7 @@ export class SettingsTab extends PluginSettingTab {
this.plugin.settings = produce(
this.plugin.settings,
(draft: Draft<PluginSettings>) => {
draft.imageGenerationModel = newValue;
}
);
await this.plugin.saveSettings();
});
});
}

renderImageGenerationModelVersion(containerEl: HTMLElement) {
new Setting(containerEl)
.setName('Image generation model version (optional)')
.setDesc('The version of the image generation model to use.')
.addText((text) => {
text
.setPlaceholder('')
.setValue(
this.plugin.settings.imageGenerationModelVersion
? this.plugin.settings.imageGenerationModelVersion
: ''
)
.onChange(async (newValue) => {
log(`Image generation model version set to: `, 'debug', newValue);
this.plugin.settings = produce(
this.plugin.settings,
(draft: Draft<PluginSettings>) => {
draft.imageGenerationModelVersion = newValue;
draft.imageGenerationModel = newValue as `${string}/${string}`; // FIXME is this ok?
}
);
await this.plugin.saveSettings();
Expand Down
7 changes: 4 additions & 3 deletions apps/plugin/src/app/types/plugin-settings.intf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ export interface PluginSettings {
copyOutputToClipboard: boolean;

// Image Generation
imageGenerationModel: string;
imageGenerationModelVersion?: string;
imageGenerationModel: `${string}/${string}` | `${string}/${string}:${string}`;
imageGenerationConfiguration: object;
}

Expand All @@ -14,7 +13,9 @@ export const DEFAULT_SETTINGS: PluginSettings = {
apiKey: '',
copyOutputToClipboard: true,

// Image Generation
// Image Generation model
// Form 1: <model_owner>/<model_name>
// Form 2: <model_owner>/<model_name>:<version>
// black-forest-labs/flux-pro
// black-forest-labs/flux-dev
// black-forest-labs/flux-schnell
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
export interface ReplicateRunModelConfiguration {
input: object;
wait?: {
interval?: number;
};
webhook?: string;
signal?: AbortSignal;
}
129 changes: 34 additions & 95 deletions apps/plugin/src/app/utils/generate-images.fn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import {
} from '../constants';
import { isApiKeyConfigured } from './is-api-key-configured.fn';
import { isImageGenerationModelConfigured } from './is-image-generation-model-configured.fn';
import { ReplicateCreatePrediction } from './replicate-create-prediction-input.intf';
import { getReplicateClient } from './get-replicate-client.fn';
import { ReplicateRunModelConfiguration } from '../types/replicate-run-model-configuration.intf';

export const generateImages = async (
prompt: string | undefined,
Expand Down Expand Up @@ -45,115 +45,54 @@ export const generateImages = async (
NOTICE_TIMEOUT
);

try {
const replicateCreatePredictionConfiguration: ReplicateCreatePrediction = {
model: settings.imageGenerationModel,
// Model configuration
input: {
...settings.imageGenerationConfiguration,
prompt, // FIXME ensure that the prompt is the one we expect in the request
},
};

if (settings.imageGenerationModelVersion) {
if ('' !== settings.imageGenerationModelVersion.trim()) {
replicateCreatePredictionConfiguration.version =
settings.imageGenerationModelVersion;
}
}
const replicateRunModelConfiguration: ReplicateRunModelConfiguration = {
// Model configuration
input: {
...settings.imageGenerationConfiguration,
prompt,
},
};

try {
log(
'Sending image generation request to Replicate.com',
'Sending image generation request to Replicate.com. Configuration: ',
'debug',
replicateCreatePredictionConfiguration
replicateRunModelConfiguration
);

let predictionResult = await replicate.predictions.create(
replicateCreatePredictionConfiguration
const output = await replicate.run(
settings.imageGenerationModel,
replicateRunModelConfiguration
);

if (predictionResult.error) {
log('Error received from Replicate.com', 'warn', predictionResult.error);
new Notice(
`${MSG_IMAGE_GENERATION_ERROR}: [${predictionResult.error}]`,
NOTICE_TIMEOUT
);
if (!output) {
log('Failed to generate images using Replicate.com', 'warn');
new Notice(MSG_IMAGE_GENERATION_ERROR, NOTICE_TIMEOUT);
return;
}

while (
predictionResult.status !== 'succeeded' &&
predictionResult.status !== 'failed'
) {
await sleep(1000);

log('Loading the image generation results from Replicate.com', 'debug');
predictionResult = await replicate.predictions.get(predictionResult.id);
log('Received response from Replicate.com', 'debug', predictionResult);

if (predictionResult?.error) {
log(
'Error received from Replicate',
'warn',
predictionResult.error.detail
);
new Notice(
`${MSG_IMAGE_GENERATION_ERROR}: [${predictionResult.error.detail}]`,
NOTICE_TIMEOUT
);
return;
}

if (predictionResult.status === 'failed') {
log('Failed to load the results from Replicate', 'warn');
new Notice(MSG_IMAGE_GENERATION_ERROR, NOTICE_TIMEOUT);
return;
}

if (predictionResult.error) {
log(
'Error received from Replicate while loading the results',
'warn',
predictionResult.error
);
new Notice(
`${MSG_IMAGE_GENERATION_ERROR}: [${predictionResult.error}]`,
NOTICE_TIMEOUT
);
return;
}

if (predictionResult.status === 'succeeded') {
log(
'Successfully loaded the results from Replicate',
'debug',
predictionResult
);

let result = '';

if (Array.isArray(predictionResult.output)) {
result = predictionResult.output.join('\n');
} else {
result = predictionResult.output;
}
let result = '';

log('Image generation result: ', 'debug', result);
if (Array.isArray(output)) {
result = output.join('\n');
} else {
result = JSON.stringify(output); // FIXME is this ok?
}

if (settings.copyOutputToClipboard) {
try {
await navigator.clipboard.writeText(result);
} catch (_) {
// Ignore errors (can occur if DevTools are open)
}
}
log('Image generation result: ', 'debug', result);

new Notice(
`Successfully generated image(s) using Replicate.com: [${result}]`,
NOTICE_TIMEOUT
);
if (settings.copyOutputToClipboard) {
try {
await navigator.clipboard.writeText(result);
} catch (_) {
// Ignore errors (can occur if DevTools are open)
}
}

new Notice(
`Successfully generated image(s) using Replicate.com: [${result}]`,
NOTICE_TIMEOUT
);
} catch (error) {
log('Error while generating image(s) using Replicate.com', 'warn', error);
new Notice(`${MSG_IMAGE_GENERATION_ERROR}: [${error}]`, NOTICE_TIMEOUT);
Expand Down

This file was deleted.

0 comments on commit 72e8d5b

Please sign in to comment.