Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions stable_audio_tools/interface/gradio.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import gc
import numpy as np
import gradio as gr
import json
import json
import re
import subprocess
import torch
import torchaudio
import scipy.io.wavfile as wavfile

from einops import rearrange
from safetensors.torch import load_file
Expand Down Expand Up @@ -159,11 +160,12 @@ def progress_callback(callback_info):

audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

torchaudio.save("output.wav", audio, sample_rate)
audio_np = audio.numpy().T
wavfile.write("output.wav", sample_rate, audio_np)

audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)

return ("output.wav", [audio_spectrogram, *preview_images])
return ((sample_rate, audio_np), [audio_spectrogram, *preview_images])

def generate_lm(
temperature=1.0,
Expand Down Expand Up @@ -193,11 +195,12 @@ def generate_lm(

audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

torchaudio.save("output.wav", audio, sample_rate)
audio_np = audio.numpy().T
wavfile.write("output.wav", sample_rate, audio_np)

audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)

return ("output.wav", [audio_spectrogram])
return ((sample_rate, audio_np), [audio_spectrogram])


def create_uncond_sampling_ui(model_config):
Expand Down Expand Up @@ -305,7 +308,7 @@ def autoencoder_process(audio, latent_noise, n_quantizers):

audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

torchaudio.save("output.wav", audio, sample_rate)
wavfile.write("output.wav", sample_rate, audio.numpy().T)

return "output.wav"

Expand Down
27 changes: 16 additions & 11 deletions stable_audio_tools/interface/interfaces/diffusion_cond.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import gc
import numpy as np
import gradio as gr
import json
import json
import re
import subprocess
import torch
import torchaudio
import threading
import scipy.io.wavfile as wavfile
import threading
import os, time, math

from einops import rearrange
from torchaudio import transforms as T
import torchaudio

from ..aeiou import audio_spectrogram_image
from ...inference.generation import generate_diffusion_cond, generate_diffusion_cond_inpaint
Expand Down Expand Up @@ -253,10 +254,14 @@ def progress_callback(callback_info):

# Encode the audio to WAV format
audio = rearrange(audio, "b d n -> d (b n)")
audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).cpu()

# Convert to numpy for saving and returning (float32 [-1, 1])
audio_np = audio.numpy().T # [samples, channels]

# save as wav file
torchaudio.save(output_wav, audio, sample_rate)
# save as wav file (int16)
audio_int16 = (audio_np * 32767).astype(np.int16)
wavfile.write(output_wav, sample_rate, audio_int16)

# If file_format is other than wav, convert to other file format
cmd = ""
Expand All @@ -278,15 +283,16 @@ def progress_callback(callback_info):
if cmd:
cmd += " -loglevel error" # make output less verbose in the cmd window
subprocess.run(cmd, shell=True, check=True)

# Let's look at a nice spectrogram too
audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
audio_spectrogram = audio_spectrogram_image(audio.mul(32767).to(torch.int16), sample_rate=sample_rate)

# Asynchronously delete the files after returning the output file, so as to prevent clutter in the directory
if file_naming in ["verbose", "prompt"]:
delete_files_async([output_wav, output_filename], 30)

return (output_filename, [audio_spectrogram, *preview_images])
# Return audio as tuple so Gradio handles encoding/serving internally
return ((sample_rate, audio_np), [audio_spectrogram, *preview_images])

# Asynchronously delete the given list of filenames after delay seconds. Sets up thread that sleeps for delay then deletes.
def delete_files_async(filenames, delay):
Expand Down Expand Up @@ -426,8 +432,7 @@ def create_sampling_ui(model_config):
]

with gr.Column():
audio_output = gr.Audio(label="Output audio", interactive=False,
waveform_options=gr.WaveformOptions(show_recording_waveform=False))
audio_output = gr.Audio(label="Output audio", interactive=False, format="wav")
audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
send_to_init_button = gr.Button("Send to init audio", scale=1)
send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])
Expand Down