demo.py (11 changes: 3 additions & 8 deletions)
@@ -24,7 +24,7 @@ def main():
parser = ArgumentParser()
parser.add_argument('--variant',
type=str,
default='large_44k_v2',
default='small_44k',
Owner commented: large_44k_v2 is a better default.

help='small_16k, small_44k, medium_44k, large_44k, large_44k_v2')
parser.add_argument('--video', type=Path, help='Path to the video file')
parser.add_argument('--prompt', type=str, help='Input prompt', default='')
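A minimal sketch of what the owner's suggestion would look like, reusing the argument name, type, and help text from the hunk above (nothing outside this diff is assumed):

```python
from argparse import ArgumentParser

# Sketch only: restores the default the owner prefers; '--variant' stays overridable.
parser = ArgumentParser()
parser.add_argument('--variant',
                    type=str,
                    default='large_44k_v2',
                    help='small_16k, small_44k, medium_44k, large_44k, large_44k_v2')
```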
@@ -62,14 +62,9 @@ def main():
skip_video_composite: bool = args.skip_video_composite
mask_away_clip: bool = args.mask_away_clip

# Force CPU
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    log.warning('CUDA/MPS are not available, running on CPU')
dtype = torch.float32 if args.full_precision else torch.bfloat16
dtype = torch.float32

output_dir.mkdir(parents=True, exist_ok=True)

gradio_demo.py (81 changes: 57 additions & 24 deletions)
@@ -3,6 +3,7 @@
from datetime import datetime
from fractions import Fraction
from pathlib import Path
import os

import gradio as gr
import torch
@@ -15,25 +16,38 @@
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Set up logging
setup_eval_logging()
log = logging.getLogger()

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    log.warning('CUDA/MPS are not available, running on CPU')
dtype = torch.bfloat16

model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')

setup_eval_logging()
# Configure device and model
device = 'cpu' # Force CPU for better compatibility
dtype = torch.float32 # Use float32 instead of bfloat16 for better compatibility

# Disable MPS/CUDA to ensure consistent behavior
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# Set environment variable for MPS fallback
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Parse arguments
parser = ArgumentParser()
parser.add_argument('--port', type=int, default=7861)
parser.add_argument('--variant', type=str, default='small_44k', choices=list(all_model_cfg.keys()))
args = parser.parse_args()

# Load model
try:
    model: ModelConfig = all_model_cfg[args.variant]
    if not model.model_path.exists():
        log.info(f'Downloading model weights for {args.variant}...')
        model.download_if_needed()
    output_dir = Path('./output/gradio')
    output_dir.mkdir(exist_ok=True, parents=True)
except Exception as e:
    log.error(f'Error loading model: {e}')
    raise
Owner commented on lines +41 to +50: download_if_needed already checks for existing model files and additionally checks for MD5.
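A minimal sketch of the simplification this comment suggests: call download_if_needed() unconditionally and drop the manual existence check. All names are reused from the diff above; the try/except wrapper from the proposed code is kept.

```python
# Sketch only: per the comment above, download_if_needed() already skips the
# download when valid (MD5-checked) weights are present, so no
# model.model_path.exists() guard is needed.
try:
    model: ModelConfig = all_model_cfg[args.variant]
    model.download_if_needed()
    output_dir = Path('./output/gradio')
    output_dir.mkdir(exist_ok=True, parents=True)
except Exception as e:
    log.error(f'Error loading model: {e}')
    raise
```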



def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
@@ -330,10 +344,29 @@ def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int,
)

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--port', type=int, default=7860)
    args = parser.parse_args()

    gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab, image_to_audio_tab],
                       ['Video-to-Audio', 'Text-to-Audio', 'Image-to-Audio (experimental)']).launch(
                           server_port=args.port, allowed_paths=[output_dir])
    # Basic configuration
    import os
    os.environ["GRADIO_DEBUG"] = "1"
    os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"  # Disable analytics

    # Create a simple interface first
    with gr.Blocks(css="button {background: #1565c0}") as interface:
        gr.Markdown("# MMAudio")

        with gr.Tab("Text-to-Audio"):
            text_to_audio_tab.render()

        with gr.Tab("Video-to-Audio"):
            video_to_audio_tab.render()

        with gr.Tab("Image-to-Audio"):
            image_to_audio_tab.render()

    # Launch with minimal configuration
    interface.launch(
        server_name="127.0.0.1",  # Only local access
        server_port=7863,
        show_error=True,
        allowed_paths=[output_dir],
        prevent_thread_lock=True
    )