Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion loader_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
import folder_paths
import safetensors.torch
import torch
try:
import torch._dynamo
DYNAMO = True
except ImportError:
DYNAMO = False

from ltx_video.models.autoencoders.causal_video_autoencoder import (
CausalVideoAutoencoder,
)
Expand All @@ -18,6 +24,42 @@
from .nodes_registry import comfy_node
from .vae import LTXVVAE

@comfy_node(name="LTXTricksTorchCompileSettings")
class LTXTricksTorchCompileSettings:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"backend": (["inductor", "cudagraphs"], {"default": "inductor", "tooltip": "Backend for torch.compile"}),
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode in compilation"}),
"mode": (
["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"],
{"default": "default", "tooltip": "Compilation mode optimization level"}
),
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic shapes in compilation"}),
"dynamo_cache_size_limit": (
"INT",
{"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "Set torch._dynamo.config.cache_size_limit"}
),
},
}

RETURN_TYPES = ("COMPILEARGS",)
RETURN_NAMES = ("torch_compile_args",)
FUNCTION = "compile_settings"
CATEGORY = "lightricks/LTXV"
DESCRIPTION = "torch compile settings for LTXV models. Requires Triton and torch 2.5.0+ recommended."

def compile_settings(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit):
compile_args = {
"backend": backend,
"fullgraph": fullgraph,
"mode": mode,
"dynamic": dynamic,
"dynamo_cache_size_limit": dynamo_cache_size_limit,
}

return (compile_args, )

@comfy_node(name="LTXVLoader")
class LTXVLoader:
Expand All @@ -30,6 +72,9 @@ def INPUT_TYPES(s):
{"tooltip": "The name of the checkpoint (model) to load."},
),
"dtype": (["bfloat16", "float32"], {"default": "bfloat16"}),
},
"optional": {
"torch_compile_args": ("COMPILEARGS", {"tooltip": "Optional compile arguments for torch.compile"})
}
}

Expand All @@ -40,7 +85,7 @@ def INPUT_TYPES(s):
TITLE = "LTXV Loader"
OUTPUT_NODE = False

def load(self, ckpt_name, dtype):
def load(self, ckpt_name, dtype, torch_compile_args=None):
dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32}
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.unet_offload_device()
Expand Down Expand Up @@ -71,6 +116,26 @@ def load(self, ckpt_name, dtype):
dtype=dtype_map[dtype],
config=unet_config,
)

# If compile arguments are provided, apply torch.compile
if torch_compile_args is not None:
backend = torch_compile_args.get("backend", "inductor")
fullgraph = torch_compile_args.get("fullgraph", False)
mode = torch_compile_args.get("mode", "default")
dynamic = torch_compile_args.get("dynamic", False)
dynamo_cache_size_limit = torch_compile_args.get("dynamo_cache_size_limit", 64)
if DYNAMO:
torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit

# Compile the diffusion_model with provided arguments
model.model.diffusion_model = torch.compile(
model.model.diffusion_model,
backend=backend,
fullgraph=fullgraph,
mode=mode,
dynamic=dynamic,
)

return (model, vae)

def _load_vae(self, weights, config=None):
Expand Down