From ba077476e95e7b9d31b9096650a8ea594d516b8d Mon Sep 17 00:00:00 2001 From: biyuxu <50160635+biyuxu@users.noreply.github.com> Date: Sat, 30 Aug 2025 08:32:02 +0800 Subject: [PATCH 1/2] Add files via upload 1.Added the subtitle paragraph segmentation function, which automatically segments sentences based on volume level and duration to achieve subtitle segmentation. 2.Modified the previous nodes and added 2 parameters: volume level setting and duration setting. 3.Improved the multi-line display function of subtitles. Subtitles that exceed the length limit will automatically be displayed in multiple lines (line breaks) according to the video resolution. --- add_subtitles_to_frames.py | 140 ++++++++++++++++++--- apply_whisper.py | 252 +++++++++++++++++++------------------ 2 files changed, 248 insertions(+), 144 deletions(-) diff --git a/add_subtitles_to_frames.py b/add_subtitles_to_frames.py index 2bb061e..d55fff9 100644 --- a/add_subtitles_to_frames.py +++ b/add_subtitles_to_frames.py @@ -2,6 +2,7 @@ from .utils import tensor2pil, pil2tensor, tensor2Mask import math import os +import textwrap FONT_DIR = os.path.join(os.path.dirname(__file__),"fonts") @@ -47,6 +48,45 @@ def INPUT_TYPES(s): FUNCTION = "add_subtitles_to_frames" CATEGORY = "whisper" + def wrap_text(self, text, font, max_width): + """将文本分成多行,确保每行宽度不超过max_width""" + # 如果文本宽度已经小于最大宽度,直接返回单行 + bbox = font.getbbox(text) + text_width = bbox[2] - bbox[0] + if text_width <= max_width: + return [text] + + # 使用textwrap进行分行 + wrapper = textwrap.TextWrapper(width=int(len(text) * max_width / text_width)) + wrapped_lines = wrapper.wrap(text) + + # 进一步检查每行宽度,确保不超过最大宽度 + final_lines = [] + for line in wrapped_lines: + line_bbox = font.getbbox(line) + line_width = line_bbox[2] - line_bbox[0] + + if line_width <= max_width: + final_lines.append(line) + else: + # 如果仍然太宽,逐字符分割 + current_line = "" + for char in line: + test_line = current_line + char + test_bbox = font.getbbox(test_line) + test_width = test_bbox[2] - test_bbox[0] + + if test_width <= max_width: + current_line = test_line + else: + if current_line: + final_lines.append(current_line) + current_line = char + + if current_line: + final_lines.append(current_line) + + return final_lines def add_subtitles_to_frames(self, images, alignment, font_family, font_size, font_color, x_position, y_position, center_x, center_y, video_fps): pil_images = tensor2pil(images) @@ -95,31 +135,93 @@ def add_subtitles_to_frames(self, images, alignment, font_family, font_size, fon d = ImageDraw.Draw(img) - # center text - text_bbox = d.textbbox((x_position, y_position), alignment_obj["value"], font=font) - if center_x: - text_width = text_bbox[2] - text_bbox[0] - x_position = (width - text_width)/2 + # 计算最大允许宽度(视频宽度减去左右各50像素) + max_text_width = width - 100 + + # 分行文本 + text_lines = self.wrap_text(alignment_obj["value"], font, max_text_width) + + # 计算文本总高度 + total_text_height = 0 + line_heights = [] + for line in text_lines: + line_bbox = font.getbbox(line) + line_height = line_bbox[3] - line_bbox[1] + line_heights.append(line_height) + total_text_height += line_height + + # 添加行间距(假设为字体大小的20%) + line_spacing = int(font_size * 0.2) + total_text_height += line_spacing * (len(text_lines) - 1) + + # 计算起始Y位置 if center_y: - text_height = text_bbox[3] - text_bbox[1] - y_position = (height - text_height)/2 - - - # add text to video frames - d.text((x_position, y_position), alignment_obj["value"], fill=font_color,font=font) + current_y = (height - total_text_height) / 2 + else: + current_y = y_position + + # 绘制每一行文本 + text_bboxes = [] # 存储每行文本的边界框 + for line in text_lines: + line_bbox = font.getbbox(line) + line_width = line_bbox[2] - line_bbox[0] + + # 计算X位置 + if center_x: + line_x = (width - line_width) / 2 + else: + line_x = x_position + + # 绘制文本到视频帧 + d.text((line_x, current_y), line, fill=font_color, font=font) + + # 记录文本位置和大小 + text_bbox = (line_x, current_y, line_x + line_width, current_y + line_heights[text_lines.index(line)]) + text_bboxes.append(text_bbox) + + # 更新Y位置 + current_y += line_heights[text_lines.index(line)] + line_spacing + + # 计算整个文本区域的最小边界框 + if text_bboxes: + min_x = min(bbox[0] for bbox in text_bboxes) + min_y = min(bbox[1] for bbox in text_bboxes) + max_x = max(bbox[2] for bbox in text_bboxes) + max_y = max(bbox[3] for bbox in text_bboxes) + overall_bbox = (min_x, min_y, max_x, max_y) + else: + overall_bbox = (0, 0, 0, 0) + pil_images_with_text.append(img) - # create mask + # 创建mask black_img = Image.new('RGB', (width, height), 'black') - d = ImageDraw.Draw(black_img) - d.text((x_position, y_position), alignment_obj["value"], fill="white",font=font) + d_mask = ImageDraw.Draw(black_img) + + # 在mask上绘制文本 + current_y_mask = current_y = (height - total_text_height) / 2 if center_y else y_position + for line in text_lines: + line_bbox = font.getbbox(line) + line_width = line_bbox[2] - line_bbox[0] + + if center_x: + line_x = (width - line_width) / 2 + else: + line_x = x_position + + d_mask.text((line_x, current_y_mask), line, fill="white", font=font) + current_y_mask += line_heights[text_lines.index(line)] + line_spacing + pil_images_masks.append(black_img) - # crop subtitles to black frame - text_bbox = d.textbbox((x_position,y_position), alignment_obj["value"], font=font) - cropped_text_frame = black_img.crop(text_bbox) + # 裁剪字幕区域 + if text_bboxes: + cropped_text_frame = black_img.crop(overall_bbox) + else: + cropped_text_frame = Image.new('RGB', (1, 1), 'black') + cropped_pil_images_with_text.append(cropped_text_frame) - subtitle_coord.append(text_bbox) + subtitle_coord.append(overall_bbox) last_frame_no = end_frame_no @@ -151,4 +253,4 @@ def add_subtitles_to_frames(self, images, alignment, font_family, font_size, fon cropped_pil_images_with_text_normalised = pil2tensor(cropped_pil_images_with_text_normalised) tensor_masks = tensor2Mask(pil2tensor(pil_images_masks)) - return (tensor_images,tensor_masks,cropped_pil_images_with_text_normalised,subtitle_coord,) + return (tensor_images,tensor_masks,cropped_pil_images_with_text_normalised,subtitle_coord,) \ No newline at end of file diff --git a/apply_whisper.py b/apply_whisper.py index cebedf9..0a9cd11 100644 --- a/apply_whisper.py +++ b/apply_whisper.py @@ -3,78 +3,11 @@ import folder_paths import uuid import torchaudio -import torch -import logging - -import comfy.model_management as mm -import comfy.model_patcher - -WHISPER_MODEL_SUBDIR = os.path.join("stt", "whisper") - -logger = logging.getLogger(__name__) - -WHISPER_PATCHER_CACHE = {} - -class WhisperModelWrapper(torch.nn.Module): - """ - A torch.nn.Module wrapper for Whisper models. - This allows ComfyUI's model management to treat Whisper models like any other - torch module, enabling device placement and memory management. - """ - def __init__(self, model_name, download_root): - super().__init__() - self.model_name = model_name - self.download_root = download_root - self.whisper_model = None - self.model_loaded_weight_memory = 0 - - def load_model(self, device): - """Load the Whisper model from disk to the specified device""" - self.whisper_model = whisper.load_model( - self.model_name, - download_root=self.download_root, - device=device - ) - # Estimate model size for memory management - model_size = sum(p.numel() * p.element_size() for p in self.whisper_model.parameters()) - self.model_loaded_weight_memory = model_size - -class WhisperPatcher(comfy.model_patcher.ModelPatcher): - """ - Custom ModelPatcher for Whisper models that integrates with ComfyUI's - model management system for proper loading/offloading. - """ - def __init__(self, model, *args, **kwargs): - super().__init__(model, *args, **kwargs) - - def patch_model(self, device_to=None, *args, **kwargs): - """ - This method is called by ComfyUI's model manager when it's time to load - the model onto the target device (usually the GPU). Our responsibility here - is to ensure the model weights are loaded from disk if they haven't been already. - """ - target_device = self.load_device - - if self.model.whisper_model is None: - logger.info(f"Loading Whisper model '{self.model.model_name}' to {target_device}...") - self.model.load_model(target_device) - self.size = self.model.model_loaded_weight_memory - else: - logger.info(f"Whisper model '{self.model.model_name}' already in memory.") - - return super().patch_model(device_to=target_device, *args, **kwargs) - - def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs): - """ - Offload the Whisper model to free up VRAM. - """ - if unpatch_weights: - logger.info(f"Offloading Whisper model '{self.model.model_name}' to {device_to}...") - self.model.whisper_model = None - self.model.model_loaded_weight_memory = 0 - mm.soft_empty_cache() - return super().unpatch_model(device_to, unpatch_weights, *args, **kwargs) - +from pydub import AudioSegment +from pydub.silence import detect_nonsilent +import tempfile +import time +import re class ApplyWhisperNode: languages_by_name = None @@ -84,14 +17,15 @@ def INPUT_TYPES(s): return { "required": { "audio": ("AUDIO",), - "model": (['tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large-v3', 'large', 'large-v3-turbo', 'turbo'],), + "model": (["base", "tiny", "small", "medium", "large"],), }, "optional": { "language": ( ["auto"] + [s.capitalize() for s in sorted(list(whisper.tokenizer.LANGUAGES.values())) ], ), - "prompt": ("STRING", {"default":""}), + "min_silence_len": ("INT", {"default": 300, "min": 100, "max": 2000, "step": 50}), + "silence_thresh": ("INT", {"default": -40, "min": -60, "max": -20, "step": 5}), } } @@ -100,68 +34,136 @@ def INPUT_TYPES(s): FUNCTION = "apply_whisper" CATEGORY = "whisper" - def apply_whisper(self, audio, model, language, prompt): + def remove_punctuation(self, text): + """移除所有标点符号并用空格替换""" + # 使用正则表达式移除所有标点符号 + text = re.sub(r'[^\w\s]', ' ', text) + # 将多个连续空格替换为单个空格 + text = re.sub(r'\s+', ' ', text) + return text.strip() + def apply_whisper(self, audio, model, language="auto", min_silence_len=300, silence_thresh=-40): # save audio bytes from VHS to file temp_dir = folder_paths.get_temp_directory() os.makedirs(temp_dir, exist_ok=True) audio_save_path = os.path.join(temp_dir, f"{uuid.uuid1()}.wav") - torchaudio.save(audio_save_path, audio['waveform'].squeeze( - 0), audio["sample_rate"]) - - cache_key = model - if cache_key not in WHISPER_PATCHER_CACHE: - load_device = mm.get_torch_device() - download_root = os.path.join(folder_paths.models_dir, WHISPER_MODEL_SUBDIR) - logger.info(f"Creating Whisper ModelPatcher for {model} on device {load_device}") - - model_wrapper = WhisperModelWrapper(model, download_root) - patcher = WhisperPatcher( - model=model_wrapper, - load_device=load_device, - offload_device=mm.unet_offload_device(), - size=0 # Will be set when model loads - ) - WHISPER_PATCHER_CACHE[cache_key] = patcher - - patcher = WHISPER_PATCHER_CACHE[cache_key] - - mm.load_model_gpu(patcher) - whisper_model = patcher.model.whisper_model - - if whisper_model is None: - logger.error("Whisper model failed to load. Please check logs for errors.") - raise RuntimeError(f"Failed to load Whisper model: {model}") - - transcribe_args = {"initial_prompt": prompt} + torchaudio.save(audio_save_path, audio['waveform'].squeeze(0), audio["sample_rate"]) + # Load audio with pydub for silence detection + audio_segment = AudioSegment.from_wav(audio_save_path) + + # Detect non-silent segments + non_silent_segments = detect_nonsilent( + audio_segment, + min_silence_len=min_silence_len, + silence_thresh=silence_thresh, + seek_step=10 + ) + + # Load whisper model + model = whisper.load_model(model) + transcribe_args = {} if language != "auto": if ApplyWhisperNode.languages_by_name is None: ApplyWhisperNode.languages_by_name = {v.lower(): k for k, v in whisper.tokenizer.LANGUAGES.items()} transcribe_args['language'] = ApplyWhisperNode.languages_by_name[language.lower()] - result = whisper_model.transcribe(audio_save_path, word_timestamps=True, **transcribe_args) - - segments = result['segments'] segments_alignment = [] words_alignment = [] - - for segment in segments: - # create segment alignments - segment_dict = { - 'value': segment['text'].strip(), - 'start': segment['start'], - 'end': segment['end'] - } - segments_alignment.append(segment_dict) - - # create word alignments - for word in segment["words"]: - word_dict = { - 'value': word["word"].strip(), - 'start': word["start"], - 'end': word['end'] + full_text = "" + + # Process each non-silent segment + for i, (start_ms, end_ms) in enumerate(non_silent_segments): + # Convert to seconds + start_time = start_ms / 1000.0 + end_time = end_ms / 1000.0 + + # Extract segment audio + segment_audio = audio_segment[start_ms:end_ms] + + # Create a temporary file path + temp_audio_path = os.path.join(temp_dir, f"temp_segment_{uuid.uuid4()}.wav") + + # Export segment to file + segment_audio.export(temp_audio_path, format="wav") + + try: + # Transcribe segment + result = model.transcribe(temp_audio_path, word_timestamps=True, **transcribe_args) + + # Add segment to alignment list + if result['segments']: + segment = result['segments'][0] + # 移除标点符号并用空格替换 + segment_text = self.remove_punctuation(segment['text']) + segment_dict = { + 'value': segment_text, + 'start': segment['start'] + start_time, # Adjust time to absolute position + 'end': segment['end'] + start_time + } + segments_alignment.append(segment_dict) + full_text += segment_text + " " + + # Add words to alignment list + for word in segment.get("words", []): + # 移除标点符号并用空格替换 + word_text = self.remove_punctuation(word["word"]) + word_dict = { + 'value': word_text, + 'start': word["start"] + start_time, # Adjust time to absolute position + 'end': word['end'] + start_time + } + words_alignment.append(word_dict) + except Exception as e: + print(f"Error transcribing segment {i}: {e}") + finally: + # Try to delete the temporary file with retries + self._safe_delete_file(temp_audio_path) + + # If no segments were found, fall back to original method + if not segments_alignment: + result = model.transcribe(audio_save_path, word_timestamps=True, **transcribe_args) + segments = result['segments'] + + for segment in segments: + # 移除标点符号并用空格替换 + segment_text = self.remove_punctuation(segment['text']) + # create segment alignments + segment_dict = { + 'value': segment_text, + 'start': segment['start'], + 'end': segment['end'] } - words_alignment.append(word_dict) - - return (result["text"].strip(), segments_alignment, words_alignment) + segments_alignment.append(segment_dict) + full_text += segment_text + " " + + # create word alignments + for word in segment["words"]: + # 移除标点符号并用空格替换 + word_text = self.remove_punctuation(word["word"]) + word_dict = { + 'value': word_text, + 'start': word["start"], + 'end': word['end'] + } + words_alignment.append(word_dict) + + # Clean up audio file + self._safe_delete_file(audio_save_path) + + return (full_text.strip(), segments_alignment, words_alignment) + + def _safe_delete_file(self, file_path, max_retries=5, delay=0.1): + """Safely delete a file with retries to handle file locking issues""" + for i in range(max_retries): + try: + if os.path.exists(file_path): + os.remove(file_path) + return True + except (PermissionError, OSError) as e: + if i < max_retries - 1: + time.sleep(delay) + else: + print(f"Warning: Could not delete file {file_path}: {e}") + return False + return False \ No newline at end of file From 002514625f8c68ee8fef82a5611adf1a51a093f6 Mon Sep 17 00:00:00 2001 From: biyuxu <50160635+biyuxu@users.noreply.github.com> Date: Sun, 31 Aug 2025 14:27:24 +0800 Subject: [PATCH 2/2] Add files via upload --- apply_whisper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/apply_whisper.py b/apply_whisper.py index 0a9cd11..80a2999 100644 --- a/apply_whisper.py +++ b/apply_whisper.py @@ -24,7 +24,7 @@ def INPUT_TYPES(s): ["auto"] + [s.capitalize() for s in sorted(list(whisper.tokenizer.LANGUAGES.values())) ], ), - "min_silence_len": ("INT", {"default": 300, "min": 100, "max": 2000, "step": 50}), + "min_silence_len": ("INT", {"default": 300, "min": 50, "max": 1000, "step": 50}), "silence_thresh": ("INT", {"default": -40, "min": -60, "max": -20, "step": 5}), } } @@ -35,9 +35,9 @@ def INPUT_TYPES(s): CATEGORY = "whisper" def remove_punctuation(self, text): - """移除所有标点符号并用空格替换""" - # 使用正则表达式移除所有标点符号 - text = re.sub(r'[^\w\s]', ' ', text) + """移除所有标点符号,但保留小数点和百分号,并用空格替换其他符号""" + # 使用正则表达式移除非保留字符 + text = re.sub(r'[^\w\s.%]', ' ', text) # 将多个连续空格替换为单个空格 text = re.sub(r'\s+', ' ', text) return text.strip()