Skip to content

Commit

Permalink
Introduce Persistance (#4)
Browse files Browse the repository at this point in the history
* feat: add UI

* Do not open URL in tests

* Introduce persistance for chats

* Slightly improve the code
  • Loading branch information
gabotechs authored May 17, 2024
1 parent c01db57 commit 21d7db8
Show file tree
Hide file tree
Showing 39 changed files with 1,956 additions and 589 deletions.
67 changes: 65 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "musicgpt"
license = "MIT"
version = "0.2.0"
version = "0.1.34"
edition = "2021"
description = "Generate music samples from natural language prompt locally with your own computer"
keywords = ["llm", "music", "audio", "ai"]
Expand Down Expand Up @@ -40,6 +40,8 @@ specta = { version = "1.0.5", features = ["uuid", "serde", "typescript", "export
axum = { version = "0.7.5", features = ["ws"] }
tower-http = { version = "0.5.2" , features = ["fs"]}
open = "5.1.2"
chrono = "0.4.38"
scopeguard = "1.2.0"

[features]
default = []
Expand Down
49 changes: 35 additions & 14 deletions src/audio_manager.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::collections::VecDeque;
use std::path::PathBuf;
use std::time::Duration;

use anyhow::anyhow;
use cpal::{ChannelCount, SampleFormat, SampleRate, Stream, SupportedBufferSize, SupportedStreamConfig};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::{
ChannelCount, SampleFormat, SampleRate, Stream, SupportedBufferSize, SupportedStreamConfig,
};

const DEFAULT_SAMPLING_RATE: u32 = 32000;

Expand All @@ -30,14 +31,12 @@ impl Default for AudioManager {

pub struct AudioStream {
pub stream: Stream,
pub duration: Duration
pub duration: Duration,
}


unsafe impl Send for AudioStream {}
unsafe impl Sync for AudioStream {}


impl AudioManager {
pub fn play_from_queue(&self, mut v: VecDeque<f32>) -> anyhow::Result<AudioStream> {
let time = 1000 * v.len() / self.sampling_rate as usize;
Expand Down Expand Up @@ -70,15 +69,11 @@ impl AudioManager {
stream.play()?;
Ok(AudioStream {
stream,
duration: Duration::from_millis(time as u64)
duration: Duration::from_millis(time as u64),
})
}

pub fn store_as_wav(
&self,
v: VecDeque<f32>,
out_path: impl Into<PathBuf>,
) -> hound::Result<()> {
pub fn to_wav(&self, v: VecDeque<f32>) -> hound::Result<Vec<u8>> {
let spec = hound::WavSpec {
channels: self.n_channels,
sample_rate: self.sampling_rate,
Expand Down Expand Up @@ -110,11 +105,37 @@ impl AudioManager {
},
};

let mut writer = hound::WavWriter::create(out_path.into(), spec)?;
for sample in v {
writer.write_sample(sample)?;
let mut buffer = vec![];
let cursor = std::io::Cursor::new(&mut buffer);
let in_memory_file = std::io::BufWriter::new(cursor);
{
let mut writer = hound::WavWriter::new(in_memory_file, spec)?;
for sample in v {
writer.write_sample(sample)?;
}
// <- we need writer to be dropped here.
}

Ok(buffer)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn saves_to_wav() -> anyhow::Result<()> {
let wav_path = concat!(env!("CARGO_MANIFEST_DIR"), "/assets/test.wav");
let audio_manager = AudioManager::default();
let reader = hound::WavReader::open(wav_path)?;
let mut data = VecDeque::new();
for sample in reader.into_samples::<f32>() {
data.push_back(sample?)
}
let buff = audio_manager.to_wav(data)?;
let wav_path_content = std::fs::read(wav_path)?;
assert_eq!(wav_path_content, buff);
Ok(())
}
}
116 changes: 116 additions & 0 deletions src/fetch_remove_data_file.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use std::error;
use std::path::PathBuf;

use axum::http::StatusCode;
use futures_util::StreamExt;
use tokio::io::AsyncWriteExt;

use crate::storage::{AppFs, Storage};

/// Loads a remote from the local data directory, downloading it from
/// the remote endpoint if necessary
///
/// # Arguments
///
/// * `url`: The URL of the remote file
/// * `file_name`: The filename in the local data directory
/// * `force`: Force the download even if the file exists
/// * `cbk`: A callback for tracking progress of the download (elapsed, total)
///
/// returns: Result<PathBuf, Error>
impl AppFs {
pub async fn fetch_remote_data_file<Cb: Fn(usize, usize)>(
&self,
url: &str,
local_file: &str,
force: bool,
cbk: Cb,
) -> std::io::Result<PathBuf> {
// At this point, the file might already exist on disk, so nothing else to do.
if self.exists(local_file).await? && !force {
return Ok(self.path_buf(local_file));
}

// If the file was not in disk, we need to download it.
let resp = reqwest::get(url).await.map_err(io_err)?;
let status_code = resp.status();
if status_code != StatusCode::OK {
return Err(io_err(format!("Invalid status code {status_code}")));
}
let total_bytes = resp.content_length().unwrap_or_default() as usize;

// The file will be first downloaded to a temporary file, to avoid corruptions.
let temp_file = format!("{local_file}.temp");
let mut file = self.create(&temp_file).await?;

// Stream the HTTP response to the file stream.
let mut stream = resp.bytes_stream();
let mut downloaded_bytes = 0;
while let Some(item) = stream.next().await {
match item {
Ok(chunk) => {
downloaded_bytes += chunk.len();
cbk(downloaded_bytes, total_bytes);
file.write_all(&chunk).await?
}
Err(err) => return Err(io_err(err)),
}
}

// If everything succeeded, we are fine to promote the newly stored temporary
// file to the actual destination.
self.mv(&temp_file, local_file).await?;

Ok(self.path_buf(local_file))
}
}

fn io_err<E>(e: E) -> std::io::Error
where
E: Into<Box<dyn error::Error + Send + Sync>>,
{
std::io::Error::new(std::io::ErrorKind::Other, e)
}

#[cfg(test)]
mod tests {
use std::path::Path;
use std::time::SystemTime;

use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};

use crate::storage::AppFs;

fn rand_string() -> String {
thread_rng()
.sample_iter(&Alphanumeric)
.take(7)
.map(char::from)
.collect()
}

#[tokio::test]
async fn downloads_remote_file() -> std::io::Result<()> {
let remote_file = "https://raw.githubusercontent.com/seanmonstar/reqwest/master/README.md";
let file_name = format!("foo/{}.txt", rand_string());

let app_fs = AppFs::new(Path::new("/tmp/downloads_remote_file_test"));

let time = SystemTime::now();
app_fs
.fetch_remote_data_file(remote_file, &file_name, false, |_, _| {})
.await?;
let download_elapsed = SystemTime::now().duration_since(time).unwrap().as_micros();

let time = SystemTime::now();
app_fs
.fetch_remote_data_file(remote_file, &file_name, false, |_, _| {})
.await?;
let cached_elapsed = SystemTime::now().duration_since(time).unwrap().as_micros();

assert!(download_elapsed / cached_elapsed > 10);

Ok(())
}
}
Loading

0 comments on commit 21d7db8

Please sign in to comment.