Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor folders #30

Merged
merged 8 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/setup
- run: cargo test -vv --no-default-features --features onnxruntime-from-source
- run: cargo test --no-default-features --features onnxruntime-from-source

smoke-test:
strategy:
Expand Down
1 change: 0 additions & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ mod build {
Ok(())
}


#[derive(Deserialize, Serialize, Clone)]
pub struct BuildInfo {
/// absolute path to all the compiled dynamic library files.
Expand Down
5 changes: 2 additions & 3 deletions src/audio_manager.rs → src/audio/audio_manager.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::collections::VecDeque;
use std::time::Duration;

use anyhow::anyhow;
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::{
ChannelCount, SampleFormat, SampleRate, Stream, SupportedBufferSize, SupportedStreamConfig,
};
use std::collections::VecDeque;
use std::time::Duration;

const DEFAULT_SAMPLING_RATE: u32 = 32000;

Expand Down
3 changes: 3 additions & 0 deletions src/audio/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod audio_manager;

pub use audio_manager::{AudioManager, AudioStream};
50 changes: 2 additions & 48 deletions src/backend/audio_generation_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@ use std::time::Duration;

use tokio_util::sync::CancellationToken;

use crate::music_gen_audio_encodec::MusicGenAudioEncodec;
use crate::music_gen_decoder::MusicGenDecoder;
use crate::music_gen_text_encoder::MusicGenTextEncoder;

const INPUT_IDS_BATCH_PER_SECOND: usize = 50;

#[derive(Clone, Debug)]
pub struct AudioGenerationRequest {
pub id: String,
Expand Down Expand Up @@ -58,47 +52,6 @@ pub trait JobProcessor: Send + Sync {
) -> ort::Result<VecDeque<f32>>;
}

pub struct MusicGenJobProcessor {
pub name: String,
pub device: String,
pub text_encoder: MusicGenTextEncoder,
pub decoder: Box<dyn MusicGenDecoder>,
pub audio_encodec: MusicGenAudioEncodec,
}

impl JobProcessor for MusicGenJobProcessor {
fn name(&self) -> String {
self.name.clone()
}

fn device(&self) -> String {
self.device.clone()
}

fn process(
&self,
prompt: &str,
secs: usize,
on_progress: Box<dyn Fn(f32) -> bool + Sync + Send + 'static>,
) -> ort::Result<VecDeque<f32>> {
let max_len = secs * INPUT_IDS_BATCH_PER_SECOND;

let (lhs, am) = self.text_encoder.encode(prompt)?;
let token_stream = self.decoder.generate_tokens(lhs, am, max_len)?;

let mut data = VecDeque::new();
while let Ok(tokens) = token_stream.recv() {
data.push_back(tokens?);
let should_exit = on_progress(data.len() as f32 / max_len as f32);
if should_exit {
return Err(ort::Error::new("Aborted"));
}
}

self.audio_encodec.encode(data)
}
}

#[derive(Clone)]
pub struct AudioGenerationBackend {
processor: Arc<dyn JobProcessor>,
Expand Down Expand Up @@ -249,7 +202,8 @@ mod tests {
// TODO: for some reason this test fails in CI with a timeout.
#[cfg(not(target_os = "macos"))]
async fn handles_job_cancellation() -> anyhow::Result<()> {
let backend = AudioGenerationBackend::new(DummyJobProcessor::new(Duration::from_millis(200)));
let backend =
AudioGenerationBackend::new(DummyJobProcessor::new(Duration::from_millis(200)));

let (tx, rx) = backend.run();

Expand Down
2 changes: 1 addition & 1 deletion src/backend/audio_generation_fanout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use specta::Type;
use tracing::info;
use uuid::Uuid;

use crate::audio_manager::AudioManager;
use crate::audio::AudioManager;
use crate::backend::audio_generation_backend::BackendOutboundMsg;
use crate::backend::music_gpt_chat::ChatEntry;
use crate::backend::music_gpt_ws_handler::IdPair;
Expand Down
18 changes: 9 additions & 9 deletions src/backend/mod.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
pub use audio_generation_backend::MusicGenJobProcessor;
pub use audio_generation_backend::JobProcessor;
pub use server::*;

mod audio_generation_backend;
mod server;
#[cfg(test)]
mod _test_utils;
mod music_gpt_chat;
mod audio_generation_backend;
mod audio_generation_fanout;
mod ws_handler;
mod music_gpt_chat;
mod music_gpt_ws_handler;
mod server;
mod ws_handler;

#[cfg(test)]
mod tests {
use specta::ts::{BigIntExportBehavior, ExportConfiguration};
use std::path::{Path, PathBuf};
use std::time::Duration;
use specta::ts::{BigIntExportBehavior, ExportConfiguration};

use crate::storage::AppFs;
use crate::backend::_test_utils::DummyJobProcessor;
use crate::backend::RunOptions;
use crate::backend::_test_utils::DummyJobProcessor;
use crate::backend::server::run;
use crate::storage::AppFs;

#[ignore]
#[tokio::test]
Expand All @@ -46,4 +46,4 @@ mod tests {
)?;
Ok(())
}
}
}
5 changes: 1 addition & 4 deletions src/backend/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ mod tests {
use async_trait::async_trait;
use std::sync::atomic::{AtomicU16, Ordering};
use std::time::Duration;

use futures_util::{SinkExt, StreamExt};
use serde::de::DeserializeOwned;
use serde::Serialize;
Expand All @@ -90,9 +89,7 @@ mod tests {

use crate::backend::_test_utils::DummyJobProcessor;
use crate::backend::music_gpt_chat::{AiChatEntry, ChatEntry, UserChatEntry};
use crate::backend::music_gpt_ws_handler::{
AbortGenerationRequest, ChatRequest, GenerateAudioRequest, InboundMsg, OutboundMsg,
};
use crate::backend::music_gpt_ws_handler::{AbortGenerationRequest, ChatRequest, GenerateAudioRequest, InboundMsg, OutboundMsg};

use super::*;

Expand Down
51 changes: 51 additions & 0 deletions src/cli/download.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use log::info;
use std::collections::VecDeque;
use std::fmt::Display;
use std::path::PathBuf;

use crate::cli::loading_bar::LoadingBarFactory;
use crate::cli::storage_ext::StorageExt;
use crate::cli::PROJECT_FS;
use crate::storage::Storage;

pub async fn download_many<T: Display>(
remote_file_spec: Vec<(T, T)>,
force_download: bool,
on_download_msg: &str,
on_finished_msg: &str,
) -> anyhow::Result<VecDeque<PathBuf>> {
let mut has_to_download = force_download;
for (_, local_filename) in remote_file_spec.iter() {
has_to_download = has_to_download || !PROJECT_FS.exists(&local_filename.to_string()).await?
}

if has_to_download {
info!("{on_download_msg}");
}
let m = LoadingBarFactory::multi();
let mut tasks = vec![];
for (remote_file, local_filename) in remote_file_spec {
let remote_file = remote_file.to_string();
let local_filename = local_filename.to_string();
let bar = m.add(LoadingBarFactory::download_bar(&local_filename));
tasks.push(tokio::spawn(async move {
PROJECT_FS
.fetch_remote_data_file(
&remote_file,
&local_filename,
force_download,
bar.into_update_callback(),
)
.await
}));
}
let mut results = VecDeque::new();
for task in tasks {
results.push_back(task.await??);
}
m.clear()?;
if has_to_download {
info!("{on_finished_msg}");
}
Ok(results)
}
46 changes: 46 additions & 0 deletions src/cli/gpu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use anyhow::anyhow;
use log::{error, info};
use ort::execution_providers::{
CUDAExecutionProvider, CoreMLExecutionProvider, ExecutionProvider, ExecutionProviderDispatch,
TensorRTExecutionProvider,
};
use ort::session::Session;

pub fn init_gpu() -> anyhow::Result<(&'static str, ExecutionProviderDispatch)> {
let mut dummy_builder = Session::builder()?;

if cfg!(feature = "tensorrt") {
let provider = TensorRTExecutionProvider::default();
match provider.register(&mut dummy_builder) {
Ok(_) => {
info!("{} detected", provider.as_str());
return Ok(("TensorRT", provider.build()));
}
Err(err) => error!("Could not load {}: {}", provider.as_str(), err),
}
}
if cfg!(feature = "cuda") {
let provider = CUDAExecutionProvider::default();
match provider.register(&mut dummy_builder) {
Ok(_) => {
info!("{} detected", provider.as_str());
return Ok(("Cuda", provider.build()));
}
Err(err) => error!("Could not load {}: {}", provider.as_str(), err),
}
}
if cfg!(feature = "coreml") {
let provider = CoreMLExecutionProvider::default().with_ane_only();
match provider.register(&mut dummy_builder) {
Ok(_) => {
info!("{} detected", provider.as_str());
return Ok(("CoreML", provider.build()));
}
Err(err) => error!("Could not load {}: {}", provider.as_str(), err),
}
}

Err(anyhow!(
"No hardware accelerator was detected, try running the program without the --gpu flag",
))
}
13 changes: 6 additions & 7 deletions src/loading_bar_factory.rs → src/cli/loading_bar.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressState, ProgressStyle};
use std::fmt::Write;
use std::ops::{Deref, DerefMut};
use std::time::Duration;

use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressState, ProgressStyle};

pub struct LoadingBarFactor;
pub struct LoadingBarFactory;

pub struct Bar(ProgressBar);

Expand All @@ -13,9 +12,9 @@ impl Bar {
self.0.set_length(total as u64);
self.0.set_position(elapsed as u64);
}
pub fn into_update_callback(self) -> Box<dyn Fn(usize, usize) + Send + Sync + 'static> {
Box::new(move |el, t| self.update_elapsed_total(el, t))

pub fn into_update_callback(self) -> Box<dyn Fn(usize, usize) + Send + Sync + 'static> {
Box::new(move |el, t| self.update_elapsed_total(el, t))
}
}

Expand Down Expand Up @@ -55,7 +54,7 @@ impl DerefMut for MultiBar {
}
}

impl LoadingBarFactor {
impl LoadingBarFactory {
pub fn multi() -> MultiBar {
MultiBar(MultiProgress::new())
}
Expand Down
Loading
Loading