-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #139 from edgenai/feat/image-generation
Feat/image generation
- Loading branch information
Showing
18 changed files
with
1,998 additions
and
122 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
use serde::Serialize; | ||
use std::path::PathBuf; | ||
use thiserror::Error; | ||
|
||
pub struct ImageGenerationArgs { | ||
pub prompt: String, | ||
pub uncond_prompt: String, | ||
pub width: Option<usize>, | ||
pub height: Option<usize>, | ||
pub steps: usize, | ||
pub images: u32, | ||
pub seed: Option<u64>, | ||
pub guidance_scale: f64, | ||
pub vae_scale: f64, | ||
} | ||
|
||
pub struct ModelFiles { | ||
pub tokenizer: PathBuf, | ||
pub clip_weights: PathBuf, | ||
pub clip2_weights: Option<PathBuf>, | ||
pub vae_weights: PathBuf, | ||
pub unet_weights: PathBuf, | ||
} | ||
|
||
#[derive(Serialize, Error, Debug)] | ||
pub enum ImageGenerationEndpointError { | ||
#[error("Could not load model: {0}")] | ||
Load(String), | ||
#[error("Failed to tokenize prompts: {0}")] | ||
Decoding(String), | ||
#[error("Failed to generate image: {0}")] | ||
Generation(String), | ||
#[error("Could not convert the output tensor into an encoded image")] | ||
Encoding(String), | ||
} | ||
|
||
#[async_trait::async_trait] | ||
pub trait ImageGenerationEndpoint { | ||
async fn generate_image( | ||
&self, | ||
model: ModelFiles, | ||
args: ImageGenerationArgs, | ||
) -> Result<Vec<Vec<u8>>, ImageGenerationEndpointError>; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
[package] | ||
name = "edgen_rt_image_generation_candle" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
async-trait = { workspace = true } | ||
candle-core = "0.4.1" | ||
candle-transformers = "0.4.1" | ||
edgen_core = { path = "../edgen_core" } | ||
image = "0.25.1" | ||
rand = "0.8.5" | ||
thiserror = { workspace = true } | ||
# https://github.com/huggingface/tokenizers/issues/1454 | ||
tokenizers = { version = "0.19.1", default-features = false, features = ["progressbar", "onig"] } | ||
tokio = { workspace = true, features = ["sync", "rt", "fs"] } | ||
tracing = { workspace = true } | ||
|
||
[features] | ||
cuda = ["candle-core/cuda", "candle-transformers/cuda"] |
Oops, something went wrong.