
Conversation

@cdoern (Contributor) commented Jun 10, 2025

Introduce a new design for key components of main_ds.py, splitting model initialization, accelerator initialization, optimizer initialization, and checkpoint saving into classes:

  1. Model
  2. Accelerator
  3. Checkpointer

The Checkpointer class introduces a unified approach to our various checkpointing techniques. A user passes in their checkpointing style (full_state or hf_format), and the checkpointer, via checkpointer.checkpoint, saves the model using the selected method, handling related techniques (such as LoRA) along the way.
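
Roughly, the shape looks like this (sketch only, not the exact code in the PR):

```python
from pathlib import Path


class Checkpointer:
    """Sketch: unify the existing save routines behind one entry point."""

    def __init__(self, model, optimizer, accelerator, strategy="full_state"):
        self.model = model
        self.optimizer = optimizer
        self.accelerator = accelerator
        # Pick the save routine once, based on the requested checkpointing style.
        self._checkpoint_fn = {
            "full_state": self.save_full_state,
            "hf_format": self.save_hf_format_accelerate,
        }[strategy]

    def checkpoint(self, output_dir: Path, samples_seen: int) -> None:
        # Callers only ever invoke checkpoint(); the selected strategy decides
        # which underlying save method runs (including any LoRA handling).
        self._checkpoint_fn(output_dir, samples_seen)

    def save_full_state(self, output_dir: Path, samples_seen: int) -> None: ...

    def save_hf_format_accelerate(self, output_dir: Path, samples_seen: int) -> None: ...
```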

This PR adds the new class along with unit tests for it.

see previous PRs #572 and #594

Note: this is probably the last of these large refactors for now; smaller follow-up PRs will handle cleanup.

@mergify mergify bot added the testing (Relates to testing) and ci-failure labels on Jun 10, 2025
@cdoern cdoern force-pushed the refactor-checkpoint branch from 5dca802 to 276e9b3 on June 10, 2025 15:15
@mergify mergify bot removed the ci-failure label Jun 10, 2025
model_conf from `AutoConfig` has some key info we need in the checkpointer. Associate it with the model class and its subclasses

Signed-off-by: Charlie Doern <[email protected]>

E2E (NVIDIA L40S x4) (python 3.11) workflow launched on this PR: View run


e2e workflow succeeded on this PR: View run, congrats!

@fynnsu (Collaborator) left a comment

Left some comments below.

I'm also assuming that the contents of the specific methods (like save_fsdp_lora_model) are largely unchanged. Is that correct?

print("[None] Skipping checkpointing.")

# pylint: disable=unused-argument
def save_fsdp_lora_model(
Collaborator:

Potentially for a future PR, but I think it would be cleaner to have a base Checkpointer abstract class and then FSDPLoRACheckpointer, HFFormatAccelerateCheckpointer, etc. subclasses, each of which implements its own checkpoint method, instead of doing our own custom routing with self._checkpoint_fn.
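
Roughly something like (names here are just illustrative):

```python
from abc import ABC, abstractmethod
from pathlib import Path


class BaseCheckpointer(ABC):
    """Shared construction for every checkpointing style."""

    def __init__(self, model, optimizer, accelerator):
        self.model = model
        self.optimizer = optimizer
        self.accelerator = accelerator

    @abstractmethod
    def checkpoint(self, output_dir: Path, samples_seen: int) -> None:
        """Save a checkpoint using this strategy."""


class FSDPLoRACheckpointer(BaseCheckpointer):
    def checkpoint(self, output_dir: Path, samples_seen: int) -> None:
        ...  # gather FSDP state, merge/save LoRA adapters


class HFFormatAccelerateCheckpointer(BaseCheckpointer):
    def checkpoint(self, output_dir: Path, samples_seen: int) -> None:
        ...  # write an HF-format checkpoint via the accelerator
```

That way the call site is always checkpointer.checkpoint(...) and the dispatch lives in the type rather than in self._checkpoint_fn.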

Contributor:

I made a similar argument for the Model class before... Class hierarchies are exactly meant for such scenarios.

accelerator,
samples_seen,
is_lora=bool(args.lora_r),
checkpointer.save_hf_format_accelerate(
Collaborator:

Shouldn't this be checkpointer.checkpoint()?

@@ -50,11 +50,13 @@ def __init__(
flash_enabled: bool = False,
lora_config: Optional[LoraConfig] = None,
lora_quant_bits: int = 0,
model_conf=None,
Collaborator:

What is model_conf for? Currently we only seem to be using it to access model_conf.model_type. Could we just store model_type instead?

Also, none of the current accesses of model_conf.model_type check that model_conf is not None before accessing the attribute, so this will raise an error if model_conf is ever actually None (its current default value).
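
e.g. something like:

```python
# store just the field we need, guarding against a None config
self.model_type = getattr(model_conf, "model_type", None)
```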

checkpointer = Checkpointer(
    strategy=strategy, model=m, optimizer=optimizer, accelerator=accelerator
)
checkpointer.load_latest_full_state(Path(args.output_dir))
Collaborator:

Can args.output_dir be set in Checkpointer.__init__? We're currently passing it into every load/checkpoint function, but it doesn't seem like it changes value (or should change value).
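
i.e. (illustrative):

```python
checkpointer = Checkpointer(
    strategy=strategy,
    model=m,
    optimizer=optimizer,
    accelerator=accelerator,
    output_dir=Path(args.output_dir),  # stored once on the instance
)
checkpointer.load_latest_full_state()  # no path argument on every call
```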

def save_fsdp_lora_model(
    self,
    output_dir: Path,
    **kwargs,
Contributor:

Please remove the kwargs that are not used; turn the ones that are used into explicit arguments for the information that needs to be passed (with proper names, types, etc.).
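
For example (the parameter names below are just placeholders for whatever the kwargs currently carry):

```python
def save_fsdp_lora_model(
    self,
    output_dir: Path,
    samples_seen: int,
    epoch: int,
) -> None:
    ...
```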

    model: Model,
    optimizer: torch.optim.Optimizer,
    accelerator: Accelerator,
    strategy="all",
Contributor:

Make it an enum.
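
For example (the enum name is illustrative; the values are the strategies used in this PR):

```python
from enum import Enum


class CheckpointStrategy(str, Enum):
    ALL = "all"
    FULL_STATE = "full_state"
    HF_FORMAT = "hf_format"
```

The signature would then become strategy: CheckpointStrategy = CheckpointStrategy.ALL, and the str mixin keeps comparisons against existing string values working.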

# pylint: disable=unused-argument
def save_full_state(
    self,
    output_dir,
Contributor:

Define types for all arguments.

print("[None] Skipping checkpointing.")

# pylint: disable=unused-argument
def save_fsdp_lora_model(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a similar argument for Model class before... Class hierarchies are exactly meant for such scenarios.


@pytest.fixture
def mock_accelerator():
    accelerator = MagicMock(spec=Accelerator)
Contributor:

Instead of mocks, you could introduce a new TestAccelerator subclass that does nothing (or the bare minimum) for test purposes. Same for the rest. Why do we need mocks just to create an object? Are the init methods destructive/invasive? (Maybe that should be fixed then; it should generally be safe and cheap to create objects.)
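
For example, a minimal no-op stand-in (hypothetical; this assumes Accelerator is the new wrapper class from this PR and the overridden methods are placeholders):

```python
class _TestAccelerator(Accelerator):
    """Bare-minimum Accelerator for unit tests: no device/distributed setup."""

    def __init__(self):
        # Deliberately skip the real __init__ so construction stays cheap.
        self.device = "cpu"

    def prepare(self, *objects):
        # Return objects unchanged instead of wrapping them.
        return objects if len(objects) > 1 else objects[0]
```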


This pull request has been automatically marked as stale because it has not had activity within 90 days. It will be automatically closed if no further activity occurs within 30 days.

@github-actions github-actions bot added the stale label Sep 10, 2025
mergify bot commented Sep 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. @cdoern please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 10, 2025
@github-actions github-actions bot removed the stale label Sep 11, 2025