Skip to content

Commit

Permalink
Refactor folders (#30)
Browse files Browse the repository at this point in the history
* Refactor folders

* Remove unused import

* move app_fs_ext to storage_ext

* Fix storage_ext import

* Fix pathbuf

* Fix import

* No need for verbose output in CI

* Tweak imports
  • Loading branch information
gabotechs authored Jan 11, 2025
1 parent 9115817 commit 2d9bddd
Show file tree
Hide file tree
Showing 29 changed files with 773 additions and 731 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/setup
- run: cargo test -vv --no-default-features --features onnxruntime-from-source
- run: cargo test --no-default-features --features onnxruntime-from-source

smoke-test:
strategy:
Expand Down
1 change: 0 additions & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ mod build {
Ok(())
}


#[derive(Deserialize, Serialize, Clone)]
pub struct BuildInfo {
/// absolute path to all the compiled dynamic library files.
Expand Down
5 changes: 2 additions & 3 deletions src/audio_manager.rs → src/audio/audio_manager.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::collections::VecDeque;
use std::time::Duration;

use anyhow::anyhow;
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::{
ChannelCount, SampleFormat, SampleRate, Stream, SupportedBufferSize, SupportedStreamConfig,
};
use std::collections::VecDeque;
use std::time::Duration;

const DEFAULT_SAMPLING_RATE: u32 = 32000;

Expand Down
3 changes: 3 additions & 0 deletions src/audio/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod audio_manager;

pub use audio_manager::{AudioManager, AudioStream};
50 changes: 2 additions & 48 deletions src/backend/audio_generation_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@ use std::time::Duration;

use tokio_util::sync::CancellationToken;

use crate::music_gen_audio_encodec::MusicGenAudioEncodec;
use crate::music_gen_decoder::MusicGenDecoder;
use crate::music_gen_text_encoder::MusicGenTextEncoder;

const INPUT_IDS_BATCH_PER_SECOND: usize = 50;

#[derive(Clone, Debug)]
pub struct AudioGenerationRequest {
pub id: String,
Expand Down Expand Up @@ -58,47 +52,6 @@ pub trait JobProcessor: Send + Sync {
) -> ort::Result<VecDeque<f32>>;
}

pub struct MusicGenJobProcessor {
pub name: String,
pub device: String,
pub text_encoder: MusicGenTextEncoder,
pub decoder: Box<dyn MusicGenDecoder>,
pub audio_encodec: MusicGenAudioEncodec,
}

impl JobProcessor for MusicGenJobProcessor {
fn name(&self) -> String {
self.name.clone()
}

fn device(&self) -> String {
self.device.clone()
}

fn process(
&self,
prompt: &str,
secs: usize,
on_progress: Box<dyn Fn(f32) -> bool + Sync + Send + 'static>,
) -> ort::Result<VecDeque<f32>> {
let max_len = secs * INPUT_IDS_BATCH_PER_SECOND;

let (lhs, am) = self.text_encoder.encode(prompt)?;
let token_stream = self.decoder.generate_tokens(lhs, am, max_len)?;

let mut data = VecDeque::new();
while let Ok(tokens) = token_stream.recv() {
data.push_back(tokens?);
let should_exit = on_progress(data.len() as f32 / max_len as f32);
if should_exit {
return Err(ort::Error::new("Aborted"));
}
}

self.audio_encodec.encode(data)
}
}

#[derive(Clone)]
pub struct AudioGenerationBackend {
processor: Arc<dyn JobProcessor>,
Expand Down Expand Up @@ -249,7 +202,8 @@ mod tests {
// TODO: for some reason this test fails in CI with a timeout.
#[cfg(not(target_os = "macos"))]
async fn handles_job_cancellation() -> anyhow::Result<()> {
let backend = AudioGenerationBackend::new(DummyJobProcessor::new(Duration::from_millis(200)));
let backend =
AudioGenerationBackend::new(DummyJobProcessor::new(Duration::from_millis(200)));

let (tx, rx) = backend.run();

Expand Down
2 changes: 1 addition & 1 deletion src/backend/audio_generation_fanout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use specta::Type;
use tracing::info;
use uuid::Uuid;

use crate::audio_manager::AudioManager;
use crate::audio::AudioManager;
use crate::backend::audio_generation_backend::BackendOutboundMsg;
use crate::backend::music_gpt_chat::ChatEntry;
use crate::backend::music_gpt_ws_handler::IdPair;
Expand Down
18 changes: 9 additions & 9 deletions src/backend/mod.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
pub use audio_generation_backend::MusicGenJobProcessor;
pub use audio_generation_backend::JobProcessor;
pub use server::*;

mod audio_generation_backend;
mod server;
#[cfg(test)]
mod _test_utils;
mod music_gpt_chat;
mod audio_generation_backend;
mod audio_generation_fanout;
mod ws_handler;
mod music_gpt_chat;
mod music_gpt_ws_handler;
mod server;
mod ws_handler;

#[cfg(test)]
mod tests {
use specta::ts::{BigIntExportBehavior, ExportConfiguration};
use std::path::{Path, PathBuf};
use std::time::Duration;
use specta::ts::{BigIntExportBehavior, ExportConfiguration};

use crate::storage::AppFs;
use crate::backend::_test_utils::DummyJobProcessor;
use crate::backend::RunOptions;
use crate::backend::_test_utils::DummyJobProcessor;
use crate::backend::server::run;
use crate::storage::AppFs;

#[ignore]
#[tokio::test]
Expand All @@ -46,4 +46,4 @@ mod tests {
)?;
Ok(())
}
}
}
5 changes: 1 addition & 4 deletions src/backend/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ mod tests {
use async_trait::async_trait;
use std::sync::atomic::{AtomicU16, Ordering};
use std::time::Duration;

use futures_util::{SinkExt, StreamExt};
use serde::de::DeserializeOwned;
use serde::Serialize;
Expand All @@ -90,9 +89,7 @@ mod tests {

use crate::backend::_test_utils::DummyJobProcessor;
use crate::backend::music_gpt_chat::{AiChatEntry, ChatEntry, UserChatEntry};
use crate::backend::music_gpt_ws_handler::{
AbortGenerationRequest, ChatRequest, GenerateAudioRequest, InboundMsg, OutboundMsg,
};
use crate::backend::music_gpt_ws_handler::{AbortGenerationRequest, ChatRequest, GenerateAudioRequest, InboundMsg, OutboundMsg};

use super::*;

Expand Down
51 changes: 51 additions & 0 deletions src/cli/download.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use log::info;
use std::collections::VecDeque;
use std::fmt::Display;
use std::path::PathBuf;

use crate::cli::loading_bar::LoadingBarFactory;
use crate::cli::storage_ext::StorageExt;
use crate::cli::PROJECT_FS;
use crate::storage::Storage;

pub async fn download_many<T: Display>(
remote_file_spec: Vec<(T, T)>,
force_download: bool,
on_download_msg: &str,
on_finished_msg: &str,
) -> anyhow::Result<VecDeque<PathBuf>> {
let mut has_to_download = force_download;
for (_, local_filename) in remote_file_spec.iter() {
has_to_download = has_to_download || !PROJECT_FS.exists(&local_filename.to_string()).await?
}

if has_to_download {
info!("{on_download_msg}");
}
let m = LoadingBarFactory::multi();
let mut tasks = vec![];
for (remote_file, local_filename) in remote_file_spec {
let remote_file = remote_file.to_string();
let local_filename = local_filename.to_string();
let bar = m.add(LoadingBarFactory::download_bar(&local_filename));
tasks.push(tokio::spawn(async move {
PROJECT_FS
.fetch_remote_data_file(
&remote_file,
&local_filename,
force_download,
bar.into_update_callback(),
)
.await
}));
}
let mut results = VecDeque::new();
for task in tasks {
results.push_back(task.await??);
}
m.clear()?;
if has_to_download {
info!("{on_finished_msg}");
}
Ok(results)
}
46 changes: 46 additions & 0 deletions src/cli/gpu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use anyhow::anyhow;
use log::{error, info};
use ort::execution_providers::{
CUDAExecutionProvider, CoreMLExecutionProvider, ExecutionProvider, ExecutionProviderDispatch,
TensorRTExecutionProvider,
};
use ort::session::Session;

pub fn init_gpu() -> anyhow::Result<(&'static str, ExecutionProviderDispatch)> {
let mut dummy_builder = Session::builder()?;

if cfg!(feature = "tensorrt") {
let provider = TensorRTExecutionProvider::default();
match provider.register(&mut dummy_builder) {
Ok(_) => {
info!("{} detected", provider.as_str());
return Ok(("TensorRT", provider.build()));
}
Err(err) => error!("Could not load {}: {}", provider.as_str(), err),
}
}
if cfg!(feature = "cuda") {
let provider = CUDAExecutionProvider::default();
match provider.register(&mut dummy_builder) {
Ok(_) => {
info!("{} detected", provider.as_str());
return Ok(("Cuda", provider.build()));
}
Err(err) => error!("Could not load {}: {}", provider.as_str(), err),
}
}
if cfg!(feature = "coreml") {
let provider = CoreMLExecutionProvider::default().with_ane_only();
match provider.register(&mut dummy_builder) {
Ok(_) => {
info!("{} detected", provider.as_str());
return Ok(("CoreML", provider.build()));
}
Err(err) => error!("Could not load {}: {}", provider.as_str(), err),
}
}

Err(anyhow!(
"No hardware accelerator was detected, try running the program without the --gpu flag",
))
}
13 changes: 6 additions & 7 deletions src/loading_bar_factory.rs → src/cli/loading_bar.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressState, ProgressStyle};
use std::fmt::Write;
use std::ops::{Deref, DerefMut};
use std::time::Duration;

use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressState, ProgressStyle};

pub struct LoadingBarFactor;
pub struct LoadingBarFactory;

pub struct Bar(ProgressBar);

Expand All @@ -13,9 +12,9 @@ impl Bar {
self.0.set_length(total as u64);
self.0.set_position(elapsed as u64);
}
pub fn into_update_callback(self) -> Box<dyn Fn(usize, usize) + Send + Sync + 'static> {
Box::new(move |el, t| self.update_elapsed_total(el, t))

pub fn into_update_callback(self) -> Box<dyn Fn(usize, usize) + Send + Sync + 'static> {
Box::new(move |el, t| self.update_elapsed_total(el, t))
}
}

Expand Down Expand Up @@ -55,7 +54,7 @@ impl DerefMut for MultiBar {
}
}

impl LoadingBarFactor {
impl LoadingBarFactory {
pub fn multi() -> MultiBar {
MultiBar(MultiProgress::new())
}
Expand Down
Loading

0 comments on commit 2d9bddd

Please sign in to comment.