diff --git a/Cargo.lock b/Cargo.lock index 69d8e5c..ac7adb9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1446,6 +1446,7 @@ dependencies = [ "serde_derive", "serde_json", "serde_yaml", + "tempfile", "testcontainers", "thiserror", "time", diff --git a/crates/edgen_core/src/settings.rs b/crates/edgen_core/src/settings.rs index 51826ba..f4b6f0e 100644 --- a/crates/edgen_core/src/settings.rs +++ b/crates/edgen_core/src/settings.rs @@ -95,6 +95,30 @@ fn build_config_file_path() -> PathBuf { config_dir.join(Path::new(&filename)) } +/// Helper to get the chat completions model directory. +pub async fn chat_completions_dir() -> String { + SETTINGS + .read() + .await + .read() + .await + .chat_completions_models_dir + .trim() + .to_string() +} + +/// Helper to get the audio transcriptions model directory. +pub async fn audio_transcriptions_dir() -> String { + SETTINGS + .read() + .await + .read() + .await + .audio_transcriptions_models_dir + .trim() + .to_string() +} + #[derive(Error, Debug, Serialize)] pub enum SettingsError { #[error("failed to read the settings file: {0}")] diff --git a/crates/edgen_server/Cargo.toml b/crates/edgen_server/Cargo.toml index 6ed02d4..cf55f54 100644 --- a/crates/edgen_server/Cargo.toml +++ b/crates/edgen_server/Cargo.toml @@ -43,3 +43,4 @@ testcontainers = "0.15.0" [dev-dependencies] levenshtein = "1.0.5" +tempfile = { workspace = true } diff --git a/crates/edgen_server/src/lib.rs b/crates/edgen_server/src/lib.rs index ccfe1c2..88b057f 100644 --- a/crates/edgen_server/src/lib.rs +++ b/crates/edgen_server/src/lib.rs @@ -20,7 +20,6 @@ use std::process::exit; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use axum::Router; use tower_http::cors::CorsLayer; use futures::executor::block_on; @@ -43,7 +42,9 @@ pub mod error; pub mod graceful_shutdown; mod llm; mod model; +pub mod model_man; pub mod openai_shim; +mod routes; pub mod status; pub mod util; mod whisper; @@ -189,32 +190,7 @@ async fn run_server(args: &cli::Serve) -> Result { ) .await; - let http_app = Router::new() - // -- AI endpoints ----------------------------------------------------- - // ---- Chat ----------------------------------------------------------- - .route( - "/v1/chat/completions", - axum::routing::post(openai_shim::chat_completions), - ) - // ---- Audio ---------------------------------------------------------- - .route( - "/v1/audio/transcriptions", - axum::routing::post(openai_shim::create_transcription), - ) - // -- AI status endpoints ---------------------------------------------- - // ---- Chat ----------------------------------------------------------- - .route( - "/v1/chat/completions/status", - axum::routing::get(status::chat_completions_status), - ) - // ---- Audio ---------------------------------------------------------- - .route( - "/v1/audio/transcriptions/status", - axum::routing::get(status::audio_transcriptions_status), - ) - // -- Miscellaneous services ------------------------------------------- - .route("/v1/misc/version", axum::routing::get(misc::edgen_version)) - .layer(CorsLayer::permissive()); + let http_app = routes::routes().layer(CorsLayer::permissive()); let uri_vector = if !args.uri.is_empty() { info!("Overriding default URI"); diff --git a/crates/edgen_server/src/misc.rs b/crates/edgen_server/src/misc.rs index 7e5aa77..c041c46 100644 --- a/crates/edgen_server/src/misc.rs +++ b/crates/edgen_server/src/misc.rs @@ -36,7 +36,7 @@ pub struct Version { build: String, } -/// GET `/v1/version`: returns the current version of edgend. +/// GET `/v1/misc/version`: returns the current version of edgend. /// /// The version is returned as json value with major, minor and patch as integer /// and build as string (which may be empty). diff --git a/crates/edgen_server/src/model_man.rs b/crates/edgen_server/src/model_man.rs new file mode 100644 index 0000000..4dc8ca6 --- /dev/null +++ b/crates/edgen_server/src/model_man.rs @@ -0,0 +1,539 @@ +/* Copyright 2023- The Binedge, Lda team. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! Model Manager Endpoints + +use std::fmt; +use std::fmt::Display; +use std::path::{Path, PathBuf}; +use std::time::{SystemTime, SystemTimeError}; + +use axum::extract; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Json, Response}; +use serde::{Deserialize, Serialize}; +use thiserror; +use tracing::warn; +use utoipa::ToSchema; + +use edgen_core::settings; + +/// GET `/v1/models`: returns a list of model descriptors for all models in all model directories. +/// +/// For any error, the endpoint returns "internal server error". +pub async fn list_models() -> Response { + match list_all_models().await { + Ok(v) => Json(v).into_response(), + Err(e) => internal_server_error(&format!("model manager: cannot list models: {:?}", e)), + } +} + +/// GET `/v1/models{:id}`: returns the model descriptor for the model indicated by 'id'. +/// +/// For any error, the endpoint returns "internal server error". +pub async fn retrieve_model(extract::Path(id): extract::Path) -> Response { + match model_id_to_desc(&id).await { + Ok(d) => Json(d).into_response(), + Err(e) => { + internal_server_error(&format!("model manager: cannot get model {}: {:?}", id, e)) + } + } +} + +/// DELETE `/v1/models{:id}`: deletes the model indicated by 'id'. +/// +/// For any error, the endpoint returns "internal server error". +pub async fn delete_model(extract::Path(id): extract::Path) -> Response { + match remove_model(&id).await { + Ok(d) => Json(d).into_response(), + Err(e) => internal_server_error(&format!( + "model manager: cannot delete model {}: {:?}", + id, e + )), + } +} + +fn internal_server_error(msg: &str) -> Response { + warn!("{}", msg); + StatusCode::INTERNAL_SERVER_ERROR.into_response() +} + +/// Model Descriptor +#[derive(ToSchema, Deserialize, Serialize, Debug, PartialEq, Eq)] +pub struct ModelDesc { + /// model Id + pub id: String, + /// when the file was created + pub created: u64, + /// object type, always 'model' + pub object: String, + /// repo owner + pub owned_by: String, +} + +/// Model Deletion Status +#[derive(ToSchema, Deserialize, Serialize, Debug, PartialEq, Eq)] +pub struct ModelDeletionStatus { + /// model Id + pub id: String, + /// object type, always 'model' + pub object: String, + /// repo owner + pub deleted: bool, +} + +#[derive(Debug, thiserror::Error)] +enum PathError { + Generic(String), + ModelNotFound, + ParseError(#[from] ParseError), + IOError(#[from] std::io::Error), + TimeError(#[from] SystemTimeError), +} + +impl Display for PathError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +async fn list_all_models() -> Result, PathError> { + let completions_dir = settings::chat_completions_dir().await; + let transcriptions_dir = settings::audio_transcriptions_dir().await; + + let mut v = vec![]; + + list_models_in_dir(Path::new(&completions_dir), &mut v).await?; + list_models_in_dir(Path::new(&transcriptions_dir), &mut v).await?; + Ok(v) +} + +async fn list_models_in_dir(path: &Path, v: &mut Vec) -> Result<(), PathError> { + let es = tokio::fs::read_dir(path).await; + if es.is_err() { + warn!("model manager: cannot read directory {:?} ({:?})", path, es); + return Err(PathError::IOError(es.unwrap_err())); + }; + let mut es = es.unwrap(); + loop { + let e = es.next_entry().await; + if e.is_err() { + warn!("model manager: cannot get entry: {:?}", e); + break; + } + let tmp = e.unwrap(); + if tmp.is_none() { + break; + } + let tmp = tmp.unwrap(); + match path_to_model_desc(tmp.path().as_path()).await { + Ok(m) => v.push(m), + Err(e) => { + warn!( + "model manager: invalid entry in directory {:?}: {:?}", + path, e + ); + } + } + } + Ok(()) +} + +async fn model_id_to_desc(id: &str) -> Result { + let path = search_model(id).await?; + path_to_model_desc(path.as_path()).await +} + +async fn search_model(id: &str) -> Result { + let model = model_id_to_path(id)?; + let dir = settings::chat_completions_dir().await; + let path = Path::new(&dir).join(&model); + if path.is_dir() { + return Ok(path); + } + let dir = settings::audio_transcriptions_dir().await; + let path = Path::new(&dir).join(&model); + if path.is_dir() { + return Ok(path); + } + Err(PathError::ModelNotFound) +} + +async fn remove_model(id: &str) -> Result { + let model = search_model(id).await?; + let _ = tokio::fs::remove_dir_all(model).await?; + Ok(ModelDeletionStatus { + id: id.to_string(), + object: "model".to_string(), + deleted: true, + }) +} + +async fn path_to_model_desc(path: &Path) -> Result { + let f = path + .file_name() + .ok_or(PathError::Generic("empty path".to_string()))?; + let model = f + .to_str() + .ok_or(PathError::Generic("invalid file name".to_string()))?; + let (owner, repo) = parse_path(model)?; + let metadata = tokio::fs::metadata(path).await?; + if !metadata.is_dir() { + return Err(PathError::Generic("not a directory".to_string())); + }; + let tp = match metadata.created() { + Ok(n) => n, + Err(_) => SystemTime::UNIX_EPOCH, // unknown + }; + + let created = tp.duration_since(SystemTime::UNIX_EPOCH)?.as_secs(); + + Ok(ModelDesc { + id: to_model_id(&owner, &repo), + created: created, + object: "model".to_string(), + owned_by: owner.to_string(), + }) +} + +fn to_model_id(owner: &str, repo: &str) -> String { + format!("{}/{}", owner, repo) +} + +fn model_id_to_path(id: &str) -> Result { + let (owner, repo) = parse_model_id(id)?; + let s = format!("models--{}--{}", owner, repo); + Ok(PathBuf::from(s)) +} + +#[derive(Debug, PartialEq, thiserror::Error)] +enum ParseError { + MissingSeparator, + NotaModel, + NoOwner, + NoRepo, +} + +impl Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +fn parse_model_id(id: &str) -> Result<(String, String), ParseError> { + let vs = id.split("/").collect::>(); + if vs.len() < 2 { + return Err(ParseError::MissingSeparator); + } + + let owner = vs[0].to_string(); + if owner.is_empty() { + return Err(ParseError::NoOwner); + } + + let repo = if vs.len() > 2 { + vs[1..].join("/") + } else { + vs[1].to_string() + }; + if repo.is_empty() { + return Err(ParseError::NoRepo); + } + + Ok((owner, repo)) +} + +fn parse_path(model_string: &str) -> Result<(String, String), ParseError> { + let vs = model_string.split("--").collect::>(); + + if vs.len() < 3 { + return Err(ParseError::MissingSeparator); + } + + if vs[0] != "models" { + return Err(ParseError::NotaModel); + } + + // the owner is always the second + // if the original owner contained double dashes + // we won't find him + let owner = vs[1].to_string(); + if owner.is_empty() { + return Err(ParseError::NoOwner); + } + + let repo = if vs.len() > 3 { + vs[2..].join("--") + } else { + vs[2].to_string() + }; + if repo.is_empty() { + return Err(ParseError::NoRepo); + } + + Ok((owner, repo)) +} + +#[cfg(test)] +mod test { + use super::*; + use std::time::SystemTime; + + use tempfile; + + // --- Parse Model Id ------------------------------------------------------------------------- + #[test] + fn parse_simple_model_id_valid() { + assert_eq!( + parse_model_id("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"), + Ok(( + "TheBloke".to_string(), + "TinyLlama-1.1B-Chat-v1.0-GGUF".to_string(), + )) + ); + } + + #[test] + fn parse_model_id_slashes_in_repo() { + assert_eq!( + parse_model_id("TheBloke/TinyLlama/1.1B/Chat/v1.0-GGUF"), + Ok(( + "TheBloke".to_string(), + "TinyLlama/1.1B/Chat/v1.0-GGUF".to_string(), + )) + ); + } + + #[test] + fn parse_model_id_slashes_in_owner_valid() { + assert_eq!( + parse_model_id("The/Bloke/TinyLlama-1.1B-Chat-v1.0-GGUF"), + Ok(( + "The".to_string(), + "Bloke/TinyLlama-1.1B-Chat-v1.0-GGUF".to_string(), + )) + ); + } + + #[test] + fn fail_model_id_slashes_in_owner_valid() { + assert_ne!( + parse_model_id("The/Bloke/TinyLlama-1.1B-Chat-v1.0-GGUF"), + Ok(( + "TheBloke".to_string(), + "TinyLlama-1.1B-Chat-v1.0-GGUF".to_string(), + )) + ); + } + + #[test] + fn fail_model_id_no_slashes_between_owner_and_repo() { + assert_eq!( + parse_model_id("The-Bloke-TinyLlama-1.1B-Chat-v1.0-GGUF"), + Err(ParseError::MissingSeparator) + ); + } + + #[test] + fn fail_model_id_no_slashes_after_owner() { + assert_eq!( + parse_model_id("The-Bloke"), + Err(ParseError::MissingSeparator) + ); + } + + #[test] + fn fail_model_id_no_repo() { + assert_eq!(parse_model_id("The-Bloke/"), Err(ParseError::NoRepo)); + } + + #[test] + fn fail_model_id_no_owner() { + assert_eq!( + parse_model_id("/The-Bloke-TinyLlama-1.1B-Chat-v1.0-GGUF"), + Err(ParseError::NoOwner) + ); + } + + #[test] + fn fail_model_id_nothing() { + assert_eq!(parse_model_id("/"), Err(ParseError::NoOwner)); + } + + #[test] + fn fail_model_id_even_less() { + assert_eq!(parse_model_id(""), Err(ParseError::MissingSeparator)); + } + + // --- Parse Model Entry ---------------------------------------------------------------------- + #[test] + fn parse_path_simple_valid() { + assert_eq!( + parse_path("models--TheBloke--TinyLlama-1.1B-Chat-v1.0-GGUF"), + Ok(( + "TheBloke".to_string(), + "TinyLlama-1.1B-Chat-v1.0-GGUF".to_string(), + )) + ); + } + + #[test] + fn parse_path_dashes_in_repo_valid() { + assert_eq!( + parse_path("models--TheBloke--TinyLlama--1.1B--Chat--v1.0--GGUF"), + Ok(( + "TheBloke".to_string(), + "TinyLlama--1.1B--Chat--v1.0--GGUF".to_string(), + )) + ); + } + + #[test] + fn parse_path_dashes_in_owner_valid() { + assert_eq!( + parse_path("models--The--Bloke--TinyLlama--1.1B--Chat--v1.0--GGUF"), + Ok(( + "The".to_string(), + "Bloke--TinyLlama--1.1B--Chat--v1.0--GGUF".to_string(), + )) + ); + } + + #[test] + fn fail_path_dashes_in_owner() { + assert_ne!( + parse_path("models--The--Bloke--TinyLlama--1.1B--Chat--v1.0--GGUF"), + Ok(( + "TheBloke".to_string(), + "TinyLlama--1.1B--Chat--v1.0--GGUF".to_string(), + )) + ); + } + + #[test] + fn fail_path_does_not_start_with_model() { + assert_eq!( + parse_path("datasets--TheBloke--TinyLlama-1.1B-Chat-v1.0-GGUF"), + Err(ParseError::NotaModel) + ); + } + + #[test] + fn fail_path_no_dashes_between_owner_and_repo() { + assert_eq!( + parse_path("models--TheBloke-TinyLlama-1.1B-Chat-v1.0-GGUF"), + Err(ParseError::MissingSeparator) + ); + } + + #[test] + fn fail_path_no_dashes_after_owner() { + assert_eq!( + parse_path("models--TheBloke"), + Err(ParseError::MissingSeparator) + ); + } + + #[test] + fn fail_path_no_repo() { + assert_eq!(parse_path("models--TheBloke--"), Err(ParseError::NoRepo)); + } + + #[test] + fn fail_path_no_owner() { + assert_eq!(parse_path("models----"), Err(ParseError::NoOwner)); + } + + #[test] + fn fail_path_no_model() { + assert_eq!( + parse_path("--TheBlock--whatever"), + Err(ParseError::NotaModel) + ); + } + + #[test] + fn fail_path_nothing() { + assert_eq!(parse_path(""), Err(ParseError::MissingSeparator)); + } + + // --- Roundtrip ------------------------------------------------------------------------------ + #[test] + fn simple_roundtrip() { + let paths = vec![ + "models--TheBloke--TinyLlama-1.1B-Chat-v1.0-GGUF", + "models--The--Bloke--TinyLlama--1.1B--Chat--v1.0--GGUF", + "models--TheBloke--TinyLlama--1.1B--Chat--v1.0--GGUF", + ]; + for path in paths.into_iter() { + let (owner, repo) = parse_path(path).unwrap(); + let id = to_model_id(&owner, &repo); + let pb = model_id_to_path(&id).unwrap(); + let round = pb.as_path().to_str().unwrap(); + assert_eq!(path, round); + } + } + + // --- path to desc --------------------------------------------------------------------------- + #[tokio::test] + async fn test_list_models_in_dir() { + let bloke = "TheBloke"; + let the = "The"; + let r1 = "TinyLlama-1.1B-Chat-v1.0-GGUF"; + let r2 = "Bloke--TinyLlama-1.1B-Chat-v1.0-GGUF"; + let r3 = "TinyLlama--1.1B--Chat--v1.0--GGUF"; + let f1 = format!("models--{}--{}", bloke, r1); + let f2 = format!("models--{}--{}", the, r2); + let f3 = format!("models--{}--{}", bloke, r3); + let f4 = "invisible".to_string(); + let f5 = "models--TheBlokeInvisible".to_string(); + let f6 = "tmp".to_string(); + + let temp = tempfile::tempdir().expect("cannot create tempfile"); + + let recent = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs() + - 2; // careful with leap seconds + + std::fs::create_dir(temp.path().join(&f1)).expect(&format!("cannot create dir {:?}", f1)); + std::fs::create_dir(temp.path().join(&f2)).expect(&format!("cannot create dir {:?}", f2)); + std::fs::create_dir(temp.path().join(&f3)).expect(&format!("cannot create dir {:?}", f3)); + std::fs::create_dir(temp.path().join(&f4)).expect(&format!("cannot create dir {:?}", f4)); + std::fs::create_dir(temp.path().join(&f5)).expect(&format!("cannot create dir {:?}", f5)); + std::fs::create_dir(temp.path().join(&f6)).expect(&format!("cannot create dir {:?}", f6)); + + let mut v = vec![]; + + let _ = list_models_in_dir(temp.path(), &mut v) + .await + .expect("cannot list directory"); + + assert_eq!(v.len(), 3); + + println!("recent is {}", recent); + for m in v { + assert_eq!(m.object, "model"); + if m.owned_by != the { + assert_eq!(m.owned_by, bloke); + } + if m.id != format!("{}/{}", bloke, r1) && m.id != format!("{}/{}", bloke, r3) { + assert_eq!(m.id, format!("{}/{}", the, r2)); + } + println!("{:?}", m); + + let d = m.created.checked_sub(recent).unwrap(); + assert!(d <= 3); + } + } +} diff --git a/crates/edgen_server/src/routes.rs b/crates/edgen_server/src/routes.rs new file mode 100644 index 0000000..687c32b --- /dev/null +++ b/crates/edgen_server/src/routes.rs @@ -0,0 +1,58 @@ +/* Copyright 2023- The Binedge, Lda team. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! Contains all routes served by Edgen + +use axum::Router; + +use crate::misc; +use crate::model_man; +use crate::openai_shim; +use crate::status; + +pub fn routes() -> Router { + Router::new() + // -- AI endpoints ----------------------------------------------------- + // ---- Chat ----------------------------------------------------------- + .route( + "/v1/chat/completions", + axum::routing::post(openai_shim::chat_completions), + ) + // ---- Audio ---------------------------------------------------------- + .route( + "/v1/audio/transcriptions", + axum::routing::post(openai_shim::create_transcription), + ) + // -- AI status endpoints ---------------------------------------------- + // ---- Chat ----------------------------------------------------------- + .route( + "/v1/chat/completions/status", + axum::routing::get(status::chat_completions_status), + ) + // ---- Audio ---------------------------------------------------------- + .route( + "/v1/audio/transcriptions/status", + axum::routing::get(status::audio_transcriptions_status), + ) + // -- Model Manager ---------------------------------------------------- + .route("/v1/models", axum::routing::get(model_man::list_models)) + .route( + "/v1/models/:model", + axum::routing::get(model_man::retrieve_model), + ) + .route( + "/v1/models/:model", + axum::routing::delete(model_man::delete_model), + ) + // -- Miscellaneous services ------------------------------------------- + .route("/v1/misc/version", axum::routing::get(misc::edgen_version)) +} diff --git a/crates/edgen_server/tests/common/mod.rs b/crates/edgen_server/tests/common/mod.rs index cc01f49..e1ac337 100644 --- a/crates/edgen_server/tests/common/mod.rs +++ b/crates/edgen_server/tests/common/mod.rs @@ -32,6 +32,7 @@ pub const TRANSCRIPTIONS_URL: &str = "/transcriptions"; pub const STATUS_URL: &str = "/status"; pub const MISC_URL: &str = "/misc"; pub const VERSION_URL: &str = "/version"; +pub const MODELS_URL: &str = "/models"; pub const CHAT_COMPLETIONS_BODY: &str = r#" { @@ -51,6 +52,7 @@ pub const CHAT_COMPLETIONS_BODY: &str = r#" "#; pub const BACKUP_DIR: &str = "env_backup"; +pub const CONFIG_BACKUP_DIR: &str = "config_backup"; pub const MY_MODEL_FILES: &str = "my_models"; #[derive(Debug, PartialEq, Eq)] @@ -69,8 +71,8 @@ impl Display for Endpoint { } } -// Backup environment (config and model directories) before running 'f'; -// restore environment, even if 'f' panicks. +/// Backup environment (config and model directories) before running 'f'; +/// restore environment, even if 'f' panicks. pub fn with_save_env(f: F) where F: FnOnce() + panic::UnwindSafe, @@ -102,7 +104,40 @@ where } } -// Start edgen before running 'f' +/// Backup config only before running 'f'; +/// restore config, even if 'f' panicks. +pub fn with_save_config(f: F) +where + F: FnOnce() + panic::UnwindSafe, +{ + println!("with save config!"); + + backup_config().unwrap(); + + println!("=============="); + println!("STARTING TESTS"); + println!("=============="); + + let r = panic::catch_unwind(f); + + println!("==========="); + println!("TESTS READY"); + println!("==========="); + + let _ = match restore_config() { + Ok(_) => (), + Err(e) => { + panic!("Panic! Cannot restore your config: {:?}", e); + } + }; + + match r { + Err(e) => panic::resume_unwind(e), + Ok(_) => (), + } +} + +/// Start edgen before running 'f' pub fn with_edgen(f: F) where F: FnOnce() + panic::UnwindSafe, @@ -123,9 +158,9 @@ where f(); } -// Backup environment (config and model directories) -// and start edgen before running 'f'; -// restore environment, even if 'f' or edgen panick. +/// Backup environment (config and model directories) +/// and start edgen before running 'f'; +/// restore environment, even if 'f' or edgen panick. pub fn with_save_edgen(f: F) where F: FnOnce() + panic::UnwindSafe, @@ -135,10 +170,27 @@ where }); } +/// Backup config directory +/// and start edgen before running 'f'; +/// restore config, even if 'f' or edgen panick. +pub fn with_save_config_edgen(f: F) +where + F: FnOnce() + panic::UnwindSafe, +{ + with_save_config(|| { + with_edgen(f); + }); +} + pub fn test_message(msg: &str) { println!("=== Test {}", msg); } +pub fn pass_always() { + test_message("pass always"); + assert!(true); +} + pub fn make_url(v: &[&str]) -> String { let mut s = "".to_string(); for e in v { @@ -191,8 +243,111 @@ pub fn reset_config() { edgen_server::config_reset().unwrap(); } -// spawn a thread to send a request to the indicated endpoint. -// This allows the caller to perform another task in the caller thread. +pub fn config_exists() { + test_message("config exists"); + assert!(settings::PROJECT_DIRS.config_dir().exists()); + assert!(settings::CONFIG_FILE.exists()); +} + +pub fn data_exists() { + test_message("data exists"); + let data = settings::PROJECT_DIRS.data_dir(); + println!("exists: {:?}", data); + assert!(data.exists()); + + let models = data.join("models"); + println!("exists: {:?}", models); + assert!(models.exists()); + + let chat = models.join("chat"); + println!("exists: {:?}", chat); + assert!(models.exists()); + + let completions = chat.join("completions"); + println!("exists: {:?}", completions); + assert!(completions.exists()); + + let audio = models.join("audio"); + println!("exists: {:?}", audio); + assert!(audio.exists()); + + let transcriptions = audio.join("transcriptions"); + println!("exists: {:?}", transcriptions); + assert!(transcriptions.exists()); +} + +/// Edit the config file: set another model dir for the indicated endpoint. +pub fn set_model_dir(ep: Endpoint, model_dir: &str) { + test_message(&format!("set {} model directory to {}", ep, model_dir,)); + + let mut config = get_config().unwrap(); + + match &ep { + Endpoint::ChatCompletions => { + config.chat_completions_models_dir = model_dir.to_string(); + } + Endpoint::AudioTranscriptions => { + config.audio_transcriptions_models_dir = model_dir.to_string(); + } + } + write_config(&config).unwrap(); + + println!("pausing for 4 secs to make sure the config file has been updated"); + std::thread::sleep(std::time::Duration::from_secs(4)); +} + +/// Edit the config file: set another model name and repo for the indicated endpoint. +/// Use the status endpoint to check whether the model was updated. +pub fn set_model(ep: Endpoint, model_name: &str, model_repo: &str) { + test_message(&format!("set {} model to {}", ep, model_name,)); + + let mut config = get_config().unwrap(); + + match &ep { + Endpoint::ChatCompletions => { + config.chat_completions_model_name = model_name.to_string(); + config.chat_completions_model_repo = model_repo.to_string(); + } + Endpoint::AudioTranscriptions => { + config.audio_transcriptions_model_name = model_name.to_string(); + config.audio_transcriptions_model_repo = model_repo.to_string(); + } + } + write_config(&config).unwrap(); + + println!("pausing for 4 secs to make sure the config file has been updated"); + std::thread::sleep(std::time::Duration::from_secs(4)); + + let url = match ep { + Endpoint::ChatCompletions => make_url(&[BASE_URL, CHAT_URL, COMPLETIONS_URL, STATUS_URL]), + Endpoint::AudioTranscriptions => { + make_url(&[BASE_URL, AUDIO_URL, TRANSCRIPTIONS_URL, STATUS_URL]) + } + }; + let stat: status::AIStatus = blocking::get(url).unwrap().json().unwrap(); + assert_eq!(stat.active_model, model_name); +} + +/// Exercise the edgen version endpoint to make sure the server is reachable. +pub fn connect_to_server_test() { + test_message("connect to server"); + assert!( + match blocking::get(make_url(&[BASE_URL, MISC_URL, VERSION_URL])) { + Err(e) => { + eprintln!("cannot connect: {:?}", e); + false + } + Ok(v) => { + assert!(v.status().is_success()); + println!("have: '{}'", v.text().unwrap()); + true + } + } + ); +} + +/// Spawn a thread to send a request to the indicated endpoint. +/// This allows the caller to perform another task in the caller thread. pub fn spawn_request(ep: Endpoint, body: String) -> thread::JoinHandle { match ep { Endpoint::ChatCompletions => spawn_chat_completions_request(body), @@ -257,7 +412,7 @@ pub fn spawn_audio_transcriptions_request() -> thread::JoinHandle { }) } -// Assert that a download is ongoing and download progress is reported. +/// Assert that a download is ongoing and download progress is reported. pub fn assert_download(endpoint: &str) { println!("requesting status of {}", endpoint); @@ -290,7 +445,7 @@ pub fn assert_download(endpoint: &str) { assert_eq!(stat.download_progress, 100); } -// Assert that *no* download is ongoing. +/// Assert that *no* download is ongoing. pub fn assert_no_download(endpoint: &str) { println!("requesting status of {}", endpoint); @@ -328,6 +483,7 @@ impl From> for BackupError { } } +// backup environment: config and data fn backup_env() -> Result<(), BackupError> { println!("backing up"); @@ -373,6 +529,7 @@ fn backup_env() -> Result<(), BackupError> { Ok(()) } +// restore environment: config and data fn restore_env() -> Result<(), io::Error> { println!("restoring"); @@ -411,3 +568,61 @@ fn restore_env() -> Result<(), io::Error> { Ok(()) } + +fn backup_config() -> Result<(), BackupError> { + println!("backing up"); + + let backup_dir = Path::new(CONFIG_BACKUP_DIR); + if backup_dir.exists() { + let msg = format!( + "directory {} exists! + This means an earlier test run did not finish correctly. \ + Restore your environment manually.", + CONFIG_BACKUP_DIR, + ); + eprintln!("{}", msg); + return Err(BackupError::Unfinished); + } + + println!("config dir: {:?}", settings::PROJECT_DIRS.config_dir()); + + fs::create_dir(&backup_dir)?; + + let cnfg = settings::PROJECT_DIRS.config_dir(); + let cnfg_bkp = backup_dir.join("config"); + + if cnfg.exists() { + println!("config bkp: {:?}", cnfg_bkp); + copy_dir(&cnfg, &cnfg_bkp)?; + fs::remove_dir_all(&cnfg)?; + } else { + println!("config {:?} does not exist", cnfg); + } + + Ok(()) +} + +fn restore_config() -> Result<(), io::Error> { + println!("restoring"); + + let backup_dir = Path::new(CONFIG_BACKUP_DIR); + + let cnfg = settings::PROJECT_DIRS.config_dir(); + let cnfg_bkp = backup_dir.join("config"); + + if cnfg.exists() { + fs::remove_dir_all(&cnfg)?; + } + + if cnfg_bkp.exists() { + println!("{:?} -> {:?}", cnfg_bkp, cnfg); + copy_dir(&cnfg_bkp, &cnfg)?; + } else { + println!("config bkp {:?} does not exist", cnfg_bkp); + } + + println!("removing {:?}", backup_dir); + fs::remove_dir_all(&backup_dir)?; + + Ok(()) +} diff --git a/crates/edgen_server/tests/modelmanager_tests.rs b/crates/edgen_server/tests/modelmanager_tests.rs new file mode 100644 index 0000000..0394388 --- /dev/null +++ b/crates/edgen_server/tests/modelmanager_tests.rs @@ -0,0 +1,189 @@ +use std::path; +use std::path::{Path, PathBuf}; +use std::time::SystemTime; + +use futures::executor::block_on; +use reqwest::blocking; + +use edgen_core::settings; +use edgen_server::model_man::{ModelDeletionStatus, ModelDesc}; + +#[allow(dead_code)] +mod common; + +#[test] +fn test_modelmanager() { + common::with_save_config_edgen(|| { + common::pass_always(); + + common::config_exists(); + + common::connect_to_server_test(); + + let my_models_dir = format!( + "{}{}{}", + common::CONFIG_BACKUP_DIR, + path::MAIN_SEPARATOR, + common::MY_MODEL_FILES, + ); + + let new_chat_completions_dir = my_models_dir.clone() + + &format!( + "{}{}{}{}", + path::MAIN_SEPARATOR, + "chat", + path::MAIN_SEPARATOR, + "completions", + ); + + let new_audio_transcriptions_dir = my_models_dir.clone() + + &format!( + "{}{}{}{}", + path::MAIN_SEPARATOR, + "audio", + path::MAIN_SEPARATOR, + "transcriptions", + ); + + common::set_model_dir(common::Endpoint::ChatCompletions, &new_chat_completions_dir); + + common::set_model_dir( + common::Endpoint::AudioTranscriptions, + &new_audio_transcriptions_dir, + ); + + make_dirs(); + + test_list_models(); + test_delete_model(); + }) +} + +// actually create the model dirs before using them +fn make_dirs() { + let dir = block_on(async { settings::chat_completions_dir().await }); + std::fs::create_dir_all(&dir).expect("cannot create chat completions model dir"); + + assert!(PathBuf::from(&dir).exists()); + + let dir = block_on(async { settings::audio_transcriptions_dir().await }); + std::fs::create_dir_all(&dir).expect("cannot create audio transcriptions model dir"); + + assert!(PathBuf::from(&dir).exists()); +} + +fn test_list_models() { + common::test_message("list models"); + + let bloke = "TheBloke"; + let the = "The"; + let r1 = "TinyLlama-1.1B-Chat-v1.0-GGUF"; + let r2 = "Bloke--TinyLlama-1.1B-Chat-v1.0-GGUF"; + let r3 = "TinyLlama--1.1B--Chat--v1.0--GGUF"; + let f1 = format!("models--{}--{}", bloke, r1); + let f2 = format!("models--{}--{}", the, r2); + let f3 = format!("models--{}--{}", bloke, r3); + let f4 = "invisible".to_string(); + let f5 = "models--TheBlokeInvisible".to_string(); + let f6 = "tmp".to_string(); + + let recent = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs() + - 2; // careful with leap seconds + + let dir1 = block_on(async { settings::audio_transcriptions_dir().await }); + let dir1 = PathBuf::from(&dir1); + + let dir2 = block_on(async { settings::chat_completions_dir().await }); + let dir2 = PathBuf::from(&dir2); + + std::fs::create_dir(dir1.join(&f1)).expect(&format!("cannot create dir {:?}", f1)); + std::fs::create_dir(dir2.join(&f2)).expect(&format!("cannot create dir {:?}", f2)); + std::fs::create_dir(dir1.join(&f3)).expect(&format!("cannot create dir {:?}", f3)); + std::fs::create_dir(dir1.join(&f4)).expect(&format!("cannot create dir {:?}", f4)); + std::fs::create_dir(dir2.join(&f5)).expect(&format!("cannot create dir {:?}", f5)); + std::fs::create_dir(dir1.join(&f6)).expect(&format!("cannot create dir {:?}", f6)); + + // --- get model descriptor + let res = blocking::get(common::make_url(&[common::BASE_URL, common::MODELS_URL])) + .expect("models get endpoint failed"); + assert!(res.status().is_success(), "models failed"); + let v: Vec = res.json().expect("cannot convert to model descs"); + + assert_eq!(v.len(), 3); + + println!("recent is {}", recent); + for m in v { + assert_eq!(m.object, "model"); + if m.owned_by != the { + assert_eq!(m.owned_by, bloke); + } + if m.id != format!("{}/{}", bloke, r1) && m.id != format!("{}/{}", bloke, r3) { + assert_eq!(m.id, format!("{}/{}", the, r2)); + } + println!("{:?}", m); + + let d = m.created.checked_sub(recent).unwrap(); + assert!(d <= 3); + } +} + +fn test_delete_model() { + common::test_message("delete model"); + + let owner = "TheFaker"; + let repo = "my-faked-model-v1-GGUF"; + let model = format!("models--{}--{}", owner, repo); + let id = format!("{}/{}", owner, repo); + let id_url = format!("{}%2f{}", owner, repo); + + let recent = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs() + - 2; // careful with leap seconds + + let dir = block_on(async { settings::chat_completions_dir().await }); + + let dir = Path::new(&dir).join(&model); + std::fs::create_dir(&dir).expect(&format!("cannot create model {:?}", dir)); + + // --- get model descriptor + let res = blocking::get(common::make_url(&[ + common::BASE_URL, + common::MODELS_URL, + "/", + &id_url, + ])) + .expect("models get endpoint failed"); + + assert!(res.status().is_success(), "models failed"); + let m: ModelDesc = res.json().expect("cannot convert to model desc"); + + println!("model descriptor: {:?}", m); + assert_eq!(m.object, "model"); + assert_eq!(m.owned_by, owner); + assert_eq!(m.id, id); + let d = m.created.checked_sub(recent).unwrap(); + assert!(d <= 3); + + // --- delete model + println!("delete model"); + let res = blocking::Client::new() + .delete(common::make_url(&[ + common::BASE_URL, + common::MODELS_URL, + "/", + &id_url, + ])) + .send() + .expect("models delete endpoint failed"); + + assert!(res.status().is_success()); + let m: ModelDeletionStatus = res.json().expect("cannot convert to model deletion status"); + assert!(m.deleted); + assert_eq!(m.id, id); + assert!(!dir.exists()); +} diff --git a/crates/edgen_server/tests/settings_tests.rs b/crates/edgen_server/tests/settings_tests.rs index f0d74da..4dbb942 100644 --- a/crates/edgen_server/tests/settings_tests.rs +++ b/crates/edgen_server/tests/settings_tests.rs @@ -3,9 +3,7 @@ use std::path; use reqwest::blocking; -use edgen_core::settings; -use edgen_server::status; - +#[allow(dead_code)] mod common; #[test] @@ -48,16 +46,16 @@ fn fake_test() { fn test_battery() { common::with_save_edgen(|| { // make sure everything is right - pass_always(); + common::pass_always(); // ================================ common::test_message("SCENARIO 1"); // ================================ - config_exists(); - data_exists(); + common::config_exists(); + common::data_exists(); // endpoints reachable - connect_to_server_test(); + common::connect_to_server_test(); chat_completions_status_reachable(); audio_transcriptions_status_reachable(); @@ -66,12 +64,12 @@ fn test_battery() { common::test_message("SCENARIO 2"); // ================================ // set small models, so we don't need to download too much - set_model( + common::set_model( common::Endpoint::ChatCompletions, common::SMALL_LLM_NAME, common::SMALL_LLM_REPO, ); - set_model( + common::set_model( common::Endpoint::AudioTranscriptions, common::SMALL_WHISPER_NAME, common::SMALL_WHISPER_REPO, @@ -113,9 +111,9 @@ fn test_battery() { "transcriptions", ); - set_model_dir(common::Endpoint::ChatCompletions, &new_chat_completions_dir); + common::set_model_dir(common::Endpoint::ChatCompletions, &new_chat_completions_dir); - set_model_dir( + common::set_model_dir( common::Endpoint::AudioTranscriptions, &new_audio_transcriptions_dir, ); @@ -147,12 +145,12 @@ fn test_battery() { // ================================ test_config_reset(); - set_model( + common::set_model( common::Endpoint::ChatCompletions, common::SMALL_LLM_NAME, common::SMALL_LLM_REPO, ); - set_model( + common::set_model( common::Endpoint::AudioTranscriptions, common::SMALL_WHISPER_NAME, common::SMALL_WHISPER_REPO, @@ -167,63 +165,6 @@ fn test_battery() { }) } -fn pass_always() { - common::test_message("pass always"); - assert!(true); -} - -// exercise the edgen version endpoint to make sure the server is reachable. -fn connect_to_server_test() { - common::test_message("connect to server"); - assert!(match blocking::get(common::make_url(&[ - common::BASE_URL, - common::MISC_URL, - common::VERSION_URL - ])) { - Err(e) => { - eprintln!("cannot connect: {:?}", e); - false - } - Ok(v) => { - println!("have: '{}'", v.text().unwrap()); - true - } - }); -} - -fn config_exists() { - common::test_message("config exists"); - assert!(settings::PROJECT_DIRS.config_dir().exists()); - assert!(settings::CONFIG_FILE.exists()); -} - -fn data_exists() { - common::test_message("data exists"); - let data = settings::PROJECT_DIRS.data_dir(); - println!("exists: {:?}", data); - assert!(data.exists()); - - let models = data.join("models"); - println!("exists: {:?}", models); - assert!(models.exists()); - - let chat = models.join("chat"); - println!("exists: {:?}", chat); - assert!(models.exists()); - - let completions = chat.join("completions"); - println!("exists: {:?}", completions); - assert!(completions.exists()); - - let audio = models.join("audio"); - println!("exists: {:?}", audio); - assert!(audio.exists()); - - let transcriptions = audio.join("transcriptions"); - println!("exists: {:?}", transcriptions); - assert!(transcriptions.exists()); -} - fn chat_completions_status_reachable() { common::test_message("chat completions status is reachable"); assert!(match blocking::get(common::make_url(&[ @@ -237,6 +178,7 @@ fn chat_completions_status_reachable() { false } Ok(v) => { + assert!(v.status().is_success()); println!("have: '{}'", v.text().unwrap()); true } @@ -256,70 +198,13 @@ fn audio_transcriptions_status_reachable() { false } Ok(v) => { + assert!(v.status().is_success()); println!("have: '{}'", v.text().unwrap()); true } }); } -// edit the config file: set another model name and repo for the indicated endpoint. -fn set_model(ep: common::Endpoint, model_name: &str, model_repo: &str) { - common::test_message(&format!("set {} model to {}", ep, model_name,)); - - let mut config = common::get_config().unwrap(); - - match &ep { - common::Endpoint::ChatCompletions => { - config.chat_completions_model_name = model_name.to_string(); - config.chat_completions_model_repo = model_repo.to_string(); - } - common::Endpoint::AudioTranscriptions => { - config.audio_transcriptions_model_name = model_name.to_string(); - config.audio_transcriptions_model_repo = model_repo.to_string(); - } - } - common::write_config(&config).unwrap(); - - println!("pausing for 4 secs to make sure the config file has been updated"); - std::thread::sleep(std::time::Duration::from_secs(4)); - let url = match ep { - common::Endpoint::ChatCompletions => common::make_url(&[ - common::BASE_URL, - common::CHAT_URL, - common::COMPLETIONS_URL, - common::STATUS_URL, - ]), - common::Endpoint::AudioTranscriptions => common::make_url(&[ - common::BASE_URL, - common::AUDIO_URL, - common::TRANSCRIPTIONS_URL, - common::STATUS_URL, - ]), - }; - let stat: status::AIStatus = blocking::get(url).unwrap().json().unwrap(); - assert_eq!(stat.active_model, model_name); -} - -// edit the config file: set another model dir for the indicated endpoint. -fn set_model_dir(ep: common::Endpoint, model_dir: &str) { - common::test_message(&format!("set {} model directory to {}", ep, model_dir,)); - - let mut config = common::get_config().unwrap(); - - match &ep { - common::Endpoint::ChatCompletions => { - config.chat_completions_models_dir = model_dir.to_string(); - } - common::Endpoint::AudioTranscriptions => { - config.audio_transcriptions_models_dir = model_dir.to_string(); - } - } - common::write_config(&config).unwrap(); - - println!("pausing for 4 secs to make sure the config file has been updated"); - std::thread::sleep(std::time::Duration::from_secs(4)); -} - fn test_config_reset() { common::test_message("test resetting config"); common::reset_config(); diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100755 index 0000000..254968a --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,18 @@ + +from edgen import Edgen, APIConnectionError +from edgen.resources.misc import Version +import pytest +import subprocess + +client = Edgen() + +def test_models(): + try: + models = client.models.list() + except APIConnectionError: + pytest.fail("No connection. Is edgen running?") + + assert(type(models) is list) + +if __name__ == "__main__": + test_models()