diff --git a/app.py b/app.py new file mode 100644 index 0000000000..c7a5ef5596 --- /dev/null +++ b/app.py @@ -0,0 +1,315 @@ +""" +FastAPI application for skeleton-based action recognition. +""" +from typing import Optional, Dict, Any +from pathlib import Path +import os +import uuid +import shutil +import datetime + +import uvicorn +from fastapi import FastAPI, File, UploadFile, HTTPException, Query +from fastapi.responses import FileResponse, JSONResponse +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel +import torch + +# Import the processing functions from demo_skeleton_refactored +from demo.demo_skeleton_refactored import ( + process_video_windows, + visualize_with_labels, + frame_extract, + detection_inference, + pose_inference, + init_recognizer +) + + +class ProcessingConfig: + """Configuration for video processing.""" + + def __init__(self): + # Create timestamped root directory for better organization + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + self.root_dir = Path(f"skeleton_recognition_{timestamp}") + self.upload_dir = self.root_dir / "uploads" + self.output_dir = self.root_dir / "processed_videos" + self.logs_dir = self.root_dir / "logs" + + # Model configurations based on infer_skl.sh + self.model_configs = { + "config": "work_dirs/posec3d_ntu60_2d_adam/slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint.py", + "checkpoint": "work_dirs/posec3d_ntu60_2d_adam/best_acc_top1_epoch_24.pth", + "det_config": "demo/demo_configs/faster-rcnn_r50_fpn_2x_coco_infer.py", + "det_checkpoint": "http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth", + "pose_config": "demo/demo_configs/td-hm_ViTPose-small_8xb64-210e_coco-256x192.py", + "pose_checkpoint": "https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_ViTPose-small_8xb64-210e_coco-256x192-62d7a712_20230314.pth", + "label_map": "tools/data/skeleton/label_map_ntu60.txt" + } + + # Processing parameters + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.short_side = 480 + self.window_size = 32 + self.window_stride = 16 + self.det_score_thr = 0.9 + self.det_cat_id = 0 + + # Create necessary directories + self._create_directories() + + def _create_directories(self) -> None: + """Create all necessary directories for processing.""" + self.root_dir.mkdir(exist_ok=True) + self.upload_dir.mkdir(exist_ok=True) + self.output_dir.mkdir(exist_ok=True) + self.logs_dir.mkdir(exist_ok=True) + + +class VideoResponse(BaseModel): + """Response model for video processing.""" + video_id: str + processed_video_path: str + full_url: str + message: str + processing_info: Dict[str, Any] + + +app = FastAPI( + title="Skeleton-based Action Recognition API", + description="API for processing videos using skeleton-based action recognition", + version="1.0.0" +) + +config = ProcessingConfig() + +# Mount the processed videos directory for direct access +app.mount("/videos", StaticFiles(directory=str(config.output_dir)), name="videos") + + +@app.post("/process_video/", response_model=VideoResponse) +async def process_video( + video: UploadFile = File(...), + det_score_thr: float = Query(0.9, description="Detection score threshold") +) -> VideoResponse: + """ + Process a video file using skeleton-based action recognition. 
+ + Args: + video (UploadFile): The input video file to process + det_score_thr (float): Detection score threshold + + Returns: + VideoResponse: Object containing the processed video information + + Raises: + HTTPException: If video processing fails + """ + try: + # Generate unique filename for the upload + video_id = str(uuid.uuid4()) + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + + # Create paths for input and output files + video_filename = Path(video.filename).stem + video_extension = Path(video.filename).suffix + input_path = config.upload_dir / \ + f"{video_id}_{video_filename}{video_extension}" + output_filename = f"{video_id}_{video_filename}_processed.mp4" + output_path = config.output_dir / output_filename + + # Log processing request + log_path = config.logs_dir / f"{video_id}_processing.log" + with open(log_path, "w") as log_file: + log_file.write(f"Processing started at: {timestamp}\n") + log_file.write(f"Input video: {video.filename}\n") + log_file.write(f"Detection score threshold: {det_score_thr}\n") + + # Save uploaded video + with open(input_path, "wb") as buffer: + shutil.copyfileobj(video.file, buffer) + + # Update config with request-specific parameters + config.det_score_thr = det_score_thr + + # Extract frames + frame_paths, frames = frame_extract( + str(input_path), + config.short_side, + str(config.upload_dir) + ) + h, w, _ = frames[0].shape + + # Perform detection + det_results, _ = detection_inference( + config.model_configs["det_config"], + config.model_configs["det_checkpoint"], + frame_paths, + config.det_score_thr, + config.det_cat_id, + config.device + ) + torch.cuda.empty_cache() + + # Perform pose estimation + pose_results, pose_data_samples = pose_inference( + config.model_configs["pose_config"], + config.model_configs["pose_checkpoint"], + frame_paths, + det_results, + config.device + ) + torch.cuda.empty_cache() + + # Initialize model + model = init_recognizer( + config.model_configs["config"], + config.model_configs["checkpoint"], + config.device + ) + + # Load labels + with open(config.model_configs["label_map"], 'r') as f: + label_map = [x.strip() for x in f.readlines()] + + # Process video windows + frame_labels, frame_confidences = process_video_windows( + model, + pose_results, + (h, w), + label_map, + config.window_size, + config.window_stride, + len(pose_results) + ) + + # Create visualization + args = type('Args', (), { + 'out_filename': str(output_path), + 'video': str(input_path), + 'det_score_thr': config.det_score_thr, + 'pose_config': config.model_configs["pose_config"], + 'pose_checkpoint': config.model_configs["pose_checkpoint"], + 'config': config.model_configs["config"], + 'checkpoint': config.model_configs["checkpoint"], + 'label_map': config.model_configs["label_map"], + 'device': config.device, + 'short_side': config.short_side, + 'window_size': config.window_size, + 'window_stride': config.window_stride, + 'det_cat_id': config.det_cat_id, + 'cfg_options': {} + }) + + visualize_with_labels( + args, + frames, + pose_data_samples, + frame_labels, + frame_confidences + ) + + # Generate full URL for accessing the video + # Change this to your actual host URL in production + host_url = "http://localhost:8000" + video_url = f"{host_url}/videos/{output_filename}" + + # Prepare processing info + processing_info = { + "timestamp": timestamp, + "input_video": str(input_path), + "output_video": str(output_path), + "detection_score_threshold": config.det_score_thr, + "num_frames_processed": len(frames), + "device_used": 
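# A hypothetical client sketch for the /process_video/ endpoint defined above.
# It assumes the server from this file is running on localhost:8000 (see the
# __main__ block) and that the third-party `requests` package is installed; the
# video path below is just one of the clips added in this PR.
import requests

API_URL = "http://localhost:8000"


def submit_video(path: str, det_score_thr: float = 0.9) -> dict:
    """Upload a video and return the JSON body of the /process_video/ response."""
    with open(path, "rb") as f:
        resp = requests.post(
            f"{API_URL}/process_video/",
            files={"video": f},                 # field name matches the UploadFile parameter
            params={"det_score_thr": det_score_thr},
            timeout=3600,                       # detection + pose + recognition can take minutes
        )
    resp.raise_for_status()
    return resp.json()


# info = submit_video("demo/lying.mp4")
# print(info["full_url"], info["processing_info"]["num_frames_processed"])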
config.device + } + + # Log completion + with open(log_path, "a") as log_file: + log_file.write( + f"Processing completed at: {datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}\n") + log_file.write(f"Output video: {str(output_path)}\n") + + return VideoResponse( + video_id=video_id, + processed_video_path=str(output_path), + full_url=video_url, + message="Video processed successfully", + processing_info=processing_info + ) + + except Exception as e: + # Log error + error_log_path = config.logs_dir / \ + f"{video_id}_error.log" if 'video_id' in locals( + ) else config.logs_dir / f"error_{timestamp}.log" + with open(error_log_path, "w") as error_file: + error_file.write( + f"Error occurred at: {datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}\n") + error_file.write(f"Error details: {str(e)}\n") + + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/video/{video_id}") +async def get_video(video_id: str): + """ + Retrieve a processed video file. + + Args: + video_id (str): The ID of the processed video + + Returns: + FileResponse: The video file + + Raises: + HTTPException: If video file is not found + """ + # Find the video with the given ID (with any filename) + for file in config.output_dir.glob(f"{video_id}_*"): + if file.is_file() and file.suffix in (".mp4", ".avi", ".mov"): + return FileResponse( + str(file), + media_type="video/mp4", + filename=file.name + ) + + raise HTTPException(status_code=404, detail="Video not found") + + +@app.get("/videos/") +async def list_videos(): + """ + List all processed videos. + + Returns: + JSONResponse: List of available videos with their IDs and paths + """ + videos = [] + for file in config.output_dir.glob("*"): + if file.is_file() and file.suffix in (".mp4", ".avi", ".mov"): + video_id = file.stem.split("_")[0] + videos.append({ + "video_id": video_id, + "filename": file.name, + "path": str(file), + "url": f"/videos/{file.name}" + }) + + return JSONResponse(content={"videos": videos}) + + +@app.on_event("startup") +async def startup_event(): + """Initialize directories and models on startup.""" + config._create_directories() + print(f"Server initialized with the following configuration:") + print(f"- Root directory: {config.root_dir}") + print(f"- Upload directory: {config.upload_dir}") + print(f"- Output directory: {config.output_dir}") + print(f"- Logs directory: {config.logs_dir}") + print(f"- Using device: {config.device}") + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/configs/skeleton/posec3d/distill_posec3d_student.py b/configs/skeleton/posec3d/distill_posec3d_student.py new file mode 100644 index 0000000000..d5c904b54b --- /dev/null +++ b/configs/skeleton/posec3d/distill_posec3d_student.py @@ -0,0 +1,139 @@ +# configs/skeleton/posec3d/distill_posec3d_student.py + +_base_ = ['../../_base_/default_runtime.py'] + +# model settings +model = dict( + type='DistillPoseC3D', + backbone=dict( + type='ResNet3dSlowOnly', # Student backbone (smaller) + in_channels=17, # Number of keypoints + base_channels=32, # Reduced from original (64) + num_stages=3, + out_indices=(2, ), + stage_blocks=(4, 6, 3), + conv1_stride_s=1, + pool1_stride_s=1, + inflate=(0, 1, 1), + spatial_strides=(2, 2, 2), + temporal_strides=(1, 1, 2), + dilations=(1, 1, 1)), + teacher_backbone=dict( + type='ResNet3dSlowOnly', # Teacher backbone (larger) + in_channels=17, + base_channels=64, # Original size + num_stages=4, + out_indices=(3, ), + stage_blocks=(3, 4, 6, 3), + conv1_stride_s=1, + pool1_stride_s=1, + inflate=(0, 1, 
1, 1), + spatial_strides=(2, 2, 2, 2), + temporal_strides=(1, 1, 2, 2), + dilations=(1, 1, 1, 1)), + teacher_checkpoint='checkpoints/posec3d_k400.pth', # Path to teacher model + cls_head=dict( + type='I3DHead', + in_channels=512, # Match student backbone output + num_classes=60, # NTU60 has 60 classes + spatial_type='avg', + dropout_ratio=0.5), + teacher_cls_head=dict( + type='I3DHead', + in_channels=512, # Match teacher backbone output + num_classes=60, + spatial_type='avg', + dropout_ratio=0.5), + train_cfg=dict( + distill_loss=dict( + type='KLDivLoss', + temperature=4.0, + alpha=0.5 # Balance between distillation and CE loss + ) + ), + test_cfg=dict(average_clips='prob') +) + +# dataset settings +dataset_type = 'PoseDataset' +ann_file_train = '/home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/data/skeleton/ntu60_2d/ntu60_2d_train.pkl' +ann_file_val = '/home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/data/skeleton/ntu60_2d/ntu60_2d_val.pkl' +train_pipeline = [ + dict(type='UniformSampleFrames', clip_len=48), + dict(type='PoseDecode'), + dict(type='PoseCompact', hw_ratio=1., allow_imgpad=True), + dict(type='Resize', scale=(-1, 64)), + dict(type='RandomResizedCrop', area_range=(0.56, 1.0)), + dict(type='Resize', scale=(56, 56), keep_ratio=False), + dict(type='Flip', flip_ratio=0.5), + dict(type='GeneratePoseTarget', + sigma=0.6, + use_score=True, + with_kp=True, + with_limb=False), + dict(type='FormatShape', input_format='NCTHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs', 'label']) +] +val_pipeline = [ + dict(type='UniformSampleFrames', clip_len=48, num_clips=1), + dict(type='PoseDecode'), + dict(type='PoseCompact', hw_ratio=1., allow_imgpad=True), + dict(type='Resize', scale=(56, 56), keep_ratio=False), + dict(type='GeneratePoseTarget', + sigma=0.6, + use_score=True, + with_kp=True, + with_limb=False), + dict(type='FormatShape', input_format='NCTHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs']) +] +test_pipeline = val_pipeline +data = dict( + videos_per_gpu=16, + workers_per_gpu=2, + test_dataloader=dict(videos_per_gpu=1), + train=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix='', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix='', + pipeline=val_pipeline), + test=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix='', + pipeline=test_pipeline)) + +# optimizer +optimizer = dict( + type='SGD', + lr=0.01, # Lower learning rate for distillation + momentum=0.9, + weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2)) + +# learning policy +lr_config = dict( + policy='CosineAnnealing', + min_lr=0, + warmup='linear', + warmup_by_epoch=True, + warmup_iters=5) +total_epochs = 10 + +# runtime settings +checkpoint_config = dict(interval=5) +evaluation = dict(interval=5, metrics=[ + 'top_k_accuracy', 'mean_class_accuracy']) +log_config = dict(interval=20, hooks=[dict(type='TextLoggerHook')]) + +# Make sure the teacher model checkpoint exists +# If you don't have it, you can download it or use a different checkpoint +# You can also set this to None and the model will use random weights (not recommended) +find_unused_parameters = True # Important for distillation training diff --git a/demo/1732629973744.mp4 b/demo/1732629973744.mp4 new file mode 100644 index 0000000000..38f9152d09 Binary files /dev/null and b/demo/1732629973744.mp4 
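# The DistillPoseC3D recognizer named in the config above is not part of this
# diff, so its loss implementation is unknown. As a sketch only, the distill_loss
# block (KLDivLoss, temperature=4.0, alpha=0.5) suggests the usual
# temperature-scaled knowledge-distillation objective:
import torch
import torch.nn.functional as F


def kd_loss(student_logits: torch.Tensor,
            teacher_logits: torch.Tensor,
            labels: torch.Tensor,
            temperature: float = 4.0,
            alpha: float = 0.5) -> torch.Tensor:
    """Blend soft-target KL divergence with ordinary cross-entropy."""
    # Soft targets: KL(student || teacher) at temperature T, scaled by T^2
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction='batchmean') * (temperature ** 2)
    # Hard targets: cross-entropy against the ground-truth action labels
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1.0 - alpha) * hard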
differ diff --git a/demo/1732698562982.mp4 b/demo/1732698562982.mp4 new file mode 100644 index 0000000000..208777d502 Binary files /dev/null and b/demo/1732698562982.mp4 differ diff --git a/demo/1736557631421.mp4 b/demo/1736557631421.mp4 new file mode 100644 index 0000000000..8e11bb5aa9 Binary files /dev/null and b/demo/1736557631421.mp4 differ diff --git a/demo/convert_video.py b/demo/convert_video.py new file mode 100644 index 0000000000..4b9f543cc5 --- /dev/null +++ b/demo/convert_video.py @@ -0,0 +1,11 @@ +import ffmpeg + +input_file = "/home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/demo/1736557631421.mp4" +output_file = "/home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/demo/1736557631421_converted.mp4" + +( + ffmpeg + .input(input_file) + .output(output_file, vcodec='libx264', preset='medium', crf=23) + .run() +) \ No newline at end of file diff --git a/demo/demo_configs/td-hm_ViTPose-base-simple_8xb64-210e_coco-256x192.py b/demo/demo_configs/td-hm_ViTPose-base-simple_8xb64-210e_coco-256x192.py new file mode 100644 index 0000000000..cb03d41527 --- /dev/null +++ b/demo/demo_configs/td-hm_ViTPose-base-simple_8xb64-210e_coco-256x192.py @@ -0,0 +1,160 @@ +# _base_ = ['../../../_base_/default_runtime.py'] + +# runtime +from mmengine.registry import MODELS +import mmpretrain +train_cfg = dict(max_epochs=210, val_interval=10) + +# optimizer +custom_imports = dict( + imports=['mmpose.engine.optim_wrappers.layer_decay_optim_wrapper'], + allow_failed_imports=False) + +optim_wrapper = dict( + optimizer=dict( + type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1), + paramwise_cfg=dict( + num_layers=12, + layer_decay_rate=0.75, + custom_keys={ + 'bias': dict(decay_multi=0.0), + 'pos_embed': dict(decay_mult=0.0), + 'relative_position_bias_table': dict(decay_mult=0.0), + 'norm': dict(decay_mult=0.0), + }, + ), + constructor='LayerDecayOptimWrapperConstructor', + clip_grad=dict(max_norm=1., norm_type=2), +) + +# learning policy +param_scheduler = [ + dict( + type='LinearLR', begin=0, end=500, start_factor=0.001, + by_epoch=False), # warm-up + dict( + type='MultiStepLR', + begin=0, + end=210, + milestones=[170, 200], + gamma=0.1, + by_epoch=True) +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=512) + +# hooks +default_hooks = dict( + checkpoint=dict(save_best='coco/AP', rule='greater', max_keep_ckpts=1)) + +# codec settings +codec = dict( + type='UDPHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2) + +# model settings +model = dict( + type='TopdownPoseEstimator', + data_preprocessor=dict( + type='PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + type='mmpretrain.VisionTransformer', + arch='base', + img_size=(256, 192), + patch_size=16, + qkv_bias=True, + drop_path_rate=0.3, + with_cls_token=False, + out_type='featmap', + patch_cfg=dict(padding=2), + init_cfg=dict( + type='Pretrained', + checkpoint='https://download.openmmlab.com/mmpose/' + 'v1/pretrained_models/mae_pretrain_vit_base_20230913.pth'), + ), + neck=dict(type='FeatureMapProcessor', scale_factor=4.0, apply_relu=True), + head=dict( + type='HeatmapHead', + in_channels=768, + out_channels=17, + deconv_out_channels=[], + deconv_kernel_sizes=[], + final_layer=dict(kernel_size=3, padding=1), + loss=dict(type='KeypointMSELoss', use_target_weight=True), + decoder=codec, + ), + test_cfg=dict( + flip_test=True, + flip_mode='heatmap', + 
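# convert_video.py above hard-codes absolute input/output paths. A small
# parameterized sketch of the same libx264 re-encode, assuming ffmpeg-python and
# a system ffmpeg binary are available; the example paths are placeholders.
import ffmpeg


def convert_to_h264(input_file: str, output_file: str, crf: int = 23) -> None:
    """Re-encode a video with libx264 using the same settings as convert_video.py."""
    (
        ffmpeg
        .input(input_file)
        .output(output_file, vcodec='libx264', preset='medium', crf=crf)
        .overwrite_output()   # replace the target file if it already exists
        .run(quiet=True)
    )


# convert_to_h264('demo/1736557631421.mp4', 'demo/1736557631421_converted.mp4')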
shift_heatmap=False, + )) + +# base dataset settings +data_root = 'data/coco/' +dataset_type = 'CocoDataset' +data_mode = 'topdown' + +# pipelines +train_pipeline = [ + dict(type='LoadImage'), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict(type='RandomBBoxTransform'), + dict(type='TopdownAffine', input_size=codec['input_size'], use_udp=True), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] +val_pipeline = [ + dict(type='LoadImage'), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=codec['input_size'], use_udp=True), + dict(type='PackPoseInputs') +] + +# data loaders +train_dataloader = dict( + batch_size=64, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='annotations/person_keypoints_train2017.json', + data_prefix=dict(img='train2017/'), + pipeline=train_pipeline, + )) +val_dataloader = dict( + batch_size=32, + num_workers=4, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='annotations/person_keypoints_val2017.json', + bbox_file='data/coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=val_pipeline, + )) +test_dataloader = val_dataloader + +# evaluators +val_evaluator = dict( + type='CocoMetric', + ann_file=data_root + 'annotations/person_keypoints_val2017.json') +test_evaluator = val_evaluator + + + +# Add this near the top of your script +# Make sure model registry is properly imported diff --git a/demo/demo_configs/td-hm_ViTPose-small_8xb64-210e_coco-256x192.py b/demo/demo_configs/td-hm_ViTPose-small_8xb64-210e_coco-256x192.py new file mode 100644 index 0000000000..2036ed8301 --- /dev/null +++ b/demo/demo_configs/td-hm_ViTPose-small_8xb64-210e_coco-256x192.py @@ -0,0 +1,164 @@ +# _base_ = ['../../../_base_/default_runtime.py'] +from mmengine.registry import MODELS +import mmpretrain + +# runtime +train_cfg = dict(max_epochs=210, val_interval=10) + +# optimizer +custom_imports = dict( + imports=['mmpose.engine.optim_wrappers.layer_decay_optim_wrapper'], + allow_failed_imports=False) + +optim_wrapper = dict( + optimizer=dict( + type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1), + paramwise_cfg=dict( + num_layers=12, + layer_decay_rate=0.8, + custom_keys={ + 'bias': dict(decay_multi=0.0), + 'pos_embed': dict(decay_mult=0.0), + 'relative_position_bias_table': dict(decay_mult=0.0), + 'norm': dict(decay_mult=0.0), + }, + ), + constructor='LayerDecayOptimWrapperConstructor', + clip_grad=dict(max_norm=1., norm_type=2), +) + +# learning policy +param_scheduler = [ + dict( + type='LinearLR', begin=0, end=500, start_factor=0.001, + by_epoch=False), # warm-up + dict( + type='MultiStepLR', + begin=0, + end=210, + milestones=[170, 200], + gamma=0.1, + by_epoch=True) +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=512) + +# hooks +default_hooks = dict( + checkpoint=dict(save_best='coco/AP', rule='greater', max_keep_ckpts=1)) + +# codec settings +codec = dict( + type='UDPHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2) + +# model settings +model = dict( + type='TopdownPoseEstimator', + 
data_preprocessor=dict( + type='PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + type='mmpretrain.VisionTransformer', + arch={ + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 384 * 4 + }, + img_size=(256, 192), + patch_size=16, + qkv_bias=True, + drop_path_rate=0.1, + with_cls_token=False, + out_type='featmap', + patch_cfg=dict(padding=2), + init_cfg=dict( + type='Pretrained', + checkpoint='https://download.openmmlab.com/mmpose/' + 'v1/pretrained_models/mae_pretrain_vit_small_20230913.pth'), + ), + head=dict( + type='HeatmapHead', + in_channels=384, + out_channels=17, + deconv_out_channels=(256, 256), + deconv_kernel_sizes=(4, 4), + loss=dict(type='KeypointMSELoss', use_target_weight=True), + decoder=codec), + test_cfg=dict( + flip_test=True, + flip_mode='heatmap', + shift_heatmap=False, + )) + +# base dataset settings +data_root = 'data/coco/' +dataset_type = 'CocoDataset' +data_mode = 'topdown' + +# pipelines +train_pipeline = [ + dict(type='LoadImage'), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict(type='RandomBBoxTransform'), + dict(type='TopdownAffine', input_size=codec['input_size'], use_udp=True), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] +val_pipeline = [ + dict(type='LoadImage'), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=codec['input_size'], use_udp=True), + dict(type='PackPoseInputs') +] + +# data loaders +train_dataloader = dict( + batch_size=64, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='annotations/person_keypoints_train2017.json', + data_prefix=dict(img='train2017/'), + pipeline=train_pipeline, + )) +val_dataloader = dict( + batch_size=32, + num_workers=4, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='annotations/person_keypoints_val2017.json', + bbox_file='data/coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=val_pipeline, + )) +test_dataloader = val_dataloader + +# evaluators +val_evaluator = dict( + type='CocoMetric', + ann_file=data_root + 'annotations/person_keypoints_val2017.json') +test_evaluator = val_evaluator + +# visualizer +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='mmpose.PoseLocalVisualizer', + vis_backends=vis_backends, + name='visualizer') diff --git a/demo/demo_enhanced_pose.py b/demo/demo_enhanced_pose.py new file mode 100644 index 0000000000..c754c2041b --- /dev/null +++ b/demo/demo_enhanced_pose.py @@ -0,0 +1,68 @@ +import os +import sys +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +import numpy as np +import mmcv +from mmaction.apis.enhanced_inference import enhanced_pose_inference +from mmdet.apis import inference_detector, init_detector + + +def main(): + # Detector settings + det_config = '/home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/demo/demo_configs/faster-rcnn_r50_fpn_2x_coco_infer.py' + det_checkpoint = 
'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth' + + # Pose estimation settings + pose_config = '/home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/configs/skeleton/posec3d/slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint.py' + pose_checkpoint = '/home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/work_dirs/slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint/best_acc_top1_epoch_24.pth' + + # Initialize detector + det_model = init_detector(det_config, det_checkpoint, device='cuda:0') + + # Input video + video_path = '/home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/data/skeleton/Le2i/Lecture_room/video_1.avi' + + # Extract frames + video = mmcv.VideoReader(video_path) + frames = [video[i] for i in range(len(video))] + frame_paths = [f'tmp_frame_{i}.jpg' for i in range(len(frames))] + + # Save frames temporarily + for frame, path in zip(frames, frame_paths): + mmcv.imwrite(frame, path) + + # Run human detection on each frame + det_results = [] + for frame_path in frame_paths: + result = inference_detector(det_model, frame_path) + # Keep only person class (usually class 0) + det_results.append( + result.pred_instances.bboxes[result.pred_instances.labels == 0].cpu().numpy()) + + # Enhanced pose estimation with batching and temporal smoothing + pose_results, _ = enhanced_pose_inference( + pose_config=pose_config, + pose_checkpoint=pose_checkpoint, + frame_paths=frame_paths, + det_results=det_results, + device='cuda:0', + batch_size=4, # Process 4 frames at once + use_temporal_smoothing=True, # Apply smoothing for video + smoothing_window_size=7, + smoothing_sigma=1.5, + keypoint_threshold=0.3 + ) + + # Clean up temporary files + for path in frame_paths: + if os.path.exists(path): + os.remove(path) + + # Now you can use pose_results for further processing, + # such as skeleton-based action recognition + print(f"Successfully processed {len(pose_results)} frames") + + +if __name__ == '__main__': + main() diff --git a/demo/demo_skeleton_refactored.py b/demo/demo_skeleton_refactored.py new file mode 100644 index 0000000000..52da0bb3d5 --- /dev/null +++ b/demo/demo_skeleton_refactored.py @@ -0,0 +1,361 @@ +""" +Copyright (c) OpenMMLab. All rights reserved. +Refactored skeleton-based action recognition demo with improved organization. 
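# The detection loop in demo_enhanced_pose.py above keeps every predicted person
# box regardless of confidence, while the rest of this PR filters detections at
# det_score_thr=0.9. A minimal sketch of combining both filters, assuming the
# mmdet 3.x pred_instances fields (bboxes, labels, scores) already used above:
import numpy as np


def filter_person_boxes(det_sample, score_thr: float = 0.9, person_label: int = 0) -> np.ndarray:
    """Return (N, 4) xyxy boxes for person detections above the score threshold."""
    inst = det_sample.pred_instances
    keep = (inst.labels == person_label) & (inst.scores >= score_thr)
    return inst.bboxes[keep].cpu().numpy()


# det_results.append(filter_person_boxes(inference_detector(det_model, frame_path)))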
+""" +import argparse +import tempfile +from typing import List, Tuple, Dict + +import cv2 +import mmcv +import mmengine +import torch +from mmengine import DictAction +from mmengine.utils import track_iter_progress + +try: + import moviepy.editor as mpy +except ImportError: + raise ImportError('Please install moviepy to enable output file') + +from mmaction.apis import ( + detection_inference, + inference_skeleton, + init_recognizer, + pose_inference +) +from mmaction.utils import frame_extract +from mmaction.registry import VISUALIZERS + +# Visualization settings +VISUALIZATION_SETTINGS = { + 'font_face': cv2.FONT_HERSHEY_DUPLEX, + 'font_scale': 0.75, + 'font_color': (255, 0, 0), # BGR, white + 'thickness': 1, + 'line_type': 1 +} + + +def parse_args(): + """Parse input arguments.""" + parser = argparse.ArgumentParser( + description='MMAction2 skeleton-based demo') + + # Input and output + parser.add_argument('video', help='video file/url') + parser.add_argument('out_filename', help='output filename') + + # Model configs + parser.add_argument( + '--config', + default=('configs/skeleton/posec3d/' + 'slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint.py'), + help='skeleton model config file path') + parser.add_argument( + '--checkpoint', + default=('https://download.openmmlab.com/mmaction/skeleton/posec3d/' + 'slowonly_r50_u48_240e_ntu60_xsub_keypoint/' + 'slowonly_r50_u48_240e_ntu60_xsub_keypoint-f3adabf1.pth'), + help='skeleton model checkpoint file/url') + + # Detection configs + parser.add_argument( + '--det-config', + default='demo/demo_configs/faster-rcnn_r50_fpn_2x_coco_infer.py', + help='human detection config file path (from mmdet)') + parser.add_argument( + '--det-checkpoint', + default=('http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/' + 'faster_rcnn_r50_fpn_2x_coco/' + 'faster_rcnn_r50_fpn_2x_coco_' + 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'), + help='human detection checkpoint file/url') + parser.add_argument( + '--det-score-thr', + type=float, + default=0.9, + help='human detection score threshold') + parser.add_argument( + '--det-cat-id', + type=int, + default=0, + help='human category id for detection') + + # Pose estimation configs + parser.add_argument( + '--pose-config', + default='demo/demo_configs/' + 'td-hm_hrnet-w32_8xb64-210e_coco-256x192_infer.py', + help='human pose estimation config file path (from mmpose)') + parser.add_argument( + '--pose-checkpoint', + default=('https://download.openmmlab.com/mmpose/top_down/hrnet/' + 'hrnet_w32_coco_256x192-c78dce93_20200708.pth'), + help='human pose estimation checkpoint file/url') + + # Other settings + parser.add_argument( + '--label-map', + default='tools/data/skeleton/label_map_ntu60.txt', + help='label map file') + parser.add_argument( + '--device', + type=str, + default='cuda:0', + help='CPU/CUDA device option') + parser.add_argument( + '--short-side', + type=int, + default=480, + help='specify the short-side length of the image') + parser.add_argument( + '--window-size', + type=int, + default=32, + help='window size for skeleton action recognition') + parser.add_argument( + '--window-stride', + type=int, + default=16, + help='stride for sliding window') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + default={}, + help='override some settings in the used config') + + return parser.parse_args() + + +def process_frame_window( + model: torch.nn.Module, + pose_results: List[Dict], + img_shape: Tuple[int, int], + label_map: List[str] +) -> Tuple[str, float, List[float]]: + 
"""Process a window of frames for action recognition. + + Args: + model: The loaded recognizer model. + pose_results: List of pose estimation results. + img_shape: Original image shape. + label_map: List of action labels. + + Returns: + Tuple containing action label, confidence score, and all prediction scores. + """ + result = inference_skeleton(model, pose_results, img_shape) + + pred_scores = result.pred_score.cpu().numpy() + max_pred_index = pred_scores.argmax() + + action_label = label_map[max_pred_index] + confidence = pred_scores[max_pred_index] + + return action_label, confidence, pred_scores + + +def check_falling_alert(action_label: str, confidence: float, frame_range: Tuple[int, int]): + """Generate warning if falling is detected with confidence threshold.""" + if action_label.lower() == 'falling' and confidence > 0.7: + start, end = frame_range + warning_msg = f"⚠️ FALL DETECTED! ⚠️\nConfidence: {confidence:.2%}\nFrames: {start}-{end}" + print("\n" + "!"*50) + print(warning_msg) + print("!"*50 + "\n") + + +def visualize_with_labels( + args, + frames: List[torch.Tensor], + pose_data_samples: List, + frame_labels: List[str], + frame_confidences: List[float] +) -> None: + """Visualize frames with action labels and skeleton overlays. + + Args: + args: Parsed command line arguments. + frames: List of video frames. + pose_data_samples: List of pose estimation results. + frame_labels: List of action labels for each frame. + frame_confidences: List of confidence scores for each frame. + """ + pose_config = mmengine.Config.fromfile(args.pose_config) + visualizer = VISUALIZERS.build(pose_config.visualizer) + visualizer.set_dataset_meta(pose_data_samples[0].dataset_meta) + + vis_frames = [] + print('Drawing skeleton and labels for each frame') + + for i, (d, f) in enumerate(track_iter_progress(list(zip(pose_data_samples, frames)))): + # Convert frame color space + f = mmcv.imconvert(f, 'bgr', 'rgb') + + # Draw pose estimation results + visualizer.add_datasample( + 'result', + f, + data_sample=d, + draw_gt=False, + draw_heatmap=False, + draw_bbox=True, + show=False, + wait_time=0, + out_file=None, + kpt_thr=0.3 + ) + vis_frame = visualizer.get_image() + + # Add action label and confidence + action_text = f"{frame_labels[i]} ({frame_confidences[i]:.2f})" + + # Highlight falling labels in red + if "falling" in frame_labels[i].lower(): + FONTCOLOR = (255, 0, 0) # Red color for falling + thickness = 2 + else: + FONTCOLOR = VISUALIZATION_SETTINGS['font_color'] + thickness = VISUALIZATION_SETTINGS['thickness'] + + cv2.putText( + vis_frame, + action_text, + (10, 30), + VISUALIZATION_SETTINGS['font_face'], + VISUALIZATION_SETTINGS['font_scale'], + FONTCOLOR, # Now using conditional color + thickness, # Now using conditional thickness + VISUALIZATION_SETTINGS['line_type'] + ) + + # Add frame number + cv2.putText( + vis_frame, + f"Frame: {i}", + (10, 60), + VISUALIZATION_SETTINGS['font_face'], + VISUALIZATION_SETTINGS['font_scale'], + VISUALIZATION_SETTINGS['font_color'], + VISUALIZATION_SETTINGS['thickness'], + VISUALIZATION_SETTINGS['line_type'] + ) + + vis_frames.append(vis_frame) + + # Create and save video + vid = mpy.ImageSequenceClip(vis_frames, fps=24) + vid.write_videofile(args.out_filename, remove_temp=True) + + +def process_video_windows( + model: torch.nn.Module, + pose_results: List[Dict], + img_shape: Tuple[int, int], + label_map: List[str], + window_size: int, + window_stride: int, + num_frames: int +) -> Tuple[List[str], List[float]]: + """Process video using sliding windows for 
action recognition. + + Args: + model: The loaded recognizer model. + pose_results: List of pose estimation results. + img_shape: Original image shape. + label_map: List of action labels. + window_size: Size of sliding window. + window_stride: Stride for sliding window. + num_frames: Total number of frames. + + Returns: + Tuple containing lists of frame labels and confidence scores. + """ + frame_labels = ["Unknown"] * num_frames + frame_confidences = [0.0] * num_frames + + print('Processing video in sliding windows') + for start_idx in track_iter_progress(range(0, num_frames - window_size + 1, window_stride)): + end_idx = start_idx + window_size + window_pose_results = pose_results[start_idx:end_idx] + + action_label, confidence, _ = process_frame_window( + model, window_pose_results, img_shape, label_map) + + # Add falling detection check + check_falling_alert(action_label, confidence, (start_idx, end_idx)) + + # Assign labels to frames in the window based on confidence + for i in range(start_idx, end_idx): + if confidence > frame_confidences[i]: + frame_labels[i] = action_label + frame_confidences[i] = confidence + + return frame_labels, frame_confidences + + +def main(): + """Main function for skeleton-based action recognition demo.""" + args = parse_args() + + # Create temporary directory for frame extraction + with tempfile.TemporaryDirectory() as tmp_dir: + # Extract video frames + frame_paths, frames = frame_extract( + args.video, args.short_side, tmp_dir) + h, w, _ = frames[0].shape + + # Perform human detection + det_results, _ = detection_inference( + args.det_config, + args.det_checkpoint, + frame_paths, + args.det_score_thr, + args.det_cat_id, + args.device + ) + torch.cuda.empty_cache() + + # Perform pose estimation + pose_results, pose_data_samples = pose_inference( + args.pose_config, + args.pose_checkpoint, + frame_paths, + det_results, + args.device + ) + torch.cuda.empty_cache() + + # Initialize action recognition model + config = mmengine.Config.fromfile(args.config) + config.merge_from_dict(args.cfg_options) + model = init_recognizer(config, args.checkpoint, args.device) + + # Load action labels + label_map = [x.strip() for x in open(args.label_map).readlines()] + + # Process video with sliding windows + frame_labels, frame_confidences = process_video_windows( + model, + pose_results, + (h, w), + label_map, + args.window_size, + args.window_stride, + len(pose_results) + ) + + # Visualize results + visualize_with_labels( + args, + frames, + pose_data_samples, + frame_labels, + frame_confidences + ) + + +if __name__ == '__main__': + main() diff --git a/demo/lying.mp4 b/demo/lying.mp4 new file mode 100644 index 0000000000..17df554312 Binary files /dev/null and b/demo/lying.mp4 differ diff --git a/demo/output_har.mp4 b/demo/output_har.mp4 new file mode 100644 index 0000000000..d1d370930d Binary files /dev/null and b/demo/output_har.mp4 differ diff --git a/infer_har.sh b/infer_har.sh new file mode 100644 index 0000000000..dd7b7d9307 --- /dev/null +++ b/infer_har.sh @@ -0,0 +1,5 @@ +python demo/demo.py /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/configs/recognition/slowfast/slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb.py \ + /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/checkpoints/slowfast_r50_8xb8-8x8x1-256e_kinetics400-rgb_20220818-1cb6dfc8.pth \ + /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/demo/1736557631421.mp4 \ + 
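# A quick worked example of the sliding-window indexing used by
# process_video_windows above: with the default window_size=32 and
# window_stride=16, a 100-frame clip yields window starts 0, 16, 32, 48, 64, so
# the last window covers frames 64-95 and frames 96-99 keep the initial
# "Unknown" label.
num_frames, window_size, window_stride = 100, 32, 16
starts = list(range(0, num_frames - window_size + 1, window_stride))
print(starts)                    # [0, 16, 32, 48, 64]
print(starts[-1] + window_size)  # 96 -> frames 96..99 are never assigned a label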
/home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/tools/data/kinetics/label_map_k400.txt \ + --out-filename /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/demo/output_har.mp4 \ No newline at end of file diff --git a/infer_skl.sh b/infer_skl.sh new file mode 100644 index 0000000000..21bb0483e1 --- /dev/null +++ b/infer_skl.sh @@ -0,0 +1,25 @@ +# python demo/demo_skeleton.py /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/demo/lying.mp4 /home/minhtranh/works/Project/Rainscales/Lying_detection/src/demo_out.mp4 \ +# --config /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/work_dirs/posec3d_ntu60_2d_adam/slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint.py \ +# --label-map /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/tools/data/skeleton/label_map_ntu60.txt \ +# --checkpoint /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/work_dirs/posec3d_ntu60_2d_adam/epoch_16.pth \ + +# # python demo/demo_enhanced_pose.py + +# VitPose-small +python demo/demo_skeleton_refactored.py /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/demo/1732698562982.mp4 /home/minhtranh/works/Project/Rainscales/Lying_detection/src/demo_out_1736557631421.mp4 \ + --det-score-thr 0.9 \ + --config /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/work_dirs/posec3d_ntu60_2d_adam/slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint.py \ + --checkpoint /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/work_dirs/posec3d_ntu60_2d_adam/best_acc_top1_epoch_24.pth \ + --label-map /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/tools/data/skeleton/label_map_ntu60.txt \ + --pose-config /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/demo/demo_configs/td-hm_ViTPose-small_8xb64-210e_coco-256x192.py \ + --pose-checkpoint https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_ViTPose-small_8xb64-210e_coco-256x192-62d7a712_20230314.pth + +# VitPose-base-simple +# python demo/demo_skeleton_refactored.py /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/data/skeleton/Le2i/Lecture_room/video_1.avi /home/minhtranh/works/Project/Rainscales/Lying_detection/src/demo_out.mp4 \ +# --det-score-thr 0.9 \ +# --config /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/work_dirs/posec3d_ntu60_2d_adam/slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint.py \ +# --checkpoint /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/work_dirs/posec3d_ntu60_2d_adam/best_acc_top1_epoch_24.pth \ +# --label-map /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/tools/data/skeleton/label_map_ntu60.txt \ +# --pose-config /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/demo/demo_configs/td-hm_ViTPose-base-simple_8xb64-210e_coco-256x192.py \ +# --pose-checkpoint https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_ViTPose-base-simple_8xb64-210e_coco-256x192-0b8234ea_20230407.pth + diff --git a/mmaction/apis/batch_inference.py b/mmaction/apis/batch_inference.py new file mode 100644 index 0000000000..61cc79d5f6 --- /dev/null +++ b/mmaction/apis/batch_inference.py @@ -0,0 +1,103 @@ +import torch +import numpy as np +from typing import List, Union, Tuple, Dict +from pathlib import Path +import torch.nn as nn +import mmengine +from tqdm import tqdm + + +def batch_pose_inference( + pose_config: Union[str, Path, mmengine.Config, nn.Module], + 
pose_checkpoint: str, + frame_paths: List[str], + det_results: List[np.ndarray], + batch_size: int = 4, + device: Union[str, torch.device] = 'cuda:0' +) -> Tuple[List[Dict[str, np.ndarray]], List]: + """Perform batched Top-Down pose estimation for better throughput. + + Args: + pose_config: Pose config file path or model object. + pose_checkpoint: Checkpoint path/url. + frame_paths: The paths of frames to do pose inference. + det_results: List of detected human boxes. + batch_size: Number of frames to process in a batch. + device: The desired device for inference. + + Returns: + Tuple of pose estimation results and data samples. + """ + try: + from mmpose.apis import init_model + from mmpose.structures import PoseDataSample, merge_data_samples + import mmcv + except (ImportError, ModuleNotFoundError): + raise ImportError('Failed to import required modules from MMPose') + + # Input validation + if not frame_paths: + raise ValueError("frame_paths cannot be empty") + if len(frame_paths) != len(det_results): + raise ValueError(f"Number of frames ({len(frame_paths)}) must match " + f"number of detection results ({len(det_results)})") + + # Model initialization + if isinstance(pose_config, nn.Module): + model = pose_config + else: + model = init_model(pose_config, pose_checkpoint, device) + + # Get the correct number of keypoints from model metadata + num_keypoints = model.dataset_meta['num_keypoints'] + + results = [] + data_samples = [] + + # Process in batches + total_batches = (len(frame_paths) + batch_size - 1) // batch_size + for batch_idx in tqdm(range(total_batches), desc="Processing batches"): + start_idx = batch_idx * batch_size + end_idx = min(start_idx + batch_size, len(frame_paths)) + + batch_frame_paths = frame_paths[start_idx:end_idx] + batch_det_results = det_results[start_idx:end_idx] + + # Load images + batch_images = [mmcv.imread(frame_path) + for frame_path in batch_frame_paths] + + # Process each image in the batch + for img, dets, frame_path in zip(batch_images, batch_det_results, batch_frame_paths): + # Validate detection format + if dets.size > 0 and dets.shape[1] < 4: + raise ValueError( + f"Detection boxes must have at least 4 values (x1,y1,x2,y2), got shape {dets.shape}") + + # Run inference using the model's test_step directly with image data + # This avoids repeated I/O operations and is more efficient + data_info = dict(img=img, bbox=dets[..., :4]) + data = model.data_preprocessor(data_info, False) + + with torch.no_grad(): + predictions = model.forward(data, mode='predict') + + pose_data_sample = merge_data_samples(predictions) + pose_data_sample.dataset_meta = model.dataset_meta + + # Handle empty predictions + if not hasattr(pose_data_sample, 'pred_instances'): + pred_instances_data = dict( + keypoints=np.empty(shape=(0, num_keypoints, 2)), + keypoints_scores=np.empty( + shape=(0, num_keypoints), dtype=np.float32), + bboxes=np.empty(shape=(0, 4), dtype=np.float32), + bbox_scores=np.empty(shape=(0), dtype=np.float32)) + pose_data_sample.pred_instances = InstanceData( + **pred_instances_data) + + poses = pose_data_sample.pred_instances.to_dict() + results.append(poses) + data_samples.append(pose_data_sample) + + return results, data_samples diff --git a/mmaction/apis/enhanced_inference.py b/mmaction/apis/enhanced_inference.py new file mode 100644 index 0000000000..b503d21008 --- /dev/null +++ b/mmaction/apis/enhanced_inference.py @@ -0,0 +1,150 @@ +import torch +import numpy as np +from typing import List, Union, Tuple, Dict, Optional +from pathlib import Path 
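# Note on batch_pose_inference above: its empty-prediction branch constructs an
# InstanceData, so the module also needs the import used elsewhere in this PR
# (see inference_refactored.py):
#
#     from mmengine.structures import InstanceData
#
# A hypothetical call site, reusing the ViTPose-small config/checkpoint from
# infer_skl.sh (frame_paths and det_results are prepared upstream):
#
#     results, samples = batch_pose_inference(
#         pose_config, pose_checkpoint, frame_paths, det_results,
#         batch_size=4, device='cuda:0')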
+import torch.nn as nn +import mmengine +from mmengine.utils import track_iter_progress + + +def enhanced_pose_inference( + pose_config: Union[str, Path, mmengine.Config, nn.Module], + pose_checkpoint: str, + frame_paths: List[str], + det_results: List[np.ndarray], + device: Union[str, torch.device] = 'cuda:0', + batch_size: int = 1, + use_temporal_smoothing: bool = False, + smoothing_window_size: int = 7, + smoothing_sigma: float = 1.5, + keypoint_threshold: float = 0.3, + return_heatmaps: bool = False +) -> Tuple[List[Dict[str, np.ndarray]], List]: + """Enhanced pose estimation with batching and temporal smoothing options. + + Args: + pose_config: Pose config file path or model object. + pose_checkpoint: Checkpoint path/url. + frame_paths: The paths of frames to do pose inference. + det_results: List of detected human boxes. + device: The desired device for inference. + batch_size: Number of frames to process in a batch. + use_temporal_smoothing: Whether to apply temporal smoothing to keypoints. + smoothing_window_size: Size of the Gaussian smoothing window. + smoothing_sigma: Standard deviation for Gaussian kernel. + keypoint_threshold: Confidence threshold for keypoints. + return_heatmaps: Whether to return heatmaps along with keypoints. + + Returns: + Tuple of pose estimation results and data samples. + """ + try: + from mmpose.apis import inference_topdown, init_model + from mmpose.structures import PoseDataSample, merge_data_samples + except (ImportError, ModuleNotFoundError): + raise ImportError('Failed to import required modules from MMPose') + + # Import the temporal_smoothing function we implemented above + try: + from mmaction.apis.temporal_processing import apply_temporal_smoothing + except ImportError: + use_temporal_smoothing = False + print("Warning: Temporal smoothing module not found. 
Disabling smoothing.") + + # Input validation + if not frame_paths: + raise ValueError("frame_paths cannot be empty") + if len(frame_paths) != len(det_results): + raise ValueError(f"Number of frames ({len(frame_paths)}) must match " + f"number of detection results ({len(det_results)})") + + # Model initialization + if isinstance(pose_config, nn.Module): + model = pose_config + else: + model = init_model(pose_config, pose_checkpoint, device) + + # Get the correct number of keypoints from model metadata + try: + num_keypoints = model.dataset_meta['num_keypoints'] + except (KeyError, AttributeError): + # For action recognition models like PoseC3D + if hasattr(model, 'backbone') and hasattr(model.backbone, 'in_channels'): + # If it's a pose recognition model, use in_channels + num_keypoints = model.backbone.in_channels + print( + f"Using backbone in_channels as num_keypoints: {num_keypoints}") + else: + # Default to standard COCO keypoints + print("Warning: Could not determine number of keypoints, using default (17)") + num_keypoints = 17 + + results = [] + data_samples = [] + + # Process frames (with optional batching) + if batch_size <= 1: + # Single frame processing + print('Performing Human Pose Estimation for each frame') + for f, d in track_iter_progress(list(zip(frame_paths, det_results))): + # Validate detection format + if d.size > 0 and d.shape[1] < 4: + raise ValueError( + f"Detection boxes must have at least 4 values (x1,y1,x2,y2), got shape {d.shape}") + + pose_data_samples: List[PoseDataSample] = inference_topdown( + model, f, d[..., :4], bbox_format='xyxy') + pose_data_sample = merge_data_samples(pose_data_samples) + pose_data_sample.dataset_meta = model.dataset_meta + + # Handle empty predictions + if not hasattr(pose_data_sample, 'pred_instances'): + pred_instances_data = dict( + keypoints=np.empty(shape=(0, num_keypoints, 2)), + keypoints_scores=np.empty( + shape=(0, num_keypoints), dtype=np.float32), + bboxes=np.empty(shape=(0, 4), dtype=np.float32), + bbox_scores=np.empty(shape=(0), dtype=np.float32)) + pose_data_sample.pred_instances = InstanceData( + **pred_instances_data) + + poses = pose_data_sample.pred_instances.to_dict() + + # Filter low-confidence keypoints + if 'keypoints_scores' in poses and poses['keypoints_scores'].size > 0: + mask = poses['keypoints_scores'] < keypoint_threshold + # Set low-confidence keypoints to 0 + if 'keypoints' in poses and poses['keypoints'].size > 0: + poses['keypoints'][mask] = 0 + + results.append(poses) + data_samples.append(pose_data_sample) + else: + # Use the batch processing function we implemented above + try: + from mmaction.apis.batch_inference import batch_pose_inference + results, data_samples = batch_pose_inference( + pose_config, pose_checkpoint, frame_paths, det_results, + batch_size=batch_size, device=device + ) + except ImportError: + print( + "Warning: Batch inference module not found. 
Falling back to single frame processing.") + # Recursively call this function with batch_size=1 + return enhanced_pose_inference( + pose_config, pose_checkpoint, frame_paths, det_results, + device=device, batch_size=1, use_temporal_smoothing=use_temporal_smoothing, + smoothing_window_size=smoothing_window_size, smoothing_sigma=smoothing_sigma, + keypoint_threshold=keypoint_threshold, return_heatmaps=return_heatmaps + ) + + # Apply temporal smoothing if requested + if use_temporal_smoothing and len(results) > 3: + results = apply_temporal_smoothing( + results, + window_size=smoothing_window_size, + sigma=smoothing_sigma, + min_score_threshold=keypoint_threshold + ) + + return results, data_samples diff --git a/mmaction/apis/inference_refactored.py b/mmaction/apis/inference_refactored.py new file mode 100644 index 0000000000..fbeeb86384 --- /dev/null +++ b/mmaction/apis/inference_refactored.py @@ -0,0 +1,294 @@ +""" +Copyright (c) OpenMMLab. All rights reserved. +Refactored version with improved error handling and organization. +""" +import os.path as osp +from pathlib import Path +from typing import List, Optional, Tuple, Union, Dict + +import mmengine +import numpy as np +import torch +import torch.nn as nn +from mmengine.dataset import Compose, pseudo_collate +from mmengine.registry import init_default_scope +from mmengine.runner import load_checkpoint +from mmengine.structures import InstanceData +from mmengine.utils import track_iter_progress + +from mmaction.registry import MODELS +from mmaction.structures import ActionDataSample + + +def init_recognizer(config: Union[str, Path, mmengine.Config], + checkpoint: Optional[str] = None, + device: Union[str, torch.device] = 'cuda:0') -> nn.Module: + """Initialize a recognizer from config file. + + Args: + config (str or :obj:`Path` or :obj:`mmengine.Config`): Config file + path, :obj:`Path` or the config object. + checkpoint (str, optional): Checkpoint path/url. Defaults to None. + device (str | torch.device): The desired device. Defaults to 'cuda:0'. + + Returns: + nn.Module: The constructed recognizer. + """ + if isinstance(config, (str, Path)): + config = mmengine.Config.fromfile(config) + elif not isinstance(config, mmengine.Config): + raise TypeError('config must be a filename or Config object, ' + f'but got {type(config)}') + + init_default_scope(config.get('default_scope', 'mmaction')) + + if hasattr(config.model, 'backbone') and config.model.backbone.get( + 'pretrained', None): + config.model.backbone.pretrained = None + model = MODELS.build(config.model) + + if checkpoint is not None: + load_checkpoint(model, checkpoint, map_location='cpu') + model.cfg = config + model.to(device) + model.eval() + return model + + +def create_empty_pose_result(num_keypoints: int) -> Dict: + """Create empty pose result with proper dimensions. + + Args: + num_keypoints (int): Number of keypoints in the model. + + Returns: + Dict: Empty pose result dictionary. + """ + return { + 'keypoints': np.empty(shape=(0, num_keypoints, 2)), + 'keypoint_scores': np.empty(shape=(0, num_keypoints)), + 'bboxes': np.empty(shape=(0, 4)), + 'bbox_scores': np.empty(shape=(0)) + } + + +def validate_inputs(frame_paths: List[str], det_results: List[np.ndarray]): + """Validate input parameters for pose inference. + + Args: + frame_paths (List[str]): List of frame paths. + det_results (List[np.ndarray]): List of detection results. + + Raises: + ValueError: If inputs are invalid. 
+ """ + if not frame_paths: + raise ValueError("frame_paths cannot be empty") + if len(frame_paths) != len(det_results): + raise ValueError( + f"Number of frames ({len(frame_paths)}) must match " + f"number of detection results ({len(det_results)})") + + +def process_pose_data(pose_data_samples: List['PoseDataSample'], + model_meta: Dict, + num_keypoints: int) -> Tuple['PoseDataSample', Dict]: + """Process pose data samples and create pose results. + + Args: + pose_data_samples (List[PoseDataSample]): List of pose data samples. + model_meta (Dict): Model metadata. + num_keypoints (int): Number of keypoints. + + Returns: + Tuple[PoseDataSample, Dict]: Processed pose data sample and poses dict. + """ + from mmpose.structures import merge_data_samples + pose_data_sample = merge_data_samples(pose_data_samples) + pose_data_sample.dataset_meta = model_meta + + if not hasattr(pose_data_sample, 'pred_instances'): + pred_instances_data = create_empty_pose_result(num_keypoints) + pose_data_sample.pred_instances = InstanceData(**pred_instances_data) + + return pose_data_sample, pose_data_sample.pred_instances.to_dict() + + +def pose_inference(pose_config: Union[str, Path, mmengine.Config, nn.Module], + pose_checkpoint: str, + frame_paths: List[str], + det_results: List[np.ndarray], + device: Union[str, torch.device] = 'cuda:0') -> tuple: + """Perform Top-Down pose estimation with improved error handling. + + Args: + pose_config: Pose config file path or pose model object. + pose_checkpoint: Checkpoint path/url. + frame_paths: The paths of frames to do pose inference. + det_results: List of detected human boxes. + device: The desired device. Defaults to 'cuda:0'. + + Returns: + Tuple[List[Dict], List[PoseDataSample]]: Pose results and data samples. + """ + try: + from mmpose.apis import inference_topdown, init_model + from mmpose.structures import PoseDataSample + except ImportError: + raise ImportError('Failed to import required mmpose components') + + validate_inputs(frame_paths, det_results) + + # Initialize model + model = pose_config if isinstance(pose_config, nn.Module) else \ + init_model(pose_config, pose_checkpoint, device) + num_keypoints = model.dataset_meta['num_keypoints'] + + results = [] + data_samples = [] + print('Performing Human Pose Estimation for each frame') + + for f, d in track_iter_progress(list(zip(frame_paths, det_results))): + if d.size == 0: + # Handle empty detection results + pred_instances_data = create_empty_pose_result(num_keypoints) + pose_data_sample = PoseDataSample() + pose_data_sample.pred_instances = InstanceData( + **pred_instances_data) + pose_data_sample.dataset_meta = model.dataset_meta + results.append(pred_instances_data) + data_samples.append(pose_data_sample) + continue + + if d.shape[1] < 4: + raise ValueError( + f"Detection boxes must have at least 4 values, got shape {d.shape}") + + pose_data_samples = inference_topdown( + model, f, d[..., :4], bbox_format='xyxy') + pose_data_sample, poses = process_pose_data( + pose_data_samples, model.dataset_meta, num_keypoints) + + results.append(poses) + data_samples.append(pose_data_sample) + + return results, data_samples + + +def prepare_skeleton_data(pose_results: List[dict], + img_shape: Tuple[int]) -> Tuple[Dict, np.ndarray, np.ndarray]: + """Prepare data for skeleton inference. + + Args: + pose_results: List of pose estimation results. + img_shape: Original image shape. + + Returns: + Tuple containing fake annotation, keypoint array, and keypoint scores. 
+ """ + h, w = img_shape + num_keypoint = pose_results[0]['keypoints'].shape[1] + num_frame = len(pose_results) + num_person = max([len(x['keypoints']) for x in pose_results]) + + fake_anno = { + 'frame_dict': '', + 'label': -1, + 'img_shape': (h, w), + 'origin_shape': (h, w), + 'start_index': 0, + 'modality': 'Pose', + 'total_frames': num_frame + } + + keypoint = np.zeros( + (num_frame, num_person, num_keypoint, 2), dtype=np.float16) + keypoint_score = np.zeros( + (num_frame, num_person, num_keypoint), dtype=np.float16) + + for f_idx, frm_pose in enumerate(pose_results): + frm_num_persons = frm_pose['keypoints'].shape[0] + for p_idx in range(frm_num_persons): + keypoint[f_idx, p_idx] = frm_pose['keypoints'][p_idx] + keypoint_score[f_idx, p_idx] = frm_pose['keypoint_scores'][p_idx] + + return fake_anno, keypoint, keypoint_score + + +def inference_skeleton(model: nn.Module, + pose_results: List[dict], + img_shape: Tuple[int], + test_pipeline: Optional[Compose] = None) -> ActionDataSample: + """Inference a pose results with the skeleton recognizer. + + Args: + model: The loaded recognizer. + pose_results: The pose estimation results dictionary. + img_shape: The original image shape. + test_pipeline: The test pipeline. Defaults to None. + + Returns: + ActionDataSample: The inference results. + """ + if not pose_results: + raise ValueError("pose_results cannot be empty") + + if test_pipeline is None: + cfg = model.cfg + init_default_scope(cfg.get('default_scope', 'mmaction')) + test_pipeline = Compose(cfg.test_pipeline) + + fake_anno, keypoint, keypoint_score = prepare_skeleton_data( + pose_results, img_shape) + + # Transpose keypoint data to match expected format + fake_anno['keypoint'] = keypoint.transpose((1, 0, 2, 3)) + fake_anno['keypoint_score'] = keypoint_score.transpose((1, 0, 2)) + + return inference_recognizer(model, fake_anno, test_pipeline) + + +def inference_recognizer(model: nn.Module, + video: Union[str, dict], + test_pipeline: Optional[Compose] = None) -> ActionDataSample: + """Inference a video with the recognizer. + + Args: + model: The loaded recognizer. + video: Video file path or results dictionary. + test_pipeline: The test pipeline. Defaults to None. + + Returns: + ActionDataSample: The inference results. 
+ """ + if test_pipeline is None: + cfg = model.cfg + init_default_scope(cfg.get('default_scope', 'mmaction')) + test_pipeline = Compose(cfg.test_pipeline) + + input_flag = None + if isinstance(video, dict): + input_flag = 'dict' + elif isinstance(video, str) and osp.exists(video): + input_flag = 'audio' if video.endswith('.npy') else 'video' + else: + raise RuntimeError(f'Unsupported video type: {type(video)}') + + if input_flag == 'dict': + data = video + elif input_flag == 'video': + data = dict(filename=video, label=-1, start_index=0, modality='RGB') + else: # audio + data = dict( + audio_path=video, + total_frames=len(np.load(video)), + start_index=0, + label=-1) + + data = test_pipeline(data) + data = pseudo_collate([data]) + + with torch.no_grad(): + result = model.test_step(data)[0] + + return result diff --git a/mmaction/apis/temporal_processing.py b/mmaction/apis/temporal_processing.py new file mode 100644 index 0000000000..d2d14d86ed --- /dev/null +++ b/mmaction/apis/temporal_processing.py @@ -0,0 +1,128 @@ +import numpy as np +from typing import List, Dict, Union +import scipy.ndimage as ndimage + + +def apply_temporal_smoothing( + pose_results: List[Dict[str, np.ndarray]], + window_size: int = 7, + sigma: float = 1.5, + min_score_threshold: float = 0.2 +) -> List[Dict[str, np.ndarray]]: + """Apply temporal smoothing to keypoint sequences for video stability. + + Args: + pose_results: List of pose results for each frame + window_size: Size of the Gaussian smoothing window (odd number) + sigma: Standard deviation for Gaussian kernel + min_score_threshold: Minimum confidence score to consider a keypoint valid + + Returns: + List of smoothed pose results + """ + if len(pose_results) < 3: + # Not enough frames to smooth + return pose_results + + # Ensure window size is odd + window_size = max(3, window_size if window_size % + 2 == 1 else window_size + 1) + + # Extract keypoints and scores across frames + all_keypoints = [] + all_scores = [] + + for frame_result in pose_results: + all_keypoints.append(frame_result.get('keypoints', np.array([]))) + all_scores.append(frame_result.get('keypoints_scores', np.array([]))) + + # Find maximum number of people across all frames + max_people = max([kpts.shape[0] if kpts.size > + 0 else 0 for kpts in all_keypoints]) + if max_people == 0: + return pose_results # No keypoints to smooth + + # Get dimensions + num_frames = len(all_keypoints) + num_keypoints = all_keypoints[0].shape[1] if all_keypoints[0].size > 0 else 0 + + if num_keypoints == 0: + return pose_results # No keypoints to smooth + + # Create aligned arrays for smoothing, padding with zeros for missing people + aligned_keypoints = np.zeros((num_frames, max_people, num_keypoints, 2)) + aligned_scores = np.zeros((num_frames, max_people, num_keypoints)) + + # Fill in the data + for i, (kpts, scores) in enumerate(zip(all_keypoints, all_scores)): + if kpts.size > 0: + aligned_keypoints[i, :kpts.shape[0]] = kpts + aligned_scores[i, :scores.shape[0]] = scores + + # Create a mask for valid keypoints (above threshold) + valid_mask = aligned_scores > min_score_threshold + + # Smooth each person's keypoints over time + smoothed_keypoints = aligned_keypoints.copy() + + # For each person and keypoint + for person_idx in range(max_people): + for kpt_idx in range(num_keypoints): + # Only smooth if we have enough valid points + person_kpt_mask = valid_mask[:, person_idx, kpt_idx] + if np.sum(person_kpt_mask) > window_size // 2: + # Smooth X coordinates + x_values = aligned_keypoints[:, person_idx, 
kpt_idx, 0] + x_valid = np.where(person_kpt_mask, x_values, np.nan) + + # Replace NaN with interpolated values for smoothing + x_interp = np.copy(x_valid) + mask = np.isnan(x_interp) + x_interp[mask] = np.interp( + np.flatnonzero(mask), + np.flatnonzero(~mask), + x_interp[~mask] + ) + + # Apply Gaussian filter + x_smoothed = ndimage.gaussian_filter1d( + x_interp, sigma=sigma, truncate=2.0) + + # Only update valid keypoints + smoothed_keypoints[:, person_idx, kpt_idx, 0] = np.where( + person_kpt_mask, x_smoothed, aligned_keypoints[:, + person_idx, kpt_idx, 0] + ) + + # Smooth Y coordinates similarly + y_values = aligned_keypoints[:, person_idx, kpt_idx, 1] + y_valid = np.where(person_kpt_mask, y_values, np.nan) + + y_interp = np.copy(y_valid) + mask = np.isnan(y_interp) + y_interp[mask] = np.interp( + np.flatnonzero(mask), + np.flatnonzero(~mask), + y_interp[~mask] + ) + + y_smoothed = ndimage.gaussian_filter1d( + y_interp, sigma=sigma, truncate=2.0) + + smoothed_keypoints[:, person_idx, kpt_idx, 1] = np.where( + person_kpt_mask, y_smoothed, aligned_keypoints[:, + person_idx, kpt_idx, 1] + ) + + # Update the original pose results with smoothed keypoints + smoothed_results = [] + for i, result in enumerate(pose_results): + new_result = result.copy() + num_people = all_keypoints[i].shape[0] if all_keypoints[i].size > 0 else 0 + + if num_people > 0: + new_result['keypoints'] = smoothed_keypoints[i, :num_people].copy() + + smoothed_results.append(new_result) + + return smoothed_results diff --git a/mmaction/models/distillation/__init__.py b/mmaction/models/distillation/__init__.py new file mode 100644 index 0000000000..d0e1c5ca32 --- /dev/null +++ b/mmaction/models/distillation/__init__.py @@ -0,0 +1,17 @@ +# Add imports for distillation components +from .losses import (KDLoss, FeatureDistillationLoss, AttentionTransferLoss, + DynamicTemperatureKDLoss, HintLoss, MultiTeacherDistillationLoss, + SelfAttentionDistillationLoss) +from .recognizers import DistillPoseRecognizer, ProgressiveDistillPoseRecognizer +from .necks import MultiLevelFeatureDistillConnector +from .backbones import PoseC3DStudentBackbone +from .heads import PoseC3DDistillHead + +__all__ = [ + # ... 
existing components + 'KDLoss', 'FeatureDistillationLoss', 'AttentionTransferLoss', + 'DynamicTemperatureKDLoss', 'HintLoss', 'MultiTeacherDistillationLoss', + 'SelfAttentionDistillationLoss', 'DistillPoseRecognizer', + 'ProgressiveDistillPoseRecognizer', 'MultiLevelFeatureDistillConnector', + 'PoseC3DStudentBackbone', 'PoseC3DDistillHead' +] \ No newline at end of file diff --git a/mmaction/models/distillation/configs/posec3d_complete_distill.py b/mmaction/models/distillation/configs/posec3d_complete_distill.py new file mode 100644 index 0000000000..1301b26964 --- /dev/null +++ b/mmaction/models/distillation/configs/posec3d_complete_distill.py @@ -0,0 +1,110 @@ +_base_ = ['./posec3d_k400.py'] + +# Teacher model config and checkpoint +teacher_config = 'configs/skeleton/posec3d/posec3d_k400.py' +teacher_checkpoint = 'checkpoints/posec3d_k400-73b07ecd.pth' + +# Student model will be much smaller than teacher +model = dict( + type='ProgressiveDistillPoseRecognizer', + backbone=dict( + type='PoseC3DStudentBackbone', + depth=18, # ResNet18 vs teacher's ResNet50 + in_channels=17, + base_channels=32, # Half the channels of teacher + num_stages=4, + out_indices=(3,), + stage_blocks=(2, 2, 2, 2), + conv1_stride_s=1, + pool1_stride_s=1, + inflate=(0, 1, 1, 1), + spatial_strides=(1, 2, 2, 2), + temporal_strides=(1, 1, 1, 1)), + cls_head=dict( + type='PoseC3DDistillHead', + in_channels=256, # Reduced from 2048 + num_classes=400, + loss_cls=dict(type='CrossEntropyLoss')), + + # Teacher configuration + teacher_config=teacher_config, + teacher_ckpt=teacher_checkpoint, + + # Feature distillation + feature_dist_cfg=dict( + type='FeatureDistillationLoss', + student_channels=256, # From student's last stage + teacher_channels=2048, # From teacher's last stage + distill_type='mse', + transform_type='linear', + weight=0.5, + ), + + # Attention distillation + attention_dist_cfg=dict( + type='AttentionTransferLoss', + beta=1.0, + normalize=True, + ), + + # Logit distillation with temperature scaling + logit_dist_cfg=dict( + type='KDLoss', + temperature=4.0, + alpha=0.5, + ), + + # Progressive distillation strategy + progressive_cfg=dict( + phases=[ + dict(epochs=(0, 20), distill_types=['feature']), # Start with feature matching + dict(epochs=(20, 40), distill_types=['feature', 'attention']), # Add attention + dict(epochs=(40, -1), distill_types=['feature', 'attention', 'logit']), # Full distillation + ] + ), +) + +# Add custom hooks for distillation process +custom_hooks = [ + dict( + type='DistillationHook', + # Alpha scheduling: gradually increase the weight of KD loss + alpha_scheduler=lambda epoch, max_epochs: min(0.1 + epoch / max_epochs, 0.5), + # Temperature scheduling: gradually reduce temperature + temp_scheduler=lambda epoch, max_epochs: max(4.0 - 3.0 * epoch / max_epochs, 1.0), + ), +] + +# Optimizer: smaller LR for distillation +optimizer = dict( + type='SGD', + lr=0.02, # Reduced from standard training + momentum=0.9, + weight_decay=0.0001, + nesterov=True, + paramwise_cfg=dict( + # Apply different LR to backbone vs. 
other components + custom_keys={ + 'backbone': dict(lr_mult=0.1), # Lower LR for backbone + 'cls_head': dict(lr_mult=1.0), # Normal LR for classification head + } + ) +) + +# Learning rate config +lr_config = dict( + policy='CosineAnnealing', + min_lr=0, + warmup='linear', + warmup_iters=2000, + warmup_ratio=0.1 +) + +# Runtime settings +total_epochs = 100 # Longer training for distillation +evaluation = dict(interval=5) # Evaluate more frequently +checkpoint_config = dict(interval=5) # Save checkpoint more frequently +log_config = dict(interval=20) + +# Working directory +work_dir = './work_dirs/posec3d_distill_complete/' \ No newline at end of file diff --git a/mmaction/models/distillation/core/hooks/distillation_hook.py b/mmaction/models/distillation/core/hooks/distillation_hook.py new file mode 100644 index 0000000000..3033d6c1bc --- /dev/null +++ b/mmaction/models/distillation/core/hooks/distillation_hook.py @@ -0,0 +1,37 @@ +from mmcv.runner import HOOKS, Hook +from mmaction.registry import MODELS + +@MODELS.register_module() +class DistillationHook(Hook): + """Hook for handling knowledge distillation during training.""" + + def __init__(self, + alpha_scheduler=None, # Schedule for alpha value + temp_scheduler=None, # Schedule for temperature + ): + self.alpha_scheduler = alpha_scheduler + self.temp_scheduler = temp_scheduler + + def before_train_epoch(self, runner): + """Called before each training epoch.""" + # Update alpha value if scheduler provided + if self.alpha_scheduler is not None: + alpha = self.alpha_scheduler(runner.epoch, runner.max_epochs) + self._update_alpha(runner.model, alpha) + + # Update temperature if scheduler provided + if self.temp_scheduler is not None: + temp = self.temp_scheduler(runner.epoch, runner.max_epochs) + self._update_temperature(runner.model, temp) + + def _update_alpha(self, model, alpha): + """Update alpha value in KD loss modules.""" + for name, module in model.named_modules(): + if hasattr(module, 'alpha') and isinstance(module.alpha, float): + module.alpha = alpha + + def _update_temperature(self, model, temp): + """Update temperature in KD loss modules.""" + for name, module in model.named_modules(): + if hasattr(module, 'temperature') and not isinstance(module, DynamicTemperatureKDLoss): + module.temperature = temp \ No newline at end of file diff --git a/mmaction/models/distillation/losses/attention_transfer.py b/mmaction/models/distillation/losses/attention_transfer.py new file mode 100644 index 0000000000..d700d567a5 --- /dev/null +++ b/mmaction/models/distillation/losses/attention_transfer.py @@ -0,0 +1,40 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmaction.models.builder import LOSSES +from mmaction.registry import MODELS + +@MODELS.register_module() +class AttentionTransferLoss(nn.Module): + """Knowledge distillation via attention transfer.""" + + def __init__(self, beta=1.0, normalize=True): + super().__init__() + self.beta = beta + self.normalize = normalize + + def _attention_map(self, feat): + """Convert feature maps to attention maps.""" + # Sum of absolute values for channel-wise attention (L2 norm along channel dimension) + return F.normalize(feat.pow(2).sum(1), p=1, dim=(1, 2)) + + def forward(self, student_feat, teacher_feat): + """ + Args: + student_feat (list[Tensor]): List of feature maps from student + teacher_feat (list[Tensor]): List of feature maps from teacher + """ + at_loss = 0 + for s, t in zip(student_feat, teacher_feat): + # Handle different spatial dimensions with 
interpolation + if s.shape[2:] != t.shape[2:]: + s = F.interpolate(s, t.shape[2:], mode='trilinear', align_corners=False) + + # Generate attention maps + s_attention = self._attention_map(s) + t_attention = self._attention_map(t) + + # Calculate L2 distance between normalized attention maps + at_loss += F.mse_loss(s_attention, t_attention) + + return self.beta * at_loss \ No newline at end of file diff --git a/mmaction/models/distillation/losses/distillation_losses.py b/mmaction/models/distillation/losses/distillation_losses.py new file mode 100644 index 0000000000..3c0e3db286 --- /dev/null +++ b/mmaction/models/distillation/losses/distillation_losses.py @@ -0,0 +1,92 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmaction.models.builder import LOSSES +from mmaction.registry import MODELS + +@MODELS.register_module() +class KDLoss(nn.Module): + """Standard KL-divergence based knowledge distillation loss.""" + + def __init__(self, temperature=4.0, alpha=0.5, reduction='batchmean'): + super().__init__() + self.temperature = temperature + self.alpha = alpha + self.reduction = reduction + self.kl_div = nn.KLDivLoss(reduction=reduction) + self.ce = nn.CrossEntropyLoss() + + def forward(self, student_logits, teacher_logits, labels=None): + """ + Args: + student_logits (Tensor): Logits from student model + teacher_logits (Tensor): Logits from teacher model + labels (Tensor, optional): Ground truth labels + """ + # Apply temperature scaling + soft_student = F.log_softmax(student_logits / self.temperature, dim=1) + soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1) + + # KL divergence loss + kd_loss = self.kl_div(soft_student, soft_teacher) * \ + (self.temperature ** 2) + + if labels is not None: + # Hard label loss (cross-entropy) + ce_loss = self.ce(student_logits, labels) + # Combined loss + total_loss = (1 - self.alpha) * ce_loss + self.alpha * kd_loss + return total_loss, ce_loss, kd_loss + else: + return kd_loss + + +@MODELS.register_module() +class FeatureDistillationLoss(nn.Module): + """Feature-based distillation loss with optional transformation.""" + + def __init__(self, + student_channels=256, + teacher_channels=2048, + distill_type='mse', + transform_type='linear', + weight=1.0): + super().__init__() + self.weight = weight + self.distill_type = distill_type + + # Feature transformation if dimensions differ + if transform_type == 'linear' and student_channels != teacher_channels: + self.transform = nn.Linear(student_channels, teacher_channels) + elif transform_type == 'conv1x1': + self.transform = nn.Conv3d(student_channels, teacher_channels, + kernel_size=1, stride=1, padding=0) + else: + self.transform = None + + def forward(self, student_feat, teacher_feat): + """Calculate feature distillation loss.""" + if self.transform is not None: + if len(student_feat.shape) == 2: # Linear features + student_feat = self.transform(student_feat) + else: # 3D features + student_feat = self.transform(student_feat) + + # Match shapes if needed (e.g., via pooling) + if student_feat.shape != teacher_feat.shape: + # Apply adaptive pooling to match spatial dimensions + if len(student_feat.shape) == 5: # 3D features (B,C,T,H,W) + pool = nn.AdaptiveAvgPool3d(teacher_feat.shape[2:]) + student_feat = pool(student_feat) + + # Calculate loss based on type + if self.distill_type == 'mse': + loss = F.mse_loss(student_feat, teacher_feat) + elif self.distill_type == 'l1': + loss = F.l1_loss(student_feat, teacher_feat) + elif self.distill_type == 'cosine': + 
student_norm = F.normalize(student_feat, dim=1) + teacher_norm = F.normalize(teacher_feat, dim=1) + loss = 1 - (student_norm * teacher_norm).sum(dim=1).mean() + + return self.weight * loss diff --git a/mmaction/models/distillation/losses/dynamic_temperature.py b/mmaction/models/distillation/losses/dynamic_temperature.py new file mode 100644 index 0000000000..d32c1dd2ec --- /dev/null +++ b/mmaction/models/distillation/losses/dynamic_temperature.py @@ -0,0 +1,75 @@ +from tools.data.video_retrieval.prepare_msvd import F +import torch +import torch.nn as nn +from mmcv.runner import Hook, HOOKS +from mmaction.models.builder import LOSSES +from mmaction.registry import MODELS + +@MODELS.register_module() +class DynamicTemperatureKDLoss(nn.Module): + """KD loss with dynamic temperature scheduling.""" + + def __init__(self, + init_temperature=4.0, + final_temperature=1.0, + alpha=0.5): + super().__init__() + self.init_temp = init_temperature + self.final_temp = final_temperature + self.alpha = alpha + self.current_temp = init_temperature + self.ce = nn.CrossEntropyLoss() + self.kl_div = nn.KLDivLoss(reduction='batchmean') + + def forward(self, student_logits, teacher_logits, labels=None): + """Forward function with current temperature.""" + # Temperature scaled logits + soft_student = F.log_softmax(student_logits / self.current_temp, dim=1) + soft_teacher = F.softmax(teacher_logits / self.current_temp, dim=1) + + # KL divergence loss with temperature scaling + kd_loss = self.kl_div(soft_student, soft_teacher) * (self.current_temp ** 2) + + if labels is not None: + ce_loss = self.ce(student_logits, labels) + total_loss = (1 - self.alpha) * ce_loss + self.alpha * kd_loss + return total_loss, ce_loss, kd_loss + else: + return kd_loss + +@MODELS.register_module() +class DynamicTemperatureHook(Hook): + """Hook to update temperature during training.""" + + def __init__(self, by_epoch=True): + self.by_epoch = by_epoch + + def before_train_epoch(self, runner): + """Update temperature at the beginning of each epoch.""" + if not self.by_epoch: + return + + # Get current progress ratio + progress_ratio = runner.epoch / runner.max_epochs + + # Find all KD loss modules + for name, module in runner.model.named_modules(): + if isinstance(module, DynamicTemperatureKDLoss): + # Linear annealing from init_temp to final_temp + module.current_temp = module.init_temp + progress_ratio * ( + module.final_temp - module.init_temp) + + def before_train_iter(self, runner): + """Update temperature at each iteration if not by_epoch.""" + if self.by_epoch: + return + + # Get current progress ratio + progress_ratio = (runner.iter + 1) / runner.max_iters + + # Find all KD loss modules + for name, module in runner.model.named_modules(): + if isinstance(module, DynamicTemperatureKDLoss): + # Linear annealing from init_temp to final_temp + module.current_temp = module.init_temp + progress_ratio * ( + module.final_temp - module.init_temp) \ No newline at end of file diff --git a/mmaction/models/distillation/losses/self_attention_distillation.py b/mmaction/models/distillation/losses/self_attention_distillation.py new file mode 100644 index 0000000000..b277b90b23 --- /dev/null +++ b/mmaction/models/distillation/losses/self_attention_distillation.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmaction.registry import MODELS + +@MODELS.register_module() +class SelfAttentionDistillationLoss(nn.Module): + """Distill self-attention knowledge from teacher to student.""" + + def __init__(self, 
weight=1.0): + super().__init__() + self.weight = weight + + def _get_attention_map(self, feat): + """Generate self-attention map: Q*K^T.""" + # Reshape from [B,C,T,H,W] to [B,C,THW] + b, c = feat.shape[:2] + feat_flat = feat.reshape(b, c, -1) + + # Normalize feature for numerical stability + feat_norm = F.normalize(feat_flat, dim=1) + + # Compute self-attention: [B,THW,THW] + attention = torch.bmm(feat_norm.transpose(1, 2), feat_norm) + return attention + + def forward(self, student_feat, teacher_feat): + """Calculate self-attention distillation loss.""" + # Generate attention maps + student_attention = self._get_attention_map(student_feat) + teacher_attention = self._get_attention_map(teacher_feat) + + # If dimensions don't match, downsample the larger one + if student_attention.shape != teacher_attention.shape: + # Get smaller dimension + min_dim = min(student_attention.shape[-1], teacher_attention.shape[-1]) + + # Adaptive pooling to smaller dimension + if student_attention.shape[-1] > min_dim: + pool = nn.AdaptiveAvgPool2d(min_dim) + student_attention = pool(student_attention) + if teacher_attention.shape[-1] > min_dim: + pool = nn.AdaptiveAvgPool2d(min_dim) + teacher_attention = pool(teacher_attention) + + # Calculate loss: MSE between attention maps + loss = F.mse_loss(student_attention, teacher_attention) + + return self.weight * loss \ No newline at end of file diff --git a/mmaction/models/distillation/models/necks/distill_connector.py b/mmaction/models/distillation/models/necks/distill_connector.py new file mode 100644 index 0000000000..fe197d43b6 --- /dev/null +++ b/mmaction/models/distillation/models/necks/distill_connector.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn +from mmaction.registry import MODELS + +@MODELS.register_module() +class MultiLevelFeatureDistillConnector(nn.Module): + """Connector for multi-level feature distillation.""" + + def __init__(self, + in_channels=[64, 128, 256, 512], + out_channels=[256, 512, 1024, 2048], + kernel_sizes=[1, 1, 1, 1]): + super().__init__() + + self.transformers = nn.ModuleList() + for i, (in_ch, out_ch, k) in enumerate(zip(in_channels, out_channels, kernel_sizes)): + self.transformers.append( + nn.Conv3d(in_ch, out_ch, kernel_size=k, stride=1, padding=k//2) + ) + + def forward(self, feats): + """Transform student features to match teacher dimensions.""" + out_feats = [] + for i, feat in enumerate(feats): + out_feats.append(self.transformers[i](feat)) + return out_feats \ No newline at end of file diff --git a/mmaction/models/distillation/recognizers/progressive_distill_recognizer.py b/mmaction/models/distillation/recognizers/progressive_distill_recognizer.py new file mode 100644 index 0000000000..cc8d27b3d0 --- /dev/null +++ b/mmaction/models/distillation/recognizers/progressive_distill_recognizer.py @@ -0,0 +1,156 @@ +import torch +import torch.nn as nn +from mmaction.registry import MODELS +from mmaction.models.recognizers.base import BaseRecognizer + +@MODELS.register_module() +class ProgressiveDistillPoseRecognizer(BaseRecognizer): + """Recognizer with progressive distillation strategy.""" + + def __init__(self, + backbone, + cls_head, + teacher_config=None, + teacher_ckpt=None, + feature_dist_cfg=None, + attention_dist_cfg=None, + logit_dist_cfg=None, + progressive_cfg=dict( + phases=[ + dict(epochs=(0, 10), distill_types=['feature']), + dict(epochs=(10, 20), distill_types=['feature', 'attention']), + dict(epochs=(20, -1), distill_types=['feature', 'attention', 'logit']) + ] + ), + train_cfg=None, + test_cfg=None): 
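+        # With the default ``progressive_cfg`` above, ``get_active_distill_types``
+        # below resolves to: epochs 0-9 feature distillation only; epochs 10-19
+        # feature + attention transfer; epochs 20 onwards feature + attention +
+        # logit (KD) distillation.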
+ + super().__init__(backbone, None, train_cfg, test_cfg) + self.cls_head = cls_head + self.teacher_model = None + self.current_epoch = 0 + self.progressive_cfg = progressive_cfg + + # Initialize distillation losses + self.feature_dist = None + self.attention_dist = None + self.logit_dist = None + + if feature_dist_cfg is not None: + from mmaction.models.builder import build_loss + self.feature_dist = build_loss(feature_dist_cfg) + + if attention_dist_cfg is not None: + from mmaction.models.builder import build_loss + self.attention_dist = build_loss(attention_dist_cfg) + + if logit_dist_cfg is not None: + from mmaction.models.builder import build_loss + self.logit_dist = build_loss(logit_dist_cfg) + + # Initialize teacher model if provided + if teacher_config is not None and teacher_ckpt is not None: + from mmaction.apis import init_recognizer + self.teacher_model = init_recognizer(teacher_config, teacher_ckpt) + # Freeze teacher model + for param in self.teacher_model.parameters(): + param.requires_grad = False + self.teacher_model.eval() + + def extract_feat(self, imgs): + """Extract features through a backbone (student).""" + x = self.backbone(imgs) + return x + + def extract_teacher_feat(self, imgs): + """Extract features from teacher model.""" + with torch.no_grad(): + if hasattr(self.teacher_model, 'backbone'): + teacher_feats = self.teacher_model.backbone(imgs) + else: + teacher_feats = self.teacher_model.extract_feat(imgs) + return teacher_feats + + def get_active_distill_types(self): + """Get active distillation types based on current epoch.""" + active_types = [] + for phase in self.progressive_cfg['phases']: + start_epoch, end_epoch = phase['epochs'] + if end_epoch == -1 or self.current_epoch < end_epoch: + if self.current_epoch >= start_epoch: + active_types.extend(phase['distill_types']) + return list(set(active_types)) # Remove duplicates + + def forward_train(self, imgs, labels, **kwargs): + """Training forward function with progressive distillation.""" + # Update current epoch if provided + if 'epoch' in kwargs: + self.current_epoch = kwargs['epoch'] + + # Get active distillation types + active_types = self.get_active_distill_types() + + # Extract features from student model + student_feat = self.extract_feat(imgs) + + # Get predictions from student + cls_score = self.cls_head(student_feat) + losses = dict() + + # Standard classification loss + cls_loss = self.cls_head.loss(cls_score, labels) + losses.update(cls_loss) + + # Skip distillation if teacher not available + if self.teacher_model is None: + return losses + + # Apply active distillation strategies + teacher_logits = None + teacher_feats = None + + if 'logit' in active_types and self.logit_dist is not None: + if teacher_logits is None: + # Extract teacher logits + with torch.no_grad(): + teacher_feat = self.extract_teacher_feat(imgs) + teacher_logits = self.teacher_model.cls_head(teacher_feat) + + # Apply logit distillation + kd_loss = self.logit_dist(cls_score, teacher_logits, labels) + if isinstance(kd_loss, tuple): + losses['loss_kd'] = kd_loss[0] + losses['loss_ce_kd'] = kd_loss[1] + losses['loss_kl_kd'] = kd_loss[2] + else: + losses['loss_kd'] = kd_loss + + if 'feature' in active_types and self.feature_dist is not None: + if teacher_feats is None: + # Extract teacher features + with torch.no_grad(): + teacher_feats = self.extract_teacher_feat(imgs) + + # Apply feature distillation + feat_dist_loss = self.feature_dist(student_feat, teacher_feats) + losses['loss_feat_dist'] = feat_dist_loss + + if 'attention' in 
active_types and self.attention_dist is not None: + if teacher_feats is None: + # Extract teacher features + with torch.no_grad(): + teacher_feats = self.extract_teacher_feat(imgs) + + # Apply attention transfer + at_loss = self.attention_dist(student_feat, teacher_feats) + losses['loss_attention'] = at_loss + + return losses + + def forward_test(self, imgs): + """Test function.""" + # During testing, only use student model + x = self.extract_feat(imgs) + cls_score = self.cls_head(x) + + return cls_score.cpu().numpy() \ No newline at end of file diff --git a/mmaction/models/recognizers/custom/__init__.py b/mmaction/models/recognizers/custom/__init__.py new file mode 100644 index 0000000000..c10f12eff3 --- /dev/null +++ b/mmaction/models/recognizers/custom/__init__.py @@ -0,0 +1,3 @@ +from .distill_posec3d import DistillPoseC3D + +__all__ = ['DistillPoseC3D'] diff --git a/mmaction/models/recognizers/custom/distill_posec3d.py b/mmaction/models/recognizers/custom/distill_posec3d.py new file mode 100644 index 0000000000..71dcedc974 --- /dev/null +++ b/mmaction/models/recognizers/custom/distill_posec3d.py @@ -0,0 +1,188 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy + +from mmaction.models.builder import RECOGNIZERS, build_backbone, build_head, build_loss +from mmaction.models.recognizers.base import BaseRecognizer + + +@RECOGNIZERS.register_module() +class DistillPoseC3D(BaseRecognizer): + """Pose C3D model with knowledge distillation. + + Args: + backbone (dict): Backbone modules to extract features. + teacher_backbone (dict): Teacher backbone modules. + teacher_checkpoint (str): Path to teacher model checkpoint. + cls_head (dict): Classification head to process features. + teacher_cls_head (dict, optional): Teacher classification head. + Default: None. + train_cfg (dict, optional): Config for training. Default: None. + test_cfg (dict, optional): Config for testing. Default: None. + """ + + def __init__(self, + backbone, + teacher_backbone, + teacher_checkpoint, + cls_head, + teacher_cls_head=None, + train_cfg=None, + test_cfg=None): + super().__init__(backbone=backbone, cls_head=cls_head, + train_cfg=train_cfg, test_cfg=test_cfg) + + # Teacher model + self.teacher_backbone = build_backbone(teacher_backbone) + if teacher_cls_head is not None: + self.teacher_cls_head = build_head(teacher_cls_head) + else: + # Use same head config as student but with teacher's feature size + teacher_cls_head_cfg = copy.deepcopy(cls_head) + teacher_cls_head_cfg['in_channels'] = self.teacher_backbone.feat_dim + self.teacher_cls_head = build_head(teacher_cls_head_cfg) + + # Load teacher weights + self.load_teacher(teacher_checkpoint) + + # Freeze teacher parameters + for param in self.teacher_backbone.parameters(): + param.requires_grad = False + for param in self.teacher_cls_head.parameters(): + param.requires_grad = False + + def load_teacher(self, checkpoint): + """Load teacher model weights.""" + if not checkpoint: + print("Warning: No teacher checkpoint provided. 
Using random weights.") + return + + print(f"Loading teacher model from {checkpoint}") + try: + state_dict = torch.load(checkpoint, map_location='cpu') + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + + # Load backbone and head weights + teacher_state_dict = {} + for k, v in state_dict.items(): + if k.startswith('backbone.'): + teacher_state_dict[k.replace('backbone.', '')] = v + elif k.startswith('cls_head.'): + teacher_state_dict[k.replace('cls_head.', '')] = v + + # Load backbone weights + backbone_state_dict = {k: v for k, v in teacher_state_dict.items() + if k in self.teacher_backbone.state_dict()} + if backbone_state_dict: + self.teacher_backbone.load_state_dict( + backbone_state_dict, strict=False) + print( + f"Loaded {len(backbone_state_dict)} keys into teacher backbone") + else: + print("Warning: No matching keys found for teacher backbone") + + # Load head weights + head_state_dict = {k: v for k, v in teacher_state_dict.items() + if k in self.teacher_cls_head.state_dict()} + if head_state_dict: + self.teacher_cls_head.load_state_dict( + head_state_dict, strict=False) + print(f"Loaded {len(head_state_dict)} keys into teacher head") + else: + print("Warning: No matching keys found for teacher head") + + except Exception as e: + print(f"Error loading teacher checkpoint: {e}") + + def extract_feat(self, imgs): + """Extract features through the backbone.""" + return self.backbone(imgs) + + def forward_train(self, imgs, labels, **kwargs): + """Forward computation during training.""" + # Student forward pass + x = self.extract_feat(imgs) + cls_score = self.cls_head(x) + + # Teacher forward pass (no gradient computation) + with torch.no_grad(): + teacher_x = self.teacher_backbone(imgs) + teacher_cls_score = self.teacher_cls_head(teacher_x) + + # Compute distillation loss + distill_loss_cfg = self.train_cfg.get('distill_loss', + dict(type='KLDivLoss', + temperature=4.0, + alpha=0.5)) + + # Build loss if it's a dict config + if isinstance(distill_loss_cfg, dict): + loss_type = distill_loss_cfg.pop('type') + if loss_type == 'KLDivLoss': + temperature = distill_loss_cfg.get('temperature', 4.0) + alpha = distill_loss_cfg.get('alpha', 0.5) + + # Soft targets from teacher + soft_targets = F.softmax( + teacher_cls_score / temperature, dim=1) + # Softmax with temperature for student + soft_prob = F.log_softmax(cls_score / temperature, dim=1) + + # KL divergence loss for soft targets + kd_loss = F.kl_div(soft_prob, soft_targets, reduction='batchmean') * \ + (temperature ** 2) + + # Hard target loss + ce_loss = F.cross_entropy(cls_score, labels) + + # Combined loss + loss = alpha * kd_loss + (1 - alpha) * ce_loss + + losses = { + 'loss_cls': loss, + 'loss_kd': kd_loss, + 'loss_ce': ce_loss + } + else: + # Use mmaction's loss builder for other loss types + distill_loss = build_loss( + dict(type=loss_type, **distill_loss_cfg)) + loss = distill_loss(cls_score, teacher_cls_score, labels) + losses = {'loss_cls': loss} + else: + # If not a dict, assume it's already a loss function + loss = distill_loss_cfg(cls_score, teacher_cls_score, labels) + losses = {'loss_cls': loss} + + return losses + + def _do_test(self, imgs): + """Defines the computation performed at every call when evaluation and + testing.""" + # Only use student model during testing + x = self.extract_feat(imgs) + cls_score = self.cls_head(x) + + return cls_score + + def forward_test(self, imgs): + """Defines the computation performed at every call when evaluation and + testing.""" + return self._do_test(imgs) + + 
def forward_dummy(self, imgs): + """Used for computing network FLOPs.""" + return self._do_test(imgs) + + def forward(self, imgs, return_loss=True, **kwargs): + """Define the computation performed at every call.""" + if kwargs.get('gradcam', False): + return self.forward_gradcam(imgs) + if return_loss: + if self.train_cfg is None: + raise ValueError('Cannot train without train_cfg') + return self.forward_train(imgs, **kwargs) + + return self.forward_test(imgs) diff --git a/processed_data/split_data.py b/processed_data/split_data.py new file mode 100644 index 0000000000..98daa07789 --- /dev/null +++ b/processed_data/split_data.py @@ -0,0 +1,66 @@ +import pickle +import numpy as np +import os + +# Create directory if it doesn't exist +os.makedirs('data/skeleton/ntu60_2d/', exist_ok=True) + +# Load the data +with open('/home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/processed_data/ntu60_2d.pkl', 'rb') as f: + data = pickle.load(f) + +# First, let's examine the structure of the data +print("Data type:", type(data)) + +# If data is a dictionary +if isinstance(data, dict): + # Get the keys and split them + keys = list(data.keys()) + np.random.seed(42) # For reproducibility + np.random.shuffle(keys) + train_keys = keys[:int(0.8 * len(keys))] + val_keys = keys[int(0.8 * len(keys)):] + + # Create train and val datasets + train_data = {k: data[k] for k in train_keys} + val_data = {k: data[k] for k in val_keys} + + print( + f"Split data into {len(train_data)} training samples and {len(val_data)} validation samples") +# If data is a list or another iterable +elif hasattr(data, '__iter__') and not isinstance(data, (str, dict)): + # Use keys from data if it's not a simple list + if hasattr(data, 'keys'): + items = list(data.keys()) + else: + items = list(range(len(data))) + + np.random.seed(42) # For reproducibility + np.random.shuffle(items) + train_items = items[:int(0.8 * len(items))] + val_items = items[int(0.8 * len(items)):] + + # Create train and val datasets based on structure + if hasattr(data, 'keys'): + train_data = {k: data[k] for k in train_items} + val_data = {k: data[k] for k in val_items} + else: + train_data = [data[i] for i in train_items] + val_data = [data[i] for i in val_items] + + print( + f"Split data into {len(train_data)} training samples and {len(val_data)} validation samples") +else: + # If data has a different structure, print it for debugging + print("Unexpected data structure. 
First few elements:") + print(data) + exit(1) + +# Save the splits +with open('data/skeleton/ntu60_2d/ntu60_2d_train.pkl', 'wb') as f: + pickle.dump(train_data, f) + print("Saved training data to data/skeleton/ntu60_2d/ntu60_2d_train.pkl") + +with open('data/skeleton/ntu60_2d/ntu60_2d_val.pkl', 'wb') as f: + pickle.dump(val_data, f) + print("Saved validation data to data/skeleton/ntu60_2d/ntu60_2d_val.pkl") diff --git a/skeleton_overlay.mp4 b/skeleton_overlay.mp4 new file mode 100644 index 0000000000..1d84ee9a77 Binary files /dev/null and b/skeleton_overlay.mp4 differ diff --git a/tools/data/skeleton/check_pickle.py b/tools/data/skeleton/check_pickle.py new file mode 100644 index 0000000000..74ff542ea2 --- /dev/null +++ b/tools/data/skeleton/check_pickle.py @@ -0,0 +1,20 @@ +import pickle +import numpy as np + +def analyze_data(data_sample): + print(data_sample.keys()) + print('Keys in the pickle file:', data_sample.keys()) + print('Keypoints shape:', data_sample['keypoint']) + print('Keypoint score shape:', data_sample['keypoint_score'].shape) + print('Total frames:', data_sample['total_frames']) + print('Sample keypoint data:', data_sample['keypoint'][0,0,:5]) + print('Frame dir data:', data_sample['frame_dir']) + print('Label:', data_sample['label']) + +if __name__ == "__main__": + data_ntu60 = pickle.load(open('/home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/data/skeleton/ntu60_2d/ntu60_2d_train.pkl', 'rb')) + data_ntu_sample = data_ntu60['annotations'][42] + analyze_data(data_ntu_sample) + print("="*50) + data = pickle.load(open('/home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/video_1.pkl', 'rb')) + analyze_data(data) \ No newline at end of file diff --git a/tools/data/skeleton/extract_zip.sh b/tools/data/skeleton/extract_zip.sh new file mode 100755 index 0000000000..819d2da96a --- /dev/null +++ b/tools/data/skeleton/extract_zip.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +# Check if correct number of arguments +if [ $# -ne 2 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign arguments to variables +ZIP_FILE="$1" +DEST_DIR="$2" + +# Check if zip file exists +if [ ! -f "$ZIP_FILE" ]; then + echo "Error: Zip file '$ZIP_FILE' not found" + exit 1 +fi + +# Check if destination directory exists, create it if it doesn't +if [ ! -d "$DEST_DIR" ]; then + echo "Creating destination directory '$DEST_DIR'" + mkdir -p "$DEST_DIR" + if [ $? -ne 0 ]; then + echo "Error: Failed to create destination directory" + exit 1 + fi +fi + +# Extract the zip file to destination +echo "Extracting '$ZIP_FILE' to '$DEST_DIR'" +unzip -q "$ZIP_FILE" -d "$DEST_DIR" + +# Check if extraction was successful +if [ $? -eq 0 ]; then + echo "Extraction completed successfully" +else + echo "Error: Failed to extract zip file" + exit 1 +fi \ No newline at end of file diff --git a/tools/data/skeleton/gen_data.sh b/tools/data/skeleton/gen_data.sh new file mode 100644 index 0000000000..0b56eb506b --- /dev/null +++ b/tools/data/skeleton/gen_data.sh @@ -0,0 +1 @@ +python3 ./tools/data/skeleton/ntu_pose_extraction.py /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/data/skeleton/Le2i/Lecture_room/video_1.avi video_1.pkl \ No newline at end of file diff --git a/tools/data/skeleton/ntu_pose_extraction_refactored.py b/tools/data/skeleton/ntu_pose_extraction_refactored.py new file mode 100644 index 0000000000..798c591307 --- /dev/null +++ b/tools/data/skeleton/ntu_pose_extraction_refactored.py @@ -0,0 +1,676 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
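+# Example invocation (illustrative; paths are placeholders):
+#   python tools/data/skeleton/ntu_pose_extraction_refactored.py \
+#       path/to/S001C001P001R001A001.avi output.pkl --device cuda:0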
+""" +NTU RGB+D Pose Extraction Module + +This module contains functions to extract human pose keypoints from videos, +particularly focused on the NTU RGB+D dataset format. It provides a complete +pipeline for: +1. Human detection in video frames +2. Post-processing of detection results with tracklet formation +3. Pose estimation based on detections +4. Aligning pose keypoints across frames +5. Creating standardized annotation format for action recognition + +The implementation handles both single-person and multi-person scenarios +with special consideration for challenging detection cases. +""" + +import abc +import argparse +import os.path as osp +from collections import defaultdict +from tempfile import TemporaryDirectory + +import mmengine +import numpy as np +import torch + +from mmaction.apis import detection_inference, pose_inference +from mmaction.utils import frame_extract + + +class Args: + """ + Configuration parameters for pose extraction pipeline. + + This class holds default values for detection and pose estimation + models, confidence thresholds, and device configuration. + """ + + def __init__(self): + # Human detection model settings + self.det_config = 'demo/demo_configs/faster-rcnn_r50-caffe_fpn_ms-1x_coco-person.py' + self.det_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco-person/faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth' + self.det_score_thr = 0.5 # Detection confidence threshold + + # Pose estimation model settings + self.pose_config = 'demo/demo_configs/td-hm_hrnet-w32_8xb64-210e_coco-256x192_infer.py' + self.pose_checkpoint = 'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth' + + # https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_ViTPose-base-simple_8xb64-210e_coco-256x192-0b8234ea_20230407.pth + + # General settings + self.device = 'cuda:0' + self.skip_postproc = False + + +# Create a global instance of Args with default values +args = Args() + + +def intersection(b0, b1): + """ + Calculate intersection area of two bounding boxes. + + Args: + b0 (ndarray): First bounding box in format [x1, y1, x2, y2, ...] + b1 (ndarray): Second bounding box in format [x1, y1, x2, y2, ...] + + Returns: + float: Intersection area + """ + l, r = max(b0[0], b1[0]), min(b0[2], b1[2]) + u, d = max(b0[1], b1[1]), min(b0[3], b1[3]) + return max(0, r - l) * max(0, d - u) + + +def iou(b0, b1): + """ + Calculate Intersection over Union (IoU) of two bounding boxes. + + Args: + b0 (ndarray): First bounding box in format [x1, y1, x2, y2, ...] + b1 (ndarray): Second bounding box in format [x1, y1, x2, y2, ...] + + Returns: + float: IoU value between 0 and 1 + """ + i = intersection(b0, b1) + u = area(b0) + area(b1) - i + return i / u + + +def area(b): + """ + Calculate area of a bounding box. + + Args: + b (ndarray): Bounding box in format [x1, y1, x2, y2, ...] + + Returns: + float: Area of the bounding box + """ + return (b[2] - b[0]) * (b[3] - b[1]) + + +def removedup(bbox): + """ + Remove duplicate or heavily overlapping bounding boxes. + + Keeps boxes with higher confidence scores when significant overlap exists. 
+ + Args: + bbox (ndarray): Array of bounding boxes, each in format [x1, y1, x2, y2, score] + + Returns: + ndarray: Filtered array of bounding boxes + """ + def inside(box0, box1, threshold=0.8): + """Check if box0 is mostly inside box1.""" + return intersection(box0, box1) / area(box0) > threshold + + num_bboxes = bbox.shape[0] + if num_bboxes <= 1: + return bbox + + valid = [] + for i in range(num_bboxes): + flag = True + for j in range(num_bboxes): + # If box i is inside box j and has lower confidence, remove box i + if i != j and inside(bbox[i], bbox[j]) and bbox[i][4] <= bbox[j][4]: + flag = False + break + if flag: + valid.append(i) + return bbox[valid] + + +def is_easy_example(det_results, num_person): + """ + Determine if a video has consistent high-confidence detections. + + An "easy example" has exactly the expected number of persons + with high confidence (>0.95) detections across all frames. + + Args: + det_results (list): List of detection results per frame + num_person (int): Expected number of persons (1 or 2) + + Returns: + tuple: (is_easy, bboxes), where: + - is_easy (bool): Whether this is an easy example + - bboxes (ndarray or int): If easy, stacked bounding boxes; + otherwise, the number of high-confidence detections + """ + threshold = 0.95 + + def thre_bbox(bboxes, threshold=threshold): + """Count and verify consistent high-confidence detections.""" + shape = [sum(bbox[:, -1] > threshold) for bbox in bboxes] + ret = np.all(np.array(shape) == shape[0]) + return shape[0] if ret else -1 + + if thre_bbox(det_results) == num_person: + # This is an easy example - filter to keep only high-confidence detections + det_results = [x[x[..., -1] > 0.95] for x in det_results] + return True, np.stack(det_results) + return False, thre_bbox(det_results) + + +def bbox2tracklet(bbox): + """ + Convert frame-by-frame bounding boxes to consistent tracklets. + + A tracklet is a sequence of bounding boxes that follows the same person + across multiple frames. This function uses IoU to associate detections + across frames. + + Args: + bbox (list): List of bounding box arrays per frame + + Returns: + dict: Dictionary mapping tracklet IDs to lists of (frame_idx, bbox) tuples + """ + iou_thre = 0.6 # IoU threshold for tracklet association + tracklet_id = -1 + tracklet_st_frame = {} # Start frame for each tracklet + tracklets = defaultdict(list) + + # Process each frame and each detection + for t, box in enumerate(bbox): + for idx in range(box.shape[0]): + matched = False + # Try to match with existing tracklets (from newest to oldest) + for tlet_id in range(tracklet_id, -1, -1): + # Conditions for matching: + # 1. IoU with latest box in tracklet exceeds threshold + cond1 = iou(tracklets[tlet_id][-1][-1], box[idx]) >= iou_thre + # 2. Not too far apart in time (max 10 frame gap) + cond2 = ( + t - tracklet_st_frame[tlet_id] - len(tracklets[tlet_id]) < 10) + # 3. Current frame not already in this tracklet + cond3 = tracklets[tlet_id][-1][0] != t + + if cond1 and cond2 and cond3: + matched = True + tracklets[tlet_id].append((t, box[idx])) + break + + # If no match found, create a new tracklet + if not matched: + tracklet_id += 1 + tracklet_st_frame[tracklet_id] = t + tracklets[tracklet_id].append((t, box[idx])) + + return tracklets + + +def drop_tracklet(tracklet): + """ + Filter out short or small tracklets. + + Removes tracklets that: + 1. Have fewer than 6 frames + 2. 
Have a mean bounding box area less than 5000 pixels + + Args: + tracklet (dict): Dictionary of tracklets from bbox2tracklet + + Returns: + dict: Filtered dictionary of tracklets + """ + # Filter out short tracklets (less than 6 frames) + tracklet = {k: v for k, v in tracklet.items() if len(v) > 5} + + def meanarea(track): + """Calculate mean area of bounding boxes in a tracklet.""" + boxes = np.stack([x[1] for x in track]).astype(np.float32) + areas = (boxes[..., 2] - boxes[..., 0]) * \ + (boxes[..., 3] - boxes[..., 1]) + return np.mean(areas) + + # Filter out small tracklets (mean area < 5000 pixels) + tracklet = {k: v for k, v in tracklet.items() if meanarea(v) > 5000} + return tracklet + + +def distance_tracklet(tracklet): + """ + Calculate mean distance of each tracklet from the center of the frame. + + This is used to prioritize tracklets that are closer to the center + when there are multiple candidates. + + Args: + tracklet (dict): Dictionary of tracklets + + Returns: + dict: Dictionary mapping tracklet IDs to mean distances from center + """ + dists = {} + for k, v in tracklet.items(): + # Stack all bounding boxes in this tracklet + bboxes = np.stack([x[1] for x in v]) + + # Calculate center coordinates of each bbox + c_x = (bboxes[..., 2] + bboxes[..., 0]) / 2. + c_y = (bboxes[..., 3] + bboxes[..., 1]) / 2. + + # Adjust to center relative to frame center (assumed to be 480, 270) + # This is an approximation based on the original NTU dataset + c_x -= 480 + c_y -= 270 + + # Calculate distance from center + c = np.concatenate([c_x[..., None], c_y[..., None]], axis=1) + dist = np.linalg.norm(c, axis=1) + dists[k] = np.mean(dist) + + return dists + + +def tracklet2bbox(track, num_frame): + """ + Convert a single tracklet to frame-by-frame bounding boxes. + + For frames where the tracklet has no detection, interpolates + from nearest available frame. + + Args: + track (list): List of (frame_idx, bbox) tuples for a single tracklet + num_frame (int): Total number of frames in the video + + Returns: + ndarray: Array of bounding boxes, one per frame + """ + # Initialize empty bounding boxes for all frames + bbox = np.zeros((num_frame, 5)) + trackd = {} + + # Populate with known detections + for k, v in track: + bbox[k] = v + trackd[k] = v + + # Fill in missing frames by finding nearest detection + for i in range(num_frame): + if bbox[i][-1] <= 0.5: # Low-confidence or missing detection + mind = np.Inf + nearest_idx = None + + # Find nearest valid frame + for k in trackd: + if np.abs(k - i) < mind: + mind = np.abs(k - i) + nearest_idx = k + + # Copy detection from nearest frame + bbox[i] = bbox[nearest_idx] + + return bbox + + +def tracklets2bbox(tracklet, num_frame): + """ + Convert multiple tracklets to a single primary tracklet with bounding boxes. + + Prioritizes tracklets that: + 1. Cover at least half the video length + 2. 
Are closer to the center of the frame + + Args: + tracklet (dict): Dictionary of tracklets + num_frame (int): Total number of frames in the video + + Returns: + tuple: (bad_frames, bboxes), where: + - bad_frames (int): Number of frames with low-confidence detections + - bboxes (ndarray): Array of bounding boxes with shape (num_frame, 1, 5) + """ + # Calculate distances from center for all tracklets + dists = distance_tracklet(tracklet) + sorted_inds = sorted(dists, key=lambda x: dists[x]) + + # Find a long enough tracklet (covers at least half the frames) + # and set a distance threshold based on it + dist_thre = np.Inf + for i in sorted_inds: + if len(tracklet[i]) >= num_frame / 2: + # Use 2x the distance of this tracklet as threshold + dist_thre = 2 * dists[i] + break + + # Set a minimum distance threshold + dist_thre = max(50, dist_thre) + + # Initialize empty bounding boxes + bbox = np.zeros((num_frame, 5)) + bboxd = {} + + # Fill in detections from tracklets within distance threshold + for idx in sorted_inds: + if dists[idx] < dist_thre: + for k, v in tracklet[idx]: + if bbox[k][-1] < 0.01: # Empty or very low confidence + bbox[k] = v + bboxd[k] = v + + # Count and fix frames with missing detections + bad = 0 + for idx in range(num_frame): + if bbox[idx][-1] < 0.01: + bad += 1 + + # Find nearest frame with detection + mind = np.Inf + mink = None + for k in bboxd: + if np.abs(k - idx) < mind: + mind = np.abs(k - idx) + mink = k + + # Copy detection from nearest frame + bbox[idx] = bboxd[mink] + + # Reshape to match expected format (num_frame, 1, 5) + return bad, bbox[:, None, :] + + +def bboxes2bbox(bbox, num_frame): + """ + Process bounding boxes for two-person scenario. + + For each frame, select the top 2 detections by confidence score, + and maintain consistent person IDs across frames using IoU. + + Args: + bbox (list): List of bounding box arrays per frame + num_frame (int): Total number of frames in the video + + Returns: + ndarray: Array of shape (num_frame, 2, 5) containing consistent + two-person tracking results + """ + # Initialize array for two persons across all frames + ret = np.zeros((num_frame, 2, 5)) + + # Process each frame's detections + for t, item in enumerate(bbox): + if item.shape[0] <= 2: + # If we have 0, 1 or 2 detections, use them as is + ret[t, :item.shape[0]] = item + else: + # If we have more than 2, select top 2 by confidence + inds = sorted( + list(range(item.shape[0])), key=lambda x: -item[x, -1]) + ret[t] = item[inds[:2]] + + # Process frames to maintain consistent person IDs + for t in range(num_frame): + # Handle frames with no detections + if ret[t, 0, -1] <= 0.01: + ret[t] = ret[t - 1] # Copy from previous frame + + # Handle frames with only one detection (need to decide which person it is) + elif ret[t, 1, -1] <= 0.01 and t > 0: + if ret[t - 1, 0, -1] > 0.01 and ret[t - 1, 1, -1] > 0.01: + # Determine which previous person this detection matches better + if iou(ret[t, 0], ret[t - 1, 0]) > iou(ret[t, 0], ret[t - 1, 1]): + # Detection matches person 0, so copy person 1 from previous frame + ret[t, 1] = ret[t - 1, 1] + else: + # Detection matches person 1, so copy person 0 and swap positions + ret[t, 1] = ret[t, 0] + ret[t, 0] = ret[t - 1, 0] + + return ret + + +def ntu_det_postproc(vid, det_results): + """ + Post-process detection results for NTU RGB+D format videos. + + This is the main algorithm for handling both easy and hard detection cases, + determining number of persons, and ensuring consistent tracking. 
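+
+    Multi-person action classes (A050-A060 and A106-A120 in the NTU naming
+    convention) are assumed to contain two performers; all other classes one.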
+ + Args: + vid (str): Path to video file + det_results (list): List of detection results per frame + + Returns: + ndarray: Processed detection results with consistent tracking + """ + # Remove duplicate detections in each frame + det_results = [removedup(x) for x in det_results] + + # Determine number of persons from video filename + # NTU RGB+D follows a specific naming convention + try: + label = int(vid.split('/')[-1].split('A')[1][:3]) + # Actions 50-60 and 106-120 in NTU RGB+D are multi-person actions + mpaction = list(range(50, 61)) + list(range(106, 121)) + n_person = 2 if label in mpaction else 1 + except (IndexError, ValueError): + # For videos not following the convention, assume single person + n_person = 1 + + # Check if this is an easy example (consistent high-confidence detections) + is_easy, bboxes = is_easy_example(det_results, n_person) + if is_easy: + print('\nEasy Example') + return bboxes + + # For harder cases, create tracklets from detections + tracklets = bbox2tracklet(det_results) + tracklets = drop_tracklet(tracklets) + + print(f'\nHard {n_person}-person Example, found {len(tracklets)} tracklet') + + # Handle single-person case + if n_person == 1: + if len(tracklets) == 1: + # Only one tracklet found, use it directly + tracklet = list(tracklets.values())[0] + det_results = tracklet2bbox(tracklet, len(det_results)) + return np.stack(det_results) + else: + # Multiple tracklets found, select the best one + bad, det_results = tracklets2bbox(tracklets, len(det_results)) + return det_results + + # Handle two-person case + if len(tracklets) <= 2: + # We found exactly the right number of tracklets + tracklets = list(tracklets.values()) + bboxes = [] + for tracklet in tracklets: + bboxes.append(tracklet2bbox(tracklet, len(det_results))[:, None]) + bbox = np.concatenate(bboxes, axis=1) + return bbox + else: + # Too many tracklets, need to select and organize them + return bboxes2bbox(det_results, len(det_results)) + + +def pose_inference_with_align(frame_paths, det_results, device): + """ + Perform pose estimation on detection results and align across frames. + + This function: + 1. Runs pose estimation on each detected person + 2. Aligns pose data across frames (handling variable person counts) + 3. Formats keypoints and confidence scores for each joint + + Args: + frame_paths (list): List of paths to video frames + det_results (list): List of detection results per frame + device (str): Device to run inference on ('cuda:0', 'cpu', etc.) 
+ + Returns: + tuple: (keypoints, scores), where: + - keypoints (ndarray): Array of shape (num_persons, num_frames, num_joints, 2) + - scores (ndarray): Array of shape (num_persons, num_frames, num_joints) + """ + # Set up args for pose inference + pose_args = Args() + pose_args.device = device + + # Filter out frames without any detections + det_results = [ + frm_dets for frm_dets in det_results if frm_dets.shape[0] > 0] + + # Run pose estimation + pose_results, _ = pose_inference( + pose_args.pose_config, + pose_args.pose_checkpoint, + frame_paths, + det_results, + device + ) + + # Align the pose results to have consistent num_person across frames + # Find the maximum number of persons detected in any frame + num_persons = max([pose['keypoints'].shape[0] for pose in pose_results]) + num_points = pose_results[0]['keypoints'].shape[1] + num_frames = len(pose_results) + + # Initialize arrays for aligned pose data + keypoints = np.zeros( + (num_persons, num_frames, num_points, 2), dtype=np.float32) + scores = np.zeros((num_persons, num_frames, num_points), dtype=np.float32) + + # Copy pose data for each person and frame + for f_idx, frm_pose in enumerate(pose_results): + frm_num_persons = frm_pose['keypoints'].shape[0] + for p_idx in range(frm_num_persons): + keypoints[p_idx, f_idx] = frm_pose['keypoints'][p_idx] + scores[p_idx, f_idx] = frm_pose['keypoint_scores'][p_idx] + + return keypoints, scores + + +def ntu_pose_extraction(vid, skip_postproc=False, device='cuda:0'): + """ + Extract pose keypoints from a video for action recognition. + + Complete pipeline that: + 1. Extracts frames from video + 2. Detects persons in frames + 3. Processes detections for consistent tracking + 4. Estimates pose keypoints + 5. Formats results into a structured annotation + + Args: + vid (str): Path to video file + skip_postproc (bool): Whether to skip detection post-processing + device (str): Device to run inference on ('cuda:0', 'cpu', etc.) 
+ + Returns: + dict: Annotation dictionary containing: + - keypoint: Pose keypoints array (num_persons, num_frames, num_joints, 2) + - keypoint_score: Confidence scores (num_persons, num_frames, num_joints) + - frame_dir: Video name without extension + - img_shape: Image dimensions (height, width) + - original_shape: Original image dimensions + - total_frames: Number of frames + - label: Action label (if filename follows NTU format) + """ + # Create a temporary directory for extracted frames + tmp_dir = TemporaryDirectory() + frame_paths, _ = frame_extract(vid, out_dir=tmp_dir.name) + + # Set up detection args + det_args = Args() + det_args.device = device + + # Run human detection + det_results, _ = detection_inference( + det_args.det_config, + det_args.det_checkpoint, + frame_paths, + det_args.det_score_thr, + device=device, + with_score=True + ) + + # Post-process detection results for consistent tracking + if not skip_postproc: + det_results = ntu_det_postproc(vid, det_results) + + # Create annotation dictionary + anno = dict() + + # Extract pose keypoints and scores + keypoints, scores = pose_inference_with_align( + frame_paths, det_results, device) + + # Fill annotation dictionary + anno['keypoint'] = keypoints + anno['keypoint_score'] = scores + anno['frame_dir'] = osp.splitext(osp.basename(vid))[0] + # Default shape, will be used if not detected + anno['img_shape'] = (1080, 1920) + anno['original_shape'] = (1080, 1920) + anno['total_frames'] = keypoints.shape[1] + + # Extract label from filename if it follows NTU naming convention + try: + # Original NTU dataset filename format: S001C001P001R001A001.avi + # where A001 is action class 1 + anno['label'] = int(osp.basename(vid).split('A')[1][:3]) - 1 + except (IndexError, ValueError): + # For other video formats, set a default label + print( + f"Warning: Video filename '{osp.basename(vid)}' doesn't follow NTU naming convention.") + print("Setting default label to 0.") + anno['label'] = 0 + + # Clean up temporary directory + tmp_dir.cleanup() + return anno + + +def parse_args(): + """ + Parse command line arguments for standalone script usage. + + Returns: + argparse.Namespace: Parsed arguments + """ + parser = argparse.ArgumentParser( + description='Generate Pose Annotation for a single NTURGB-D video') + parser.add_argument('video', type=str, help='source video') + parser.add_argument('output', type=str, help='output pickle name') + parser.add_argument('--device', type=str, default='cuda:0', + help='Device to use for inference, e.g., cuda:0, cpu') + parser.add_argument('--skip-postproc', action='store_true', + help='Skip detection post-processing (tracking)') + + args = parser.parse_args() + return args + + +if __name__ == '__main__': + # Parse command line arguments + cmd_args = parse_args() + + # Extract pose annotations + anno = ntu_pose_extraction( + cmd_args.video, + cmd_args.skip_postproc, + cmd_args.device + ) + + # Save the annotations + mmengine.dump(anno, cmd_args.output) diff --git a/tools/data/skeleton/overlay_skeleton.py b/tools/data/skeleton/overlay_skeleton.py new file mode 100644 index 0000000000..138d0bdac7 --- /dev/null +++ b/tools/data/skeleton/overlay_skeleton.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python +# Copyright (c) OpenMMLab. All rights reserved. 
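+# Overlays the 2D COCO-format keypoints stored in a pose pickle (for example
+# the output of ntu_pose_extraction) onto the source video.
+#
+# Example invocation (illustrative; paths are placeholders):
+#   python tools/data/skeleton/overlay_skeleton.py input.avi video_1.pkl \
+#       skeleton_overlay.mp4 --keypoint_threshold 0.3 --fps 30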
+import argparse
+import pickle
+import numpy as np
+import cv2
+import os
+import os.path as osp
+
+try:
+    import moviepy.editor as mpy
+except ImportError:
+    raise ImportError('Please install moviepy to enable video output')
+
+# Define the skeleton connections (edges between joints)
+# Based on the standard 17 keypoints (COCO format)
+SKELETON_CONNECTIONS = [
+    (0, 1), (0, 2), (1, 3), (2, 4),  # Head (nose-eyes-ears)
+    (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),  # Shoulders and arms
+    (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16)  # Torso and legs
+]
+
+# Define colors for visualization
+KEYPOINT_COLOR = (0, 255, 0)  # Green
+CONNECTION_COLOR = (0, 0, 255)  # Red
+THICKNESS = 2
+KEYPOINT_RADIUS = 4
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Overlay skeleton on video')
+    parser.add_argument('video_file', help='path to the original video file')
+    parser.add_argument('pickle_file', help='path to the skeleton pickle file')
+    parser.add_argument('out_filename', help='output video filename')
+    parser.add_argument('--keypoint_threshold', type=float, default=0.3,
+                        help='Threshold for keypoint confidence score')
+    parser.add_argument('--fps', type=int, default=30,
+                        help='FPS for output video')
+    args = parser.parse_args()
+    return args
+
+
+def overlay_skeleton_on_video(video_file, pickle_file, out_filename, keypoint_threshold=0.3, fps=30):
+    # Load skeleton data
+    print(f"Loading skeleton data from {pickle_file}")
+    with open(pickle_file, 'rb') as f:
+        data = pickle.load(f)
+
+    # Extract keypoints and scores
+    # shape: (num_person, num_frames, num_keypoints, 2)
+    keypoints = data['keypoint']
+    # shape: (num_person, num_frames, num_keypoints)
+    keypoint_scores = data['keypoint_score']
+    num_persons, num_frames, num_keypoints, _ = keypoints.shape
+
+    print(
+        f"Loaded data with {num_persons} persons, {num_frames} frames, {num_keypoints} keypoints")
+
+    # Open the video file
+    cap = cv2.VideoCapture(video_file)
+    if not cap.isOpened():
+        raise ValueError(f"Failed to open video file: {video_file}")
+
+    # Get video properties
+    video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    video_fps = cap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    print(
+        f"Video properties: {video_width}x{video_height}, {video_fps} fps, {total_frames} frames")
+
+    # Create a temporary directory for the rendered frames
+    tmp_dir = "tmp_overlay_vis"
+    os.makedirs(tmp_dir, exist_ok=True)
+
+    frame_files = []
+    frame_idx = 0
+
+    # Walk through the video frame by frame and draw the skeleton overlay
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        # Stop if there are more video frames than skeleton frames
+        if frame_idx >= num_frames:
+            break
+
+        # Draw skeletons for each person on this frame
+        for person_idx in range(num_persons):
+            # Get skeleton data for this person in this frame
+            person_keypoints = keypoints[person_idx, frame_idx]
+            person_scores = keypoint_scores[person_idx, frame_idx]
+
+            # Draw keypoints
+            for kp_idx in range(num_keypoints):
+                x, y = person_keypoints[kp_idx]
+                score = person_scores[kp_idx]
+
+                # Only draw keypoints with score above threshold
+                if score > keypoint_threshold:
+                    x, y = int(x), int(y)
+                    cv2.circle(frame, (x, y), KEYPOINT_RADIUS,
+                               KEYPOINT_COLOR, -1)
+
+            # Draw skeleton connections
+            for start_idx, end_idx in SKELETON_CONNECTIONS:
+                start_score = person_scores[start_idx]
+                end_score = person_scores[end_idx]
+
+                # Only draw connection if both keypoints have score above threshold
+                if start_score > keypoint_threshold and end_score > keypoint_threshold:
+                    start_x, start_y = person_keypoints[start_idx]
+                    end_x, end_y = person_keypoints[end_idx]
+
+                    start_x, start_y = int(start_x), int(start_y)
+                    end_x, end_y = int(end_x), int(end_y)
+
+                    cv2.line(frame, (start_x, start_y),
+                             (end_x, end_y), CONNECTION_COLOR, THICKNESS)
+
+        # Add frame number
+        cv2.putText(frame, f"Frame: {frame_idx}", (30, 30),
+                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
+
+        # Save the frame
+        frame_file = osp.join(tmp_dir, f"frame_{frame_idx:04d}.png")
+        cv2.imwrite(frame_file, frame)
+        frame_files.append(frame_file)
+
+        frame_idx += 1
+
+        # Print progress roughly every 10% (guard against clips shorter than 10 frames)
+        if frame_idx == 1 or frame_idx % max(1, num_frames // 10) == 0:
+            print(f"Processed {frame_idx}/{num_frames} frames")
+
+    cap.release()
+
+    # Create a video from the frames
+    print(f"Creating video from {len(frame_files)} frames")
+    clip = mpy.ImageSequenceClip(frame_files, fps=fps)
+    clip.write_videofile(out_filename)
+
+    # Clean up temporary files
+    print("Cleaning up temporary files")
+    for file in frame_files:
+        os.remove(file)
+    os.rmdir(tmp_dir)
+
+    print(f"Video with overlaid skeleton saved to {out_filename}")
+
+
+def main():
+    args = parse_args()
+    overlay_skeleton_on_video(args.video_file, args.pickle_file, args.out_filename,
+                              args.keypoint_threshold, args.fps)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/train.sh b/train.sh
new file mode 100644
index 0000000000..c40d1688ad
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,6 @@
+export CUBLAS_WORKSPACE_CONFIG=:16:8
+
+python /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/tools/train.py /home/minhtranh/works/Project/Rainscales/Lying_detection/mmaction2/configs/skeleton/posec3d/slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint.py \
+    --work-dir work_dirs/posec3d_ntu60_2d_adam/ \
+    --seed 0 \
+    --amp
\ No newline at end of file
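
Usage note: a minimal end-to-end sketch of the two new skeleton tools added in this patch. The pose-extraction script's path is assumed here to be tools/data/skeleton/ntu_pose_extraction.py (matching the function it defines), and input.avi is only a placeholder video; the overlay script's arguments follow its parse_args above.

    python tools/data/skeleton/ntu_pose_extraction.py input.avi input_pose.pkl --device cuda:0
    python tools/data/skeleton/overlay_skeleton.py input.avi input_pose.pkl input_overlay.mp4 \
        --keypoint_threshold 0.3 --fps 30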