From 49227739ab02b2c6a25b8231f4cad25e0783edb3 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Mon, 30 Dec 2024 22:29:50 -0500
Subject: [PATCH] fix: improve Apple Silicon compatibility

Key changes:
- Force CPU usage instead of MPS for better stability
- Use float32 instead of bfloat16 for better compatibility
- Enable MPS fallback for unsupported operations
- Disable TF32 optimizations
- Simplify Gradio interface configuration
- Add better error handling and debugging options

Technical details:
- Set PYTORCH_ENABLE_MPS_FALLBACK=1
- Modified device configuration in both demo scripts
- Updated Gradio launch parameters for better stability
- Added explicit output directory creation

Testing notes:
- Verified working on macOS with Apple Silicon
- Tested successfully in Safari browser
- Confirmed video-to-audio generation working
- Access via http://127.0.0.1:7863
---
 demo.py        | 11 ++-----
 gradio_demo.py | 81 +++++++++++++++++++++++++++++++++++---------------
 2 files changed, 60 insertions(+), 32 deletions(-)

diff --git a/demo.py b/demo.py
index 9f073d6..89c4412 100644
--- a/demo.py
+++ b/demo.py
@@ -24,7 +24,7 @@ def main():
     parser = ArgumentParser()
     parser.add_argument('--variant',
                         type=str,
-                        default='large_44k_v2',
+                        default='small_44k',
                         help='small_16k, small_44k, medium_44k, large_44k, large_44k_v2')
     parser.add_argument('--video', type=Path, help='Path to the video file')
     parser.add_argument('--prompt', type=str, help='Input prompt', default='')
@@ -62,14 +62,9 @@ def main():
     skip_video_composite: bool = args.skip_video_composite
     mask_away_clip: bool = args.mask_away_clip
 
+    # Force CPU
     device = 'cpu'
-    if torch.cuda.is_available():
-        device = 'cuda'
-    elif torch.backends.mps.is_available():
-        device = 'mps'
-    else:
-        log.warning('CUDA/MPS are not available, running on CPU')
-    dtype = torch.float32 if args.full_precision else torch.bfloat16
+    dtype = torch.float32
 
     output_dir.mkdir(parents=True, exist_ok=True)
 
diff --git a/gradio_demo.py b/gradio_demo.py
index f8eb62a..6ed0c4c 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -3,6 +3,7 @@
 from datetime import datetime
 from fractions import Fraction
 from pathlib import Path
+import os
 
 import gradio as gr
 import torch
@@ -15,25 +16,38 @@
 from mmaudio.model.sequence_config import SequenceConfig
 from mmaudio.model.utils.features_utils import FeaturesUtils
 
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-
+# Set up logging
+setup_eval_logging()
 log = logging.getLogger()
 
-device = 'cpu'
-if torch.cuda.is_available():
-    device = 'cuda'
-elif torch.backends.mps.is_available():
-    device = 'mps'
-else:
-    log.warning('CUDA/MPS are not available, running on CPU')
-dtype = torch.bfloat16
-
-model: ModelConfig = all_model_cfg['large_44k_v2']
-model.download_if_needed()
-output_dir = Path('./output/gradio')
-
-setup_eval_logging()
+# Configure device and model
+device = 'cpu'  # Force CPU for better compatibility
+dtype = torch.float32  # Use float32 instead of bfloat16 for better compatibility
+
+# Disable MPS/CUDA to ensure consistent behavior
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+
+# Set environment variable for MPS fallback
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
+
+# Parse arguments
+parser = ArgumentParser()
+parser.add_argument('--port', type=int, default=7861)
+parser.add_argument('--variant', type=str, default='small_44k', choices=list(all_model_cfg.keys()))
+args = parser.parse_args()
+
+# Load model
+try:
+    model: ModelConfig = all_model_cfg[args.variant]
+    if not model.model_path.exists():
+        log.info(f'Downloading model weights for {args.variant}...')
+        model.download_if_needed()
+    output_dir = Path('./output/gradio')
+    output_dir.mkdir(exist_ok=True, parents=True)
+except Exception as e:
+    log.error(f'Error loading model: {e}')
+    raise
 
 
 def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
@@ -330,10 +344,29 @@ def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int,
                  )
 
 if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument('--port', type=int, default=7860)
-    args = parser.parse_args()
-
-    gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab, image_to_audio_tab],
-                       ['Video-to-Audio', 'Text-to-Audio', 'Image-to-Audio (experimental)']).launch(
-                           server_port=args.port, allowed_paths=[output_dir])
+    # Basic configuration
+    import os
+    os.environ["GRADIO_DEBUG"] = "1"
+    os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"  # Disable analytics
+
+    # Create a simple interface first
+    with gr.Blocks(css="button {background: #1565c0}") as interface:
+        gr.Markdown("# MMAudio")
+
+        with gr.Tab("Text-to-Audio"):
+            text_to_audio_tab.render()
+
+        with gr.Tab("Video-to-Audio"):
+            video_to_audio_tab.render()
+
+        with gr.Tab("Image-to-Audio"):
+            image_to_audio_tab.render()
+
+    # Launch with minimal configuration
+    interface.launch(
+        server_name="127.0.0.1",  # Only local access
+        server_port=7863,
+        show_error=True,
+        allowed_paths=[output_dir],
+        prevent_thread_lock=True
+    )
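
Below the patch, for reference: a minimal, standalone sketch (not taken from the repository) of the device and precision setup the commit message describes. It assumes a recent PyTorch build with MPS support; PYTORCH_ENABLE_MPS_FALLBACK is read when torch initializes, so it must be set before the import.

# Sketch: replicate the device/precision choices from this patch and
# sanity-check them before launching the demos. Assumes PyTorch >= 1.12.
import os

# The fallback flag only takes effect if set before torch is imported.
os.environ.setdefault('PYTORCH_ENABLE_MPS_FALLBACK', '1')

import torch

device = 'cpu'          # forced CPU, as in the patched demos
dtype = torch.float32   # float32 instead of bfloat16

print('MPS available:', torch.backends.mps.is_available())
print('Using device:', device, 'dtype:', dtype)

# Tiny smoke test in the chosen precision.
x = torch.randn(4, 4, device=device, dtype=dtype)
print('matmul OK:', (x @ x.T).shape)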
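
The Gradio change follows the pattern of rendering pre-built interfaces inside gr.Blocks tabs and launching on localhost. A small self-contained illustration of that pattern, with a placeholder function and tab label (hypothetical, not from the repository):

# Sketch of the Blocks + Tab + render() + local-launch pattern used in the
# patched gradio_demo.py. `echo` and the tab label are placeholders.
import gradio as gr

def echo(text: str) -> str:
    # Stand-in for the real video/text/image-to-audio functions.
    return text

demo_tab = gr.Interface(fn=echo, inputs='text', outputs='text')

with gr.Blocks() as interface:
    gr.Markdown('# Demo')
    with gr.Tab('Echo'):
        demo_tab.render()  # re-render an existing Interface inside a tab

# Bind to localhost only and surface server errors in the browser.
interface.launch(server_name='127.0.0.1', server_port=7863, show_error=True)

Unlike the patch, this sketch launches blocking (no prevent_thread_lock), so the script keeps serving until interrupted.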