Commit 025d3cb

added logging to train script

Author: Curt Tigges (committed)
1 parent cd79a72 commit 025d3cb

File tree

1 file changed (+34 −7 lines)


scripts/train_clt.py

Lines changed: 34 additions & 7 deletions
@@ -16,8 +16,12 @@
 # Attempt to import transformers for model dimension detection
 try:
     from transformers import AutoConfig
+    import transformers  # Import the library itself to check version
+    import sys  # Import sys to check path
 except ImportError:
     AutoConfig = None
+    transformers = None  # type: ignore
+    sys = None  # type: ignore

 # Import necessary CLT components
 try:
@@ -46,10 +50,33 @@ def get_model_dimensions(model_name: str) -> tuple[Optional[int], Optional[int]]
         return None, None  # Indicate failure to auto-detect

     try:
+        if transformers and hasattr(transformers, "__version__"):
+            logger.info(f"Transformers library version: {transformers.__version__}")
+        if sys:
+            logger.info(f"Python sys.path: {sys.path}")
+
+        logger.info(f"Attempting to load config for model_name: '{model_name}'")
         config = AutoConfig.from_pretrained(model_name)
-        num_layers = getattr(config, "num-hidden-layers", None) or getattr(config, "n-layer", None)
-        d_model = getattr(config, "hidden-size", None) or getattr(config, "n-embd", None)
+        logger.info(f"Loaded config object: type={type(config)}")
+        if hasattr(config, "to_dict"):
+            # Log only a few key attributes to avoid excessively long log messages
+            # if the config is huge. Relevant ones might be 'model_type', 'architectures'.
+            config_dict_summary = {
+                k: v
+                for k, v in config.to_dict().items()
+                if k in ["model_type", "architectures", "num_hidden_layers", "n_layer", "hidden_size", "n_embd"]
+            }
+            logger.info(f"Config content summary: {config_dict_summary}")
+            # If still debugging, can log the full dict, but be wary of verbosity:
+            # logger.debug(f"Full config content: {config.to_dict()}")
+        elif hasattr(config, "__dict__"):
+            logger.info(f"Config content (vars): {vars(config)}")
+        else:
+            logger.info(f"Config object does not have to_dict or __dict__ methods. Content: {config}")

+        num_layers = getattr(config, "num_hidden_layers", None) or getattr(config, "n_layer", None)
+        d_model = getattr(config, "hidden_size", None) or getattr(config, "n_embd", None)
+        logger.info(f"Attempted to get dimensions: num_layers={num_layers}, d_model={d_model}")
         if num_layers is None or d_model is None:
             logger.warning(
                 f"Could not automatically determine num_layers or d_model for {model_name}. "
@@ -86,7 +113,7 @@ def parse_args():
         "--output-dir",
         type=str,
         default=f"clt_train_{int(time.time())}",
-        help="Directory to save logs, checkpoints, and final model. If resuming, this might be overridden by --resume-from-checkpoint-dir.",
+        help="Directory to save logs, checkpoints, and final model. If resuming, this might be overridden by --resume_from_checkpoint_dir.",
     )
     core_group.add_argument(
         "--model-name",
@@ -106,16 +133,16 @@ def parse_args():
         help="Enable distributed training (requires torchrun/appropriate launcher).",
     )
     core_group.add_argument(
-        "--resume-from-checkpoint-dir",
+        "--resume_from_checkpoint_dir",
         type=str,
         default=None,
-        help="Path to the output directory of a previous run to resume from. Will attempt to load 'latest' or a specific step if --resume-step is also given.",
+        help="Path to the output directory of a previous run to resume from. Will attempt to load 'latest' or a specific step if --resume_step is also given.",
     )
     core_group.add_argument(
-        "--resume-step",
+        "--resume_step",
         type=int,
         default=None,
-        help="Optional specific step to resume from. Used in conjunction with --resume-from-checkpoint-dir.",
+        help="Optional specific step to resume from. Used in conjunction with --resume_from_checkpoint_dir.",
     )

     # --- Local Activation Source Parameters ---
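Note: the pattern this commit instruments (load an AutoConfig, log a trimmed summary of it, then fall back across the attribute names different architectures use for depth and width) can be exercised on its own. The sketch below is illustrative only, not the script's actual get_model_dimensions; the logging setup and the "gpt2" model name are assumptions for the example.

import logging
from typing import Optional, Tuple

from transformers import AutoConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def detect_dims(model_name: str) -> Tuple[Optional[int], Optional[int]]:
    """Illustrative sketch: log a short config summary, then try common attribute names."""
    config = AutoConfig.from_pretrained(model_name)

    # Keep the log short: only the keys relevant to depth/width detection.
    keys = ["model_type", "architectures", "num_hidden_layers", "n_layer", "hidden_size", "n_embd"]
    summary = {k: v for k, v in config.to_dict().items() if k in keys}
    logger.info("Config summary for %s: %s", model_name, summary)

    # Different architectures expose depth/width under different names.
    num_layers = getattr(config, "num_hidden_layers", None) or getattr(config, "n_layer", None)
    d_model = getattr(config, "hidden_size", None) or getattr(config, "n_embd", None)
    return num_layers, d_model


if __name__ == "__main__":
    print(detect_dims("gpt2"))  # e.g. (12, 768) for GPT-2 small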
