Skip to content

Commit

Permalink
recovered cohere fn -> #18
Browse files Browse the repository at this point in the history
  • Loading branch information
bm777 committed Apr 7, 2024
1 parent f78b0db commit 27feac8
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@ env_logger = "0.11.3"
log = "0.4"
diesel_migrations = {version = "2.1.0" }
chrono = "0.4.37"
reqwest = { version = "0.12.2", features = ["json", "rustls-tls"] }
tokio = { version = "1.37.0", features = ["full"] }
26 changes: 26 additions & 0 deletions src/cohere/embed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Client, Error};
use serde_json::{Value, json};
use std::collections::HashMap;


pub async fn generate_embed(key: &String, texts: &Vec<String>) -> Result<Value, Error>{
let client = Client::new();

let mut params: HashMap<&str, Value> = HashMap::new();

params.insert("model", "embed-english-light-v3.0".into());
params.insert("texts", json!(texts));
params.insert("input_type", "search_document".into());

// let json_params = json!(params);

let response = client.post("https://api.cohere.ai/v1/embed")
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", key))
.json(&params)
.send()
.await?;

Ok(response.json().await?)
}
3 changes: 3 additions & 0 deletions src/cohere/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub mod embed;
pub mod rerank;
pub mod sum;
26 changes: 26 additions & 0 deletions src/cohere/rerank.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Client, Error};
use serde_json::{Value, json};
use std::collections::HashMap;


pub async fn rerank(key: &String, query: &String, page: &Vec<String>) -> Result<Value, Error>{
let client = Client::new();

let mut params: HashMap<&str, Value> = HashMap::new();

params.insert("model", "rerank-english-v2.0".into());
// convert query to value
params.insert("query", json!(query));
params.insert("top_n", 3.into()); // default to 5 in production
params.insert("documents", json!(page));

let response = client.post("https://api.cohere.ai/v1/rerank")
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", key))
.json(&params)
.send()
.await?;

Ok(response.json().await?)
}
23 changes: 23 additions & 0 deletions src/cohere/sum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Client, Error};
use serde_json::{Value, json};
use std::collections::HashMap;


pub async fn summarize(key: &String, page: &String) -> Result<Value, Error>{
let client = Client::new();

let mut params: HashMap<&str, Value> = HashMap::new();

params.insert("preamble", "Summarize the webpage so that it can be retrieved later.".into());
params.insert("message", json!(page));

let response = client.post("https://api.cohere.ai/v1/chat")
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", key))
.json(&params)
.send()
.await?;

Ok(response.json().await?)
}
53 changes: 53 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ use shellexpand;
mod schema;
mod models;
mod utils;
mod cohere;

use crate::cohere::{
embed::generate_embed,
rerank::rerank,
sum::summarize,
};

// Database connection pool -> thoughout teh APP
type DbPool = Pool<ConnectionManager<SqliteConnection>>;
Expand Down Expand Up @@ -104,6 +111,52 @@ async fn main() -> std::io::Result<()> {

log::info!("Server running.... -> 1777");

let api = "C6z92rBNOj1oZ28eZmMiLJdn2SKYOh8QNh4aiyP0";
// let response = generate_embed(&api.to_string(), &vec!["I ping you!".to_string(), "I try another string".to_string()])
// .await
// .unwrap();
// println!("Response: {}", response);

let _query = "how to write python code?";
let _docs = vec![
"it is easy to write python code".to_string(),
"I try another letter here, and i love paris".to_string(),
"I am a software engineer".to_string(),
"I am a mechanical engineer".to_string(),
"I am a civil engineer".to_string(),
];
// let response = rerank(
// &api.to_string(),
// &query.to_string(),
// &docs)
// .await
// .unwrap();
// let results = response.get("results").unwrap();
// let arr = results.as_array().unwrap();
// let top3pages = vec![
// docs[arr[0].get("index").unwrap().as_i64().unwrap() as usize].clone(),
// docs[arr[1].get("index").unwrap().as_i64().unwrap() as usize].clone(),
// docs[arr[2].get("index").unwrap().as_i64().unwrap() as usize].clone(),
// ];
// println!("Response: {}", results);
// println!("Top 3 pages: {:?}", top3pages);

let _page = "Fiat money is a type of currency that is not backed by a precious metal, such as gold or silver.
It is typically designated by the issuing government to be legal tender, and is authorized by government
regulation. Since the end of the Bretton Woods system in 1971, the major currencies in the world are fiat money.
Fiat money generally does not have intrinsic value and does not have use value. It has value only because the individuals
who use it as a unit of account - or, in the case of currency, a medium of exchange - agree on its value.[1]
They trust that it will be accepted by merchants and other people as a means of payment for liabilities.";

// let response = summarize(&api.to_string(), &page.to_string())
// .await
// .unwrap();

// let summary = response.get("text").unwrap().to_string();

// println!("Response: {}", response);
// println!("Summary: {:?}", summary);

HttpServer::new(move || {
App::new()
.app_data(web::Data::new(pool.clone()))
Expand Down

0 comments on commit 27feac8

Please sign in to comment.