diff --git a/loader_node.py b/loader_node.py index 303b759..d7fa0e1 100644 --- a/loader_node.py +++ b/loader_node.py @@ -7,6 +7,12 @@ import folder_paths import safetensors.torch import torch +try: + import torch._dynamo + DYNAMO = True +except ImportError: + DYNAMO = False + from ltx_video.models.autoencoders.causal_video_autoencoder import ( CausalVideoAutoencoder, ) @@ -18,6 +24,42 @@ from .nodes_registry import comfy_node from .vae import LTXVVAE +@comfy_node(name="LTXTricksTorchCompileSettings") +class LTXTricksTorchCompileSettings: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "backend": (["inductor", "cudagraphs"], {"default": "inductor", "tooltip": "Backend for torch.compile"}), + "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode in compilation"}), + "mode": ( + ["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], + {"default": "default", "tooltip": "Compilation mode optimization level"} + ), + "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic shapes in compilation"}), + "dynamo_cache_size_limit": ( + "INT", + {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "Set torch._dynamo.config.cache_size_limit"} + ), + }, + } + + RETURN_TYPES = ("COMPILEARGS",) + RETURN_NAMES = ("torch_compile_args",) + FUNCTION = "compile_settings" + CATEGORY = "lightricks/LTXV" + DESCRIPTION = "torch compile settings for LTXV models. Requires Triton and torch 2.5.0+ recommended." + + def compile_settings(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit): + compile_args = { + "backend": backend, + "fullgraph": fullgraph, + "mode": mode, + "dynamic": dynamic, + "dynamo_cache_size_limit": dynamo_cache_size_limit, + } + + return (compile_args, ) @comfy_node(name="LTXVLoader") class LTXVLoader: @@ -30,6 +72,9 @@ def INPUT_TYPES(s): {"tooltip": "The name of the checkpoint (model) to load."}, ), "dtype": (["bfloat16", "float32"], {"default": "bfloat16"}), + }, + "optional": { + "torch_compile_args": ("COMPILEARGS", {"tooltip": "Optional compile arguments for torch.compile"}) } } @@ -40,7 +85,7 @@ def INPUT_TYPES(s): TITLE = "LTXV Loader" OUTPUT_NODE = False - def load(self, ckpt_name, dtype): + def load(self, ckpt_name, dtype, torch_compile_args=None): dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32} load_device = comfy.model_management.get_torch_device() offload_device = comfy.model_management.unet_offload_device() @@ -71,6 +116,26 @@ def load(self, ckpt_name, dtype): dtype=dtype_map[dtype], config=unet_config, ) + + # If compile arguments are provided, apply torch.compile + if torch_compile_args is not None: + backend = torch_compile_args.get("backend", "inductor") + fullgraph = torch_compile_args.get("fullgraph", False) + mode = torch_compile_args.get("mode", "default") + dynamic = torch_compile_args.get("dynamic", False) + dynamo_cache_size_limit = torch_compile_args.get("dynamo_cache_size_limit", 64) + if DYNAMO: + torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit + + # Compile the diffusion_model with provided arguments + model.model.diffusion_model = torch.compile( + model.model.diffusion_model, + backend=backend, + fullgraph=fullgraph, + mode=mode, + dynamic=dynamic, + ) + return (model, vae) def _load_vae(self, weights, config=None):