Skip to content

Commit c8aa9c1

Browse files
authored
Merge pull request #31 from curt-tigges/fix/restore-conversion-script
Fix/restore conversion script
2 parents 3623275 + 802fc64 commit c8aa9c1

File tree

9 files changed

+354
-170
lines changed

9 files changed

+354
-170
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ vis/
206206
clt_test_pythia_70m_jumprelu/
207207
clt_smoke_output_local_wandb_batchtopk/
208208
clt_smoke_output_remote_wandb/
209+
wandb/
209210

210211
# models
211212
*.pt

clt/activation_generation/generator.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@
1414
from __future__ import annotations
1515

1616
import os
17-
import time
1817
import json
1918
import queue
2019
import random
2120
import logging
2221
import threading
2322
from pathlib import Path
24-
from typing import Dict, List, Optional, Tuple, Any
23+
from typing import Dict, List, Optional, Tuple, Any, DefaultDict
2524
from concurrent.futures import ThreadPoolExecutor, as_completed
2625

2726
import torch
@@ -36,7 +35,7 @@
3635
from clt.config.data_config import ActivationConfig # noqa: E402
3736

3837
# --- Profiling Imports ---
39-
import time # Already imported, but good to note
38+
import time # Keep this one
4039
from contextlib import contextmanager
4140
from collections import defaultdict
4241
import psutil
@@ -58,8 +57,8 @@
5857
# --- Performance Profiler Class ---
5958
class PerformanceProfiler:
6059
def __init__(self, chunk_tokens_threshold: int = 1_000_000):
61-
self.timings = defaultdict(list)
62-
self.memory_snapshots = []
60+
self.timings: DefaultDict[str, List[float]] = defaultdict(list)
61+
self.memory_snapshots: List[Dict[str, Any]] = []
6362
self.chunk_tokens_threshold = chunk_tokens_threshold
6463
self.system_metrics_log: List[Dict[str, Any]] = []
6564
self.layer_ids_ref: Optional[List[int]] = None
@@ -143,7 +142,7 @@ def log_system_metrics(self, interval_name: str = "interval"):
143142
return metrics
144143

145144
def report(self):
146-
print("\n=== Performance Report ===")
145+
logger.info("\n=== Performance Report ===")
147146
# Sort by total time descending for timings
148147
sorted_timings = sorted(self.timings.items(), key=lambda item: sum(item[1]), reverse=True)
149148

@@ -155,15 +154,17 @@ def report(self):
155154
min_time = min(times)
156155
max_time = max(times)
157156

158-
print(f"\n--- Operation: {name} ---")
159-
print(f" Count: {len(times)}")
160-
print(f" Total time: {total_time:.3f}s")
161-
print(f" Avg time: {avg_time:.4f}s")
162-
print(f" Min time: {min_time:.4f}s")
163-
print(f" Max time: {max_time:.4f}s")
157+
logger.info(f"\n--- Operation: {name} ---")
158+
logger.info(f" Count: {len(times)}")
159+
logger.info(f" Total time: {total_time:.3f}s")
160+
logger.info(f" Avg time: {avg_time:.4f}s")
161+
logger.info(f" Min time: {min_time:.4f}s")
162+
logger.info(f" Max time: {max_time:.4f}s")
164163

