diff --git a/.gitignore b/.gitignore index 213ee8f..e833cba 100644 --- a/.gitignore +++ b/.gitignore @@ -206,6 +206,7 @@ vis/ clt_test_pythia_70m_jumprelu/ clt_smoke_output_local_wandb_batchtopk/ clt_smoke_output_remote_wandb/ +wandb/ # models *.pt diff --git a/clt/activation_generation/generator.py b/clt/activation_generation/generator.py index 650f684..4b34e31 100644 --- a/clt/activation_generation/generator.py +++ b/clt/activation_generation/generator.py @@ -14,14 +14,13 @@ from __future__ import annotations import os -import time import json import queue import random import logging import threading from pathlib import Path -from typing import Dict, List, Optional, Tuple, Any +from typing import Dict, List, Optional, Tuple, Any, DefaultDict from concurrent.futures import ThreadPoolExecutor, as_completed import torch @@ -36,7 +35,7 @@ from clt.config.data_config import ActivationConfig # noqa: E402 # --- Profiling Imports --- -import time # Already imported, but good to note +import time # Keep this one from contextlib import contextmanager from collections import defaultdict import psutil @@ -56,8 +55,8 @@ # --- Performance Profiler Class --- class PerformanceProfiler: def __init__(self, chunk_tokens_threshold: int = 1_000_000): - self.timings = defaultdict(list) - self.memory_snapshots = [] + self.timings: DefaultDict[str, List[float]] = defaultdict(list) + self.memory_snapshots: List[Dict[str, Any]] = [] self.chunk_tokens_threshold = chunk_tokens_threshold self.system_metrics_log: List[Dict[str, Any]] = [] self.layer_ids_ref: Optional[List[int]] = None @@ -141,7 +140,7 @@ def log_system_metrics(self, interval_name: str = "interval"): return metrics def report(self): - print("\n=== Performance Report ===") + logger.info("\n=== Performance Report ===") # Sort by total time descending for timings sorted_timings = sorted(self.timings.items(), key=lambda item: sum(item[1]), reverse=True) @@ -153,15 +152,17 @@ def report(self): min_time = min(times) max_time = max(times) - print(f"\n--- Operation: {name} ---") - print(f" Count: {len(times)}") - print(f" Total time: {total_time:.3f}s") - print(f" Avg time: {avg_time:.4f}s") - print(f" Min time: {min_time:.4f}s") - print(f" Max time: {max_time:.4f}s") + logger.info(f"\n--- Operation: {name} ---") + logger.info(f" Count: {len(times)}") + logger.info(f" Total time: {total_time:.3f}s") + logger.info(f" Avg time: {avg_time:.4f}s") + logger.info(f" Min time: {min_time:.4f}s") + logger.info(f" Max time: {max_time:.4f}s") if "chunk_write_total_idx" in name: # New unique name per chunk - print(f" Avg ms/k-tok (for this chunk): {avg_time / self.chunk_tokens_threshold * 1000 * 1000:.2f}") + logger.info( + f" Avg ms/k-tok (for this chunk): {avg_time / self.chunk_tokens_threshold * 1000 * 1000:.2f}" + ) elif ( name == "batch_processing_total" and self.batch_processing_total_calls > 0 @@ -171,29 +172,29 @@ def report(self): self.total_tokens_processed_for_batch_profiling / self.batch_processing_total_calls ) if avg_tok_per_batch_call > 0: - print( + logger.info( f" Avg ms/k-tok (estimated for batch_processing_total): {avg_time / avg_tok_per_batch_call * 1000 * 1000:.2f}" ) - print("\n=== Memory Snapshots (showing top 10 by RSS delta) ===") + logger.info("\n=== Memory Snapshots (showing top 10 by RSS delta) ===") interesting_mem_snapshots = sorted( self.memory_snapshots, key=lambda x: abs(x["rss_delta_bytes"]), reverse=True )[:10] for snap in interesting_mem_snapshots: - print( + logger.info( f" {snap['name']} (took {snap['duration_s']:.3f}s): Total RSS 
{snap['rss_total_bytes'] / (1024**3):.3f} GB (ΔRSS {snap['rss_delta_bytes'] / (1024**3):.3f} GB)" ) - print("\n=== System Metrics Log (sample) ===") + logger.info("\n=== System Metrics Log (sample) ===") for i, metrics in enumerate(self.system_metrics_log[:5]): # Print first 5 samples - print( + logger.info( f" Sample {i} ({metrics['interval_name']}): CPU {metrics['cpu_percent']:.1f}%, Mem {metrics['memory_percent']:.1f}%, GPU {metrics['gpu_util_percent']:.1f}% (Mem {metrics['gpu_memory_percent']:.1f}%)" ) if len(self.system_metrics_log) > 5: - print(" ...") + logger.info(" ...") if self.system_metrics_log: # Check if not empty before accessing last element metrics = self.system_metrics_log[-1] - print( + logger.info( f" Sample End ({metrics['interval_name']}): CPU {metrics['cpu_percent']:.1f}%, Mem {metrics['memory_percent']:.1f}%, GPU {metrics['gpu_util_percent']:.1f}% (Mem {metrics['gpu_memory_percent']:.1f}%)" ) @@ -286,7 +287,7 @@ def _async_uploader(upload_q: "queue.Queue[Optional[Path]]", cfg: ActivationConf # --> ADDED: Retry Loop <-- for attempt in range(max_retries_per_chunk): try: - print( + logger.info( f"[Uploader Thread Attempt {attempt + 1}/{max_retries_per_chunk}] Uploading chunk: {p.name} to {url}" ) with open(p, "rb") as f: @@ -970,7 +971,7 @@ def _upload_binary_file(self, path: Path, endpoint: str): try: activation_config_instance = ActivationConfig(**loaded_config) except TypeError as e: - print(f"Error creating ActivationConfig from YAML. Ensure all keys are correct: {e}") + logger.error(f"Error creating ActivationConfig from YAML. Ensure all keys are correct: {e}") import sys sys.exit(1) diff --git a/clt/config/data_config.py b/clt/config/data_config.py index 092e813..8c0104b 100644 --- a/clt/config/data_config.py +++ b/clt/config/data_config.py @@ -1,5 +1,8 @@ from dataclasses import dataclass, field from typing import Literal, Optional, Dict, Any +import logging + +logger = logging.getLogger(__name__) @dataclass @@ -74,7 +77,23 @@ def __post_init__(self): except ImportError: raise ImportError("h5py is required for HDF5 output format. Install with: pip install h5py") if self.compression not in ["lz4", "gzip", None, False]: - print( + logger.warning( f"Warning: Unsupported compression '{self.compression}'. Will attempt without compression for {self.output_format}." ) # Allow generator to handle disabling if format doesn't support it. + + # Example: Print a summary or key values + # This is more for user feedback than programmatic use. + logger.info( + "ActivationConfig Summary:\n" + f" Model: {self.model_name}\n" + f" Dataset: {self.dataset_path} (Split: {self.dataset_split})\n" + f" Target Tokens: {self.target_total_tokens}\n" + f" Chunk Threshold: {self.chunk_token_threshold}\n" + f" Activation Dtype: {self.activation_dtype}\n" + f" Output Dir: {self.activation_dir}" + ) + if self.remote_server_url: + logger.info(f" Remote Server URL: {self.remote_server_url}") + if self.delete_after_upload: + logger.info(" Delete after upload: Enabled") diff --git a/clt/models/clt.py b/clt/models/clt.py index 3de28fd..3e1391c 100644 --- a/clt/models/clt.py +++ b/clt/models/clt.py @@ -275,3 +275,30 @@ def convert_to_jumprelu_inplace(self, default_theta_value: float = 1e6) -> None: logger.info( f"Rank {self.rank}: CLT model config updated by ThetaManager. New activation_fn='{self.config.activation_fn}'." 
) + + # --- Back-compat: expose ThetaManager.log_threshold at model level --- + @property + def log_threshold(self) -> Optional[torch.nn.Parameter]: + """Proxy to ``theta_manager.log_threshold`` for backward compatibility. + + Older training scripts, conversion utilities and tests referenced + ``model.log_threshold`` directly. After the Step-5 refactor the + parameter migrated into the dedicated ``ThetaManager`` module. We + now expose a read-only view that always returns the *current* parameter + held by ``self.theta_manager``. Modifying the returned tensor (e.g. + in-place updates to ``.data``) therefore continues to work as before. + Assigning a brand-new ``nn.Parameter`` to ``model.log_threshold`` will + forward the assignment to ``theta_manager`` to preserve the linkage. + """ + if hasattr(self, "theta_manager") and self.theta_manager is not None: + return self.theta_manager.log_threshold + return None + + @log_threshold.setter + def log_threshold(self, new_param: Optional[torch.nn.Parameter]) -> None: + # Keep property writable so callers that used to assign a fresh + # parameter (rare) do not break. We delegate the storage to + # ``ThetaManager`` so there is a single source of truth. + if not hasattr(self, "theta_manager") or self.theta_manager is None: + raise AttributeError("ThetaManager is not initialised; cannot set log_threshold.") + self.theta_manager.log_threshold = new_param diff --git a/clt/training/checkpointing.py b/clt/training/checkpointing.py index 8adc9d1..6cd05bc 100644 --- a/clt/training/checkpointing.py +++ b/clt/training/checkpointing.py @@ -6,19 +6,22 @@ from torch.distributed.checkpoint.state_dict_loader import load_state_dict from torch.distributed.checkpoint.default_planner import DefaultSavePlanner, DefaultLoadPlanner from torch.distributed.checkpoint.filesystem import FileSystemWriter, FileSystemReader -from typing import Optional, Union, Dict, Any +from typing import Optional, Union, Dict, Any, TYPE_CHECKING from safetensors.torch import save_file as save_safetensors_file, load_file as load_safetensors_file +import logging # Import for type hinting, moved outside TYPE_CHECKING for runtime availability from clt.training.wandb_logger import WandBLogger, DummyWandBLogger # Forward declarations for type hinting to avoid circular imports -if False: # TYPE_CHECKING +if TYPE_CHECKING: from clt.models.clt import CrossLayerTranscoder from clt.training.data.base_store import BaseActivationStore # from clt.training.wandb_logger import WandBLogger, DummyWandBLogger # No longer needed here +logger = logging.getLogger(__name__) + class CheckpointManager: wandb_logger: Union[WandBLogger, DummyWandBLogger] @@ -86,7 +89,7 @@ def _save_checkpoint( torch.save(trainer_state_to_save, latest_trainer_state_path) except Exception as e: - print(f"Warning: Failed to save non-distributed checkpoint at step {step}: {e}") + logger.warning(f"Warning: Failed to save non-distributed checkpoint at step {step}: {e}") return # --- Distributed Save --- @@ -109,7 +112,9 @@ def _save_checkpoint( no_dist=False, ) except Exception as e: - print(f"Rank {self.rank}: Warning: Failed to save distributed model checkpoint at step {step}: {e}") + logger.warning( + f"Rank {self.rank}: Warning: Failed to save distributed model checkpoint at step {step}: {e}" + ) if self.rank == 0: # Save activation store @@ -119,7 +124,7 @@ def _save_checkpoint( torch.save(self.activation_store.state_dict(), store_checkpoint_path) torch.save(self.activation_store.state_dict(), latest_store_path) except Exception as e: - 
print(f"Rank 0: Warning: Failed to save activation store state at step {step}: {e}") + logger.warning(f"Rank 0: Warning: Failed to save activation store state at step {step}: {e}") # Save consolidated model as .safetensors model_safetensors_path = os.path.join(checkpoint_dir, "model.safetensors") @@ -128,11 +133,11 @@ def _save_checkpoint( full_model_state_dict = self.model.state_dict() save_safetensors_file(full_model_state_dict, model_safetensors_path) save_safetensors_file(full_model_state_dict, latest_model_safetensors_path) - print( + logger.info( f"Rank 0: Saved consolidated model to {model_safetensors_path} and {latest_model_safetensors_path}" ) except Exception as e: - print(f"Rank 0: Warning: Failed to save consolidated .safetensors model at step {step}: {e}") + logger.warning(f"Rank 0: Warning: Failed to save consolidated .safetensors model at step {step}: {e}") # Save trainer state (optimizer, scheduler, etc.) trainer_state_filepath = os.path.join(checkpoint_dir, "trainer_state.pt") @@ -140,9 +145,11 @@ def _save_checkpoint( try: torch.save(trainer_state_to_save, trainer_state_filepath) torch.save(trainer_state_to_save, latest_trainer_state_filepath) - print(f"Rank 0: Saved trainer state to {trainer_state_filepath} and {latest_trainer_state_filepath}") + logger.info( + f"Rank 0: Saved trainer state to {trainer_state_filepath} and {latest_trainer_state_filepath}" + ) except Exception as e: - print(f"Rank 0: Warning: Failed to save trainer state at step {step}: {e}") + logger.warning(f"Rank 0: Warning: Failed to save trainer state at step {step}: {e}") self.wandb_logger.log_artifact( artifact_path=checkpoint_dir, @@ -168,16 +175,16 @@ def load_checkpoint(self, checkpoint_path: str) -> Dict[str, Any]: if not self.distributed: # Non-distributed load model_file_path = checkpoint_path # Expecting path to .safetensors if not (os.path.isfile(model_file_path) and model_file_path.endswith(".safetensors")): - print( + logger.error( f"Error: For non-distributed load, checkpoint_path must be a .safetensors model file. 
Got: {model_file_path}" ) return loaded_trainer_state # Return empty if path is not as expected - print(f"Attempting to load non-distributed checkpoint from model file: {model_file_path}") + logger.info(f"Attempting to load non-distributed checkpoint from model file: {model_file_path}") try: full_state_dict = load_safetensors_file(model_file_path, device=str(self.device)) self.model.load_state_dict(full_state_dict) - print(f"Successfully loaded non-distributed model from {model_file_path}") + logger.info(f"Successfully loaded non-distributed model from {model_file_path}") # Infer paths for store and trainer state base_dir = os.path.dirname(model_file_path) @@ -195,7 +202,7 @@ def load_checkpoint(self, checkpoint_path: str) -> Dict[str, Any]: trainer_state_fname = f"trainer_state_{step_str}.pt" if not store_checkpoint_fname or not trainer_state_fname: - print( + logger.warning( f"Warning: Could not determine store/trainer state filenames from model path {model_file_path}" ) else: @@ -206,27 +213,27 @@ def load_checkpoint(self, checkpoint_path: str) -> Dict[str, Any]: if os.path.exists(store_path): store_state = torch.load(store_path, map_location=self.device) self.activation_store.load_state_dict(store_state) - print(f"Loaded activation store state from {store_path}") + logger.info(f"Loaded activation store state from {store_path}") else: - print(f"Warning: Activation store checkpoint not found at {store_path}") + logger.warning(f"Warning: Activation store checkpoint not found at {store_path}") # Load trainer state if os.path.exists(trainer_state_path): loaded_trainer_state = torch.load( trainer_state_path, map_location=self.device, weights_only=False ) - print(f"Loaded trainer state from {trainer_state_path}") + logger.info(f"Loaded trainer state from {trainer_state_path}") else: - print(f"Warning: Trainer state checkpoint not found at {trainer_state_path}") + logger.warning(f"Warning: Trainer state checkpoint not found at {trainer_state_path}") except Exception as e: - print(f"Error loading non-distributed checkpoint from {model_file_path}: {e}") + logger.error(f"Error loading non-distributed checkpoint from {model_file_path}: {e}") return loaded_trainer_state # --- Distributed Load --- # checkpoint_path is a directory for distributed checkpoints if not os.path.isdir(checkpoint_path): - print( + logger.error( f"Error: Checkpoint path {checkpoint_path} is not a directory. Distributed checkpoints are saved as directories." ) return loaded_trainer_state @@ -241,14 +248,14 @@ def load_checkpoint(self, checkpoint_path: str) -> Dict[str, Any]: ) # The model's state_dict is modified in-place by load_state_dict # self.model.load_state_dict(state_dict_to_load) # Not needed if modified in-place - print(f"Rank {self.rank}: Loaded distributed model checkpoint from {checkpoint_path}") + logger.info(f"Rank {self.rank}: Loaded distributed model checkpoint from {checkpoint_path}") except Exception as e: - print(f"Rank {self.rank}: Error loading distributed model checkpoint from {checkpoint_path}: {e}") + logger.error(f"Rank {self.rank}: Error loading distributed model checkpoint from {checkpoint_path}: {e}") # Attempt to load consolidated if sharded load fails (e.g. loading TP model on single GPU) # This part is tricky because load_state_dict above might have partially modified the model. # A cleaner approach for "load TP sharded on single GPU" would be separate. 
# For now, if distributed load_state_dict fails, we try the consolidated .safetensors - print( + logger.info( f"Rank {self.rank}: Attempting to load consolidated model.safetensors from the directory as fallback." ) consolidated_model_path = os.path.join(checkpoint_path, "model.safetensors") @@ -257,14 +264,16 @@ def load_checkpoint(self, checkpoint_path: str) -> Dict[str, Any]: # This load is for a single rank, assuming this rank needs the full model full_model_state = load_safetensors_file(consolidated_model_path, device=str(self.device)) self.model.load_state_dict(full_model_state) - print(f"Rank {self.rank}: Successfully loaded consolidated model from {consolidated_model_path}") + logger.info( + f"Rank {self.rank}: Successfully loaded consolidated model from {consolidated_model_path}" + ) except Exception as e_consol: - print( + logger.error( f"Rank {self.rank}: Failed to load consolidated model from {consolidated_model_path}: {e_consol}" ) return loaded_trainer_state # Failed both sharded and consolidated else: - print( + logger.info( f"Rank {self.rank}: Consolidated model.safetensors not found in {checkpoint_path}. Cannot fallback." ) return loaded_trainer_state # Failed sharded, no consolidated to fallback to @@ -278,22 +287,24 @@ def load_checkpoint(self, checkpoint_path: str) -> Dict[str, Any]: try: store_state = torch.load(store_file_path, map_location=self.device) self.activation_store.load_state_dict(store_state) - print(f"Rank 0: Loaded activation store state from {store_file_path}") + logger.info(f"Rank 0: Loaded activation store state from {store_file_path}") except Exception as e: - print(f"Rank 0: Warning: Failed to load activation store state from {store_file_path}: {e}") + logger.warning( + f"Rank 0: Warning: Failed to load activation store state from {store_file_path}: {e}" + ) else: - print(f"Rank 0: Warning: Activation store checkpoint not found in {checkpoint_path}") + logger.warning(f"Rank 0: Warning: Activation store checkpoint not found in {checkpoint_path}") trainer_state_file_path = os.path.join(checkpoint_path, "trainer_state.pt") if os.path.exists(trainer_state_file_path): try: # map_location CPU for items that might be on CUDA but not needed there by all ranks yet loaded_trainer_state = torch.load(trainer_state_file_path, map_location="cpu", weights_only=False) - print(f"Rank 0: Loaded trainer state from {trainer_state_file_path}") + logger.info(f"Rank 0: Loaded trainer state from {trainer_state_file_path}") except Exception as e: - print(f"Rank 0: Warning: Failed to load trainer state from {trainer_state_file_path}: {e}") + logger.warning(f"Rank 0: Warning: Failed to load trainer state from {trainer_state_file_path}: {e}") else: - print(f"Rank 0: Warning: Trainer state file not found in {checkpoint_path}") + logger.warning(f"Rank 0: Warning: Trainer state file not found in {checkpoint_path}") # Barrier to ensure all ranks have attempted model loading before proceeding. # And rank 0 has loaded other states. @@ -311,18 +322,20 @@ def load_checkpoint(self, checkpoint_path: str) -> Dict[str, Any]: # Ensure loaded_trainer_state is a dict even if broadcast failed or returned None (defensive) if loaded_trainer_state is None: loaded_trainer_state = {} - print(f"Rank {self.rank}: Received broadcasted trainer state. Step: {loaded_trainer_state.get('step')}") + logger.info( + f"Rank {self.rank}: Received broadcasted trainer state. 
Step: {loaded_trainer_state.get('step')}" + ) return loaded_trainer_state def _load_non_distributed_checkpoint(self, checkpoint_path: str, store_checkpoint_path: Optional[str] = None): """Loads a standard single-file model checkpoint (.pt or .safetensors).""" if self.distributed: - print("Error: Attempting to load non-distributed checkpoint in distributed mode.") + logger.error("Error: Attempting to load non-distributed checkpoint in distributed mode.") return if not os.path.exists(checkpoint_path): - print(f"Error: Model checkpoint not found at {checkpoint_path}") + logger.error(f"Error: Model checkpoint not found at {checkpoint_path}") return try: if checkpoint_path.endswith(".safetensors"): @@ -332,12 +345,14 @@ def _load_non_distributed_checkpoint(self, checkpoint_path: str, store_checkpoin elif checkpoint_path.endswith(".pt"): full_state_dict = torch.load(checkpoint_path, map_location=self.device) else: - print(f"Error: Unknown checkpoint file extension for {checkpoint_path}. Must be .pt or .safetensors.") + logger.error( + f"Error: Unknown checkpoint file extension for {checkpoint_path}. Must be .pt or .safetensors." + ) return self.model.load_state_dict(full_state_dict) - print(f"Loaded non-distributed model checkpoint from {checkpoint_path}") + logger.info(f"Loaded non-distributed model checkpoint from {checkpoint_path}") except Exception as e: - print(f"Error loading non-distributed model checkpoint from {checkpoint_path}: {e}") + logger.error(f"Error loading non-distributed model checkpoint from {checkpoint_path}: {e}") return if store_checkpoint_path is None: @@ -372,12 +387,12 @@ def _load_non_distributed_checkpoint(self, checkpoint_path: str, store_checkpoin store_state = torch.load(store_checkpoint_path, map_location=self.device) if hasattr(self, "activation_store") and self.activation_store is not None: self.activation_store.load_state_dict(store_state) - print(f"Loaded activation store state from {store_checkpoint_path}") + logger.info(f"Loaded activation store state from {store_checkpoint_path}") else: - print("Warning: Activation store not initialized. Cannot load state.") + logger.warning("Warning: Activation store not initialized. Cannot load state.") except Exception as e: - print(f"Warning: Failed to load activation store state from {store_checkpoint_path}: {e}") + logger.warning(f"Warning: Failed to load activation store state from {store_checkpoint_path}: {e}") else: - print( + logger.warning( f"Warning: Activation store checkpoint path not found or specified: {store_checkpoint_path}. Store state not loaded." ) diff --git a/clt/training/evaluator.py b/clt/training/evaluator.py index 3d9e651..25e94f6 100644 --- a/clt/training/evaluator.py +++ b/clt/training/evaluator.py @@ -6,6 +6,7 @@ import time # Import time import datetime # Import datetime +from clt.config import TrainingConfig, CLTConfig # Ensure these are imported for type hints from clt.models.clt import CrossLayerTranscoder # Configure logging @@ -50,6 +51,7 @@ def __init__( # Store normalisation stats if provided self.mean_tg = mean_tg or {} self.std_tg = std_tg or {} + self.metrics_history: List[Dict[str, Any]] = [] # For storing metrics over time if needed @staticmethod def _log_density(density: torch.Tensor, eps: float = 1e-10) -> torch.Tensor: @@ -185,13 +187,13 @@ def _compute_sparsity(self, activations: Dict[int, torch.Tensor]) -> Dict[str, A Dictionary with L0 stats under 'sparsity/' and 'layerwise/l0/' keys. 
""" if not activations or not any(v.numel() > 0 for v in activations.values()): - print("Warning: Received empty activations for sparsity computation. " "Returning zeros.") - num_layers = self.model.config.num_layers + if not self.model.world_size > 1 or self.model.rank == 0: + logger.warning("Warning: Received empty activations for sparsity computation. " "Returning zeros.") return { "sparsity/total_l0": 0.0, "sparsity/avg_l0": 0.0, "sparsity/sparsity_fraction": 1.0, # Renamed from 'sparsity' - "layerwise/l0": {f"layer_{i}": 0.0 for i in range(num_layers)}, + "layerwise/l0": {f"layer_{i}": 0.0 for i in range(self.model.config.num_layers)}, } per_layer_l0_dict = {} @@ -401,7 +403,89 @@ def _compute_dead_neuron_metrics(self, dead_neuron_mask: Optional[torch.Tensor]) per_layer_dead_dict[f"layer_{layer_idx}"] = dead_neuron_mask[layer_idx].sum().item() dead_neuron_metrics["layerwise/dead_features"] = per_layer_dead_dict else: - print( - f"Warning: Received dead_neuron_mask with unexpected shape {dead_neuron_mask.shape}. Expected {expected_shape}. Skipping dead neuron eval metrics." - ) + if not self.model.world_size > 1 or self.model.rank == 0: + logger.warning( + f"Warning: Received dead_neuron_mask with unexpected shape {dead_neuron_mask.shape}. Expected {expected_shape}. Skipping dead neuron eval metrics." + ) return dead_neuron_metrics + + def print_evaluation_report( + self, + step: int, + metrics: Dict[str, Any], + detailed_metrics: Dict[str, Any], + current_training_config: Optional[TrainingConfig] = None, + current_clt_config: Optional[CLTConfig] = None, + ): + if not self.model.world_size > 1 or self.model.rank == 0: + logger.info( + "\n=======================================================================" + "\n--- Model Evaluation Report ---" + ) + logger.info(f"Evaluation at Step: {step}") + logger.info(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}") + logger.info("--- Overall Performance ---") + logger.info(f" Total Reconstruction Loss: {metrics.get('reconstruction/total_loss', float('nan')):.4f}") + logger.info(f" Total Sparsity Loss: {metrics.get('sparsity/total_loss', float('nan')):.4f}") + logger.info( + f" Explained Variance (Avg): {metrics.get('reconstruction/explained_variance', float('nan')):.4f}" + ) + logger.info(f" NMSE (Avg): {metrics.get('reconstruction/nmse', float('nan')):.4f}") + logger.info(f" L0 Norm (Avg per token): {metrics.get('sparsity/avg_l0', float('nan')):.2f}") + logger.info(f" Sparsity Fraction: {metrics.get('sparsity/sparsity_fraction', float('nan')):.4f}") + + # Layer-wise details + if metrics.get("layerwise/reconstruction_loss"): + logger.info("--- Layer-wise Reconstruction Loss ---") + for layer, loss in metrics["layerwise/reconstruction_loss"].items(): + logger.info(f" {layer}: {loss:.4f}") + + if metrics.get("layerwise/explained_variance"): + logger.info("--- Layer-wise Explained Variance ---") + for layer, ev in metrics["layerwise/explained_variance"].items(): + logger.info(f" {layer}: {ev:.4f}") + + if metrics.get("layerwise/nmse"): + logger.info("--- Layer-wise NMSE ---") + for layer, nmse_val in metrics["layerwise/nmse"].items(): + logger.info(f" {layer}: {nmse_val:.4f}") + + if metrics.get("layerwise/l0"): + logger.info("--- Layer-wise L0 Norm (Avg per token) ---") + for layer, l0 in metrics["layerwise/l0"].items(): + logger.info(f" {layer}: {l0:.2f}") + + # Feature density details + if detailed_metrics.get("feature_density_per_layer"): + logger.info("Feature Density Per Layer:") + for layer, density in 
detailed_metrics["feature_density_per_layer"].items(): + logger.info(f" Layer {layer}: {density:.4f}") + + # Dead features details + if detailed_metrics.get("dead_features_per_layer_eval"): + logger.info("Dead Features Per Layer (Evaluation Batch):") + for layer, count in detailed_metrics["dead_features_per_layer_eval"].items(): + logger.info(f" Layer {layer}: {count}") + + # Active features details + if detailed_metrics.get("active_features_per_layer_eval"): + logger.info("Active Features Per Layer (Evaluation Batch):") + for layer, count in detailed_metrics["active_features_per_layer_eval"].items(): + logger.info(f" Layer {layer}: {count}") + + # Overall dead/active features + logger.info(f"Total Dead Features (Eval Batch): {detailed_metrics.get('dead_features/total_eval', 0)}") + logger.info(f"Total Active Features (Eval Batch): {detailed_metrics.get('active_features/total_eval', 0)}") + + if current_training_config: + logger.info("--- Training Configuration ---") + logger.info(f" Learning Rate: {current_training_config.learning_rate}") + logger.info(f" Sparsity Lambda: {current_training_config.sparsity_lambda}") + # Add other relevant training config details + + if current_clt_config: + logger.info("--- CLT Model Configuration ---") + logger.info(f" Activation Function: {current_clt_config.activation_fn}") + logger.info(f" Number of Features: {current_clt_config.num_features}") + # Add other relevant CLT config details + logger.info("=======================================================================\n") diff --git a/clt/training/trainer.py b/clt/training/trainer.py index cf8dae6..359be85 100644 --- a/clt/training/trainer.py +++ b/clt/training/trainer.py @@ -442,54 +442,54 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: """ # Print startup message from rank 0 only if not self.distributed or self.rank == 0: - print(f"Starting CLT training on {self.device}...") - print( + logger.info(f"Starting CLT training on {self.device}...") + logger.info( f"Model has {self.clt_config.num_features} features per layer " f"and {self.clt_config.num_layers} layers" ) - print(f"Training for {self.training_config.training_steps} steps.") - print(f"Logging to {self.log_dir}") + logger.info(f"Training for {self.training_config.training_steps} steps.") + logger.info(f"Logging to {self.log_dir}") if self.distributed: - print(f"Distributed training with {self.world_size} processes (Tensor Parallelism)") + logger.info(f"Distributed training with {self.world_size} processes (Tensor Parallelism)") # Check if using normalization and notify user if self.training_config.normalization_method == "estimated_mean_std": - print("\n>>> NORMALIZATION PHASE <<<") - print("Normalization statistics are being estimated from dataset activations.") - print("This may take some time, but happens only once before training begins.") - print(f"Using {self.training_config.normalization_estimation_batches} batches for estimation.\n") + logger.info("\n>>> NORMALIZATION PHASE <<<") + logger.info("Normalization statistics are being estimated from dataset activations.") + logger.info("This may take some time, but happens only once before training begins.") + logger.info(f"Using {self.training_config.normalization_estimation_batches} batches for estimation.\n") # Make sure we flush stdout to ensure prints appear immediately, # especially important in Jupyter/interactive environments sys.stdout.flush() # Wait for 1 second to ensure output is displayed before training starts time.sleep(1) - print("\n>>> TRAINING PHASE 
<<<") + logger.info("\n>>> TRAINING PHASE <<<") sys.stdout.flush() # After the existing startup messages if self.distributed: - print("\n!!! DIAGNOSTIC INFO !!!") - print(f"Rank {self.rank}: Process group type: {type(self.process_group)}") - print(f"Rank {self.rank}: RowParallelLinear _reduce does NOT divide by world_size") - print(f"Rank {self.rank}: Using weight regularization in sparsity penalty") - print(f"Rank {self.rank}: Averaging replicated parameter gradients") + logger.info("\n!!! DIAGNOSTIC INFO !!!") + logger.info(f"Rank {self.rank}: Process group type: {type(self.process_group)}") + logger.info(f"Rank {self.rank}: RowParallelLinear _reduce does NOT divide by world_size") + logger.info(f"Rank {self.rank}: Using weight regularization in sparsity penalty") + logger.info(f"Rank {self.rank}: Averaging replicated parameter gradients") # Check if activation store has rank/world attributes before accessing store_rank = getattr(self.activation_store, "rank", "N/A") store_world = getattr(self.activation_store, "world", "N/A") - print(f"Rank {self.rank}: Data sharding: rank={store_rank}, world={store_world}") - print(f"Rank {self.rank}: Batch size tokens: {self.training_config.train_batch_size_tokens}") - print(f"Rank {self.rank}: Sparsity lambda: {self.training_config.sparsity_lambda}") + logger.info(f"Rank {self.rank}: Data sharding: rank={store_rank}, world={store_world}") + logger.info(f"Rank {self.rank}: Batch size tokens: {self.training_config.train_batch_size_tokens}") + logger.info(f"Rank {self.rank}: Sparsity lambda: {self.training_config.sparsity_lambda}") # Check if activation store actually loaded correctly batch_avail = next(iter(self.activation_store), None) - print(f"Rank {self.rank}: First batch available: {batch_avail is not None}") + logger.info(f"Rank {self.rank}: First batch available: {batch_avail is not None}") # Force torch to compile/execute our code by running a tiny forward/backward pass dummy = torch.ones(1, device=self.device, requires_grad=True) dummy_out = dummy * 2 dummy_out.backward() - print("!!! END DIAGNOSTIC !!!\n") + logger.info("!!! END DIAGNOSTIC !!!\n") # --- Enable Anomaly Detection (if configured) --- if self.training_config.debug_anomaly: @@ -584,7 +584,7 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: except StopIteration: # Rank 0 prints message if not self.distributed or self.rank == 0: - print("Activation store exhausted. Training finished early.") + logger.info("Activation store exhausted. Training finished early.") if self.distributed: dist.barrier() # Ensure all ranks see this break # Exit training loop if data runs out @@ -593,7 +593,7 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: # This check should ideally happen *before* moving data potentially if not inputs or not targets or not any(v.numel() > 0 for v in inputs.values()): if not self.distributed or self.rank == 0: - print(f"\nRank {self.rank}: Warning: Received empty batch at step {step}. Skipping.") + logger.warning(f"Rank {self.rank}: Warning: Received empty batch at step {step}. 
Skipping.") continue # --- BEGIN: One-time Normalization Check --- @@ -901,12 +901,12 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: except KeyboardInterrupt: if not self.distributed or self.rank == 0: - print("\nTraining interrupted by user.") + logger.info("\nTraining interrupted by user.") finally: if isinstance(pbar, tqdm): pbar.close() if not self.distributed or self.rank == 0: - print(f"Training loop finished at step {step}.") + logger.info(f"Training loop finished at step {step}.") # Sync before final save attempt if self.distributed: @@ -937,37 +937,39 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: trainer_state_to_save=final_trainer_state_for_checkpoint, ) except IOError as e: # More specific: catch IOError for checkpoint saving - print(f"Rank {self.rank}: Warning: Failed to save final distributed model state due to IOError: {e}") + logger.warning( + f"Rank {self.rank}: Warning: Failed to save final distributed model state due to IOError: {e}" + ) except Exception as e: # Catch other potential errors during final save but log them as more critical - print(f"Rank {self.rank}: CRITICAL: Unexpected error during final model state save: {e}") + logger.critical(f"Rank {self.rank}: CRITICAL: Unexpected error during final model state save: {e}") # Rank 0 saves store, metrics, logs artifact if not self.distributed or self.rank == 0: - print(f"Saving final activation store state to {final_store_path}...") + logger.info(f"Saving final activation store state to {final_store_path}...") os.makedirs(final_checkpoint_dir, exist_ok=True) try: # Check if the store has a close method before calling (for compatibility) if hasattr(self.activation_store, "close") and callable(getattr(self.activation_store, "close")): self.activation_store.close() except IOError as e: # More specific: catch IOError for store closing - print(f"Rank 0: Warning: Failed to close activation store due to IOError: {e}") + logger.warning(f"Rank 0: Warning: Failed to close activation store due to IOError: {e}") except Exception as e: # Catch other potential errors during store close - print(f"Rank 0: Warning: Unexpected error closing activation store: {e}") + logger.warning(f"Rank 0: Warning: Unexpected error closing activation store: {e}") - print("Saving final metrics...") + logger.info("Saving final metrics...") # self.metric_logger._save_metrics_to_disk() # Final save - this should be robust try: self.metric_logger._save_metrics_to_disk() except IOError as e: - print(f"Rank 0: Warning: Failed to save final metrics to disk due to IOError: {e}") + logger.warning(f"Rank 0: Warning: Failed to save final metrics to disk due to IOError: {e}") except Exception as e: - print(f"Rank 0: Warning: Unexpected error saving final metrics: {e}") + logger.warning(f"Rank 0: Warning: Unexpected error saving final metrics: {e}") # --- Save CLT Config to JSON --- # The config saved here will now reflect the configuration *during training* (e.g. BatchTopK) # The user will need to run estimate_theta_posthoc and then save the converted JumpReLU model themselves. 
config_save_path = os.path.join(self.log_dir, "cfg.json") - print(f"Saving CLT configuration (as trained) to {config_save_path}...") + logger.info(f"Saving CLT configuration (as trained) to {config_save_path}...") try: config_dict_as_trained = asdict(self.clt_config) @@ -988,16 +990,16 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: with open(config_save_path, "w") as f: json.dump(config_dict_as_trained, f, indent=2) - print(f"Successfully saved training configuration to {config_save_path}") + logger.info(f"Successfully saved training configuration to {config_save_path}") if self.clt_config.activation_fn == "batchtopk": - print( + logger.info( "NOTE: Model was trained with BatchTopK. Run estimate_theta_posthoc() on the saved model to convert to JumpReLU and finalize theta values." ) except IOError as e: # More specific: catch IOError for config saving - print(f"Rank 0: Warning: Failed to save CLT configuration to JSON due to IOError: {e}") + logger.warning(f"Rank 0: Warning: Failed to save CLT configuration to JSON due to IOError: {e}") except Exception as e: # Catch other potential errors during config saving - print(f"Rank 0: Warning: Unexpected error saving CLT configuration to JSON: {e}") + logger.warning(f"Rank 0: Warning: Unexpected error saving CLT configuration to JSON: {e}") # --- End Save CLT Config --- # Log final checkpoint directory as artifact @@ -1005,7 +1007,7 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: # Finish WandB logging self.wandb_logger.finish() - print(f"Training completed! Final checkpoint saved to {final_checkpoint_dir}") + logger.info(f"Training completed! Final checkpoint saved to {final_checkpoint_dir}") # --- Close the activation store (stops prefetch thread if applicable) --- # if hasattr(self.activation_store, "close") and callable(getattr(self.activation_store, "close")): diff --git a/clt/training/wandb_logger.py b/clt/training/wandb_logger.py index 754fcf5..3c351af 100644 --- a/clt/training/wandb_logger.py +++ b/clt/training/wandb_logger.py @@ -1,9 +1,13 @@ -import time -import importlib.util +import wandb +import os from typing import Dict, Optional, Any +import logging +from dataclasses import asdict from clt.config import CLTConfig, TrainingConfig +logger = logging.getLogger(__name__) + # Define the dummy logger class explicitly for better type checking class DummyWandBLogger: @@ -32,6 +36,7 @@ class WandBLogger: """Wrapper class for Weights & Biases logging.""" _run_id: Optional[str] = None + wandb_run: Optional[Any] def __init__( self, @@ -50,67 +55,51 @@ def __init__( """ self.enabled = training_config.enable_wandb self.log_dir = log_dir + self.resume_wandb_id = resume_wandb_id if not self.enabled: + logger.info("WandB logging is disabled by training_config.enable_wandb=False.") + self.wandb_run = None return - # Check if wandb is installed - if not importlib.util.find_spec("wandb"): - print( - "Warning: WandB logging requested but wandb not installed. " - "Install with 'pip install wandb'. Continuing without WandB." 
- ) - self.enabled = False - return - - # Import wandb - import wandb - - # Set up run name with timestamp if not provided - run_name = training_config.wandb_run_name - if run_name is None: - run_name = f"clt-{time.strftime('%Y%m%d-%H%M%S')}" - - # Initialize wandb - wandb_init_kwargs = { - "project": training_config.wandb_project, - "entity": training_config.wandb_entity, - "name": run_name, - "dir": log_dir, - "tags": training_config.wandb_tags, - "config": { - **clt_config.__dict__, - **training_config.__dict__, - "log_dir": log_dir, - }, - } + try: + # Try to get project from env var first, then from config + project_name = os.environ.get("WANDB_PROJECT", training_config.wandb_project) + entity_name = os.environ.get("WANDB_ENTITY", training_config.wandb_entity) + run_name = training_config.wandb_run_name # Can be None - if resume_wandb_id: - wandb_init_kwargs["id"] = resume_wandb_id - wandb_init_kwargs["resume"] = "must" - # If resuming by ID, let WandB use the original run's name or handle naming. - # Setting name explicitly here might conflict if the auto-generated name differs. - # Let's try removing the name from kwargs if resume_wandb_id is present. - if "name" in wandb_init_kwargs: - # Important: Only remove 'name' if we are truly trying to resume by ID. - # If resume_wandb_id was found, we prioritize it. - del wandb_init_kwargs["name"] - print( - f"Attempting to resume WandB run with ID: {resume_wandb_id} and resume='must'. Name will be sourced from existing run." - ) - - wandb.init(**wandb_init_kwargs) - - if wandb.run is not None: - print(f"WandB logging initialized: {wandb.run.name} (ID: {wandb.run.id})") - self._run_id = wandb.run.id - else: - if resume_wandb_id: - print( - f"Warning: Failed to resume WandB run {resume_wandb_id}. A new run might have been started or init failed." + if self.resume_wandb_id: + logger.info(f"Attempting to resume WandB run with ID: {self.resume_wandb_id}") + # When resuming, wandb.init will use the passed id. Do not pass project/entity/name again if resuming. + self.wandb_run = wandb.init(id=self.resume_wandb_id, resume="allow") + else: + self.wandb_run = wandb.init( + project=project_name, + entity=entity_name, + name=run_name, + config=self._create_wandb_config(training_config, clt_config), + reinit=True, # Allow re-initialization in the same process if needed ) + + if self.wandb_run: + logger.info(f"WandB logging initialized: {self.wandb_run.name} (ID: {self.wandb_run.id})") else: - print("Warning: WandB run initialization failed.") + logger.warning("Warning: WandB run initialization failed but no exception was raised.") + + except ImportError: + logger.warning("wandb library not installed. 
Please install with `pip install wandb` to use WandB logging.") + self.wandb_run = None + except Exception as e: + logger.error(f"Error initializing WandB: {e}") + self.wandb_run = None + + def _create_wandb_config(self, training_config: TrainingConfig, clt_config: CLTConfig) -> Dict[str, Any]: + config_dict = { + "training_config": asdict(training_config), + "clt_config": asdict(clt_config), + "log_dir": self.log_dir, + } + return config_dict def get_current_wandb_run_id(self) -> Optional[str]: """Returns the current WandB run ID, if a run is active.""" @@ -169,7 +158,11 @@ def log_step( metrics["training/total_tokens_processed"] = total_tokens_processed # Log to wandb - wandb.log(metrics, step=step) + if self.wandb_run: + try: + self.wandb_run.log(metrics, step=step) + except Exception as e: + logger.error(f"Wandb: Error logging metrics at step {step}: {e}") def log_evaluation(self, step: int, eval_metrics: Dict[str, Any]): """Log evaluation metrics, organized by the structure from CLTEvaluator. @@ -202,7 +195,7 @@ def log_evaluation(self, step: int, eval_metrics: Dict[str, Any]): try: wandb_log_dict[wandb_key] = wandb.Histogram(layer_value) except Exception as e: - print(f"Wandb: Error creating histogram for {wandb_key}: {e}") + logger.error(f"Wandb: Error creating histogram for {wandb_key}: {e}") # Fallback: log mean or placeholder try: mean_val = sum(layer_value) / len(layer_value) if layer_value else 0.0 @@ -220,7 +213,7 @@ def log_evaluation(self, step: int, eval_metrics: Dict[str, Any]): try: wandb_log_dict[key] = wandb.Histogram(value) except Exception as e: - print(f"Wandb: Error creating aggregate histogram for {key}: {e}") + logger.error(f"Wandb: Error creating aggregate histogram for {key}: {e}") # Optional Fallback: log mean of aggregate data try: mean_val = sum(value) / len(value) if value else 0.0 @@ -250,7 +243,11 @@ def log_evaluation(self, step: int, eval_metrics: Dict[str, Any]): # Log the prepared dictionary to wandb if wandb_log_dict: - wandb.log(wandb_log_dict, step=step) + if self.wandb_run: + try: + self.wandb_run.log(wandb_log_dict, step=step) + except Exception as e: + logger.error(f"Wandb: Error logging metrics at step {step}: {e}") def log_artifact(self, artifact_path: str, artifact_type: str, name: Optional[str] = None): """Log an artifact to WandB. diff --git a/scripts/convert_batchtopk_to_jumprelu.py b/scripts/convert_batchtopk_to_jumprelu.py index f4abe29..74a26b4 100644 --- a/scripts/convert_batchtopk_to_jumprelu.py +++ b/scripts/convert_batchtopk_to_jumprelu.py @@ -30,6 +30,37 @@ logger = logging.getLogger(__name__) +def _remap_checkpoint_keys(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Remaps old state_dict keys to the new format with module prefixes.""" + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("encoders."): + new_key = "encoder_module." + key + elif key.startswith("decoders."): + new_key = "decoder_module." + key + else: + new_key = key + new_state_dict[new_key] = value + if not any(k.startswith("encoder_module.") or k.startswith("decoder_module.") for k in new_state_dict.keys()): + if any(k.startswith("encoders.") or k.startswith("decoders.") for k in state_dict.keys()): + logger.warning( + "Key remapping applied, but no keys were actually changed. " + "This might indicate the checkpoint is already in the new format or the remapping logic is flawed." + ) + else: + # Neither old nor new prefixes found, probably a different kind of checkpoint or already remapped. 
+ pass # Don't log a warning if it seems like it's already fine. + + # Check if we missed any expected old keys + remapped_old_keys = any(k.startswith("encoders.") or k.startswith("decoders.") for k in new_state_dict.keys()) + if remapped_old_keys: + logger.error("CRITICAL: Key remapping failed to remove all old-style keys. Check remapping logic.") + # This state is problematic, so we might choose to raise an error or return the original dict. + # For now, returning the partially remapped dict and letting load_state_dict fail might be more informative. + + return new_state_dict + + def main(args): if args.device: device = torch.device(args.device) @@ -94,10 +125,13 @@ def main(args): if args.batchtopk_checkpoint_path.endswith(".safetensors"): logger.info(f"Loading BatchTopK model state from safetensors file: {args.batchtopk_checkpoint_path}") state_dict = load_safetensors_file(args.batchtopk_checkpoint_path, device=device.type) + state_dict = _remap_checkpoint_keys(state_dict) # Remap keys model.load_state_dict(state_dict) else: logger.info(f"Loading BatchTopK model state from .pt file: {args.batchtopk_checkpoint_path}") - model.load_state_dict(torch.load(args.batchtopk_checkpoint_path, map_location=device)) + state_dict = torch.load(args.batchtopk_checkpoint_path, map_location=device) + state_dict = _remap_checkpoint_keys(state_dict) # Remap keys + model.load_state_dict(state_dict) model.eval() logger.info("BatchTopK model loaded and set to eval mode.") @@ -367,12 +401,16 @@ def main(args): storage_reader=FileSystemReader(original_checkpoint_path), no_dist=True, ) + state_dict_to_populate_orig = _remap_checkpoint_keys(state_dict_to_populate_orig) original_model.load_state_dict(state_dict_to_populate_orig) elif original_checkpoint_path.endswith(".safetensors"): state_dict_orig = load_safetensors_file(original_checkpoint_path, device=device.type) + state_dict_orig = _remap_checkpoint_keys(state_dict_orig) original_model.load_state_dict(state_dict_orig) else: - original_model.load_state_dict(torch.load(original_checkpoint_path, map_location=device)) + state_dict_orig = torch.load(original_checkpoint_path, map_location=device) + state_dict_orig = _remap_checkpoint_keys(state_dict_orig) + original_model.load_state_dict(state_dict_orig) original_model.eval() logger.info("Original model loaded successfully for L0 target calculation.")
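Note on the new checkpoint-key remapping (illustrative, not part of the diff above): the `_remap_checkpoint_keys` helper added to `scripts/convert_batchtopk_to_jumprelu.py` prefixes legacy `encoders.*` / `decoders.*` keys with the post-refactor module names (`encoder_module.` / `decoder_module.`) and passes every other key through unchanged. The sketch below is a minimal, self-contained approximation of that behaviour for review purposes; the example key names and tensor shapes are hypothetical placeholders, not taken from a real checkpoint.

import torch
from typing import Dict

def remap_checkpoint_keys(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Prefix legacy top-level module names; leave all other keys untouched.
    remapped: Dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        if key.startswith("encoders."):
            remapped["encoder_module." + key] = value
        elif key.startswith("decoders."):
            remapped["decoder_module." + key] = value
        else:
            remapped[key] = value
    return remapped

# Hypothetical old-style checkpoint keys (shapes are placeholders):
old_style = {
    "encoders.0.weight": torch.zeros(8, 4),
    "decoders.0.weight": torch.zeros(4, 8),
    "some_other_param": torch.zeros(4),
}
new_style = remap_checkpoint_keys(old_style)
assert "encoder_module.encoders.0.weight" in new_style
assert "decoder_module.decoders.0.weight" in new_style
assert "some_other_param" in new_style  # non-prefixed keys pass through unchanged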