from __future__ import annotations

import os
- import time
import json
import queue
import random
import logging
import threading
from pathlib import Path
- from typing import Dict, List, Optional, Tuple, Any
+ from typing import Dict, List, Optional, Tuple, Any, DefaultDict
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch

@@ -36,7 +35,7 @@
from clt.config.data_config import ActivationConfig  # noqa: E402

# --- Profiling Imports ---
- import time  # Already imported, but good to note
+ import time  # Keep this one
from contextlib import contextmanager
from collections import defaultdict
import psutil
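psutil is brought in here for the system-metrics sampling used by the profiler below. As a point of reference, here is a minimal sketch of what one such sample might contain; the GPU fields are stubbed out because they would need an extra dependency (e.g. pynvml) that this diff does not show, and the helper name is purely illustrative:

import psutil


def sample_system_metrics(interval_name: str = "interval") -> dict:
    # CPU and memory percentages come straight from psutil; GPU utilisation is
    # left as a placeholder since it requires a separate library not shown here.
    return {
        "interval_name": interval_name,
        "cpu_percent": psutil.cpu_percent(interval=None),
        "memory_percent": psutil.virtual_memory().percent,
        "gpu_util_percent": 0.0,
        "gpu_memory_percent": 0.0,
    }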
@@ -58,8 +57,8 @@
# --- Performance Profiler Class ---
class PerformanceProfiler:
    def __init__(self, chunk_tokens_threshold: int = 1_000_000):
-         self.timings = defaultdict(list)
-         self.memory_snapshots = []
+         self.timings: DefaultDict[str, List[float]] = defaultdict(list)
+         self.memory_snapshots: List[Dict[str, Any]] = []
        self.chunk_tokens_threshold = chunk_tokens_threshold
        self.system_metrics_log: List[Dict[str, Any]] = []
        self.layer_ids_ref: Optional[List[int]] = None
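For context, a minimal sketch of how the typed timing container might be filled; the `time_block` helper and its use of `time.perf_counter` are illustrative assumptions, not code from this diff:

import time
from contextlib import contextmanager


@contextmanager
def time_block(profiler: "PerformanceProfiler", name: str):
    # Hypothetical helper: appends the elapsed wall-clock time for `name` to
    # profiler.timings, matching the DefaultDict[str, List[float]] annotation above.
    start = time.perf_counter()
    try:
        yield
    finally:
        profiler.timings[name].append(time.perf_counter() - start)

Used as `with time_block(profiler, "chunk_write_total_idx_0"): ...`, each timed region contributes one float to its operation's list, which is what report() later aggregates.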
@@ -143,7 +142,7 @@ def log_system_metrics(self, interval_name: str = "interval"):
        return metrics

    def report(self):
-         print("\n=== Performance Report ===")
+         logger.info("\n=== Performance Report ===")
        # Sort by total time descending for timings
        sorted_timings = sorted(self.timings.items(), key=lambda item: sum(item[1]), reverse=True)
@@ -155,15 +154,17 @@ def report(self):
            min_time = min(times)
            max_time = max(times)

-             print(f"\n--- Operation: {name} ---")
-             print(f" Count: {len(times)}")
-             print(f" Total time: {total_time:.3f}s")
-             print(f" Avg time: {avg_time:.4f}s")
-             print(f" Min time: {min_time:.4f}s")
-             print(f" Max time: {max_time:.4f}s")
+             logger.info(f"\n--- Operation: {name} ---")
+             logger.info(f" Count: {len(times)}")
+             logger.info(f" Total time: {total_time:.3f}s")
+             logger.info(f" Avg time: {avg_time:.4f}s")
+             logger.info(f" Min time: {min_time:.4f}s")
+             logger.info(f" Max time: {max_time:.4f}s")

            if "chunk_write_total_idx" in name:  # New unique name per chunk
-                 print(f" Avg ms/k-tok (for this chunk): {avg_time / self.chunk_tokens_threshold * 1000 * 1000:.2f}")
+                 logger.info(
+                     f" Avg ms/k-tok (for this chunk): {avg_time / self.chunk_tokens_threshold * 1000 * 1000:.2f}"
+                 )
            elif (
                name == "batch_processing_total"
                and self.batch_processing_total_calls > 0
@@ -173,29 +174,29 @@ def report(self):
                    self.total_tokens_processed_for_batch_profiling / self.batch_processing_total_calls
                )
                if avg_tok_per_batch_call > 0:
-                     print(
+                     logger.info(
                        f" Avg ms/k-tok (estimated for batch_processing_total): {avg_time / avg_tok_per_batch_call * 1000 * 1000:.2f}"
                    )

-         print("\n=== Memory Snapshots (showing top 10 by RSS delta) ===")
+         logger.info("\n=== Memory Snapshots (showing top 10 by RSS delta) ===")
        interesting_mem_snapshots = sorted(
            self.memory_snapshots, key=lambda x: abs(x["rss_delta_bytes"]), reverse=True
        )[:10]
        for snap in interesting_mem_snapshots:
-             print(
+             logger.info(
                f" {snap['name']} (took {snap['duration_s']:.3f}s): Total RSS {snap['rss_total_bytes'] / (1024 ** 3):.3f} GB (ΔRSS {snap['rss_delta_bytes'] / (1024 ** 3):.3f} GB)"
            )

-         print("\n=== System Metrics Log (sample) ===")
+         logger.info("\n=== System Metrics Log (sample) ===")
        for i, metrics in enumerate(self.system_metrics_log[:5]):  # Print first 5 samples
-             print(
+             logger.info(
                f" Sample {i} ({metrics['interval_name']}): CPU {metrics['cpu_percent']:.1f}%, Mem {metrics['memory_percent']:.1f}%, GPU {metrics['gpu_util_percent']:.1f}% (Mem {metrics['gpu_memory_percent']:.1f}%)"
            )
        if len(self.system_metrics_log) > 5:
-             print(" ...")
+             logger.info(" ...")
        if self.system_metrics_log:  # Check if not empty before accessing last element
            metrics = self.system_metrics_log[-1]
-             print(
+             logger.info(
                f" Sample End ({metrics['interval_name']}): CPU {metrics['cpu_percent']:.1f}%, Mem {metrics['memory_percent']:.1f}%, GPU {metrics['gpu_util_percent']:.1f}% (Mem {metrics['gpu_memory_percent']:.1f}%)"
            )

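These hunks replace print with a module-level logger; a minimal sketch of the setup that conversion relies on (the logger name and format string here are assumptions, not shown in the diff):

import logging

# Assumed module-level logger used by the converted report() output above.
logger = logging.getLogger(__name__)

# Basic configuration so the report remains visible when the script runs standalone.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)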
@@ -288,7 +289,7 @@ def _async_uploader(upload_q: "queue.Queue[Optional[Path]]", cfg: ActivationConf
    # --> ADDED: Retry Loop <--
    for attempt in range(max_retries_per_chunk):
        try:
-             print(
+             logger.info(
                f"[Uploader Thread Attempt {attempt + 1}/{max_retries_per_chunk}] Uploading chunk: {p.name} to {url}"
            )
            with open(p, "rb") as f:
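The retry loop above pairs naturally with backoff between attempts. A minimal sketch of that pattern, assuming an HTTP upload via requests; the function name, timeout, and backoff values are illustrative and not taken from this diff:

import logging
import time
from pathlib import Path

import requests  # assumed HTTP client; the real uploader may use a different transport

logger = logging.getLogger(__name__)


def upload_chunk_with_retry(p: Path, url: str, max_retries_per_chunk: int = 3, base_delay_s: float = 1.0) -> bool:
    # Retry the chunk upload with exponential backoff; returns True on success.
    for attempt in range(max_retries_per_chunk):
        try:
            with open(p, "rb") as f:
                resp = requests.post(url, data=f, timeout=300)
            resp.raise_for_status()
            return True
        except Exception as e:
            logger.warning(
                "Attempt %d/%d failed for %s: %s", attempt + 1, max_retries_per_chunk, p.name, e
            )
            time.sleep(base_delay_s * (2**attempt))
    return False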
@@ -972,7 +973,7 @@ def _upload_binary_file(self, path: Path, endpoint: str):
    try:
        activation_config_instance = ActivationConfig(**loaded_config)
    except TypeError as e:
-         print(f"Error creating ActivationConfig from YAML. Ensure all keys are correct: {e}")
+         logger.error(f"Error creating ActivationConfig from YAML. Ensure all keys are correct: {e}")
        import sys

        sys.exit(1)
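For context, a minimal sketch of the surrounding load-from-YAML flow that this error handling guards, assuming the config is parsed with PyYAML; the helper name and path handling are illustrative, not part of the PR:

import logging
import sys
from pathlib import Path

import yaml  # assumed: the activation config is read from a YAML file

from clt.config.data_config import ActivationConfig

logger = logging.getLogger(__name__)


def load_activation_config(path: Path) -> ActivationConfig:
    # Parse the YAML file and build the config object, exiting on unexpected keys.
    with open(path, "r") as f:
        loaded_config = yaml.safe_load(f)
    try:
        return ActivationConfig(**loaded_config)
    except TypeError as e:
        logger.error(f"Error creating ActivationConfig from YAML. Ensure all keys are correct: {e}")
        sys.exit(1)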