165164
if "chunk_write_total_idx" in name: # New unique name per chunk
166-
print(f" Avg ms/k-tok (for this chunk): {avg_time / self.chunk_tokens_threshold * 1000 * 1000:.2f}")
165+
logger.info(
166+
f" Avg ms/k-tok (for this chunk): {avg_time / self.chunk_tokens_threshold * 1000 * 1000:.2f}"
167+
)
167168
elif (
168169
name == "batch_processing_total"
169170
and self.batch_processing_total_calls > 0
@@ -173,29 +174,29 @@ def report(self):
173174
self.total_tokens_processed_for_batch_profiling / self.batch_processing_total_calls
174175
)
175176
if avg_tok_per_batch_call > 0:
176-
print(
177+
logger.info(
177178
f" Avg ms/k-tok (estimated for batch_processing_total): {avg_time / avg_tok_per_batch_call * 1000 * 1000:.2f}"
178179
)
179180

180-
print("\n=== Memory Snapshots (showing top 10 by RSS delta) ===")
181+
logger.info("\n=== Memory Snapshots (showing top 10 by RSS delta) ===")
181182
interesting_mem_snapshots = sorted(
182183
self.memory_snapshots, key=lambda x: abs(x["rss_delta_bytes"]), reverse=True
183184
)[:10]
184185
for snap in interesting_mem_snapshots:
185-
print(
186+
logger.info(
186187
f" {snap['name']} (took {snap['duration_s']:.3f}s): Total RSS {snap['rss_total_bytes'] / (1024**3):.3f} GB (ΔRSS {snap['rss_delta_bytes'] / (1024**3):.3f} GB)"
187188
)
188189

189-
print("\n=== System Metrics Log (sample) ===")
190+
logger.info("\n=== System Metrics Log (sample) ===")
190191
for i, metrics in enumerate(self.system_metrics_log[:5]): # Print first 5 samples
191-
print(
192+
logger.info(
192193
f" Sample {i} ({metrics['interval_name']}): CPU {metrics['cpu_percent']:.1f}%, Mem {metrics['memory_percent']:.1f}%, GPU {metrics['gpu_util_percent']:.1f}% (Mem {metrics['gpu_memory_percent']:.1f}%)"
193194
)
194195
if len(self.system_metrics_log) > 5:
195-
print(" ...")
196+
logger.info(" ...")
196197
if self.system_metrics_log: # Check if not empty before accessing last element
197198
metrics = self.system_metrics_log[-1]
198-
print(
199+
logger.info(
199200
f" Sample End ({metrics['interval_name']}): CPU {metrics['cpu_percent']:.1f}%, Mem {metrics['memory_percent']:.1f}%, GPU {metrics['gpu_util_percent']:.1f}% (Mem {metrics['gpu_memory_percent']:.1f}%)"
200201
)
201202

@@ -288,7 +289,7 @@ def _async_uploader(upload_q: "queue.Queue[Optional[Path]]", cfg: ActivationConf
288289
# --> ADDED: Retry Loop <--
289290
for attempt in range(max_retries_per_chunk):
290291
try:
291-
print(
292+
logger.info(
292293
f"[Uploader Thread Attempt {attempt + 1}/{max_retries_per_chunk}] Uploading chunk: {p.name} to {url}"
293294
)
294295
with open(p, "rb") as f:
@@ -972,7 +973,7 @@ def _upload_binary_file(self, path: Path, endpoint: str):
972973
try:
973974
activation_config_instance = ActivationConfig(**loaded_config)
974975
except TypeError as e:
975-
print(f"Error creating ActivationConfig from YAML. Ensure all keys are correct: {e}")
976+
logger.error(f"Error creating ActivationConfig from YAML. Ensure all keys are correct: {e}")
976977
import sys
977978

978979
sys.exit(1)

clt/config/data_config.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from dataclasses import dataclass, field
22
from typing import Literal, Optional, Dict, Any
3+
import logging
4+
5+
logger = logging.getLogger(__name__)
36

47

58
@dataclass
@@ -74,7 +77,23 @@ def __post_init__(self):
7477
except ImportError:
7578
raise ImportError("h5py is required for HDF5 output format. Install with: pip install h5py")
7679
if self.compression not in ["lz4", "gzip", None, False]:
77-
print(
80+
logger.warning(
7881
f"Warning: Unsupported compression '{self.compression}'. Will attempt without compression for {self.output_format}."
7982
)
8083
# Allow generator to handle disabling if format doesn't support it.
84+
85+
# Log a summary of the key configuration values.
86+
# This is more for user feedback than programmatic use.
87+
logger.info(
88+
"ActivationConfig Summary:\n"
89+
f" Model: {self.model_name}\n"
90+
f" Dataset: {self.dataset_path} (Split: {self.dataset_split})\n"
91+
f" Target Tokens: {self.target_total_tokens}\n"
92+
f" Chunk Threshold: {self.chunk_token_threshold}\n"
93+
f" Activation Dtype: {self.activation_dtype}\n"
94+
f" Output Dir: {self.activation_dir}"
95+
)
96+
if self.remote_server_url:
97+
logger.info(f" Remote Server URL: {self.remote_server_url}")
98+
if self.delete_after_upload:
99+
logger.info(" Delete after upload: Enabled")

clt/models/clt.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,30 @@ def convert_to_jumprelu_inplace(self, default_theta_value: float = 1e6) -> None:
275275
logger.info(
276276
f"Rank {self.rank}: CLT model config updated by ThetaManager. New activation_fn='{self.config.activation_fn}'."
277277
)
278+
279+
# --- Back-compat: expose ThetaManager.log_threshold at model level ---
280+
@property
281+
def log_threshold(self) -> Optional[torch.nn.Parameter]:
282+
"""Proxy to ``theta_manager.log_threshold`` for backward compatibility.
283+
284+
Older training scripts, conversion utilities and tests referenced
285+
``model.log_threshold`` directly. After the Step-5 refactor the
286+
parameter migrated into the dedicated ``ThetaManager`` module. We
287+
now expose a proxy property that always returns the *current* parameter
288+
held by ``self.theta_manager``. Modifying the returned tensor (e.g.
289+
in-place updates to ``.data``) therefore continues to work as before.
290+
Assigning a brand-new ``nn.Parameter`` to ``model.log_threshold`` will
291+
forward the assignment to ``theta_manager`` to preserve the linkage.
292+
"""
293+
if hasattr(self, "theta_manager") and self.theta_manager is not None:
294+
return self.theta_manager.log_threshold
295+
return None
296+
297+
@log_threshold.setter
298+
def log_threshold(self, new_param: Optional[torch.nn.Parameter]) -> None:
299+
# Keep property writable so callers that used to assign a fresh
300+
# parameter (rare) do not break. We delegate the storage to
301+
# ``ThetaManager`` so there is a single source of truth.
302+
if not hasattr(self, "theta_manager") or self.theta_manager is None:
303+
raise AttributeError("ThetaManager is not initialised; cannot set log_threshold.")
304+
self.theta_manager.log_threshold = new_param

0 commit comments

Comments
 (0)