diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 49dacb6fb1..4a3334dd6b 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -513,7 +513,56 @@ process: mem_required: '9GB' - whitespace_normalization_mapper: # normalize different kinds of whitespaces to English whitespace. + +# When use HumanVBench mapper, keep_stats_in_res_ds should be set true + + - video_human_tracks_extraction_mapper: # Get the body and face trajectory bounding box of people in one shot of the video. To ensure correctness, it should be applied after video_split_by_scene_mapper + face_track_bbox_path: /home/daoyuan_mm/data-juicer/tmptreciept/tpt # The storage location of the bounding box tracks of the characters in the video + YOLOv8_human_model_path: ./thirdparty/humanvbench_models/YOLOv8_human/weights/best.pt + mem_required: '10GB' + + - video_humantrack_face_demographic_mapper: # Get the facial demographics of each person based on the results of video_human_tracks_extraction_mapper + original_data_save_path: /home/daoyuan_mm/data-juicer/tmptreciept/tpt2 # The location where the specific results of each frame's detection are stored + detect_interval: 5 + + - video_audio_attribute_mapper: # If the audio is speech, classify the gender and age of the speech + hf_audio_mapper: '/mnt/daoyuan_open_research/zt_data/pt_model/wav2vec2-large-robust-24-ft-age-gender' # Huggingface model name for speech age and gender classification + mem_required: '7GB' + + - video_captioning_from_human_tracks_mapper: # Based on the results of video_human_tracks_extraction_mapper, focus on the single person in the video for captioning + video_describe_model_path: /mnt/daoyuan_open_research/zt_data/pt_model/videollm/VideoLLaMA3-7B # model path to sharegpt4video-8b + trust_remote_code: true + tempt_video_path: /home/daoyuan_mm/data-juicer/tmptreciept/tpt2 # Used to store temporary videos that will be removed finally. + mem_required: '40GB' + + - video_captioning_face_attribute_emotion_mapper: # Based on the results of video_human_tracks_extraction_mapper, focus on judging the gender, age, and race of a single person in the video + face_track_query: Please only describe the appearance and facial emotions of the person in the video in detail. Don't mention the background. Less than 80 words. + trust_remote_code: true + cropping_face_video_tempt_path: /home/daoyuan_mm/data-juicer/tmptreciept/tpt2 # Used to store temporary videos + video_describe_model_path: /mnt/daoyuan_open_research/zt_data/pt_model/videollm/VideoLLaMA3-7B # Huggingface model DAMO-NLP-SG/VideoLLaMA2-7B-16F + mem_required: '40GB' + + - video_active_speaker_mapper: # Based on the results of video_human_tracks_extraction_mapper, determine whether each person is an active speaker + tempt_save_path: /home/daoyuan_mm/data-juicer/tmptreciept/tpt2 # Used to store temporary videos + Light_ASD_model_path: /home/daoyuan_mm/data-juicer/thirdparty/humanvbench_models/Light-ASD/weight/finetuning_TalkSet.model + acitve_threshold: 15 + mem_required: '10GB' + + + - video_audio_speech_ASR_mapper: # Automatic speech recognition from video speech + model_dir_ASR: '/mnt/daoyuan_open_research/zt_data/pt_model/SenseVoiceSmall' # Huggingface model FunAudioLLM/SenseVoiceSmall + mem_required: '20GB' + + - video_audio_speech_emotion_mapper: # Speech emotion recognition from video speech + model_dir_emo: '/mnt/daoyuan_open_research/zt_data/pt_model/SenseVoiceSmall' # # Huggingface model FunAudioLLM/SenseVoiceSmall + mem_required: '20GB' + # Filter ops + - video_face_ratio_filter: # Filter to retain human-centric videos + threshold: 0.65 # The lower limit of the ratio of frames with faces to the total number of video frames + detect_interval: 4 + any_or_all: any + - alphanumeric_filter: # filter text with alphabet/numeric ratio out of specific range. tokenization: false # whether to count the ratio of alphanumeric to the total number of tokens. min_ratio: 0.0 # the min ratio of filter range diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py index 8cb986b2b3..259284c747 100644 --- a/data_juicer/ops/filter/__init__.py +++ b/data_juicer/ops/filter/__init__.py @@ -43,6 +43,7 @@ from .video_watermark_filter import VideoWatermarkFilter from .word_repetition_filter import WordRepetitionFilter from .words_num_filter import WordsNumFilter +from .video_face_ratio_filter import VideoFaceRatioFilter __all__ = [ 'AlphanumericFilter', 'AudioDurationFilter', 'AudioNMFSNRFilter', @@ -61,7 +62,7 @@ 'VideoMotionScoreFilter', 'VideoMotionScoreRaftFilter', 'VideoNSFWFilter', 'VideoOcrAreaRatioFilter', 'VideoResolutionFilter', 'VideoTaggingFromFramesFilter', 'VideoWatermarkFilter', - 'WordRepetitionFilter', 'WordsNumFilter' + 'WordRepetitionFilter', 'WordsNumFilter', 'VideoFaceRatioFilter' ] NON_STATS_FILTERS = [ diff --git a/data_juicer/ops/filter/video_face_ratio_filter.py b/data_juicer/ops/filter/video_face_ratio_filter.py new file mode 100644 index 0000000000..ba54ab0f50 --- /dev/null +++ b/data_juicer/ops/filter/video_face_ratio_filter.py @@ -0,0 +1,139 @@ +import av +import numpy as np +from jsonargparse.typing import ClosedUnitInterval +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.mm_utils import (load_data_with_context, load_video, + pil_to_opencv, pil_to_opencv, process_each_frame) +from ..base_op import OPERATORS, Filter +from ..op_fusion import LOADED_VIDEOS + +import psutil +import gc,os + +OP_NAME = 'video_face_ratio_filter' + +import cv2,dlib +from PIL import ImageFilter + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoFaceRatioFilter(Filter): + """Keep data samples whose videos' durations are within a specified range. + """ + + def __init__(self, + threshold: ClosedUnitInterval = 0.8, + detect_interval: int = 1, + any_or_all: str = 'all', + *args, + **kwargs): + """ + Initialization method. + + :param any_or_all: keep this sample with 'any' or 'all' strategy of + all videos. 'any': keep this sample if any videos meet the + condition. 'all': keep this sample only if all videos meet the + condition. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.threshold = threshold + + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. ' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + + # Initialize face detector + self.detector = dlib.get_frontal_face_detector() + + + self.detect_interval = detect_interval + + + def compute_stats_single(self, sample, rank=None, context=False): + # check if it's computed already + if StatsKeys.video_face_exist in sample[Fields.stats]: + return sample + + # load videos + loaded_video_keys = sample[self.video_key] + video_faces_ratio = {} + + # face_detect_S3FD = get_model(self.detector_key, rank=rank) + + process = psutil.Process(os.getpid()) + # memory_before = process.memory_info().rss / 1024 ** 2 # MB + + + for video_key in loaded_video_keys: + try: + with av.open(video_key) as container: + # getting video stream + video_stream = next(s for s in container.streams if s.type == 'video') + # iterate over the video frame and detect faces + frame_counter = 0 + total_frames = 0 + frames_with_face = 0 + detect_num = 0 + for packet in container.demux(video_stream): + try: + for frame in packet.decode(): + total_frames += 1 + frame_counter += 1 + + if frame_counter % self.detect_interval == 0: + detect_num = detect_num + 1 + img = frame.to_image() + image = pil_to_opencv(img) + # imageNumpy = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + # faces = face_detect_S3FD.detect_faces(imageNumpy, conf_th=0.9, scales=[0.25]) + faces = self.detector(image) + if len(faces) > 0: + frames_with_face += 1 + except Exception as e: + print(f"Frame decoding error in video {video_key}: {e}") + frames_with_face = 0 + detect_num = 0 + + # calculate the proportion of the number of face frames + if detect_num > 0: + face_ratio = frames_with_face / detect_num + else: + face_ratio = 0.0 + video_faces_ratio[video_key] = face_ratio + except av.AVError as e: + print(f"Error opening video {video_key}: {e}") + video_faces_ratio[video_key] = 0.0 + finally: + container.close() + + video_faces_ratio[video_key] = face_ratio + + # get video faces ratio + sample[Fields.stats][StatsKeys.video_face_exist] = [ + video_faces_ratio[video_key] for video_key in sample[self.video_key] + ] + + memory_after = process.memory_info().rss / 1024 ** 2 # MB + print(f"Memory Usage: {memory_after:.2f} MB") + + gc.collect() + + return sample + + def process_single(self, sample): + video_faces_ratio = sample[Fields.stats][StatsKeys.video_face_exist] + keep_bools = np.array([ + duration >= self.threshold + for duration in video_faces_ratio + ]) + if len(keep_bools) <= 0: + return True + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 8ffe7cc8e8..d7d2e9a6cb 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -73,6 +73,15 @@ from .video_tagging_from_audio_mapper import VideoTaggingFromAudioMapper from .video_tagging_from_frames_mapper import VideoTaggingFromFramesMapper from .whitespace_normalization_mapper import WhitespaceNormalizationMapper +from .video_active_speaker_mapper import VideoActiveSpeakerMapper +from .video_audio_attribute_mapper import VideoAudioAttributeMapper +from .video_audio_speech_ASR_mapper import VideoAudioSpeechASRMapper +from .video_audio_speech_emotion_mapper import VideoAudioSpeechEmotionMapper +from .video_captioning_face_attribute_emotion_mapper import VideoCaptioningFaceAttributeEmotionMapper +from .video_captioning_from_human_tracks_mapper import VideoCaptioningFromHumanTracksMapper +from .video_human_tracks_extraction_mapper import VideoHumanTracksExtractionMapper +from .video_captioning_face_attribute_emotion_mapper import VideoCaptioningFaceAttributeEmotionMapper +from .video_humantrack_face_demographic_mapper import VideoHumantrackFaceDemographicMapper __all__ = [ 'AudioFFmpegWrappedMapper', 'CalibrateQAMapper', 'CalibrateQueryMapper', @@ -105,5 +114,9 @@ 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper', - 'WhitespaceNormalizationMapper' + 'WhitespaceNormalizationMapper','VideoActiveSpeakerMapper', + 'VideoAudioAttributeMapper', 'VideoAudioSpeechASRMapper', + 'VideoCaptioningFaceAttributeEmotionMapper','VideoCaptioningFromHumanTracksMapper', + 'VideoHumanTracksExtractionMapper', 'VideoCaptioningFaceAttributeEmotionMapper', + 'VideoHumantrackFaceDemographicMapper', 'VideoAudioSpeechEmotionMapper' ] diff --git a/data_juicer/ops/mapper/video_active_speaker_mapper.py b/data_juicer/ops/mapper/video_active_speaker_mapper.py new file mode 100644 index 0000000000..efc709de78 --- /dev/null +++ b/data_juicer/ops/mapper/video_active_speaker_mapper.py @@ -0,0 +1,208 @@ +from data_juicer.utils.ASD_mapper_utils import get_video_array_cv2,evaluate_network, \ + crop_video_with_facetrack, longest_continuous_actives + +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS +from data_juicer.utils.model_utils import get_model, prepare_model +import gc,os + +OP_NAME = 'video_active_speaker_mapper' + +import torch +import sys +sys.path.append('./thirdparty/humanvbench_models/Light-ASD') +from data_juicer.utils.constant import Fields, MetaKeys +import tempfile +import shutil, pickle +from shutil import rmtree +import os, subprocess +import tqdm, glob +# from model.faceDetector.s3fd import S3FD + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoActiveSpeakerMapper(Mapper): + _accelerator = 'cuda' + _batched_op = True + + """ + """ + + _default_kwargs = {'upsample_num_times': 0} + + def __init__(self, + tempt_save_path: str = './HumanVBenchRecipe/dj_ASD_tempt', + Light_ASD_model_path: str = './thirdparty/humanvbench_models/Light-ASD/weight/finetuning_TalkSet.model', + acitve_threshold: int = 15, + active_speaker_flag: str = MetaKeys.active_speaker_flag, + *args, + **kwargs): + """ + Initialization method. + + :param blur_type: + """ + kwargs.setdefault('mem_required', '10GB') + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + self.acitve_threshold = acitve_threshold + + self.tempt_save_path = tempt_save_path + + # Initialize ASD model + self.ASD_model_key = prepare_model(model_type='Light_ASD', + pretrained_model_name_or_path=Light_ASD_model_path) + + self.active_speaker_flag = active_speaker_flag + + def active_speaker_detection_revise(self, active_score,is_child_descrip,speech_audio,face_gender): + speech_child = speech_audio['child'][0] + speech_male = speech_audio['male'][0] + speech_female = speech_audio['female'][0] + if speech_male > speech_female: + speech_gender = 'Man' + speech_gender_confidence = speech_male + else: + speech_gender = 'Woman' + speech_gender_confidence = speech_female + + if 'No' in is_child_descrip or 'no' in is_child_descrip: + is_child_apperance = False + else: + is_child_apperance = True + + if speech_child < 0.1: + is_child_voice = False + elif speech_audio['Age'][0]<=12: + is_child_voice = True + else: + is_child_voice = 'Not Sure' + + # Consistency detection: only perform false positive detection on positive samples + if active_score>self.acitve_threshold: + speak_active = True + # age consistency test: + if not is_child_voice == 'Not Sure': + if is_child_apperance == is_child_voice: + # gender consistency test + if speech_gender_confidence > 0.85 and float(face_gender[1]) > 0.85: + if not speech_gender == face_gender[0]: + speak_active = False + else: + speak_active = False + return speak_active + else: + return False + + + def process_single(self, sample, rank=None): + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.source_file] = [] + return sample + + if not MetaKeys.video_audio_tags in sample[Fields.meta]: + raise ValueError("video_active_speaker_mapper must be operated after video_tagging_from_audio_mapper.") + + if not MetaKeys.human_track_data_path in sample[Fields.meta]: + raise ValueError("video_active_speaker_mapper must be operated after video_human_tracks_extraction_mapper.") + + if not MetaKeys.audio_speech_attribute in sample[Fields.meta]: + raise ValueError("video_active_speaker_mapper must be operated after video_audio_attribute_mapper.") + + if not MetaKeys.video_facetrack_attribute_demographic in sample[Fields.meta]: + raise ValueError("video_active_speaker_mapper must be operated after video_humantrack_face_demographic_mapper.") + + if not MetaKeys.video_track_is_child in sample[Fields.meta]: + raise ValueError("video_active_speaker_mapper must be operated after video_captioning_from_human_tracks_mapper.") + + loaded_video_keys = sample[self.video_key] + audio_speech_attribute = sample[Fields.meta][MetaKeys.audio_speech_attribute] + face_demographic = sample[Fields.meta][MetaKeys.video_facetrack_attribute_demographic][0] + child_flag = sample[Fields.meta][MetaKeys.video_track_is_child][0] + + Total_result = [] + + temp_dir = tempfile.mkdtemp(dir=self.tempt_save_path) + pyaviPath = os.path.join(temp_dir, 'pyavi') + pyframesPath = os.path.join(temp_dir, 'pyframes') + pyworkPath = os.path.join(temp_dir, 'pywork') + pycropPath = os.path.join(temp_dir, 'pycrop') + if os.path.exists(temp_dir): + rmtree(temp_dir) + + audio_tag = sample[Fields.meta][MetaKeys.video_audio_tags] + asd_detection_model = get_model(self.ASD_model_key, rank=rank) + + for id_out,video_key in enumerate(loaded_video_keys): + os.makedirs(pyaviPath, exist_ok = False) # The path for the input video, input audio, output video + os.makedirs(pyframesPath, exist_ok = False) # Save all the video frames + os.makedirs(pyworkPath, exist_ok = False) # Save the results in this process by the pckl method + os.makedirs(pycropPath, exist_ok = False) # Save the detected face clips (audio+video) in this process + + # Extract audio + audio_is_empty = False + audioFilePath = os.path.join(pyaviPath, 'audio.wav') + command = ("ffmpeg -y -i '%s' -qscale:a 0 -ac 1 -vn -threads %d -ar 16000 %s -loglevel panic" % \ + (video_key, 10, audioFilePath)) + if audio_tag[id_out] == "EMPTY": + audio_is_empty = True + else: + subprocess.call(command, shell=True, stdout=None) + + + video_array = get_video_array_cv2(video_key) + + def load_pkl(file_path): + with open(file_path, 'rb') as file: + return pickle.load(file) + # get allTracks + allTracks = [load_pkl(item['bbox_path']) for item in sample[Fields.meta][MetaKeys.human_track_data_path][id_out]] + + # Face clips cropping + for ii, track in tqdm.tqdm(enumerate(allTracks), total = len(allTracks)): + result = crop_video_with_facetrack(video_array, track, os.path.join(pycropPath, '%05d' % ii), audioFilePath, audio_is_empty) + if not result: + raise ValueError("something wrong with crop_video_with_facetrack.") + + # Active Speaker Detection + if audio_tag[id_out] == 'Speech': + files = glob.glob("%s/*.avi"%pycropPath) + files.sort() + try: + scores = evaluate_network(files, asd_detection_model, pycropPath) + except: + scores = [[-10000]]* len(allTracks) + + else: + scores = [[-10000]]* len(allTracks) + + for id in range(len(scores)): + allTracks[id]['active_scores'] = scores[id] + + update_track = allTracks + # for validation + # visualization(vidTracks, scores, video_array, pyaviPath) + + shutil.rmtree(temp_dir) + + speak_flag_for_tracks_in_a_video = [] + for track_idx,track_i in enumerate(update_track): + active_count = longest_continuous_actives(track_i['active_scores']) + audio_attri = audio_speech_attribute[id_out][0] + is_child_descrip = child_flag[id_out][track_idx][0] + face_gender = face_demographic[id_out][track_idx]['gender'] + flag = self.active_speaker_detection_revise(active_count, is_child_descrip, audio_attri, face_gender) + speak_flag_for_tracks_in_a_video.append(flag) + + + Total_result.append(speak_flag_for_tracks_in_a_video) + torch.cuda.empty_cache() + + sample[Fields.meta][self.active_speaker_flag] = Total_result + + gc.collect() + torch.cuda.empty_cache() + + return sample diff --git a/data_juicer/ops/mapper/video_audio_attribute_mapper.py b/data_juicer/ops/mapper/video_audio_attribute_mapper.py new file mode 100644 index 0000000000..03955b0442 --- /dev/null +++ b/data_juicer/ops/mapper/video_audio_attribute_mapper.py @@ -0,0 +1,98 @@ +import librosa +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.mm_utils import extract_audio_from_video +from thirdparty.humanvbench_models.audio_code.wav2vec_age_gender import process_func,AgeGenderModel +from ..base_op import OPERATORS, Mapper +from data_juicer.utils.model_utils import get_model, prepare_model + +NAME = 'video_audio_attribute_mapper' +CHECK_PKGS = [ + 'transformers', 'transformers_stream_generator', 'einops', 'accelerate', + 'tiktoken' +] + +from data_juicer.utils.model_utils import get_model, prepare_model + + + +@OPERATORS.register_module(NAME) +class VideoAudioAttributeMapper(Mapper): + """Mapper to caption a video according to its audio streams based on + Qwen-Audio model. + """ + _accelerator = 'cuda' + _batched_op = True + + def __init__(self, + hf_audio_mapper: str = None, + tag_field_name: str = MetaKeys.audio_speech_attribute, + *args, **kwargs): + """ + Initialization method. + + :param keep_original_sample: whether to keep the original sample. If + it's set to False, there will be only captioned sample in the + final datasets and the original sample will be removed. It's True + in default. + :param args: extra args + :param kwargs: extra args + """ + kwargs.setdefault('mem_required', '7GB') + super().__init__(*args, **kwargs) + self._model_sampling_rate = 16000 + + self._hf_summarizer = hf_audio_mapper if hf_audio_mapper else 'audeering/wav2vec2-large-robust-24-ft-age-gender' # noqa: E501 + self.model_key = prepare_model( + model_type='wav2vec2_age_gender', + pretrained_model_name_or_path=self._hf_summarizer, + ) + self.tag_field_name = tag_field_name + + def process_single(self, sample, rank=None): + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + return [] + + if not MetaKeys.video_audio_tags in sample[Fields.meta]: + raise ValueError("video_audio_attribute_mapper must be operated after video_tagging_from_audio_mapper.") + + # get paths of all video(s) + loaded_video_keys = sample[self.video_key] + audio_tag = sample[Fields.meta][MetaKeys.video_audio_tags] + + Total_result = [] + # get models + model, processor = get_model(self.model_key, rank, self.use_cuda()) + + for i,video in enumerate(loaded_video_keys): + audio_tag_this = audio_tag[i] + if not audio_tag_this == 'Speech': + Total_result.append([]) + else: + ys, srs, valid_indexes = extract_audio_from_video( + video, stream_indexes=[0]) + if len(valid_indexes) == 0: + # there is no valid audio streams. Skip! + Total_result.append([]) + continue + + # inference + y = ys[0] + sr = srs[0] + # check if it meets the sampling rate condition of the model + if sr != self._model_sampling_rate: + y = librosa.resample(y, + orig_sr=sr, + target_sr=self._model_sampling_rate) + sr = self._model_sampling_rate + + Age_female_male_child = process_func(y, sr, processor, model, device=model.device)[0] + Age_female_male_child_dict = {} + Age_female_male_child_dict['Age'] = [int(Age_female_male_child[0]*100)] + Age_female_male_child_dict['female'] = [Age_female_male_child[1]] + Age_female_male_child_dict['male'] = [Age_female_male_child[2]] + Age_female_male_child_dict['child'] = [Age_female_male_child[3]] + Total_result.append([Age_female_male_child_dict]) + + sample[Fields.meta][self.tag_field_name] = Total_result + return sample diff --git a/data_juicer/ops/mapper/video_audio_speech_ASR_mapper.py b/data_juicer/ops/mapper/video_audio_speech_ASR_mapper.py new file mode 100644 index 0000000000..3e08ffff53 --- /dev/null +++ b/data_juicer/ops/mapper/video_audio_speech_ASR_mapper.py @@ -0,0 +1,105 @@ +import librosa +from data_juicer.utils.mm_utils import extract_audio_from_video +from data_juicer.utils.model_utils import get_model, prepare_model +from ..base_op import OPERATORS, Mapper +import gc +from data_juicer.utils.constant import Fields, MetaKeys + +OP_NAME = 'video_audio_speech_ASR_mapper' + +import torch +torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +class VideoAudioSpeechASRMapper(Mapper): + """Mapper to generate video tags from audio streams extracted by video + using the Audio Spectrogram Transformer. + """ + _accelerator = 'cuda' + _batched_op = True + + def __init__(self, + model_dir_ASR = 'FunAudioLLM/SenseVoiceSmall', + speech_ASR: str = MetaKeys.speech_ASR, + *args, + **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + kwargs.setdefault('mem_required', '20GB') + super().__init__(*args, **kwargs) + self._batched_op = True + self._model_sampling_rate = 16000 + self.model_dir_ASR = model_dir_ASR + + self.model_key = prepare_model( + model_type='SenseVoiceSmall', + pretrained_model_name_or_path=model_dir_ASR, + ) + + self.speech_ASR = speech_ASR + + def process_single(self, sample, rank=None): + # check if it's generated already + if MetaKeys.speech_emotion in sample[Fields.meta]: + return sample + + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.source_file] = [] + return sample + + if not MetaKeys.video_audio_tags in sample[Fields.meta]: + raise ValueError("video_active_speaker_mapper must be operated after video_tagging_from_audio_mapper.") + + + # load video paths + loaded_video_keys = sample[self.video_key] + audio_tags = sample[Fields.meta][MetaKeys.video_audio_tags] + + ASR_model, kwargs1= get_model(self.model_key, rank=rank) + + # model, feature_extractor = get_model(self.model_key, rank=rank) + video_audio_tags = [] + + for id,video_path in enumerate(loaded_video_keys): + if audio_tags[id] == 'Speech': + # only extract audio data and sr for index 0 for now + ys, srs, valid_indexes = extract_audio_from_video( + video_path, stream_indexes=[0]) + if len(valid_indexes) == 0: + # there is no valid audio streams. Skip! + video_audio_tags.append(self._no_audio_label) + continue + + # inference + y = ys[0] + sr = srs[0] + # check if it meets the sampling rate condition of the model + if sr != self._model_sampling_rate: + y = librosa.resample(y, + orig_sr=sr, + target_sr=self._model_sampling_rate) + sr = self._model_sampling_rate + + inputs = torch.tensor(y).to(next(ASR_model.parameters()).device) + with torch.no_grad(): + output_ASR_emo = ASR_model.inference( + data_in=inputs, + language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech" + use_itn=False, + **kwargs1, + ) + + video_audio_tags.append({'language':output_ASR_emo[0][0]['text'].split('<|',1)[-1].split('|>')[0], 'asr': output_ASR_emo[0][0]['text'].split('|>',4)[-1]}) + else: + video_audio_tags.append('') + + sample[Fields.meta][self.speech_ASR] = video_audio_tags + gc.collect() + torch.cuda.empty_cache() + return sample diff --git a/data_juicer/ops/mapper/video_audio_speech_emotion_mapper.py b/data_juicer/ops/mapper/video_audio_speech_emotion_mapper.py new file mode 100644 index 0000000000..3b9f27670b --- /dev/null +++ b/data_juicer/ops/mapper/video_audio_speech_emotion_mapper.py @@ -0,0 +1,104 @@ +import librosa +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.mm_utils import extract_audio_from_video +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper +import gc + +OP_NAME = 'video_audio_speech_emotion_mapper' + +import torch +torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +class VideoAudioSpeechEmotionMapper(Mapper): + """Mapper to generate video tags from audio streams extracted by video + using the Audio Spectrogram Transformer. + """ + _accelerator = 'cuda' + _batched_op = True + + def __init__(self, + model_dir_emo='FunAudioLLM/SenseVoiceSmall', + speech_Emo: str = MetaKeys.speech_emotion, + *args, + **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + kwargs.setdefault('mem_required', '20GB') + super().__init__(*args, **kwargs) + self._batched_op = True + self._model_sampling_rate = 16000 + self.model_dir_emo = model_dir_emo + + self.model_key = prepare_model( + model_type='SenseVoiceSmall', + pretrained_model_name_or_path=self.model_dir_emo, + ) + + self.speech_Emo = speech_Emo + + def process_single(self, sample, rank=None): + # check if it's generated already + if MetaKeys.speech_emotion in sample[Fields.meta]: + return sample + + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.source_file] = [] + return sample + + if not MetaKeys.video_audio_tags in sample[Fields.meta]: + raise ValueError("video_active_speaker_mapper must be operated after video_tagging_from_audio_mapper.") + + + # load video paths + loaded_video_keys = sample[self.video_key] + audio_tags = sample[Fields.meta][MetaKeys.video_audio_tags] + + Emo_model, kwargs1= get_model(self.model_key, rank=rank) + + video_audio_tags = [] + for id,video_path in enumerate(loaded_video_keys): + if audio_tags[id] == 'Speech': + # only extract audio data and sr for index 0 for now + ys, srs, valid_indexes = extract_audio_from_video( + video_path, stream_indexes=[0]) + if len(valid_indexes) == 0: + # there is no valid audio streams. Skip! + video_audio_tags.append(self._no_audio_label) + continue + + # inference + y = ys[0] + sr = srs[0] + # check if it meets the sampling rate condition of the model + if sr != self._model_sampling_rate: + y = librosa.resample(y, + orig_sr=sr, + target_sr=self._model_sampling_rate) + sr = self._model_sampling_rate + + inputs = torch.tensor(y).to(next(Emo_model.parameters()).device) + with torch.no_grad(): + output_emo = Emo_model.inference( + data_in=inputs, + language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech" + use_itn=False, + **kwargs1, + ) + + video_audio_tags.append(output_emo[0][0]['text'].split('<|',2)[-1].split('|>')[0]) + else: + video_audio_tags.append('') + + sample[Fields.meta][self.speech_Emo] = video_audio_tags + gc.collect() + torch.cuda.empty_cache() + return sample diff --git a/data_juicer/ops/mapper/video_captioning_face_attribute_emotion_mapper.py b/data_juicer/ops/mapper/video_captioning_face_attribute_emotion_mapper.py new file mode 100644 index 0000000000..5fcc683b8b --- /dev/null +++ b/data_juicer/ops/mapper/video_captioning_face_attribute_emotion_mapper.py @@ -0,0 +1,150 @@ +import numpy as np +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS +from data_juicer.utils.ASD_mapper_utils import get_video_array_cv2 +import gc + +OP_NAME = 'video_captioning_face_attribute_emotion_mapper' + +import torch, os, tempfile, shutil +from shutil import rmtree +import pickle, copy, cv2 +import transformers # noqa: F401 + +# avoid hanging when calling clip in multiprocessing +torch.set_num_threads(1) +import sys + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoCaptioningFaceAttributeEmotionMapper(Mapper): + _accelerator = 'cuda' + _batched_op = True + + def __init__( + self, + face_track_query: str = "Please describe the person's facial expression, tell me the person's emotion through the video, like Happiness, Excitement, Love, Gratitude, Relief, Pride, Anger, Sadness, Fear, Guilt, Shame, Disgust, Surprise, Confusion, Curiosity, Boredom ...", + trust_remote_code: bool = False, + cropping_face_video_tempt_path = './tempt_video/tmp_video_remove', + video_describe_model_path: str = 'DAMO-NLP-SG/VideoLLaMA3-7B', + video_facetrack_attribute_emotion: str = MetaKeys.video_facetrack_attribute_emotion, + *args, + **kwargs + ): + """ + Initialization method. + + :param hf_video_blip: video-blip model name on huggingface + to generate caption + """ + kwargs.setdefault('mem_required', '40GB') + super().__init__(*args, **kwargs) + + self._batched_op = True + self._accelerator = 'cuda' + self.context_param = 0.8 + + # self.pre_query_prompt = "The provided image arranges keyframes from a video in a grid view, keyframes are separated with white bands. " + self.query = face_track_query + self.cropping_face_video_tempt_path = cropping_face_video_tempt_path + + self.video_describe_model_path = video_describe_model_path if video_describe_model_path else 'DAMO-NLP-SG/VideoLLaMA3-7B' + self.model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=video_describe_model_path, + trust_remote_code=trust_remote_code, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2" + ) + + self.video_facetrack_attribute_emotion = video_facetrack_attribute_emotion + + + + def process_single(self, samples, rank=None): + + if not MetaKeys.human_track_data_path in samples[Fields.meta]: + raise ValueError("video_captioning_from_human_tracks_mapper must be operated after video_human_tracks_extraction_mapper.") + + + Total_information = [] + video_samples = samples[Fields.meta][MetaKeys.human_track_data_path] + loaded_video_keys = samples[self.video_key] + + cropping_face_video_tempt_path = tempfile.mkdtemp(dir=self.cropping_face_video_tempt_path) + if os.path.exists(cropping_face_video_tempt_path): + rmtree(cropping_face_video_tempt_path) + + os.makedirs(cropping_face_video_tempt_path, exist_ok = False) + model, processor = get_model(self.model_key, rank, self.use_cuda()) + for vedio_id,ASD_attribute_all_tracks_for_one_video in enumerate(video_samples): + if len(ASD_attribute_all_tracks_for_one_video) == 0: + Total_information.append([]) + continue + + description_for_each_track = [] + video_array = get_video_array_cv2(loaded_video_keys[vedio_id]) + for track_id,tracks_now in enumerate(ASD_attribute_all_tracks_for_one_video): + cs = self.context_param + + with open(tracks_now['bbox_path'], 'rb') as f: + bbox_data = pickle.load(f) + xys_bbox = bbox_data['xys_bbox'] + track_frame = bbox_data['frame'] + + face_video_out_path = os.path.join(cropping_face_video_tempt_path, loaded_video_keys[vedio_id].split('/')[-1][:-4] + '__' + str(track_id) + '.mp4') + vOut = cv2.VideoWriter(face_video_out_path, cv2.VideoWriter_fourcc(*'XVID'), 25, (224,224))# Write video + + start_frame_id_in = 0 + start_frame_id_out = track_frame[start_frame_id_in] # tag + while start_frame_id_in + 1 = minTrack: # Discard the shot frames less than minTrack frames + allTracks.extend(track_shot(faces[shot[0].frame_num:shot[1].frame_num])) # 'frames' to present this tracks' timestep, 'bbox' presents the location of the faces + + # Get face and human tracks + for ii, track in tqdm.tqdm(enumerate(allTracks), total = len(allTracks)): + result = get_face_and_human_tracks(video_array, track, human_detection_model) + if result: + vidTracks.append(result) + # merge + people_num_atleast, update_track = post_merge(vidTracks,video_array) + + for i in range(len(update_track)): + save_bbox_name = os.path.join(self.face_track_bbox_path, video_key.split("/")[-1][:-4] +'_'+str(i)+'.pkl') + xy_bbox = update_track[i]['track']['bbox'] + xys_bbox = update_track[i]['proc_track'] + xy_human_bbox = update_track[i]['human_bbox'] + frames = update_track[i]['track']['frame'] + bbox_dict = {'frame':frames, 'xy_bbox':xy_bbox, 'xys_bbox':xys_bbox, 'xy_human_bbox':xy_human_bbox} + f_save = open(save_bbox_name, 'wb') + pickle.dump(bbox_dict, f_save) + f_save.close() + del update_track[i]['human_bbox'] + del update_track[i]['proc_track'] + del update_track[i]['track'] + update_track[i]['bbox_path'] = save_bbox_name + + + Total_result.append(update_track) + min_people_in_video.append(people_num_atleast) + torch.cuda.empty_cache() + + sample[Fields.meta][self.tag_field_name_human_track_path] = Total_result + sample[Fields.meta][self.tag_field_name_people_num] = min_people_in_video + + gc.collect() + torch.cuda.empty_cache() + + return sample diff --git a/data_juicer/ops/mapper/video_humantrack_face_demographic_mapper.py b/data_juicer/ops/mapper/video_humantrack_face_demographic_mapper.py new file mode 100644 index 0000000000..d211134c59 --- /dev/null +++ b/data_juicer/ops/mapper/video_humantrack_face_demographic_mapper.py @@ -0,0 +1,195 @@ +import numpy as np +from data_juicer.utils.constant import Fields, MetaKeys +from deepface import DeepFace +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS +from data_juicer.utils.ASD_mapper_utils import get_video_array_cv2 +import gc + +OP_NAME = 'video_humantrack_face_demographic_mapper' + +import torch, os +import pickle + +# avoid hanging when calling clip in multiprocessing +torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoHumantrackFaceDemographicMapper(Mapper): + """Mapper to generate samples whose captions are generated based on + a video-to-text model and sampled video frame.""" + + def __init__( + self, + original_data_save_path = '', + detect_interval: int = 5, + tag_field_name: str = MetaKeys.video_facetrack_attribute_demographic, + *args, + **kwargs + ): + """ + Initialization method. + + :param hf_video_blip: video-blip model name on huggingface + to generate caption + """ + super().__init__(*args, **kwargs) + + self.interval = detect_interval + self.original_data_save_path = original_data_save_path + self.tag_field_name = tag_field_name + + def process_single(self, samples, rank=None, context=False): + if not MetaKeys.human_track_data_path in samples[Fields.meta]: + raise ValueError("video_humantrack_face_demographic_mapper must be operated after video_human_tracks_extraction_mapper.") + + Total_information = [] + video_samples = samples[Fields.meta][MetaKeys.human_track_data_path] + loaded_video_keys = samples[self.video_key] + + for vedio_id,ASD_attribute_all_tracks_for_one_video in enumerate(video_samples): + if len(ASD_attribute_all_tracks_for_one_video) == 0: + Total_information.append([]) + continue + description_for_each_track = [] + video_array = get_video_array_cv2(loaded_video_keys[vedio_id]) + for track_id,tracks_now in enumerate(ASD_attribute_all_tracks_for_one_video): + face_attribute_dict_with_framestamp = {} + + bbox_path = tracks_now['bbox_path'] + with open(bbox_path, 'rb') as f: + bbox_data = pickle.load(f) + xys_bbox = bbox_data['xys_bbox'] + track_frame = bbox_data['frame'] + + + total_len = len(track_frame) + if total_len > 75: + interval = int(total_len/15) + else: + interval = self.interval + + + start_frame_id_in = 0 + start_frame_id_out = track_frame[start_frame_id_in] # tag + cs = 0.5 + while start_frame_id_in + interval iouThres and iou > max_iou: + best_match = face + max_iou = iou + else: + break + + if best_match is not None: + track.append(best_match) + frameFaces.remove(best_match) + + if track == []: + break + elif len(track) > minTrack: + frameNum = np.array([ f['frame'] for f in track ]) + bboxes = np.array([np.array(f['bbox']) for f in track]) + frameI = np.arange(frameNum[0],frameNum[-1]+1) + bboxesI = [] + for ij in range(0,4): + interpfn = interp1d(frameNum, bboxes[:,ij]) + bboxesI.append(interpfn(frameI)) + bboxesI = np.stack(bboxesI, axis=1) + if max(np.mean(bboxesI[:,2]-bboxesI[:,0]), np.mean(bboxesI[:,3]-bboxesI[:,1])) > 1: + tracks.append({'frame':frameI,'bbox':bboxesI}) + return tracks + + +def find_human_bounding_box(face_bbox, human_bboxes): + head_x1, head_y1, head_x2, head_y2 = face_bbox + head_center_x = (head_x1 + head_x2)/2 + + candidate_bboxes = [] + + for human_bbox in human_bboxes: + human_x1, human_y1, human_x2, human_y2 = human_bbox + + if (human_x1 <= head_x1 and head_x2 <= human_x2) and (human_y1 <= head_y1 and head_y2 <= human_y2): + candidate_bboxes.append(human_bbox) + + if not candidate_bboxes: + return () + + # Select the human body bounding box with the smallest distance between (x1 + x2) / 2 and (x1 + x2) / 2 of face_bbox + closest_bbox = min(candidate_bboxes, key=lambda bbox: (((bbox[0] + bbox[2]) / 2) - head_center_x)**2 + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])) + + return closest_bbox + +def update_negative_ones(values): + n = len(values) + i = 0 + + while i < n: + if values[i] == -1: + # Find the nearest number on the left + left_index = i - 1 + while left_index >= 0 and values[left_index] == -1: + left_index -= 1 + + # Find the nearest number on the right + right_index = i + 1 + while right_index < n and values[right_index] == -1: + right_index += 1 + + # Update the value of -1 + if left_index >= 0 and right_index < n: + left_value = values[left_index] + right_value = values[right_index] + values[i] = (left_value + right_value) / 2 + elif left_index >= 0: + values[i] = values[left_index] + elif right_index < n: + values[i] = values[right_index] + else: + raise ValueError("Unable to find valid values ​​on both the left and right to update -1 at index {i}") + i += 1 + + return values + + +def detect_and_mark_anomalies(data, window_size=7, std_multiplier=2): + data = np.array(data) + result = data.copy() + + for i in range(len(data)): + if data[i] > 0: + start = max(0, i - window_size) + end = min(len(data), i + window_size + 1) + neighbors = data[start:end] + + neighbors = np.delete(neighbors, np.where(neighbors == data[i])) + + positive_neighbors = neighbors[neighbors > 0] + + if len(positive_neighbors) < 2: + continue + + mean = np.mean(positive_neighbors) + std = np.std(positive_neighbors) + + if abs(data[i] - mean) > std * std_multiplier: + result[i] = -1 + + return result + + +def get_face_and_human_tracks(video_array, track, human_detection_pipeline): + dets = {'x':[], 'y':[], 's':[]} + for det in track['bbox']: # Read the tracks + dets['s'].append(max((det[3]-det[1]), (det[2]-det[0]))/2) + dets['y'].append((det[1]+det[3])/2) # crop center x + dets['x'].append((det[0]+det[2])/2) # crop center y + + # human_bounding_box + human_bbox = {'x1':[], 'y1':[], 'x2':[], 'y2':[]} + for in_id,out_track_id in enumerate(track['frame']): # Read the tracks + frame_ = video_array[out_track_id] + head_x1, head_y1, head_x2, head_y2 = track['bbox'][in_id] + human_bbox_list = demo(frame_, human_detection_pipeline) + result = find_human_bounding_box((head_x1, head_y1, head_x2, head_y2), human_bbox_list) + if result == (): + human_bbox['x1'].append(-1) + human_bbox['y1'].append(-1) + human_bbox['x2'].append(-1) + human_bbox['y2'].append(-1) + else: + human_bbox['x1'].append(result[0]) + human_bbox['y1'].append(result[1]) + human_bbox['x2'].append(result[2]) + human_bbox['y2'].append(result[3]) + if (np.array(human_bbox['x1'])<0).sum() > 0: + if all(element < 0 for element in human_bbox['x1']): + return False + human_bbox['x1'] = detect_and_mark_anomalies(human_bbox['x1'], window_size=30, std_multiplier=10) + human_bbox['x1'] = update_negative_ones(human_bbox['x1']) + if (np.array(human_bbox['y1'])<0).sum() > 0: + human_bbox['y1'] = detect_and_mark_anomalies(human_bbox['y1'], window_size=30, std_multiplier=10) + human_bbox['y1'] = update_negative_ones(human_bbox['y1']) + if (np.array(human_bbox['x2'])<0).sum() > 0: + human_bbox['x2'] = detect_and_mark_anomalies(human_bbox['x2'], window_size=30, std_multiplier=10) + human_bbox['x2'] = update_negative_ones(human_bbox['x2']) + if (np.array(human_bbox['y2'])<0).sum() > 0: + human_bbox['y2'] = detect_and_mark_anomalies(human_bbox['y2'], window_size=30, std_multiplier=10) + human_bbox['y2'] = update_negative_ones(human_bbox['y2']) + human_bbox['x1'] = signal.medfilt(human_bbox['x1'], kernel_size=5).tolist() + human_bbox['y1'] = signal.medfilt(human_bbox['y1'], kernel_size=5).tolist() + human_bbox['x2'] = signal.medfilt(human_bbox['x2'], kernel_size=5).tolist() + human_bbox['y2'] = signal.medfilt(human_bbox['y2'], kernel_size=5).tolist() + + return {'track':track, 'proc_track':dets, 'human_bbox':human_bbox} + +def crop_video_with_facetrack(video_array, track, cropFile, audioFilePath,is_empty=False): + if is_empty: + return True + + dets = track['xys_bbox'] + # CPU: crop the face clips + vOut = cv2.VideoWriter(cropFile + 't.avi', cv2.VideoWriter_fourcc(*'XVID'), 25, (224,224))# Write video + + for fidx, frame in enumerate(track['frame']): + cs = 0.4 + bs = dets['s'][fidx] # Detection box size + bsi = int(bs * (1 + 2 * cs)) # Pad videos by this amount + image = video_array[frame] + frame = numpy.pad(image, ((bsi,bsi), (bsi,bsi), (0, 0)), 'constant', constant_values=(110, 110)) + my = dets['y'][fidx] + bsi # BBox center Y + mx = dets['x'][fidx] + bsi # BBox center X + face = frame[int(my-bs):int(my+bs*(1+2*cs)),int(mx-bs*(1+cs)):int(mx+bs*(1+cs))] + vOut.write(cv2.resize(face, (224, 224))) + audioTmp = cropFile + '.wav' + audioStart = (track['frame'][0]) / 25 + audioEnd = (track['frame'][-1]+1) / 25 + vOut.release() + command = ("ffmpeg -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 -threads %d -ss %.3f -to %.3f %s -loglevel panic" % \ + (audioFilePath, 10, audioStart, audioEnd, audioTmp)) + output = subprocess.call(command, shell=True, stdout=None) # Crop audio file + _, audio = wavfile.read(audioTmp) + command = ("ffmpeg -y -i %st.avi -i %s -threads %d -c:v copy -c:a copy %s.avi -loglevel panic" % \ + (cropFile, audioTmp, 10, cropFile)) # Combine audio and video file + output = subprocess.call(command, shell=True, stdout=None) + os.remove(cropFile + 't.avi') + return True + + + +def crop_video(video_array, track, cropFile, audioFilePath, human_detection_pipeline,is_empty=False): + dets = {'x':[], 'y':[], 's':[]} + for det in track['bbox']: # Read the tracks + dets['s'].append(max((det[3]-det[1]), (det[2]-det[0]))/2) + dets['y'].append((det[1]+det[3])/2) # crop center x + dets['x'].append((det[0]+det[2])/2) # crop center y + + # human_bounding_box + human_bbox = {'x1':[], 'y1':[], 'x2':[], 'y2':[]} + for in_id,out_track_id in enumerate(track['frame']): # Read the tracks + frame_ = video_array[out_track_id] + head_x1, head_y1, head_x2, head_y2 = track['bbox'][in_id] + human_bbox_list = demo(frame_, human_detection_pipeline) + result = find_human_bounding_box((head_x1, head_y1, head_x2, head_y2), human_bbox_list) + if result == (): + human_bbox['x1'].append(-1) + human_bbox['y1'].append(-1) + human_bbox['x2'].append(-1) + human_bbox['y2'].append(-1) + else: + human_bbox['x1'].append(result[0]) + human_bbox['y1'].append(result[1]) + human_bbox['x2'].append(result[2]) + human_bbox['y2'].append(result[3]) + if (np.array(human_bbox['x1'])<0).sum() > 0: + if all(element < 0 for element in human_bbox['x1']): + return False + human_bbox['x1'] = update_negative_ones(human_bbox['x1']) + if (np.array(human_bbox['y1'])<0).sum() > 0: + human_bbox['y1'] = update_negative_ones(human_bbox['y1']) + if (np.array(human_bbox['x2'])<0).sum() > 0: + human_bbox['x2'] = update_negative_ones(human_bbox['x2']) + if (np.array(human_bbox['y2'])<0).sum() > 0: + human_bbox['y2'] = update_negative_ones(human_bbox['y2']) + human_bbox['x1'] = signal.medfilt(human_bbox['x1'], kernel_size=5).tolist() + human_bbox['y1'] = signal.medfilt(human_bbox['y1'], kernel_size=5).tolist() + human_bbox['x2'] = signal.medfilt(human_bbox['x2'], kernel_size=5).tolist() + human_bbox['y2'] = signal.medfilt(human_bbox['y2'], kernel_size=5).tolist() + + if is_empty: + return {'track':track, 'proc_track':dets, 'human_bbox':human_bbox} + + # CPU: crop the face clips + vOut = cv2.VideoWriter(cropFile + 't.avi', cv2.VideoWriter_fourcc(*'XVID'), 25, (224,224))# Write video + + for fidx, frame in enumerate(track['frame']): + cs = 0.4 + bs = dets['s'][fidx] # Detection box size + bsi = int(bs * (1 + 2 * cs)) # Pad videos by this amount + image = video_array[frame] + frame = numpy.pad(image, ((bsi,bsi), (bsi,bsi), (0, 0)), 'constant', constant_values=(110, 110)) + my = dets['y'][fidx] + bsi # BBox center Y + mx = dets['x'][fidx] + bsi # BBox center X + face = frame[int(my-bs):int(my+bs*(1+2*cs)),int(mx-bs*(1+cs)):int(mx+bs*(1+cs))] + vOut.write(cv2.resize(face, (224, 224))) + audioTmp = cropFile + '.wav' + audioStart = (track['frame'][0]) / 25 + audioEnd = (track['frame'][-1]+1) / 25 + vOut.release() + command = ("ffmpeg -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 -threads %d -ss %.3f -to %.3f %s -loglevel panic" % \ + (audioFilePath, 10, audioStart, audioEnd, audioTmp)) + output = subprocess.call(command, shell=True, stdout=None) # Crop audio file + _, audio = wavfile.read(audioTmp) + command = ("ffmpeg -y -i %st.avi -i %s -threads %d -c:v copy -c:a copy %s.avi -loglevel panic" % \ + (cropFile, audioTmp, 10, cropFile)) # Combine audio and video file + output = subprocess.call(command, shell=True, stdout=None) + os.remove(cropFile + 't.avi') + return {'track':track, 'proc_track':dets, 'human_bbox':human_bbox} + + +def evaluate_network(files, s, pycropPath): + # GPU: active speaker detection by pretrained model + allScores = [] + # durationSet = {1,2,4,6} # To make the result more reliable + durationSet = {1,1,1,2,2,2,3,3,4,5,6} # Use this line can get more reliable result + for file in tqdm.tqdm(files, total = len(files)): + fileName = os.path.splitext(file.split('/')[-1])[0] # Load audio and video + _, audio = wavfile.read(os.path.join(pycropPath, fileName + '.wav')) + if len(audio) == 0: + scores = numpy.array([-5]) + allScores.append(allScore) + continue + + audioFeature = python_speech_features.mfcc(audio, 16000, numcep = 13, winlen = 0.025, winstep = 0.010) + + video = cv2.VideoCapture(os.path.join(pycropPath, fileName + '.avi')) + videoFeature = [] + while video.isOpened(): + ret, frames = video.read() + if ret == True: + face = cv2.cvtColor(frames, cv2.COLOR_BGR2GRAY) + face = cv2.resize(face, (224,224)) + face = face[int(112-(112/2)):int(112+(112/2)), int(112-(112/2)):int(112+(112/2))] + videoFeature.append(face) + else: + break + video.release() + videoFeature = np.array(videoFeature) + length = min((audioFeature.shape[0] - audioFeature.shape[0] % 4) / 100, videoFeature.shape[0]) + audioFeature = audioFeature[:int(round(length * 100)),:] + videoFeature = videoFeature[:int(round(length * 25)),:,:] + allScore = [] # Evaluation use model + for duration in durationSet: + batchSize = int(math.ceil(length / duration)) + scores = [] + with torch.no_grad(): + for i in range(batchSize): + inputA = torch.FloatTensor(audioFeature[i * duration * 100:(i+1) * duration * 100,:]).unsqueeze(0).to(next(s.parameters()).device) + inputV = torch.FloatTensor(videoFeature[i * duration * 25: (i+1) * duration * 25,:,:]).unsqueeze(0).to(next(s.parameters()).device) + embedA = s.model.forward_audio_frontend(inputA) + embedV = s.model.forward_visual_frontend(inputV) + out = s.model.forward_audio_visual_backend(embedA, embedV) + score = s.lossAV.forward(out, labels = None) + scores.extend(score) + del inputA + del inputV + del embedA + del embedV + allScore.append(scores) + allScore = numpy.round((numpy.mean(numpy.array(allScore), axis = 0)), 1).astype(float) + allScores.append(allScore) + return allScores + + +def visualization(tracks, scores, video_array, pyaviPath): + # CPU: visulize the result for video format + + faces = [[] for i in range(video_array.shape[0])] + for tidx, track in enumerate(tracks): + score = scores[tidx] + for fidx, frame in enumerate(track['track']['frame'].tolist()): + s = score[max(fidx - 2, 0): min(fidx + 3, len(score) - 1)] # average smoothing + s = numpy.mean(s) + faces[frame].append({'track':tidx, 'score':float(s),'s':track['proc_track']['s'][fidx], 'x':track['proc_track']['x'][fidx], 'y':track['proc_track']['y'][fidx]}) + firstImage = video_array[0] + fw = firstImage.shape[1] + fh = firstImage.shape[0] + vOut = cv2.VideoWriter(os.path.join(pyaviPath, 'video_only.avi'), cv2.VideoWriter_fourcc(*'XVID'), 25, (fw,fh)) + colorDict = {0: 0, 1: 255} + for fidx in tqdm.tqdm(range(video_array.shape[0])): + image = video_array[fidx] + for face in faces[fidx]: + clr = colorDict[int((face['score'] >= 0))] + txt = round(face['score'], 1) + cv2.rectangle(image, (int(face['x']-face['s']), int(face['y']-face['s'])), (int(face['x']+face['s']), int(face['y']+face['s'])),(0,clr,255-clr),10) + cv2.putText(image,'%s'%(txt), (int(face['x']-face['s']), int(face['y']-face['s'])), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0,clr,255-clr),5) + vOut.write(image) + vOut.release() + command = ("ffmpeg -y -i %s -i %s -threads %d -c:v copy -c:a copy %s -loglevel panic" % \ + (os.path.join(pyaviPath, 'video_only.avi'), os.path.join(pyaviPath, 'audio.wav'), \ + 10, os.path.join(pyaviPath,'video_out.avi'))) + output = subprocess.call(command, shell=True, stdout=None) + +def calculate_good_matches(matches, ratio=0.75): + good_matches = [] + for m, n in matches: + if m.distance < ratio * n.distance: + good_matches.append(m) + return len(good_matches) + +def find_max_intersection_and_remaining_dicts(dicts): + if not dicts: + return [], [] + + track_frames = [d['track']['frame'] for d in dicts] + + all_elements = set() + for frame in track_frames: + all_elements.update(frame) + + max_combination_indices = [] + max_intersection = set() + + for elem in all_elements: + current_combination_indices = [] + current_intersection = set([elem]) + + for i, frame in enumerate(track_frames): + if elem in frame: + current_combination_indices.append(i) + current_intersection.intersection_update(frame) + + if len(current_combination_indices) > len(max_combination_indices): + max_combination_indices = current_combination_indices + max_intersection = current_intersection + + max_combination = [dicts[i] for i in max_combination_indices] + remaining_dicts = [d for i, d in enumerate(dicts) if i not in max_combination_indices] + + return max_combination, remaining_dicts + +def get_faces_array(frame,s,x,y): + cs = 0.4 + bs = s # Detection box size + bsi = int(bs * (1 + 2 * cs)) # Pad videos by this amount + image = frame + frame = np.pad(image, ((bsi,bsi), (bsi,bsi), (0, 0)), 'constant', constant_values=(110, 110)) + my = y + bsi # BBox center Y + mx = x + bsi # BBox center X + face = frame[int(my-bs):int(my+bs*(1+2*cs)),int(mx-bs*(1+cs)):int(mx+bs*(1+cs))] + return face + + +def order_track_distance(track1,track2,video_array): + # Get the last face frame of track1 and the first face frame of track2 + track1_end_frame = video_array[track1['track']['frame'][-1]] + track1_s = track1['proc_track']['s'][-1] + track1_x = track1['proc_track']['x'][-1] + track1_y = track1['proc_track']['y'][-1] + track1_end_face_array = get_faces_array(track1_end_frame,track1_s,track1_x,track1_y) + + track2_start_frame = video_array[track2['track']['frame'][0]] + track2_s = track2['proc_track']['s'][0] + track2_x = track2['proc_track']['x'][0] + track2_y = track2['proc_track']['y'][0] + track2_strat_face_array = get_faces_array(track2_start_frame,track2_s,track2_x,track2_y) + + # Calculate the area overlap ratio + track1_bbox = track1['track']['bbox'][-1] + track2_bbox = track2['track']['bbox'][0] + iou = bb_intersection_over_union(track1_bbox, track2_bbox) + if iou <= 0.2: + distance_iou = 10000 + else: + distance_iou = math.exp(-5*iou) + + normalized_distance = 0 + + # face_id distance (with facenet) + result = DeepFace.verify(track1_end_face_array, track2_strat_face_array, model_name='Facenet', detector_backend = 'skip') + facenet_distance = result['distance'] + if facenet_distance > 0.85: + facenet_distance = facenet_distance + 10000 + + distance = 2*distance_iou + normalized_distance + facenet_distance + + return distance + +def update_remain(remaining_dicts, pop_item): + updated_dicts = [item for item in remaining_dicts if item['track']['bbox'].shape != pop_item['track']['bbox'].shape or (item['track']['bbox'] != pop_item['track']['bbox']).any()] + return updated_dicts + +def order_merge_tracks(track1,track2): + new_track = {} + new_track['proc_track'] = {} + new_track['proc_track']['x'] = track1['proc_track']['x'] + track2['proc_track']['x'] + new_track['proc_track']['y'] = track1['proc_track']['y'] + track2['proc_track']['y'] + new_track['proc_track']['s'] = track1['proc_track']['s'] + track2['proc_track']['s'] + new_track['human_bbox'] = {} + new_track['human_bbox']['x1'] = track1['human_bbox']['x1'] + track2['human_bbox']['x1'] + new_track['human_bbox']['y1'] = track1['human_bbox']['y1'] + track2['human_bbox']['y1'] + new_track['human_bbox']['x2'] = track1['human_bbox']['x2'] + track2['human_bbox']['x2'] + new_track['human_bbox']['y2'] = track1['human_bbox']['y2'] + track2['human_bbox']['y2'] + + new_track['track'] = {} + for key in list(track1['track'].keys()): + object1 = track1['track'][key] + object2 = track2['track'][key] + if isinstance(object1, np.ndarray): + new_track['track'][key] = np.concatenate((object1, object2)) + elif isinstance(object1, list): + new_track['track'][key] = object1 + object2 + else: + raise('new data type') + + return new_track + +def post_merge(vidTracks,video_array): + # Find the maximum overlapping tracks as the initial anchor + anchor_combination, remaining_dicts = find_max_intersection_and_remaining_dicts(vidTracks) + end_frame = video_array.shape[0] + continue_flag = np.ones((len(anchor_combination),2)) + max_iteration = 10 + iteration_count = 0 + while iteration_count0: + for track_ind in range(len(anchor_combination)): + track = anchor_combination[track_ind] + # Try to extend forward + if continue_flag[track_ind][0]: + if track['track']['frame'][0] == 0: + continue_flag[track_ind][0] = 0 + else: + # Find the candidate that is connected to it and is in the front row + possible_prior_tracks = [] + for checktrack in remaining_dicts: + if checktrack['track']['frame'][-1]+1 == track['track']['frame'][0] or checktrack['track']['frame'][-1]+2 == track['track']['frame'][0]: + possible_prior_tracks.append(checktrack) + # If it is not zero, then check the calculated distance + if len(possible_prior_tracks)>0: + distance_score_list = [] + for possible_prior_track in possible_prior_tracks: + distance_score_list.append(order_track_distance(possible_prior_track, track, video_array)) + distance_score_array = np.array(distance_score_list) + if min(distance_score_array) < 10000: + min_index = np.argmin(distance_score_array) + new_anchor = order_merge_tracks(possible_prior_tracks[min_index], track) + # update_anchor() + anchor_combination[track_ind] = new_anchor + track = new_anchor + remaining_dicts = update_remain(remaining_dicts, possible_prior_tracks[min_index]) + else: + continue_flag[track_ind][0] = 0 + else: + continue_flag[track_ind][0] = 0 + # Try to extend backwards + if continue_flag[track_ind][1]: + if track['track']['frame'][-1] == end_frame: + continue_flag[track_ind][0] = 0 + else: + # Find the candidate that is connected to it and in front of it + possible_after_tracks = [] + for checktrack in remaining_dicts: + if checktrack['track']['frame'][0]-1 == track['track']['frame'][-1] or checktrack['track']['frame'][0]-2 == track['track']['frame'][-1]: + possible_after_tracks.append(checktrack) + # If it is not zero, then check the calculated distance + if len(possible_after_tracks)>0: + distance_score_list = [] + for possible_after_track in possible_after_tracks: + distance_score_list.append(order_track_distance(track, possible_after_track, video_array)) + distance_score_array = np.array(distance_score_list) + if min(distance_score_array) < 10000: + min_index = np.argmin(distance_score_array) + new_anchor = order_merge_tracks(track, possible_after_tracks[min_index]) + # update_anchor() + anchor_combination[track_ind] = new_anchor + remaining_dicts = update_remain(remaining_dicts, possible_after_tracks[min_index]) + else: + continue_flag[track_ind][1] = 0 + else: + continue_flag[track_ind][1] = 0 + + final_tracks = anchor_combination + remaining_dicts + if len(final_tracks) > 5: + sorted_tracks = sorted(final_tracks, key=lambda x: len(x['track']['frame']), reverse=True) + top_tracks = sorted_tracks[:5] + else: + top_tracks = final_tracks + # return len(anchor_combination), top_5_tracks + returntracks = [] + for item in top_tracks: + if len(item['track']['frame'])>15: + returntracks.append(item) + return len(anchor_combination), returntracks + + +def longest_continuous_actives(arr): + max_length = 0 + current_length = 0 + + for num in arr: + if num > 0: + current_length += 1 + if current_length > max_length: + max_length = current_length + else: + current_length = 0 + + return max_length + +import pickle +import moviepy.editor as mp + +def annotate_video_with_bounding_boxes_with_audio(video_path, q_human_video_track_bbox, output_path): + bbox_path = q_human_video_track_bbox['bbox_path'] + frame_indices = q_human_video_track_bbox['track']['frame'] + video_array = get_video_array_cv2(video_path) + + with open(bbox_path, 'rb') as f: + bbox_data = pickle.load(f) + xy_bbox = bbox_data['xy_bbox'] + + # Get video dimensions and frame rate + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) # Get original video frame rate + num_frames, height, width, channels = video_array.shape + assert channels == 3, "Input video must have 3 channels (BGR)." + + # Initialize video writer + fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for mp4 + temp_video_path = output_path.split('.')[0] + 'temp.mp4' + out = cv2.VideoWriter(temp_video_path, fourcc, fps, (width, height)) # Use original FPS + + # Annotate video frames with bounding boxes + for i in range(num_frames): + frame = video_array[i] + if i in frame_indices: + idx = frame_indices.index(i) + x1, y1, x2, y2 = xy_bbox[idx] + # Draw bounding box + thickness = max(int((x2 - x1) / 40), 2) + cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), thickness) + + # Write frame to temporary video + out.write(frame) + + out.release() + cap.release() # Release the video capture object + + # Load original video and audio + original_video = mp.VideoFileClip(video_path) + annotated_video = mp.VideoFileClip(temp_video_path) + + # Combine annotated video with original audio, ensuring alignment + final_video = annotated_video.set_audio(original_video.audio) + + # Write the final output video with audio + final_video.write_videofile(output_path, codec='libx264', audio_codec='aac', fps=fps) + + # Clean up temporary video file + annotated_video.close() + original_video.close() + + # Optionally, remove the temporary video file + import os + if os.path.exists(temp_video_path): + os.remove(temp_video_path) + + return output_path + +def annotate_video_with_bounding_boxes_withText_with_audio(video_path, q_human_video_track_bbox, output_path, numbers): + bbox_path = q_human_video_track_bbox['bbox_path'] + frame_indices = q_human_video_track_bbox['track']['frame'] + video_array = get_video_array_cv2(video_path) + + with open(bbox_path, 'rb') as f: + bbox_data = pickle.load(f) + xy_bbox = bbox_data['xy_bbox'] + + # Get video dimensions and frame rate + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) # Get original video frame rate + num_frames, height, width, channels = video_array.shape + assert channels == 3, "Input video must have 3 channels (BGR)." + + # Initialize video writer + fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for mp4 + temp_video_path = output_path.split('.')[0] + 'temp.mp4' + out = cv2.VideoWriter(temp_video_path, fourcc, fps, (width, height)) # Use original FPS + + # Annotate video frames with bounding boxes + for i in range(num_frames): + frame = video_array[i] + if i in frame_indices: + idx = frame_indices.index(i) + x1, y1, x2, y2 = xy_bbox[idx] + # Draw bounding box + thickness = max(int((x2 - x1) / 40), 2) + cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), thickness) + # Put the number in the top-left corner of the bounding box + cv2.putText(frame, numbers, (int(x1) + 10, int(y1) + 35), cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 255), 3) + + # Write frame to temporary video + out.write(frame) + + out.release() + cap.release() # Release the video capture object + + # Load original video and audio + original_video = mp.VideoFileClip(video_path) + annotated_video = mp.VideoFileClip(temp_video_path) + + # Combine annotated video with original audio, ensuring alignment + final_video = annotated_video.set_audio(original_video.audio) + + # Write the final output video with audio + final_video.write_videofile(output_path, codec='libx264', audio_codec='aac', fps=fps) + + # Clean up temporary video file + annotated_video.close() + original_video.close() + + # Optionally, remove the temporary video file + import os + if os.path.exists(temp_video_path): + os.remove(temp_video_path) + + return output_path + + +def annotate_video_with_bounding_boxes(video_array, frame_indices, bounding_boxes, output_path): + """ + Annotates specified frames in the video with bounding boxes and saves the result to a new video file. + + :param video_array: Input video as a numpy array with shape (num_frames, height, width, channels). + :param frame_indices: List of frame indices to annotate. + :param bounding_boxes: Array of bounding box coordinates with shape (num_frames_to_annotate, 4), where each bounding box is (x, y, w, h). + :param output_path: Path to save the output video. + """ + # Get video dimensions + num_frames, height, width, channels = video_array.shape + assert channels == 3, "Input video must have 3 channels (BGR)." + + # Initialize video writer + fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for mp4 + out = cv2.VideoWriter(output_path, fourcc, 30.0, (width, height)) + + # option 1: keep all video + for i in range(num_frames): + frame = video_array[i] + if i in frame_indices: + idx = frame_indices.index(i) + x1, y1, x2, y2 = bounding_boxes[idx] + # Draw bounding box + thinkness = max(int((x2-x1)/40),2) + cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), thinkness) + + # Write frame to output video + out.write(frame) + + # option 2:crap + # for in_id, out_id in enumerate(frame_indices): + # frame = video_array[out_id] + # x1, y1, x2, y2 = bounding_boxes[in_id] + # # Draw bounding box + # cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 5) + # # Write frame to output video + # out.write(frame) + + out.release() + return output_path + + +def crop_from_array(frame_before_crop, coords): + x1, y1, x2, y2 = coords + cropped_frame = frame_before_crop[y1:y2, x1:x2] + return cropped_frame \ No newline at end of file diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 242ad3afe5..aff4a3a862 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -33,6 +33,20 @@ class BatchMetaKeys(object): class MetaKeys(object): + # === humanvbench related tags === + + active_speaker_flag = 'active_speaker_flag' + audio_speech_attribute = 'audio_speech_attribute' + speech_ASR = 'speech_ASR' + speech_emotion = 'speech_emotion' + video_facetrack_attribute_demographic = 'video_facetrack_attribute_demographic' + video_facetrack_attribute_emotion = 'video_facetrack_attribute_emotion' + track_video_caption = 'track_video_caption' + video_track_is_child = 'video_track_is_child' + human_track_data_path = 'human_track_data_path' + number_people_in_video = 'number_people_in_video' + + # === text related tags === # # sentiment @@ -240,6 +254,7 @@ class StatsKeysConstant(object): video_motion_score = 'video_motion_score' video_nsfw_score = 'video_nsfw_score' video_watermark_prob = 'video_watermark_prob' + # === multimodal === # image-text @@ -250,6 +265,10 @@ class StatsKeysConstant(object): # video-text video_frames_text_similarity = 'video_frames_text_similarity' + # video-face-ratio + video_face_exist = 'video_face_exist' + + class StatsKeys(object, metaclass=StatsKeysMeta): _constants_class = StatsKeysConstant diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index dd99032e36..25d83a2900 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -751,6 +751,123 @@ def prepare_vllm_model(pretrained_model_name_or_path, **model_params): return (model, tokenizer) +def prepare_SenseVoiceSmall_model(pretrained_model_name_or_path, **model_params): + """ + Prepare and load light sharegpt4video. + + :param model_name: input model name. + """ + from thirdparty.humanvbench_models.SenseVoice.model import SenseVoiceSmall + + logger.info('Loading ASR_Emo_model model...') + ASR_Emo_model, kwargs1 = SenseVoiceSmall.from_pretrained(model=pretrained_model_name_or_path) + + ASR_Emo_model.eval() + return ASR_Emo_model, kwargs1 + +def prepare_light_asd_model( + pretrained_model_name_or_path='weight/finetuning_TalkSet.model', **model_params): + """ + Prepare and load light asd model. + + :param model_name: input model name. + """ + logger.info('Loading light_asd model...') + from ASD import ASD + model = ASD() + model.loadParameters(pretrained_model_name_or_path) + model.eval() + return model + +def prepare_YOLOv8_human_model( + pretrained_model_name_or_path='./thirdparty/humanvbench_models/YOLOv8_human/weights/best.pt', **model_params): + """ + Prepare and load light YOLOv8_human. + + :param model_name: input model name. + """ + logger.info('Loading YOLOv8_human model...') + human_detection_model = torch.load(pretrained_model_name_or_path)['model'].float() + human_detection_model.half() + human_detection_model.eval() + return human_detection_model + +import sys +sys.path.append("./thirdparty/humanvbench_models/Light-ASD") +def prepare_face_detect_S3FD_model(model_path=None, **model_params): + """ + Prepare and load light asd model. + + :param model_name: input model name. + """ + logger.info('Loading face_detect_S3FD_model model...') + from model.faceDetector.s3fd import S3FD + model = S3FD() + return model + + +import torch +import torch.nn as nn +from transformers import Wav2Vec2Processor +from transformers.models.wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Model, + Wav2Vec2PreTrainedModel, +) +def prepare_wav2vec2_age_gender_model(pretrained_model_name_or_path = 'audeering/wav2vec2-large-robust-24-ft-age-gender', **model_params): + + class ModelHead(nn.Module): + r"""Classification head.""" + + def __init__(self, config, num_labels): + + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.final_dropout) + self.out_proj = nn.Linear(config.hidden_size, num_labels) + + def forward(self, features, **kwargs): + + x = features + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + + return x + + + class AgeGenderModel(Wav2Vec2PreTrainedModel): + r"""Speech emotion classifier.""" + + def __init__(self, config): + + super().__init__(config) + + self.config = config + self.wav2vec2 = Wav2Vec2Model(config) + self.age = ModelHead(config, 1) + self.gender = ModelHead(config, 3) + self.init_weights() + + def forward( + self, + input_values, + ): + + outputs = self.wav2vec2(input_values) + hidden_states = outputs[0] + hidden_states = torch.mean(hidden_states, dim=1) + logits_age = self.age(hidden_states) + logits_gender = torch.softmax(self.gender(hidden_states), dim=1) + + return hidden_states, logits_age, logits_gender + + processor = Wav2Vec2Processor.from_pretrained(pretrained_model_name_or_path) + model = AgeGenderModel.from_pretrained(pretrained_model_name_or_path) + return model, processor + def update_sampling_params(sampling_params, pretrained_model_name_or_path, @@ -817,6 +934,11 @@ def update_sampling_params(sampling_params, 'spacy': prepare_spacy_model, 'video_blip': prepare_video_blip_model, 'vllm': prepare_vllm_model, + 'Light_ASD': prepare_light_asd_model, + 'SenseVoiceSmall': prepare_SenseVoiceSmall_model, + 'YOLOv8_human': prepare_YOLOv8_human_model, + 'face_detect_S3FD': prepare_face_detect_S3FD_model, + 'wav2vec2_age_gender': prepare_wav2vec2_age_gender_model } _MODELS_WITHOUT_FILE_LOCK = { diff --git a/thirdparty/humanvbench_models/.gitmodules b/thirdparty/humanvbench_models/.gitmodules new file mode 100644 index 0000000000..b3785fb6aa --- /dev/null +++ b/thirdparty/humanvbench_models/.gitmodules @@ -0,0 +1,10 @@ +[submodule "YOLOv8_human"] + path = YOLOv8_human + url = https://github.com/jahongir7174/YOLOv8-human.git + commit_id = 8f8a65e +[submodule "Light-ASD"] + path = Light-ASD + url = https://github.com/Junhua-Liao/Light-ASD.git +[submodule "SenseVoice"] + path = SenseVoice + url = https://github.com/FunAudioLLM/SenseVoice.git diff --git a/thirdparty/humanvbench_models/Light-ASD_changes.diff b/thirdparty/humanvbench_models/Light-ASD_changes.diff new file mode 100644 index 0000000000..19451570b2 --- /dev/null +++ b/thirdparty/humanvbench_models/Light-ASD_changes.diff @@ -0,0 +1,66 @@ +diff --git a/__pycache__/ASD.cpython-39.pyc b/__pycache__/ASD.cpython-39.pyc +new file mode 100644 +index 0000000..2b6d14b +Binary files /dev/null and b/__pycache__/ASD.cpython-39.pyc differ +diff --git a/__pycache__/loss.cpython-39.pyc b/__pycache__/loss.cpython-39.pyc +new file mode 100644 +index 0000000..38f0c69 +Binary files /dev/null and b/__pycache__/loss.cpython-39.pyc differ +diff --git a/model/__pycache__/Classifier.cpython-39.pyc b/model/__pycache__/Classifier.cpython-39.pyc +new file mode 100644 +index 0000000..91653ac +Binary files /dev/null and b/model/__pycache__/Classifier.cpython-39.pyc differ +diff --git a/model/__pycache__/Encoder.cpython-39.pyc b/model/__pycache__/Encoder.cpython-39.pyc +new file mode 100644 +index 0000000..5a935d0 +Binary files /dev/null and b/model/__pycache__/Encoder.cpython-39.pyc differ +diff --git a/model/__pycache__/Model.cpython-39.pyc b/model/__pycache__/Model.cpython-39.pyc +new file mode 100644 +index 0000000..38b7681 +Binary files /dev/null and b/model/__pycache__/Model.cpython-39.pyc differ +diff --git a/model/faceDetector/__pycache__/__init__.cpython-39.pyc b/model/faceDetector/__pycache__/__init__.cpython-39.pyc +new file mode 100644 +index 0000000..4fcd28a +Binary files /dev/null and b/model/faceDetector/__pycache__/__init__.cpython-39.pyc differ +diff --git a/model/faceDetector/s3fd/__init__.py b/model/faceDetector/s3fd/__init__.py +index 943292a..a029f3d 100644 +--- a/model/faceDetector/s3fd/__init__.py ++++ b/model/faceDetector/s3fd/__init__.py +@@ -6,7 +6,7 @@ from torchvision import transforms + from .nets import S3FDNet + from .box_utils import nms_ + +-PATH_WEIGHT = 'model/faceDetector/s3fd/sfd_face.pth' ++PATH_WEIGHT = './thirdparty/humanvbench_models/Light-ASD/model/faceDetector/s3fd/sfd_face.pth' + if os.path.isfile(PATH_WEIGHT) == False: + Link = "1KafnHz7ccT-3IyddBsL5yi2xGtxAKypt" + cmd = "gdown --id %s -O %s"%(Link, PATH_WEIGHT) +diff --git a/model/faceDetector/s3fd/__pycache__/__init__.cpython-39.pyc b/model/faceDetector/s3fd/__pycache__/__init__.cpython-39.pyc +new file mode 100644 +index 0000000..1859ab0 +Binary files /dev/null and b/model/faceDetector/s3fd/__pycache__/__init__.cpython-39.pyc differ +diff --git a/model/faceDetector/s3fd/__pycache__/box_utils.cpython-39.pyc b/model/faceDetector/s3fd/__pycache__/box_utils.cpython-39.pyc +new file mode 100644 +index 0000000..c41063b +Binary files /dev/null and b/model/faceDetector/s3fd/__pycache__/box_utils.cpython-39.pyc differ +diff --git a/model/faceDetector/s3fd/__pycache__/nets.cpython-39.pyc b/model/faceDetector/s3fd/__pycache__/nets.cpython-39.pyc +new file mode 100644 +index 0000000..116e8ef +Binary files /dev/null and b/model/faceDetector/s3fd/__pycache__/nets.cpython-39.pyc differ +diff --git a/model/faceDetector/s3fd/box_utils.py b/model/faceDetector/s3fd/box_utils.py +index 0779bcd..1bf4be2 100644 +--- a/model/faceDetector/s3fd/box_utils.py ++++ b/model/faceDetector/s3fd/box_utils.py +@@ -35,7 +35,7 @@ def nms_(dets, thresh): + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + +- return np.array(keep).astype(np.int) ++ return np.array(keep).astype(int) + + + def decode(loc, priors, variances): +diff --git a/model/faceDetector/s3fd/sfd_face.pth b/model/faceDetector/s3fd/sfd_face.pth +new file mode 100644 +index 0000000..2bdf053 +Binary files /dev/null and b/model/faceDetector/s3fd/sfd_face.pth differ diff --git a/thirdparty/humanvbench_models/README.md b/thirdparty/humanvbench_models/README.md new file mode 100644 index 0000000000..030f8cc190 --- /dev/null +++ b/thirdparty/humanvbench_models/README.md @@ -0,0 +1,6 @@ +git clone +https://github.com/jahongir7174/YOLOv8-human.git +https://github.com/Junhua-Liao/Light-ASD.git +https://github.com/FunAudioLLM/SenseVoice.git + +其中./thirdparty/humanvbench_models/Light-ASD/model/faceDetector/s3fd目录下有一个sfd_face.pth是需要单独下载放进去的:https://huggingface.co/lithiumice/syncnet/tree/main \ No newline at end of file diff --git a/thirdparty/humanvbench_models/SenseVoice_changes.diff b/thirdparty/humanvbench_models/SenseVoice_changes.diff new file mode 100644 index 0000000000..bf121b6311 --- /dev/null +++ b/thirdparty/humanvbench_models/SenseVoice_changes.diff @@ -0,0 +1,15 @@ +diff --git a/model.py b/model.py +index a89defd..11b1285 100644 +--- a/model.py ++++ b/model.py +@@ -13,7 +13,9 @@ from funasr.train_utils.device_funcs import force_gatherable + from funasr.losses.label_smoothing_loss import LabelSmoothingLoss + from funasr.metrics.compute_acc import compute_accuracy, th_accuracy + from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank +-from utils.ctc_alignment import ctc_forced_align ++import sys ++sys.path.append('/home/daoyuan_mm/data-juicer/thirdparty/humanvbench_models') ++from SenseVoice.utils.ctc_alignment import ctc_forced_align + + class SinusoidalPositionEncoder(torch.nn.Module): + """ """ diff --git a/thirdparty/humanvbench_models/YOLOv8_human_changes.diff b/thirdparty/humanvbench_models/YOLOv8_human_changes.diff new file mode 100644 index 0000000000..1c1e72a744 --- /dev/null +++ b/thirdparty/humanvbench_models/YOLOv8_human_changes.diff @@ -0,0 +1,145 @@ +diff --git a/__pycache__/dj.cpython-39.pyc b/__pycache__/dj.cpython-39.pyc +new file mode 100644 +index 0000000..7bc6a03 +Binary files /dev/null and b/__pycache__/dj.cpython-39.pyc differ +diff --git a/dj.py b/dj.py +new file mode 100644 +index 0000000..ce25877 +--- /dev/null ++++ b/dj.py +@@ -0,0 +1,111 @@ ++import sys ++import warnings ++from argparse import ArgumentParser ++ ++import numpy ++import torch ++sys.path.append('./thirdparty/humanvbench_models/YOLOv8_human') ++from nets import nn ++from util import non_max_suppression ++ ++warnings.filterwarnings("ignore") ++ ++ ++@torch.no_grad() ++def demo(img_array, model): ++ import cv2 ++ ++ frame = img_array ++ image = frame.copy() ++ shape = image.shape[:2] ++ ++ r = 640 / max(shape[0], shape[1]) ++ if r != 1: ++ resample = cv2.INTER_LINEAR if r > 1 else cv2.INTER_AREA ++ image = cv2.resize(image, dsize=(int(shape[1] * r), int(shape[0] * r)), interpolation=resample) ++ height, width = image.shape[:2] ++ ++ # Scale ratio (new / old) ++ r = min(1.0, 640 / height, 640 / width) ++ ++ # Compute padding ++ pad = int(round(width * r)), int(round(height * r)) ++ w = numpy.mod((640 - pad[0]), 32) / 2 ++ h = numpy.mod((640 - pad[1]), 32) / 2 ++ ++ if (width, height) != pad: # resize ++ image = cv2.resize(image, pad, interpolation=cv2.INTER_LINEAR) ++ top, bottom = int(round(h - 0.1)), int(round(h + 0.1)) ++ left, right = int(round(w - 0.1)), int(round(w + 0.1)) ++ image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT) # add border ++ ++ # Convert HWC to CHW, BGR to RGB ++ x = image.transpose((2, 0, 1))[::-1] ++ x = numpy.ascontiguousarray(x) ++ x = torch.from_numpy(x) ++ x = x.unsqueeze(dim=0) ++ x = x.to(next(model.parameters()).device) ++ x = x.half() ++ x = x / 255 ++ # Inference ++ outputs = model(x) ++ # NMS ++ outputs = non_max_suppression(outputs, 0.25, 0.7) ++ final_output_box_list = [] ++ for output in outputs: ++ output[:, [0, 2]] -= w # x padding ++ output[:, [1, 3]] -= h # y padding ++ output[:, :4] /= min(height / shape[0], width / shape[1]) ++ ++ output[:, 0].clamp_(0, shape[1]) # x1 ++ output[:, 1].clamp_(0, shape[0]) # y1 ++ output[:, 2].clamp_(0, shape[1]) # x2 ++ output[:, 3].clamp_(0, shape[0]) # y2 ++ ++ for box in output: ++ box = box.cpu().numpy() ++ x1, y1, x2, y2, score, index = box ++ final_output_box_list.append((x1, y1, x2, y2)) ++ # cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2) ++ del x ++ return final_output_box_list ++ ++ ++ ++def profile(args, params): ++ model = nn.yolo_v8_n(len(params['names'])) ++ shape = (1, 3, args.input_size, args.input_size) ++ ++ model.eval() ++ model(torch.zeros(shape)) ++ params = sum(p.numel() for p in model.parameters()) ++ if args.local_rank == 0: ++ print(f'Number of parameters: {int(params)}') ++ ++ ++def human_detect(img_array): ++ parser = ArgumentParser() ++ parser.add_argument('--input-size', default=640, type=int) ++ parser.add_argument('--local_rank', default=0, type=int) ++ ++ args = parser.parse_args() ++ ++ args.local_rank = int(os.getenv('LOCAL_RANK', 0)) ++ args.world_size = int(os.getenv('WORLD_SIZE', 1)) ++ args.distributed = int(os.getenv('WORLD_SIZE', 1)) > 1 ++ ++ if args.distributed: ++ torch.cuda.set_device(device=args.local_rank) ++ torch.distributed.init_process_group(backend='nccl', init_method='env://') ++ ++ if args.local_rank == 0: ++ if not os.path.exists('weights'): ++ os.makedirs('weights') ++ ++ profile(args, img_array) ++ ++ demo(args,img_array) ++ ++ ++if __name__ == "__main__": ++ main() +diff --git a/nets/__pycache__/nn.cpython-39.pyc b/nets/__pycache__/nn.cpython-39.pyc +new file mode 100644 +index 0000000..6c6d13f +Binary files /dev/null and b/nets/__pycache__/nn.cpython-39.pyc differ +diff --git a/nets/nn.py b/nets/nn.py +index 66aec47..0dd5ee4 100644 +--- a/nets/nn.py ++++ b/nets/nn.py +@@ -1,8 +1,8 @@ + import math +- ++import sys + import torch +- +-from utils.util import make_anchors ++sys.path.append('./thirdparty/humanvbench_models/YOLOv8_human/utils') ++from util import make_anchors + + + def fuse_conv(conv, norm): +diff --git a/utils/__pycache__/util.cpython-39.pyc b/utils/__pycache__/util.cpython-39.pyc +new file mode 100644 +index 0000000..284455f +Binary files /dev/null and b/utils/__pycache__/util.cpython-39.pyc differ diff --git a/thirdparty/humanvbench_models/audio_code/wav2vec_age_gender.py b/thirdparty/humanvbench_models/audio_code/wav2vec_age_gender.py new file mode 100644 index 0000000000..3d4fee61a3 --- /dev/null +++ b/thirdparty/humanvbench_models/audio_code/wav2vec_age_gender.py @@ -0,0 +1,112 @@ +import numpy as np +import torch +import torch.nn as nn +from transformers import Wav2Vec2Processor +from transformers.models.wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Model, + Wav2Vec2PreTrainedModel, +) + + +class ModelHead(nn.Module): + r"""Classification head.""" + + def __init__(self, config, num_labels): + + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.final_dropout) + self.out_proj = nn.Linear(config.hidden_size, num_labels) + + def forward(self, features, **kwargs): + + x = features + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + + return x + + +class AgeGenderModel(Wav2Vec2PreTrainedModel): + r"""Speech emotion classifier.""" + + def __init__(self, config): + + super().__init__(config) + + self.config = config + self.wav2vec2 = Wav2Vec2Model(config) + self.age = ModelHead(config, 1) + self.gender = ModelHead(config, 3) + self.init_weights() + + def forward( + self, + input_values, + ): + + outputs = self.wav2vec2(input_values) + hidden_states = outputs[0] + hidden_states = torch.mean(hidden_states, dim=1) + logits_age = self.age(hidden_states) + logits_gender = torch.softmax(self.gender(hidden_states), dim=1) + + return hidden_states, logits_age, logits_gender + + + +# load model from hub +# device = 'cpu' +# model_name = '/mnt1/daoyuan_mm/wav2vec2-large-robust-24-ft-age-gender' +# processor = Wav2Vec2Processor.from_pretrained(model_name) +# model = AgeGenderModel.from_pretrained(model_name) + +# dummy signal +# sampling_rate = 16000 +# signal = np.zeros((1, sampling_rate), dtype=np.float32) + + +def process_func( + x: np.ndarray, + sampling_rate: int, + processor, + model, + device, + embeddings: bool = False, +) -> np.ndarray: + r"""Predict age and gender or extract embeddings from raw audio signal.""" + + # run through processor to normalize signal + # always returns a batch, so we just get the first entry + # then we put it on the device + y = processor(x, sampling_rate=sampling_rate) + y = y['input_values'][0] + y = y.reshape(1, -1) + y = torch.from_numpy(y).to(device) + + # run through model + with torch.no_grad(): + y = model(y) + if embeddings: + y = y[0] + else: + y = torch.hstack([y[1], y[2]]) + + # convert to numpy + y = y.detach().cpu().numpy() + + return y + + +# print(process_func(signal, sampling_rate)) +# # Age female male child +# # [[ 0.33793038 0.2715511 0.2275236 0.5009253 ]] + +# print(process_func(signal, sampling_rate, embeddings=True)) +# Pooled hidden states of last transformer layer +# [[ 0.024444 0.0508722 0.04930823 ... 0.07247854 -0.0697901 +# -0.0170537 ]] diff --git a/tools/process_data.py b/tools/process_data.py index a97ef9a408..49d6726917 100644 --- a/tools/process_data.py +++ b/tools/process_data.py @@ -1,5 +1,4 @@ from loguru import logger - from data_juicer.config import init_configs from data_juicer.core import Executor