16 changes: 11 additions & 5 deletions optimum/exporters/neuron/__main__.py
@@ -85,6 +85,7 @@
FluxKontextPipeline,
FluxPipeline,
ModelMixin,
QwenImagePipeline,
StableDiffusionPipeline,
StableDiffusionXLPipeline,
)
@@ -228,13 +229,18 @@ def infer_shapes_of_diffusers(
if isinstance(model, (FluxPipeline, FluxKontextPipeline)):
max_sequence_length_2 = input_shapes["text_encoder"].get("sequence_length", None) or max_sequence_length_2

vae_encoder_num_channels = model.vae.config.in_channels
vae_decoder_num_channels = model.vae.config.latent_channels
vae_scale_factor = 2 ** (len(model.vae.config.block_out_channels) - 1) or 8
if isinstance(model, QwenImagePipeline):
vae_encoder_num_channels = 3
vae_decoder_num_channels = model.vae.config.z_dim
vae_scale_factor = 2 ** len(model.vae.temperal_downsample) or 8
else:
vae_encoder_num_channels = model.vae.config.in_channels
vae_decoder_num_channels = model.vae.config.latent_channels
vae_scale_factor = 2 ** (len(model.vae.config.block_out_channels) - 1) or 8
height = input_shapes["unet_or_transformer"]["height"]
scaled_height = height // vae_scale_factor
width = input_shapes["unet_or_transformer"]["width"]
scaled_width = width // vae_scale_factor
scaled_height = 2 * (int(height) // (vae_scale_factor * 2))
scaled_width = 2 * (int(width) // (vae_scale_factor * 2))

# Text encoders
if input_shapes["text_encoder"].get("sequence_length") is None or hasattr(model, "text_encoder_2"):
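Note on the shape inference above: the QwenImage VAE derives its spatial compression factor from the length of `temperal_downsample` (the attribute name as spelled in diffusers) instead of `block_out_channels`, takes 3-channel RGB input on the encoder side, and exposes its latent channels as `z_dim`. The latent height and width are also rounded down to even values, presumably so the transformer can pack them into 2x2 patches. A minimal standalone sketch of that arithmetic (the helper and config values below are illustrative, not part of the PR):

def infer_latent_shape(height, width, is_qwen_image=True,
                       temperal_downsample=(False, True, True),
                       block_out_channels=(128, 256, 512, 512)):
    if is_qwen_image:
        # QwenImage-style VAE: spatial compression is 2 per downsample stage.
        vae_scale_factor = 2 ** len(temperal_downsample) or 8
    else:
        # Standard KL VAE: compression is 2^(number of blocks - 1).
        vae_scale_factor = 2 ** (len(block_out_channels) - 1) or 8
    # Round latent dims down to even values so they can be split into 2x2 patches.
    scaled_height = 2 * (int(height) // (vae_scale_factor * 2))
    scaled_width = 2 * (int(width) // (vae_scale_factor * 2))
    return vae_scale_factor, scaled_height, scaled_width

# A 1024x1024 image with a 3-stage QwenImage VAE -> scale factor 8, 128x128 latents.
print(infer_latent_shape(1024, 1024))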
213 changes: 146 additions & 67 deletions optimum/exporters/neuron/model_configs.py
@@ -50,14 +50,16 @@
DummyBeamValuesGenerator,
DummyControNetInputGenerator,
DummyFluxKontextTransformerRotaryEmbGenerator,
DummyFluxTransformerRotaryEmbGenerator,
DummyIPAdapterInputGenerator,
DummyMaskedPosGenerator,
DummyQwenImageTransformerInputGenerator,
DummyTimestepInputGenerator,
DummyTransformerRotaryEmbGenerator,
WhisperDummyTextInputGenerator,
get_checkpoint_shard_files,
saved_model_in_temporary_directory,
)
from optimum.neuron.models.inference.flux.modeling_flux import NeuronFluxTransformer2DModel

from .config import (
AudioNeuronConfig,
@@ -72,6 +74,7 @@
FluxTransformerNeuronWrapper,
NoCacheModelWrapper,
PixartTransformerNeuronWrapper,
QwenImageTransformerNeuronWrapper,
SentenceTransformersCLIPNeuronWrapper,
SentenceTransformersTransformerNeuronWrapper,
T5DecoderWrapper,
@@ -793,77 +796,27 @@ def outputs(self) -> list[str]:
return ["out_hidden_states"]


@register_in_tasks_manager("flux-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
class FluxTransformerNeuronConfig(VisionNeuronConfig):
ATOL_FOR_VALIDATION = 1e-3
INPUT_ARGS = (
"batch_size",
"sequence_length",
"num_channels",
"width",
"height",
"vae_scale_factor",
"encoder_hidden_size",
"rotary_axes_dim",
)
MODEL_TYPE = "flux-transformer-2d"
CUSTOM_MODEL_WRAPPER = FluxTransformerNeuronWrapper
NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
height="height",
width="width",
num_channels="in_channels",
vocab_size="attention_head_dim",
hidden_size="joint_attention_dim",
projection_size="pooled_projection_dim",
allow_new=True,
)

DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTransformerTimestepInputGenerator,
DummyFluxTransformerVisionInputGenerator,
DummyFluxTransformerTextInputGenerator,
)

@property
def inputs(self) -> list[str]:
common_inputs = [
"hidden_states",
"encoder_hidden_states",
"pooled_projections",
"timestep",
# Q: Why `image_rotary_emb` but not `txt_ids` and `img_ids`? We compute the rotary positional embeddings in CPU to save Neuron memory.
# shape: [txt_ids.shape(0)+img_ids.shape(0), sum(axes_dim), 2]
"image_rotary_emb",
]
if getattr(self._config, "guidance_embeds", False):
common_inputs.append("guidance")

return common_inputs

@property
def outputs(self) -> list[str]:
return ["out_hidden_states"]

class BaseModelBuilderNeuronConfig(VisionNeuronConfig):
TENSOR_PARALLEL_MODEL = None

def patch_model_and_prepare_aliases(self, model_or_path, *args):
base_model_instance = BaseModelInstance(
partial(self.get_parallel_callable, self._config),
input_output_aliases={},
)
return base_model_instance, None

def get_parallel_callable(self, config):
from optimum.neuron.models.inference.flux.modeling_flux import NeuronFluxTransformer2DModel

# Parallelize Flux transformer with NxD backend modeling
valid_params = inspect.signature(NeuronFluxTransformer2DModel.__init__).parameters
valid_params = inspect.signature(self.TENSOR_PARALLEL_MODEL.__init__).parameters
model_config = {k: v for k, v in config.items() if k in valid_params and k != "self"}
model = NeuronFluxTransformer2DModel(**model_config)
model = self.TENSOR_PARALLEL_MODEL(**model_config)
model.eval()
if self.float_dtype == torch.bfloat16:
model.bfloat16()

return model

# Adapted from diffusers.models.modeling_utils.ModelMixin.from_pretrained, this is a helper function for loading checkpoints required by `ModelBuilder`.
def get_checkpoint_loader_fn(self):
is_local = os.path.isdir(self.pretrained_model_name_or_path)
@@ -910,6 +863,120 @@ def get_checkpoint_loader_fn(self):

return merged_state_dict


@register_in_tasks_manager("flux-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
class FluxTransformerNeuronConfig(BaseModelBuilderNeuronConfig):
ATOL_FOR_VALIDATION = 1e-3
INPUT_ARGS = (
"batch_size",
"sequence_length",
"num_channels",
"width",
"height",
"vae_scale_factor",
"encoder_hidden_size",
"rotary_axes_dim",
)
MODEL_TYPE = "flux-transformer-2d"
CUSTOM_MODEL_WRAPPER = FluxTransformerNeuronWrapper
TENSOR_PARALLEL_MODEL = NeuronFluxTransformer2DModel
NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
height="height",
width="width",
num_channels="in_channels",
vocab_size="attention_head_dim",
hidden_size="joint_attention_dim",
projection_size="pooled_projection_dim",
allow_new=True,
)

DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTransformerTimestepInputGenerator,
DummyFluxTransformerVisionInputGenerator,
DummyFluxTransformerTextInputGenerator,
DummyTransformerRotaryEmbGenerator,
)

@property
def inputs(self) -> list[str]:
common_inputs = [
"hidden_states",
"encoder_hidden_states",
"pooled_projections",
"timestep",
# Q: Why `image_rotary_emb` but not `txt_ids` and `img_ids`? We compute the rotary positional embeddings in CPU to save Neuron memory.
# shape: [txt_ids.shape(0)+img_ids.shape(0), sum(axes_dim), 2]
"image_rotary_emb",
]
if getattr(self._config, "guidance_embeds", False):
common_inputs.append("guidance")

return common_inputs

@property
def outputs(self) -> list[str]:
return ["out_hidden_states"]

@property
def is_flux_kontext(self) -> bool:
return self._is_flux_kontext

@is_flux_kontext.setter
def is_flux_kontext(self, is_flux_kontext: bool):
self._is_flux_kontext = is_flux_kontext


@register_in_tasks_manager("qwen-image-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
class QwenImageTransformerNeuronConfig(BaseModelBuilderNeuronConfig):
ATOL_FOR_VALIDATION = 1e-3
INPUT_ARGS = (
"batch_size",
"sequence_length",
"num_channels",
"width",
"height",
"vae_scale_factor",
"encoder_hidden_size",
"rotary_axes_dim",
)
MODEL_TYPE = "qwen-image-transformer-2d"
CUSTOM_MODEL_WRAPPER = QwenImageTransformerNeuronWrapper
NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
height="height",
width="width",
num_channels="in_channels",
vocab_size="attention_head_dim",
hidden_size="joint_attention_dim",
projection_size="pooled_projection_dim",
allow_new=True,
)

DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTransformerTimestepInputGenerator,
DummyQwenImageTransformerInputGenerator,
DummyTransformerRotaryEmbGenerator,
)

@property
def inputs(self) -> list[str]:
common_inputs = [
"hidden_states",
"encoder_hidden_states",
"pooled_projections",
"timestep",
# Q: Why `image_rotary_emb` but not `txt_ids` and `img_ids`? We compute the rotary positional embeddings in CPU to save Neuron memory.
# shape: [txt_ids.shape(0)+img_ids.shape(0), sum(axes_dim), 2]
"image_rotary_emb",
]
if getattr(self._config, "guidance_embeds", False):
common_inputs.append("guidance")

return common_inputs

@property
def outputs(self) -> list[str]:
return ["out_hidden_states"]

def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
if self.is_flux_kontext:
self.DUMMY_INPUT_GENERATOR_CLASSES = self.DUMMY_INPUT_GENERATOR_CLASSES + (
@@ -921,7 +988,7 @@ def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
)
else:
self.DUMMY_INPUT_GENERATOR_CLASSES = self.DUMMY_INPUT_GENERATOR_CLASSES + (
DummyFluxTransformerRotaryEmbGenerator,
DummyTransformerRotaryEmbGenerator,
)
dummy_inputs = super().generate_dummy_inputs(**kwargs)

@@ -930,14 +997,6 @@ def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
else:
return dummy_inputs

@property
def is_flux_kontext(self) -> bool:
return self._is_flux_kontext

@is_flux_kontext.setter
def is_flux_kontext(self, is_flux_kontext: bool):
self._is_flux_kontext = is_flux_kontext


@register_in_tasks_manager("controlnet", *["semantic-segmentation"], library_name="diffusers")
class ControlNetNeuronConfig(VisionNeuronConfig):
@@ -1038,6 +1097,26 @@ def patch_model_and_prepare_aliases(
return super().patch_model_and_prepare_aliases(model=model, input_names=input_names, forward_with_tuple=True)


@register_in_tasks_manager("qwen-image-vae-encoder", *["semantic-segmentation"], library_name="diffusers")
class QwenImageVaeEncoderNeuronConfig(VaeEncoderNeuronConfig):
pass


@register_in_tasks_manager("qwen-image-vae-decoder", *["semantic-segmentation"], library_name="diffusers")
class QwenImageVaeDecoderNeuronConfig(VaeDecoderNeuronConfig):
pass


@register_in_tasks_manager("qwen2-5-vl", *["feature-extraction"], library_name="diffusers")
class Qwen2_5_VLEncoderNeuronConfig(TextEncoderNeuronConfig):
ATOL_FOR_VALIDATION = 1e-3
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

@property
def inputs(self) -> list[str]:
return ["input_ids", "attention_mask"]


class T5EncoderBaseNeuronConfig(TextSeq2SeqNeuronConfig):
ATOL_FOR_VALIDATION = 1e-3
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
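The net effect of the model_configs.py changes above: the ModelBuilder plumbing that was previously hard-wired to Flux is factored into `BaseModelBuilderNeuronConfig`, and subclasses only declare which NxD modeling class to use via `TENSOR_PARALLEL_MODEL`. The `inspect.signature` filtering in `get_parallel_callable` keeps only the config entries that the target class's `__init__` accepts. A minimal standalone sketch of that filtering pattern (the tiny model class is a hypothetical stand-in, not the actual NxD transformer):

import inspect

import torch


class TinyTransformer(torch.nn.Module):
    # Hypothetical stand-in for a TENSOR_PARALLEL_MODEL such as NeuronFluxTransformer2DModel.
    def __init__(self, in_channels: int = 64, num_layers: int = 2):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            torch.nn.Linear(in_channels, in_channels) for _ in range(num_layers)
        )


def build_from_config(model_cls, config: dict) -> torch.nn.Module:
    # Keep only the config entries that model_cls.__init__ actually accepts.
    valid_params = inspect.signature(model_cls.__init__).parameters
    model_kwargs = {k: v for k, v in config.items() if k in valid_params and k != "self"}
    model = model_cls(**model_kwargs)
    model.eval()
    return model


# "guidance_embeds" is dropped silently because TinyTransformer does not accept it.
model = build_from_config(TinyTransformer, {"in_channels": 32, "num_layers": 4, "guidance_embeds": False})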
38 changes: 38 additions & 0 deletions optimum/exporters/neuron/model_wrappers.py
@@ -161,6 +161,44 @@ def forward(self, *inputs):
return out_tuple


class QwenImageTransformerNeuronWrapper(torch.nn.Module):
def __init__(self, model, input_names: list[str], device: str = None):
super().__init__()
self.model = model
self.dtype = model.dtype
self.input_names = input_names
self.device = device

def forward(self, *inputs):
if len(inputs) != len(self.input_names):
raise ValueError(
f"The model needs {len(self.input_names)} inputs: {self.input_names}."
f" But only {len(input)} inputs are passed."
)

ordered_inputs = dict(zip(self.input_names, inputs))

hidden_states = ordered_inputs.pop("hidden_states", None)
encoder_hidden_states = ordered_inputs.pop("encoder_hidden_states", None)
encoder_hidden_states_mask = ordered_inputs.pop("encoder_hidden_states_mask", None)
timestep = ordered_inputs.pop("timestep", None)
guidance = ordered_inputs.pop("guidance", None)
image_rotary_emb = ordered_inputs.pop("image_rotary_emb", None)

out_tuple = self.model(
hidden_states=hidden_states,
timestep=timestep,
guidance=guidance,
encoder_hidden_states_mask=encoder_hidden_states_mask,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=None,
return_dict=False,
)

return out_tuple


class ControlNetNeuronWrapper(torch.nn.Module):
def __init__(self, model, input_names: list[str], device: str = None):
super().__init__()
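Like the other diffusion wrappers in this file, `QwenImageTransformerNeuronWrapper` exists because Neuron tracing passes inputs positionally while the diffusers transformer expects keyword arguments, so the wrapper re-associates them through `input_names`. A minimal sketch of that calling convention, assuming the wrapper class defined above (the dummy module and tensor shapes below are illustrative, not the real QwenImage transformer):

import torch


class DummyQwenImageTransformer(torch.nn.Module):
    # Illustrative stand-in that echoes its main input, mimicking the keyword-only interface.
    @property
    def dtype(self):
        return torch.float32

    def forward(
        self,
        hidden_states=None,
        encoder_hidden_states=None,
        encoder_hidden_states_mask=None,
        timestep=None,
        guidance=None,
        image_rotary_emb=None,
        joint_attention_kwargs=None,
        return_dict=False,
    ):
        return (hidden_states,)


input_names = ["hidden_states", "encoder_hidden_states", "timestep", "image_rotary_emb"]
wrapper = QwenImageTransformerNeuronWrapper(DummyQwenImageTransformer(), input_names)

# Positional inputs in the order of `input_names`, as the tracer would pass them.
out = wrapper(
    torch.randn(1, 4096, 64),    # packed latent hidden_states
    torch.randn(1, 256, 3584),   # prompt embeddings from the text encoder
    torch.tensor([1.0]),         # timestep
    torch.randn(4352, 128, 2),   # precomputed rotary embeddings
)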