Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
thewh1teagle committed Jan 8, 2024
1 parent 5b7a08b commit dc5677d
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 151 deletions.
6 changes: 6 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"[rust]": {
"editor.defaultFormatter": "rust-lang.rust-analyzer",
"editor.formatOnSave": true
}
}
3 changes: 3 additions & 0 deletions rustfmt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
indent_style = "Block"
reorder_imports = true
max_width = 130
89 changes: 44 additions & 45 deletions src/audio/encoder.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
extern crate ffmpeg_next as ffmpeg;
use std::path::{Path, PathBuf};

use anyhow::{Context, Result};
use ffmpeg::{codec, filter, format, frame, media};
use ffmpeg::{rescale, Rescale};
use anyhow::Result;
use log::debug;
use std::path::{Path, PathBuf};

fn filter(
spec: &str,
decoder: &codec::decoder::Audio,
encoder: &codec::encoder::Audio,
) -> Result<filter::Graph, ffmpeg::Error> {
fn filter(spec: &str, decoder: &codec::decoder::Audio, encoder: &codec::encoder::Audio) -> Result<filter::Graph> {
let mut filter = filter::Graph::new();


let channel_layout = if !decoder.channel_layout().is_empty() {
decoder.channel_layout()
} else {
ffmpeg_next::channel_layout::ChannelLayout::MONO
};

let args = format!(
"time_base={}:sample_rate={}:sample_fmt={}:channel_layout=0x{:x}",
decoder.time_base(),
Expand All @@ -29,11 +23,11 @@ fn filter(
);
debug!("args are {args}");

filter.add(&filter::find("abuffer").unwrap(), "in", &args)?;
filter.add(&filter::find("abuffersink").unwrap(), "out", "")?;
filter.add(&filter::find("abuffer").context("cant find abuffer")?, "in", &args)?;
filter.add(&filter::find("abuffersink").context("cant find abuffersink")?, "out", "")?;

{
let mut out = filter.get("out").unwrap();
let mut out = filter.get("out").context("out is none")?;

out.set_sample_format(encoder.format());
out.set_channel_layout(ffmpeg_next::channel_layout::ChannelLayout::MONO); // TODO change
Expand All @@ -52,7 +46,7 @@ fn filter(
{
filter
.get("out")
.unwrap()
.context("out is none")?
.sink()
.set_frame_size(encoder.frame_size());
}
Expand All @@ -75,7 +69,7 @@ fn transcoder<P: AsRef<Path>>(
octx: &mut format::context::Output,
path: &P,
filter_spec: &str,
) -> Result<Transcoder, ffmpeg::Error> {
) -> Result<Transcoder> {
let input = ictx
.streams()
.best(media::Type::Audio)
Expand All @@ -85,10 +79,7 @@ fn transcoder<P: AsRef<Path>>(
let codec = ffmpeg::encoder::find(octx.format().codec(path, media::Type::Audio))
.expect("failed to find encoder")
.audio()?;
let global = octx
.format()
.flags()
.contains(ffmpeg::format::flag::Flags::GLOBAL_HEADER);
let global = octx.format().flags().contains(ffmpeg::format::flag::Flags::GLOBAL_HEADER);

decoder.set_parameters(input.parameters())?;

Expand All @@ -111,9 +102,9 @@ fn transcoder<P: AsRef<Path>>(
encoder.set_format(
codec
.formats()
.expect("unknown supported formats")
.context("unknown supported formats")?
.next()
.unwrap(),
.context("codec not found")?,
);
encoder.set_bit_rate(decoder.bit_rate());
encoder.set_max_bit_rate(decoder.max_bit_rate());
Expand All @@ -140,52 +131,60 @@ fn transcoder<P: AsRef<Path>>(
}

impl Transcoder {
fn send_frame_to_encoder(&mut self, frame: &ffmpeg::Frame) {
self.encoder.send_frame(frame).unwrap();
fn send_frame_to_encoder(&mut self, frame: &ffmpeg::Frame) -> Result<()> {
self.encoder.send_frame(frame)?;
Ok(())
}

fn send_eof_to_encoder(&mut self) {
self.encoder.send_eof().unwrap();
fn send_eof_to_encoder(&mut self) -> Result<()> {
self.encoder.send_eof()?;
Ok(())
}

fn receive_and_process_encoded_packets(&mut self, octx: &mut format::context::Output) {
fn receive_and_process_encoded_packets(&mut self, octx: &mut format::context::Output) -> Result<()> {
let mut encoded = ffmpeg::Packet::empty();
while self.encoder.receive_packet(&mut encoded).is_ok() {
encoded.set_stream(0);
encoded.rescale_ts(self.in_time_base, self.out_time_base);
encoded.write_interleaved(octx).unwrap();
encoded.write_interleaved(octx)?;
}
Ok(())
}

fn add_frame_to_filter(&mut self, frame: &ffmpeg::Frame) {
self.filter.get("in").unwrap().source().add(frame).unwrap();
fn add_frame_to_filter(&mut self, frame: &ffmpeg::Frame) -> Result<()> {
self.filter.get("in").context("in is none")?.source().add(frame)?;
Ok(())
}

fn flush_filter(&mut self) {
self.filter.get("in").unwrap().source().flush().unwrap();
fn flush_filter(&mut self) -> Result<()> {
self.filter.get("in").context("in is none")?.source().flush()?;
Ok(())
}

fn get_and_process_filtered_frames(&mut self, octx: &mut format::context::Output) {
fn get_and_process_filtered_frames(&mut self, octx: &mut format::context::Output) -> Result<()> {
let mut filtered = frame::Audio::empty();
while self
.filter
.get("out")
.unwrap()
.context("out is none")?
.sink()
.frame(&mut filtered)
.is_ok()
{
self.send_frame_to_encoder(&filtered);
self.receive_and_process_encoded_packets(octx);
}
Ok(())
}

fn send_packet_to_decoder(&mut self, packet: &ffmpeg::Packet) {
self.decoder.send_packet(packet).unwrap();
fn send_packet_to_decoder(&mut self, packet: &ffmpeg::Packet) -> Result<()> {
self.decoder.send_packet(packet)?;
Ok(())
}

fn send_eof_to_decoder(&mut self) {
self.decoder.send_eof().unwrap();
fn send_eof_to_decoder(&mut self) -> Result<()> {
self.decoder.send_eof()?;
Ok(())
}

fn receive_and_process_decoded_frames(&mut self, octx: &mut format::context::Output) {
Expand All @@ -211,24 +210,24 @@ impl Transcoder {
// Example 3: Seek to a specified position (in seconds)
// transcode-audio in.mp3 out.mp3 anull 30
pub fn convert_to_16khz(input: PathBuf, output: PathBuf, filter: Option<String>, seek: Option<String>) -> Result<()> {
ffmpeg::init().unwrap();
ffmpeg::init()?;

let filter = filter.unwrap_or_else(|| "anull".to_owned());
let seek = seek.and_then(|s| s.parse::<i64>().ok());

let mut ictx = format::input(&input).unwrap();
let mut octx = format::output(&output).unwrap();
let mut transcoder = transcoder(&mut ictx, &mut octx, &output, &filter).unwrap();
let mut ictx = format::input(&input)?;
let mut octx = format::output(&output)?;
let mut transcoder = transcoder(&mut ictx, &mut octx, &output, &filter)?;

if let Some(position) = seek {
// If the position was given in seconds, rescale it to ffmpegs base timebase.
let position = position.rescale((1, 1), rescale::TIME_BASE);
// If this seek was embedded in the transcoding loop, a call of `flush()`
// for every opened buffer after the successful seek would be advisable.
ictx.seek(position, ..position).unwrap();
ictx.seek(position, ..position)?;
}
octx.set_metadata(ictx.metadata().to_owned());
octx.write_header().unwrap();
octx.write_header()?;

for (stream, mut packet) in ictx.packets() {
if stream.index() == transcoder.stream {
Expand All @@ -247,6 +246,6 @@ pub fn convert_to_16khz(input: PathBuf, output: PathBuf, filter: Option<String>,
transcoder.send_eof_to_encoder();
transcoder.receive_and_process_encoded_packets(&mut octx);

octx.write_trailer().unwrap();
octx.write_trailer()?;
Ok(())
}
}
16 changes: 7 additions & 9 deletions src/audio/mod.rs
Original file line number Diff line number Diff line change
@@ -1,39 +1,37 @@
use std::path::PathBuf;
use anyhow::Result;
use ffmpeg_next as ffmpeg;
use std::path::PathBuf;
mod encoder;
pub struct Audio {}


impl Audio {
pub fn try_create() -> Result<Self> {

Ok(Audio {})
}

pub fn convert(&self, input: PathBuf, output: PathBuf) -> Result<()> {
encoder::convert_to_16khz(input, output, None, None)?;
Ok(())
}

}

#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::tempdir;
use anyhow::Result;
use log::debug;
use std::fs;
use tempfile::tempdir;

fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}

fn wait_for_enter() {
fn wait_for_enter() -> Result<()> {
println!("PRESS ENTER");
let mut buffer = String::new();
std::io::stdin().read_line(&mut buffer).unwrap();
std::io::stdin().read_line(&mut buffer)?;
Ok(())
}

#[test]
Expand All @@ -56,4 +54,4 @@ mod tests {

Ok(())
}
}
}
9 changes: 4 additions & 5 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
use std::path::PathBuf;
use anyhow::Result;
use app_dirs2::{AppInfo, AppDataType, app_root};
use app_dirs2;
use app_dirs2::{app_root, AppDataType, AppInfo};
use std::path::PathBuf;

pub const APP_NAME: &str = "ruscribe";
pub const URL: &str = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.bin?download=true";
pub const FILENAME: &str = "ggml-medium.bin";
pub const HASH: &str = ""; // TODO
pub const APP_INFO: AppInfo = AppInfo {
name: "ruscribe",
author: "github.com/thewh1teagle",
author: "github.com.thewh1teagle",
};


pub fn get_model_path() -> Result<PathBuf> {
let app_config = app_root(AppDataType::UserData, &APP_INFO)?;
let filepath = app_config.join(FILENAME);
let filepath = app_config.join(FILENAME);
Ok(filepath)
}
38 changes: 18 additions & 20 deletions src/downloader.rs
Original file line number Diff line number Diff line change
@@ -1,38 +1,34 @@
use std::path::PathBuf;
use crate::config;
use anyhow::{bail, Context, Ok, Result};
use env_logger;
use futures_util::StreamExt;
use indicatif;
use reqwest::{self, Request};
use anyhow::{Result, Context, Ok, bail};
use indicatif::{ProgressBar, ProgressStyle};
use futures_util::StreamExt;
use std::io::Write;
use std::clone::Clone;
use env_logger;
use log::debug;
use reqwest::{self, Request};
use sha256::{digest, try_digest};
use std::clone::Clone;
use std::io::Write;
use std::path::PathBuf;
// https://huggingface.co/ggerganov/whisper.cpp/tree/main


struct Downloader {
client: reqwest::Client
client: reqwest::Client,
}

impl Downloader {
fn new() -> Self {
let client = reqwest::Client::new();

Downloader { client }
}


async fn download(&mut self, url: &str, path: PathBuf, hash: Option<&str>) -> Result<()> {
if path.exists() {
debug!("file {} exists!", path.display());
return Ok(());
}
let res = self.client.get(url)
.send()
.await?;
let res = self.client.get(url).send().await?;
let total_size = res
.content_length()
.context(format!("Failed to get content length from '{}'", url))?;
Expand All @@ -56,8 +52,9 @@ impl Downloader {
Ok(())
}

async fn _verify(path: &PathBuf, hash: String) -> Result<()> { // TODO
let val = try_digest(path).unwrap();
async fn _verify(path: &PathBuf, hash: String) -> Result<()> {
// TODO
let val = try_digest(path)?;
if val != hash {
bail!("Invalid file hash!");
}
Expand All @@ -67,9 +64,9 @@ impl Downloader {

#[cfg(test)]
mod tests {
use crate::{config, downloader};
use anyhow::{Context, Result};
use app_dirs2::*;
use crate::{downloader, config};

fn init() {
let _ = env_logger::builder().is_test(true).try_init();
Expand All @@ -81,8 +78,9 @@ mod tests {
let mut d = downloader::Downloader::new();
let app_config = app_root(AppDataType::UserData, &config::APP_INFO)?;
let filepath = config::get_model_path()?;
d.download(config::URL, filepath, Some(config::HASH)).await.context("Cant download")?;
d.download(config::URL, filepath, Some(config::HASH))
.await
.context("Cant download")?;
Ok(())
}

}
}
Loading

0 comments on commit dc5677d

Please sign in to comment.