diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py index 34fdf4c62..9ecb725c7 100644 --- a/optimum/exporters/neuron/__main__.py +++ b/optimum/exporters/neuron/__main__.py @@ -85,6 +85,7 @@ FluxKontextPipeline, FluxPipeline, ModelMixin, + QwenImagePipeline, StableDiffusionPipeline, StableDiffusionXLPipeline, ) @@ -228,13 +229,18 @@ def infer_shapes_of_diffusers( if isinstance(model, (FluxPipeline, FluxKontextPipeline)): max_sequence_length_2 = input_shapes["text_encoder"].get("sequence_length", None) or max_sequence_length_2 - vae_encoder_num_channels = model.vae.config.in_channels - vae_decoder_num_channels = model.vae.config.latent_channels - vae_scale_factor = 2 ** (len(model.vae.config.block_out_channels) - 1) or 8 + if isinstance(model, QwenImagePipeline): + vae_encoder_num_channels = 3 + vae_decoder_num_channels = model.vae.config.z_dim + vae_scale_factor = 2 ** len(model.vae.temperal_downsample) or 8 + else: + vae_encoder_num_channels = model.vae.config.in_channels + vae_decoder_num_channels = model.vae.config.latent_channels + vae_scale_factor = 2 ** (len(model.vae.config.block_out_channels) - 1) or 8 height = input_shapes["unet_or_transformer"]["height"] - scaled_height = height // vae_scale_factor width = input_shapes["unet_or_transformer"]["width"] - scaled_width = width // vae_scale_factor + scaled_height = 2 * (int(height) // (vae_scale_factor * 2)) + scaled_width = 2 * (int(width) // (vae_scale_factor * 2)) # Text encoders if input_shapes["text_encoder"].get("sequence_length") is None or hasattr(model, "text_encoder_2"): diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py index a2919424d..82f8b9025 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs.py @@ -50,14 +50,16 @@ DummyBeamValuesGenerator, DummyControNetInputGenerator, DummyFluxKontextTransformerRotaryEmbGenerator, - DummyFluxTransformerRotaryEmbGenerator, DummyIPAdapterInputGenerator, DummyMaskedPosGenerator, + DummyQwenImageTransformerInputGenerator, DummyTimestepInputGenerator, + DummyTransformerRotaryEmbGenerator, WhisperDummyTextInputGenerator, get_checkpoint_shard_files, saved_model_in_temporary_directory, ) +from optimum.neuron.models.inference.flux.modeling_flux import NeuronFluxTransformer2DModel from .config import ( AudioNeuronConfig, @@ -72,6 +74,7 @@ FluxTransformerNeuronWrapper, NoCacheModelWrapper, PixartTransformerNeuronWrapper, + QwenImageTransformerNeuronWrapper, SentenceTransformersCLIPNeuronWrapper, SentenceTransformersTransformerNeuronWrapper, T5DecoderWrapper, @@ -793,77 +796,27 @@ def outputs(self) -> list[str]: return ["out_hidden_states"] -@register_in_tasks_manager("flux-transformer-2d", *["semantic-segmentation"], library_name="diffusers") -class FluxTransformerNeuronConfig(VisionNeuronConfig): - ATOL_FOR_VALIDATION = 1e-3 - INPUT_ARGS = ( - "batch_size", - "sequence_length", - "num_channels", - "width", - "height", - "vae_scale_factor", - "encoder_hidden_size", - "rotary_axes_dim", - ) - MODEL_TYPE = "flux-transformer-2d" - CUSTOM_MODEL_WRAPPER = FluxTransformerNeuronWrapper - NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( - height="height", - width="width", - num_channels="in_channels", - vocab_size="attention_head_dim", - hidden_size="joint_attention_dim", - projection_size="pooled_projection_dim", - allow_new=True, - ) - - DUMMY_INPUT_GENERATOR_CLASSES = ( - DummyTransformerTimestepInputGenerator, - DummyFluxTransformerVisionInputGenerator, - 
DummyFluxTransformerTextInputGenerator, - ) - - @property - def inputs(self) -> list[str]: - common_inputs = [ - "hidden_states", - "encoder_hidden_states", - "pooled_projections", - "timestep", - # Q: Why `image_rotary_emb` but not `txt_ids` and `img_ids`? We compute the rotary positional embeddings in CPU to save Neuron memory. - # shape: [txt_ids.shape(0)+img_ids.shape(0), sum(axes_dim), 2] - "image_rotary_emb", - ] - if getattr(self._config, "guidance_embeds", False): - common_inputs.append("guidance") - - return common_inputs - - @property - def outputs(self) -> list[str]: - return ["out_hidden_states"] - +class BaseModelBuilderNeuronConfig(VisionNeuronConfig): + TENSOR_PARALLEL_MODEL = None + def patch_model_and_prepare_aliases(self, model_or_path, *args): base_model_instance = BaseModelInstance( partial(self.get_parallel_callable, self._config), input_output_aliases={}, ) return base_model_instance, None - + def get_parallel_callable(self, config): - from optimum.neuron.models.inference.flux.modeling_flux import NeuronFluxTransformer2DModel - # Parallelize Flux transformer with NxD backend modeling - valid_params = inspect.signature(NeuronFluxTransformer2DModel.__init__).parameters + valid_params = inspect.signature(self.TENSOR_PARALLEL_MODEL.__init__).parameters model_config = {k: v for k, v in config.items() if k in valid_params and k != "self"} - model = NeuronFluxTransformer2DModel(**model_config) + model = self.TENSOR_PARALLEL_MODEL(**model_config) model.eval() if self.float_dtype == torch.bfloat16: model.bfloat16() return model - + # Adapted from diffusers.models.modeling_utils.ModelMixin.from_pretrained, this is a helper function for loading checkpoints required by `ModelBuilder`. def get_checkpoint_loader_fn(self): is_local = os.path.isdir(self.pretrained_model_name_or_path) @@ -910,6 +863,120 @@ def get_checkpoint_loader_fn(self): return merged_state_dict + +@register_in_tasks_manager("flux-transformer-2d", *["semantic-segmentation"], library_name="diffusers") +class FluxTransformerNeuronConfig(BaseModelBuilderNeuronConfig): + ATOL_FOR_VALIDATION = 1e-3 + INPUT_ARGS = ( + "batch_size", + "sequence_length", + "num_channels", + "width", + "height", + "vae_scale_factor", + "encoder_hidden_size", + "rotary_axes_dim", + ) + MODEL_TYPE = "flux-transformer-2d" + CUSTOM_MODEL_WRAPPER = FluxTransformerNeuronWrapper + TENSOR_PARALLEL_MODEL = NeuronFluxTransformer2DModel + NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( + height="height", + width="width", + num_channels="in_channels", + vocab_size="attention_head_dim", + hidden_size="joint_attention_dim", + projection_size="pooled_projection_dim", + allow_new=True, + ) + + DUMMY_INPUT_GENERATOR_CLASSES = ( + DummyTransformerTimestepInputGenerator, + DummyFluxTransformerVisionInputGenerator, + DummyFluxTransformerTextInputGenerator, + DummyTransformerRotaryEmbGenerator, + ) + + @property + def inputs(self) -> list[str]: + common_inputs = [ + "hidden_states", + "encoder_hidden_states", + "pooled_projections", + "timestep", + # Q: Why `image_rotary_emb` but not `txt_ids` and `img_ids`? We compute the rotary positional embeddings in CPU to save Neuron memory. 
+            # shape: [txt_ids.shape(0)+img_ids.shape(0), sum(axes_dim), 2]
+            "image_rotary_emb",
+        ]
+        if getattr(self._config, "guidance_embeds", False):
+            common_inputs.append("guidance")
+
+        return common_inputs
+
+    @property
+    def outputs(self) -> list[str]:
+        return ["out_hidden_states"]
+
+    @property
+    def is_flux_kontext(self) -> bool:
+        return self._is_flux_kontext
+
+    @is_flux_kontext.setter
+    def is_flux_kontext(self, is_flux_kontext: bool):
+        self._is_flux_kontext = is_flux_kontext
+
+
+@register_in_tasks_manager("qwen-image-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
+class QwenImageTransformerNeuronConfig(BaseModelBuilderNeuronConfig):
+    ATOL_FOR_VALIDATION = 1e-3
+    INPUT_ARGS = (
+        "batch_size",
+        "sequence_length",
+        "num_channels",
+        "width",
+        "height",
+        "vae_scale_factor",
+        "encoder_hidden_size",
+        "rotary_axes_dim",
+    )
+    MODEL_TYPE = "qwen-image-transformer-2d"
+    CUSTOM_MODEL_WRAPPER = QwenImageTransformerNeuronWrapper
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
+        height="height",
+        width="width",
+        num_channels="in_channels",
+        vocab_size="attention_head_dim",
+        hidden_size="joint_attention_dim",
+        projection_size="pooled_projection_dim",
+        allow_new=True,
+    )
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyTransformerTimestepInputGenerator,
+        DummyQwenImageTransformerInputGenerator,
+        DummyTransformerRotaryEmbGenerator,
+    )
+
+    @property
+    def inputs(self) -> list[str]:
+        common_inputs = [
+            "hidden_states",
+            "encoder_hidden_states",
+            "encoder_hidden_states_mask",
+            "timestep",
+            # Q: Why `image_rotary_emb` but not `txt_ids` and `img_ids`? We compute the rotary positional embeddings in CPU to save Neuron memory.
+            # shape: [txt_ids.shape(0)+img_ids.shape(0), sum(axes_dim), 2]
+            "image_rotary_emb",
+        ]
+        if getattr(self._config, "guidance_embeds", False):
+            common_inputs.append("guidance")
+
+        return common_inputs
+
+    @property
+    def outputs(self) -> list[str]:
+        return ["out_hidden_states"]
+
     def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
         if self.is_flux_kontext:
             self.DUMMY_INPUT_GENERATOR_CLASSES = self.DUMMY_INPUT_GENERATOR_CLASSES + (
@@ -921,7 +988,7 @@ def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
             )
         else:
             self.DUMMY_INPUT_GENERATOR_CLASSES = self.DUMMY_INPUT_GENERATOR_CLASSES + (
-                DummyFluxTransformerRotaryEmbGenerator,
+                DummyTransformerRotaryEmbGenerator,
             )
 
         dummy_inputs = super().generate_dummy_inputs(**kwargs)
@@ -930,14 +997,6 @@ def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
         else:
             return dummy_inputs
 
-    @property
-    def is_flux_kontext(self) -> bool:
-        return self._is_flux_kontext
-
-    @is_flux_kontext.setter
-    def is_flux_kontext(self, is_flux_kontext: bool):
-        self._is_flux_kontext = is_flux_kontext
-
 
 @register_in_tasks_manager("controlnet", *["semantic-segmentation"], library_name="diffusers")
 class ControlNetNeuronConfig(VisionNeuronConfig):
@@ -1038,6 +1097,26 @@ def patch_model_and_prepare_aliases(
         return super().patch_model_and_prepare_aliases(model=model, input_names=input_names, forward_with_tuple=True)
 
 
+@register_in_tasks_manager("qwen-image-vae-encoder", *["semantic-segmentation"], library_name="diffusers")
+class QwenImageVaeEncoderNeuronConfig(VaeEncoderNeuronConfig):
+    pass
+
+
+@register_in_tasks_manager("qwen-image-vae-decoder", *["semantic-segmentation"], library_name="diffusers")
+class QwenImageVaeDecoderNeuronConfig(VaeDecoderNeuronConfig):
+    pass
+
+
+@register_in_tasks_manager("qwen2-5-vl", *["feature-extraction"], library_name="diffusers")
+class Qwen2_5_VLEncoderNeuronConfig(TextEncoderNeuronConfig):
+    ATOL_FOR_VALIDATION = 1e-3
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    @property
+    def inputs(self) -> list[str]:
+        return ["input_ids", "attention_mask"]
+
+
 class T5EncoderBaseNeuronConfig(TextSeq2SeqNeuronConfig):
     ATOL_FOR_VALIDATION = 1e-3
     NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
diff --git a/optimum/exporters/neuron/model_wrappers.py b/optimum/exporters/neuron/model_wrappers.py
index e5a1425b7..bdc5c6ce7 100644
--- a/optimum/exporters/neuron/model_wrappers.py
+++ b/optimum/exporters/neuron/model_wrappers.py
@@ -161,6 +161,44 @@ def forward(self, *inputs):
         return out_tuple
 
 
+class QwenImageTransformerNeuronWrapper(torch.nn.Module):
+    def __init__(self, model, input_names: list[str], device: str = None):
+        super().__init__()
+        self.model = model
+        self.dtype = model.dtype
+        self.input_names = input_names
+        self.device = device
+
+    def forward(self, *inputs):
+        if len(inputs) != len(self.input_names):
+            raise ValueError(
+                f"The model needs {len(self.input_names)} inputs: {self.input_names}."
+                f" But only {len(inputs)} inputs are passed."
+            )
+
+        ordered_inputs = dict(zip(self.input_names, inputs))
+
+        hidden_states = ordered_inputs.pop("hidden_states", None)
+        encoder_hidden_states = ordered_inputs.pop("encoder_hidden_states", None)
+        encoder_hidden_states_mask = ordered_inputs.pop("encoder_hidden_states_mask", None)
+        timestep = ordered_inputs.pop("timestep", None)
+        guidance = ordered_inputs.pop("guidance", None)
+        image_rotary_emb = ordered_inputs.pop("image_rotary_emb", None)
+
+        out_tuple = self.model(
+            hidden_states=hidden_states,
+            timestep=timestep,
+            guidance=guidance,
+            encoder_hidden_states_mask=encoder_hidden_states_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+            joint_attention_kwargs=None,
+            return_dict=False,
+        )
+
+        return out_tuple
+
+
 class ControlNetNeuronWrapper(torch.nn.Module):
     def __init__(self, model, input_names: list[str], device: str = None):
         super().__init__()
diff --git a/optimum/exporters/neuron/utils.py b/optimum/exporters/neuron/utils.py
index 2eb8aeedd..175d28537 100644
--- a/optimum/exporters/neuron/utils.py
+++ b/optimum/exporters/neuron/utils.py
@@ -64,6 +64,7 @@
     FluxKontextPipeline,
     FluxPipeline,
     ModelMixin,
+    QwenImagePipeline,
     StableDiffusionXLImg2ImgPipeline,
     StableDiffusionXLInpaintPipeline,
     StableDiffusionXLPipeline,
@@ -426,7 +427,7 @@ def get_submodels_for_export_diffusion(
     is_stable_diffusion_xl = isinstance(
         pipeline, (StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline)
    )
-    is_flux = isinstance(pipeline, (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxKontextPipeline))
+    is_flux_or_qwen = isinstance(pipeline, (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxKontextPipeline, QwenImagePipeline))
 
     # Lora
     pipeline = _load_lora_weights_to_pipeline(pipeline=pipeline, lora_args=lora_args)
@@ -480,7 +481,7 @@ def get_submodels_for_export_diffusion(
     # Diffusion transformer
     transformer = getattr(pipeline, "transformer", None)
     if transformer is not None:
-        if not is_flux:  # The following will be handled by `ModelBuilder` if `is_flux`.
+        if not is_flux_or_qwen:  # The following will be handled by `ModelBuilder` if `is_flux_or_qwen`.
transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False) transformer.config.text_encoder_projection_dim = projection_dim # apply optimized scaled_dot_product_attention @@ -579,6 +580,8 @@ def __getattr__(self, name): "UNet2DConditionModel": "unet", "PixArtTransformer2DModel": "pixart-transformer-2d", "T5EncoderModel": "t5", + "Qwen2_5_VLForConditionalGeneration": "qwen2-5-vl", + "QwenImageTransformer2DModel": "qwen-image-transformer-2d", } @@ -690,3 +693,5 @@ def get_encoder_decoder_models_for_export( ) return models_for_export + + diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index 65228a30f..eb7cd46f4 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -82,6 +82,7 @@ LCMScheduler, PixArtAlphaPipeline, PixArtSigmaPipeline, + QwenImagePipeline, StableDiffusionControlNetPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, @@ -1628,3 +1629,8 @@ class NeuronFluxKontextPipeline(NeuronDiffusionPipelineBase, FluxKontextPipeline class NeuronFluxInpaintPipeline(NeuronDiffusionPipelineBase, FluxInpaintPipeline): main_input_name = "prompt" auto_model_class = FluxInpaintPipeline + + +class NeuronQwenImagePipeline(NeuronDiffusionPipelineBase, QwenImagePipeline): + main_input_name = "prompt" + auto_model_class = QwenImagePipeline diff --git a/optimum/neuron/models/inference/backend/modules/diffusion/__init__.py b/optimum/neuron/models/inference/backend/modules/diffusion/__init__.py new file mode 100644 index 000000000..6071c51e5 --- /dev/null +++ b/optimum/neuron/models/inference/backend/modules/diffusion/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/optimum/neuron/models/inference/flux/modules/activations.py b/optimum/neuron/models/inference/backend/modules/diffusion/activations.py similarity index 100% rename from optimum/neuron/models/inference/flux/modules/activations.py rename to optimum/neuron/models/inference/backend/modules/diffusion/activations.py diff --git a/optimum/neuron/models/inference/backend/modules/diffusion/attention.py b/optimum/neuron/models/inference/backend/modules/diffusion/attention.py new file mode 100644 index 000000000..af8ce1291 --- /dev/null +++ b/optimum/neuron/models/inference/backend/modules/diffusion/attention.py @@ -0,0 +1,457 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# This implementation is derived from the Diffusers library. +# The original codebase has been optimized and modified to achieve optimal performance +# characteristics when executed on Amazon Neuron devices. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size +from neuronxcc.nki.language import nc + +from ..rms_norm import NeuronRMSNorm +from .activations import NeuronGELU +from .embeddings import apply_rotary_emb + + +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel # noqa: E402 +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel # noqa: E402 + +from torch_neuronx.xla_impl.ops import nki_jit # noqa: E402 + + +_flash_fwd_call = nki_jit()(attention_isa_kernel) + + +def attention_wrapper_sharded_without_swap(query, key, value): + bs, n_head, q_len, d_head = query.shape # my change + q = query.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, q_len)) + k = key.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, q_len)) + v = value.clone().reshape((bs * n_head, q_len, d_head)) + + attn_output = torch.zeros((bs * n_head, q_len, d_head), dtype=torch.bfloat16, device=q.device) + + vc_size = int(os.getenv("NEURON_RT_VIRTUAL_CORE_SIZE", "1")) + use_sharded_attention_kernel = vc_size == 2 + scale = 1 / math.sqrt(d_head) + + if use_sharded_attention_kernel: + grid = (nc(2),) + _flash_fwd_call[grid](q, k, v, scale, attn_output, kernel_name="AttentionMMSoftmaxMMWithoutSwap") + else: + _flash_fwd_call(q, k, v, scale, attn_output, kernel_name="AttentionMMSoftmaxMMWithoutSwap") + + attn_output = attn_output.reshape((bs, n_head, q_len, d_head)) + + return attn_output + + +class NeuronAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + kv_heads (`int`, *optional*, defaults to `None`): + The number of key and value heads to use for multi-head attention. Defaults to `heads`. If + `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi + Query Attention (MQA) otherwise GQA is used. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. 
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: int | None = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: str | None = None, + qk_norm: str | None = None, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, + out_bias: bool = True, + scale_qk: bool = True, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + out_dim: int = None, + out_context_dim: int = None, + context_pre_only=None, + pre_only=False, + is_causal: bool = False, + pad_heads: bool = True, + reduce_dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.is_causal = is_causal + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + + tp_degree = get_tensor_model_parallel_size() + if pad_heads: + self.heads = math.ceil(heads / tp_degree) * tp_degree + self.padded_inner_dim = dim_head * self.heads + # Only shard the heads, dim_head is unchanged. 
+ # So that the original RMSNorm and apply_rotary_emb implementations still work + self.heads = self.heads // tp_degree + + self.added_kv_proj_dim = added_kv_proj_dim + + self.group_norm = None + + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "rms_norm": + self.norm_q = NeuronRMSNorm(dim_head, eps=eps) + self.norm_k = NeuronRMSNorm(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'rms_norm'") + + if cross_attention_norm is None: + self.norm_cross = None + else: + raise ValueError(f"unknown cross_attention_norm: {cross_attention_norm}. Should be None") + # breakpoint() + self.to_q = ColumnParallelLinear( + query_dim, + self.padded_inner_dim, + bias=bias, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.to_k = ColumnParallelLinear( + self.cross_attention_dim, + self.padded_inner_dim, + bias=bias, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.to_v = ColumnParallelLinear( + self.cross_attention_dim, + self.padded_inner_dim, + bias=bias, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = ColumnParallelLinear( + added_kv_proj_dim, + self.padded_inner_dim, + bias=added_proj_bias, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.add_v_proj = ColumnParallelLinear( + added_kv_proj_dim, + self.padded_inner_dim, + bias=added_proj_bias, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + if self.context_pre_only is not None: + self.add_q_proj = ColumnParallelLinear( + added_kv_proj_dim, + self.padded_inner_dim, + bias=added_proj_bias, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append( + RowParallelLinear( + self.padded_inner_dim, + self.out_dim, + bias=out_bias, + input_is_parallel=True, + reduce_dtype=reduce_dtype, + ) + ) + self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = RowParallelLinear( + self.padded_inner_dim, + self.out_context_dim, + bias=out_bias, + input_is_parallel=True, + reduce_dtype=reduce_dtype, + ) + else: + self.to_add_out = None + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "rms_norm": + self.norm_added_q = NeuronRMSNorm(dim_head, eps=eps) + self.norm_added_k = NeuronRMSNorm(dim_head, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" + ) + else: + self.norm_added_q = None + self.norm_added_k = None + + def forward( + self, + hidden_states: torch.Tensor, + image_rotary_emb: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. 
+ """ + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = self.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, self.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, self.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, self.heads, head_dim + ).transpose(1, 2) + + if self.norm_added_q is not None: + encoder_hidden_states_query_proj = self.norm_added_q(encoder_hidden_states_query_proj) + if self.norm_added_k is not None: + encoder_hidden_states_key_proj = self.norm_added_k(encoder_hidden_states_key_proj) + + # attention + # the concatenation is happening along the sequence dimension after the transpose operation above. + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if attention_mask is not None: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + else: + hidden_states = attention_wrapper_sharded_without_swap(query, key, value) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + # splitting along the sequence dimension + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + encoder_hidden_states = self.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + if self.padded_inner_dim != self.out_dim: + # For the single transformer block, we don't have an output projection to remove the padded hidden dimension + # So we just cut them here base on the expected out_dim (they are all zeros) + return hidden_states[..., : self.out_dim] + else: + return hidden_states + + +class NeuronFeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. 
If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + reduce_dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu-approximate": + act_fn = NeuronGELU( + dim, + inner_dim, + approximate="tanh", + bias=bias, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append( + RowParallelLinear(inner_dim, dim_out, bias=bias, input_is_parallel=True, reduce_dtype=reduce_dtype) + ) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states diff --git a/optimum/neuron/models/inference/flux/modules/embeddings.py b/optimum/neuron/models/inference/backend/modules/diffusion/embeddings.py similarity index 96% rename from optimum/neuron/models/inference/flux/modules/embeddings.py rename to optimum/neuron/models/inference/backend/modules/diffusion/embeddings.py index d4ca3c147..c87e7d489 100644 --- a/optimum/neuron/models/inference/flux/modules/embeddings.py +++ b/optimum/neuron/models/inference/backend/modules/diffusion/embeddings.py @@ -382,6 +382,22 @@ def forward(self, sample): return sample +class NeuronQwenTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = NeuronTimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, hidden_states): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D) + + conditioning = timesteps_emb + + return conditioning + + class Timesteps(nn.Module): def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): super().__init__() diff --git a/optimum/neuron/models/inference/flux/modules/normalization.py b/optimum/neuron/models/inference/backend/modules/diffusion/normalization.py similarity index 100% rename from optimum/neuron/models/inference/flux/modules/normalization.py rename to optimum/neuron/models/inference/backend/modules/diffusion/normalization.py diff --git a/optimum/neuron/models/inference/flux/modeling_flux.py b/optimum/neuron/models/inference/flux/modeling_flux.py index a5d384d23..e71283755 100644 --- a/optimum/neuron/models/inference/flux/modeling_flux.py +++ b/optimum/neuron/models/inference/flux/modeling_flux.py @@ -18,75 +18,35 @@ """ import logging 
-import math -import os from types import SimpleNamespace from typing import Any import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from neuronx_distributed.parallel_layers.layer_norm import LayerNorm from neuronx_distributed.parallel_layers.layers import ( ColumnParallelLinear, RowParallelLinear, ) from neuronx_distributed.parallel_layers.mappings import reduce_from_tensor_model_parallel_region -from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size -from ..backend.modules.rms_norm import NeuronRMSNorm -from .modules.activations import NeuronGELU -from .modules.embeddings import ( +from ..backend.modules.diffusion.attention import NeuronAttention, NeuronFeedForward +from ..backend.modules.diffusion.embeddings import ( FluxPosEmbed, NeuronCombinedTimestepGuidanceTextProjEmbeddings, NeuronCombinedTimestepTextProjEmbeddings, - apply_rotary_emb, ) -from .modules.normalization import ( +from ..backend.modules.diffusion.normalization import ( NeuronAdaLayerNormContinuous, NeuronAdaLayerNormZero, NeuronAdaLayerNormZeroSingle, ) -try: - from neuronxcc.nki._private_kernels.attention import attention_isa_kernel # noqa: E402 -except ImportError: - from neuronxcc.nki.kernels.attention import attention_isa_kernel # noqa: E402 - -from neuronxcc.nki.language import nc -from torch_neuronx.xla_impl.ops import nki_jit # noqa: E402 - - -_flash_fwd_call = nki_jit()(attention_isa_kernel) - logger = logging.getLogger(__name__) # pylint: disable=invalid-name -def attention_wrapper_sharded_without_swap(query, key, value): - bs, n_head, q_len, d_head = query.shape # my change - q = query.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, q_len)) - k = key.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, q_len)) - v = value.clone().reshape((bs * n_head, q_len, d_head)) - - attn_output = torch.zeros((bs * n_head, q_len, d_head), dtype=torch.bfloat16, device=q.device) - - vc_size = int(os.getenv("NEURON_RT_VIRTUAL_CORE_SIZE", "1")) - use_sharded_attention_kernel = vc_size == 2 - scale = 1 / math.sqrt(d_head) - - if use_sharded_attention_kernel: - grid = (nc(2),) - _flash_fwd_call[grid](q, k, v, scale, attn_output, kernel_name="AttentionMMSoftmaxMMWithoutSwap") - else: - _flash_fwd_call(q, k, v, scale, attn_output, kernel_name="AttentionMMSoftmaxMMWithoutSwap") - - attn_output = attn_output.reshape((bs, n_head, q_len, d_head)) - - return attn_output - - class NeuronFluxTransformer2DModel(torch.nn.Module): """ The Transformer model introduced in Flux. @@ -507,398 +467,3 @@ def forward( encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) return encoder_hidden_states, hidden_states - - -class NeuronFeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
- """ - - def __init__( - self, - dim: int, - dim_out: int | None = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - inner_dim=None, - bias: bool = True, - reduce_dtype: torch.dtype = torch.bfloat16, - ): - super().__init__() - if inner_dim is None: - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - - if activation_fn == "gelu-approximate": - act_fn = NeuronGELU( - dim, - inner_dim, - approximate="tanh", - bias=bias, - gather_output=False, - reduce_dtype=reduce_dtype, - ) - - self.net = nn.ModuleList([]) - # project in - self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append( - RowParallelLinear(inner_dim, dim_out, bias=bias, input_is_parallel=True, reduce_dtype=reduce_dtype) - ) - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - logger.warning( - "The `scale` argument is deprecated and will be ignored. Please remove it, " - "as passing it will raise an error in the future. `scale` should directly be " - "passed while calling the underlying pipeline component i.e., " - "via `cross_attention_kwargs`." - ) - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states - - -class NeuronAttention(nn.Module): - r""" - A cross attention layer. - - Parameters: - query_dim (`int`): - The number of channels in the query. - cross_attention_dim (`int`, *optional*): - The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. - heads (`int`, *optional*, defaults to 8): - The number of heads to use for multi-head attention. - kv_heads (`int`, *optional*, defaults to `None`): - The number of key and value heads to use for multi-head attention. Defaults to `heads`. If - `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi - Query Attention (MQA) otherwise GQA is used. - dim_head (`int`, *optional*, defaults to 64): - The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): - The dropout probability to use. - bias (`bool`, *optional*, defaults to False): - Set to `True` for the query, key, and value linear layers to contain a bias parameter. - upcast_attention (`bool`, *optional*, defaults to False): - Set to `True` to upcast the attention computation to `float32`. - upcast_softmax (`bool`, *optional*, defaults to False): - Set to `True` to upcast the softmax computation to `float32`. - cross_attention_norm (`str`, *optional*, defaults to `None`): - The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. - cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups to use for the group norm in the cross attention. - added_kv_proj_dim (`int`, *optional*, defaults to `None`): - The number of channels to use for the added key and value projections. If `None`, no projection is used. - norm_num_groups (`int`, *optional*, defaults to `None`): - The number of groups to use for the group norm in the attention. - spatial_norm_dim (`int`, *optional*, defaults to `None`): - The number of channels to use for the spatial normalization. 
- out_bias (`bool`, *optional*, defaults to `True`): - Set to `True` to use a bias in the output linear layer. - scale_qk (`bool`, *optional*, defaults to `True`): - Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. - eps (`float`, *optional*, defaults to 1e-5): - An additional value added to the denominator in group normalization that is used for numerical stability. - rescale_output_factor (`float`, *optional*, defaults to 1.0): - A factor to rescale the output by dividing it with this value. - residual_connection (`bool`, *optional*, defaults to `False`): - Set to `True` to add the residual connection to the output. - _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): - Set to `True` if the attention block is loaded from a deprecated state dict. - """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: int | None = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - cross_attention_norm: str | None = None, - qk_norm: str | None = None, - added_kv_proj_dim: int | None = None, - added_proj_bias: bool | None = True, - out_bias: bool = True, - scale_qk: bool = True, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - out_dim: int = None, - out_context_dim: int = None, - context_pre_only=None, - pre_only=False, - is_causal: bool = False, - pad_heads: bool = True, - reduce_dtype: torch.dtype = torch.bfloat16, - ): - super().__init__() - - self.query_dim = query_dim - self.use_bias = bias - self.is_cross_attention = cross_attention_dim is not None - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - self.rescale_output_factor = rescale_output_factor - self.residual_connection = residual_connection - self.dropout = dropout - self.fused_projections = False - self.out_dim = out_dim if out_dim is not None else query_dim - self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim - self.context_pre_only = context_pre_only - self.pre_only = pre_only - self.is_causal = is_causal - - # we make use of this private variable to know whether this class is loaded - # with an deprecated state dict so that we can convert it on the fly - self._from_deprecated_attn_block = _from_deprecated_attn_block - - self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 - - self.heads = out_dim // dim_head if out_dim is not None else heads - - tp_degree = get_tensor_model_parallel_size() - if pad_heads: - self.heads = math.ceil(heads / tp_degree) * tp_degree - self.padded_inner_dim = dim_head * self.heads - # Only shard the heads, dim_head is unchanged. - # So that the original RMSNorm and apply_rotary_emb implementations still work - self.heads = self.heads // tp_degree - - self.added_kv_proj_dim = added_kv_proj_dim - - self.group_norm = None - - self.spatial_norm = None - - if qk_norm is None: - self.norm_q = None - self.norm_k = None - elif qk_norm == "rms_norm": - self.norm_q = NeuronRMSNorm(dim_head, eps=eps) - self.norm_k = NeuronRMSNorm(dim_head, eps=eps) - else: - raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'rms_norm'") - - if cross_attention_norm is None: - self.norm_cross = None - else: - raise ValueError(f"unknown cross_attention_norm: {cross_attention_norm}. 
Should be None") - # breakpoint() - self.to_q = ColumnParallelLinear( - query_dim, - self.padded_inner_dim, - bias=bias, - gather_output=False, - reduce_dtype=reduce_dtype, - ) - self.to_k = ColumnParallelLinear( - self.cross_attention_dim, - self.padded_inner_dim, - bias=bias, - gather_output=False, - reduce_dtype=reduce_dtype, - ) - self.to_v = ColumnParallelLinear( - self.cross_attention_dim, - self.padded_inner_dim, - bias=bias, - gather_output=False, - reduce_dtype=reduce_dtype, - ) - - self.added_proj_bias = added_proj_bias - if self.added_kv_proj_dim is not None: - self.add_k_proj = ColumnParallelLinear( - added_kv_proj_dim, - self.padded_inner_dim, - bias=added_proj_bias, - gather_output=False, - reduce_dtype=reduce_dtype, - ) - self.add_v_proj = ColumnParallelLinear( - added_kv_proj_dim, - self.padded_inner_dim, - bias=added_proj_bias, - gather_output=False, - reduce_dtype=reduce_dtype, - ) - if self.context_pre_only is not None: - self.add_q_proj = ColumnParallelLinear( - added_kv_proj_dim, - self.padded_inner_dim, - bias=added_proj_bias, - gather_output=False, - reduce_dtype=reduce_dtype, - ) - else: - self.add_q_proj = None - self.add_k_proj = None - self.add_v_proj = None - - if not self.pre_only: - self.to_out = nn.ModuleList([]) - self.to_out.append( - RowParallelLinear( - self.padded_inner_dim, - self.out_dim, - bias=out_bias, - input_is_parallel=True, - reduce_dtype=reduce_dtype, - ) - ) - self.to_out.append(nn.Dropout(dropout)) - else: - self.to_out = None - - if self.context_pre_only is not None and not self.context_pre_only: - self.to_add_out = RowParallelLinear( - self.padded_inner_dim, - self.out_context_dim, - bias=out_bias, - input_is_parallel=True, - reduce_dtype=reduce_dtype, - ) - else: - self.to_add_out = None - - if qk_norm is not None and added_kv_proj_dim is not None: - if qk_norm == "rms_norm": - self.norm_added_q = NeuronRMSNorm(dim_head, eps=eps) - self.norm_added_k = NeuronRMSNorm(dim_head, eps=eps) - else: - raise ValueError( - f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" - ) - else: - self.norm_added_q = None - self.norm_added_k = None - - def forward( - self, - hidden_states: torch.Tensor, - image_rotary_emb: torch.Tensor, - attention_mask: torch.Tensor | None = None, - encoder_hidden_states: torch.Tensor | None = None, - ) -> torch.Tensor: - r""" - The forward method of the `Attention` class. - - Args: - hidden_states (`torch.Tensor`): - The hidden states of the query. - encoder_hidden_states (`torch.Tensor`, *optional*): - The hidden states of the encoder. - attention_mask (`torch.Tensor`, *optional*): - The attention mask to use. If `None`, no mask is applied. - **cross_attention_kwargs: - Additional keyword arguments to pass along to the cross attention. - - Returns: - `torch.Tensor`: The output of the attention layer. - """ - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. 
- query = self.to_q(hidden_states) - key = self.to_k(hidden_states) - value = self.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // self.heads - - query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - - if self.norm_q is not None: - query = self.norm_q(query) - if self.norm_k is not None: - key = self.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. - encoder_hidden_states_query_proj = self.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, self.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, self.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, self.heads, head_dim - ).transpose(1, 2) - - if self.norm_added_q is not None: - encoder_hidden_states_query_proj = self.norm_added_q(encoder_hidden_states_query_proj) - if self.norm_added_k is not None: - encoder_hidden_states_key_proj = self.norm_added_k(encoder_hidden_states_key_proj) - - # attention - # the concatenation is happening along the sequence dimension after the transpose operation above. - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if attention_mask is not None: - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - else: - hidden_states = attention_wrapper_sharded_without_swap(query, key, value) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - # splitting along the sequence dimension - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - encoder_hidden_states = self.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - if self.padded_inner_dim != self.out_dim: - # For the single transformer block, we don't have an output projection to remove the padded hidden dimension - # So we just cut them here base on the expected out_dim (they are all zeros) - return hidden_states[..., : self.out_dim] - else: - return hidden_states diff --git a/optimum/neuron/models/inference/qwenimage/modeling_qwenimage.py b/optimum/neuron/models/inference/qwenimage/modeling_qwenimage.py new file mode 100644 index 000000000..5efa8cd8c --- /dev/null +++ b/optimum/neuron/models/inference/qwenimage/modeling_qwenimage.py @@ -0,0 +1,42 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adapted from `neuronx_distributed_inference/models/diffusers/flux/modeling_flux.py`. +""" + +import logging + +import torch +import torch.nn as nn + + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +class NeuronQwenEmbedRope(nn.Module): + pass + + +class NeuronQwenDoubleStreamAttnProcessor2_0: + pass + + +class NeuronQwenImageTransformer2DModel(torch.nn.Module): + pass + + +class NeuronQwenImageTransformerBlock(nn.Module): + pass diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index df6444fde..a5b08e8df 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -57,6 +57,8 @@ "DummyControNetInputGenerator", "ASTDummyAudioInputGenerator", "DummyIPAdapterInputGenerator", + "DummyQwenImageTransformerInputGenerator", + "DummyTransformerRotaryEmbGenerator", "DummyFluxTransformerRotaryEmbGenerator", "DummyFluxKontextTransformerRotaryEmbGenerator", "DummyTimestepInputGenerator", @@ -130,10 +132,11 @@ DummyBeamValuesGenerator, DummyControNetInputGenerator, DummyFluxKontextTransformerRotaryEmbGenerator, - DummyFluxTransformerRotaryEmbGenerator, DummyIPAdapterInputGenerator, DummyMaskedPosGenerator, + DummyQwenImageTransformerInputGenerator, DummyTimestepInputGenerator, + DummyTransformerRotaryEmbGenerator, WhisperDummyTextInputGenerator, ) from .misc import ( diff --git a/optimum/neuron/utils/input_generators.py b/optimum/neuron/utils/input_generators.py index 6e4cf77bd..f98483dcd 100644 --- a/optimum/neuron/utils/input_generators.py +++ b/optimum/neuron/utils/input_generators.py @@ -351,7 +351,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int return super().generate(input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype) -class DummyFluxTransformerRotaryEmbGenerator(DummyInputGenerator): +class DummyTransformerRotaryEmbGenerator(DummyInputGenerator): """ Generates dummy image rotary embedding. """ @@ -385,12 +385,32 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) -class DummyFluxKontextTransformerRotaryEmbGenerator(DummyInputGenerator): +class DummyFluxKontextTransformerRotaryEmbGenerator(DummyTransformerRotaryEmbGenerator): """ - Generates dummy image rotary embedding. + Generates dummy image rotary embedding for Flux Kontext. 
""" - SUPPORTED_INPUT_NAMES = ("image_rotary_emb",) + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "image_rotary_emb": + shape = [ + self.sequence_length + (self.height // 2) * (self.width // 2) * 2, + self.rotary_axes_dim, + 2, # freqs_cos, freqs_sin + ] + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + + +class DummyQwenImageTransformerInputGenerator(DummyInputGenerator): + """ + Generates dummy inputs for Qwen Image transformer. + """ + + SUPPORTED_INPUT_NAMES = ( + "hidden_states", + "guidance", + "encoder_hidden_states", + "encoder_hidden_states_mask", + ) def __init__( self, @@ -410,10 +430,20 @@ def __init__( self.normalized_config = normalized_config def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): - if input_name == "image_rotary_emb": - shape = [ - self.sequence_length + (self.height // 2) * (self.width // 2) * 2, - self.rotary_axes_dim, - 2, # freqs_cos, freqs_sin - ] + if input_name == "hidden_states": + shape = [] return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + elif input_name == "guidance": + shape = [self.batch_size] + return self.random_float_tensor(shape, min_value=0, max_value=1, framework=framework, dtype=float_dtype) + elif input_name == "encoder_hidden_states": + shape = [] + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + elif input_name == "encoder_hidden_states_mask": + shape = [self.batch_size, self.sequence_length] + return self.random_mask_tensor( + shape=shape, + padding_side="right", + framework=framework, + dtype=int_dtype, + )