diff --git a/paddleformers/quantization/quantization_config.py b/paddleformers/quantization/quantization_config.py index df403b873e..a78e0b0f92 100644 --- a/paddleformers/quantization/quantization_config.py +++ b/paddleformers/quantization/quantization_config.py @@ -16,7 +16,10 @@ import json from dataclasses import dataclass -from paddle.nn.quant.quantized_linear import _get_arch_info +try: + from paddle.nn.quant.quantized_linear import _get_arch_info +except: + _get_arch_info = None quant_inference_mapping = {"avg": "abs_max", "abs_max_channel_wise": "abs_max_channel_wise", "abs_max": "abs_max"} fp8_format_mapping = { @@ -114,7 +117,8 @@ def __init__( f"weight_quantize_algo:{weight_quantize_algo} not in supported list ['weight_only_int8', 'weight_only_int4', 'llm.int8', 'a8w8', 'nf4', 'fp4']" ) if ( - (isinstance(weight_quantize_algo, dict) and "fp8linear" in weight_quantize_algo) + _get_arch_info is not None + and (isinstance(weight_quantize_algo, dict) and "fp8linear" in weight_quantize_algo) or weight_quantize_algo == "fp8linear" ) and _get_arch_info() not in [89, 90]: raise RuntimeError("fp8Linear is only supported on NVIDIA Hopper GPUs.") diff --git a/paddleformers/transformers/__init__.py b/paddleformers/transformers/__init__.py index b91717c012..12b7f0f509 100644 --- a/paddleformers/transformers/__init__.py +++ b/paddleformers/transformers/__init__.py @@ -37,12 +37,6 @@ "AddedToken", "normalize_chars", "tokenize_special_chars,convert_to_unicode,", - "PreTrainedTokenizer", - ], - "tokenizer_utils_base": [ - "PaddingStrategy", - "TextInput", - "TensorType", ], "attention_utils": ["create_bigbird_rand_mask_idx_list"], "tensor_parallel_utils": [], @@ -88,6 +82,11 @@ "AutoDiscriminator", "AutoModelForConditionalGeneration", ], + "tokenizer_utils_base": [ + "PaddingStrategy", + "TextInput", + "TensorType", + ], "auto.processing": ["AutoProcessor"], "auto.tokenizer": ["AutoTokenizer"], "deepseek_v2.configuration": ["DeepseekV2Config"], @@ -320,6 +319,8 @@ "Qwen3MoePretrainingCriterion", ], "qwen3_moe.modeling_pp": ["Qwen3MoeForCausalLMPipe"], + "ernie4_5vl.tokenizer": ["Ernie4_5_VLTokenizer"], + "ernie4_5vl": [], "bert": [], "llama": [], "qwen2": [], @@ -346,6 +347,7 @@ tokenize_special_chars, convert_to_unicode, ) + from .tokenizer_utils_fast import PretrainedTokenizerFast from .processing_utils import ProcessorMixin from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin from .image_processing_utils import ImageProcessingMixin diff --git a/paddleformers/transformers/auto/configuration.py b/paddleformers/transformers/auto/configuration.py index f8ef19493a..6bca0681f5 100644 --- a/paddleformers/transformers/auto/configuration.py +++ b/paddleformers/transformers/auto/configuration.py @@ -25,7 +25,6 @@ from ...utils.import_utils import import_module from ...utils.log import logger from ..configuration_utils import PretrainedConfig -from ..model_utils import PretrainedModel __all__ = [ "AutoConfig", @@ -222,6 +221,7 @@ def _get_config_class_from_config( if config_class is not PretrainedConfig: model_config_class = config_class return model_config_class + from ..model_utils import PretrainedModel assert inspect.isclass(model_class) and issubclass( model_class, PretrainedModel diff --git a/paddleformers/transformers/auto/tokenizer.py b/paddleformers/transformers/auto/tokenizer.py index 37d3df4408..86d21aefec 100644 --- a/paddleformers/transformers/auto/tokenizer.py +++ b/paddleformers/transformers/auto/tokenizer.py @@ -13,6 +13,8 @@ # See the License for the specific 
language governing permissions and # limitations under the License. import json + +# import logging import os import warnings from typing import Dict, Optional, Union @@ -140,13 +142,36 @@ def get_paddleformers_tokenizer_config( return result -class AutoTokenizer(hf.AutoTokenizer): +def _bind_paddle_mixin_if_available(tokenizer_class): """ - Adapted from transformers.AutoTokenizer.from_pretrained with modifications: - 1. Added get_paddleformers_tokenizer_config() to extend tokenizer_config.json download source - 2. Explicitly binds PaddleTokenizerMixin to the tokenizer class before final instantiation + Bind the PaddleTokenizerMixin if Paddle is available; otherwise, return the original class. - Note: This extends HuggingFace's standard tokenizer loading logic with PaddlePaddle integration. + Args: + tokenizer_class: The original tokenizer class. + + Returns: + The tokenizer class bound with PaddleTokenizerMixin, or the original class. + """ + try: + return type(tokenizer_class.__name__, (PaddleTokenizerMixin, tokenizer_class), {}) + except: + return tokenizer_class + + +class AutoTokenizer(hf.AutoTokenizer): + """ + Smart AutoTokenizer that automatically adapts based on available dependencies: + + 1. **Multi-source support**: Supports HuggingFace, PaddleFormers, and other download sources + 2. **Conditional Paddle integration**: Automatically detects PaddlePaddle availability + 3. **Fallback compatibility**: Works seamlessly with or without Paddle dependencies + 4. **Enhanced functionality**: Extends HuggingFace's standard tokenizer loading logic + + Features: + - Automatically binds PaddleTokenizerMixin when PaddlePaddle is available + - Falls back to pure Transformers mode when PaddlePaddle is not available + - Maintains full compatibility with all HuggingFace tokenizers + - Supports custom download sources through environment variables """ @classmethod @@ -201,7 +226,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): if tokenizer_class is None: raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.") - tokenizer_class = type(tokenizer_class.__name__, (PaddleTokenizerMixin, tokenizer_class), {}) + + # Bind PaddleTokenizerMixin + tokenizer_class = _bind_paddle_mixin_if_available(tokenizer_class) return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) # Next, let's try to use the tokenizer_config file to get the tokenizer class. 
@@ -268,6 +295,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): or tokenizer_class_from_name(config_tokenizer_class + "Fast") is not None ) ) + if has_remote_code: if use_fast and tokenizer_auto_map[1] is not None: class_ref = tokenizer_auto_map[1] @@ -285,11 +313,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs) _ = kwargs.pop("code_revision", None) tokenizer_class.register_for_auto_class() - tokenizer_class = type(tokenizer_class.__name__, (PaddleTokenizerMixin, tokenizer_class), {}) + + # Bind PaddleTokenizerMixin + tokenizer_class = _bind_paddle_mixin_if_available(tokenizer_class) return tokenizer_class.from_pretrained( pretrained_model_name_or_path, *inputs, trust_remote_code=trust_remote_code, **kwargs ) elif config_tokenizer_class is not None: + tokenizer_class = None if use_fast and not config_tokenizer_class.endswith("Fast"): tokenizer_class_candidate = f"{config_tokenizer_class}Fast" @@ -301,7 +332,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): raise ValueError( f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported." ) - tokenizer_class = type(tokenizer_class.__name__, (PaddleTokenizerMixin, tokenizer_class), {}) + + # Bind PaddleTokenizerMixin + tokenizer_class = _bind_paddle_mixin_if_available(tokenizer_class) return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) # Otherwise we have to be creative. @@ -321,15 +354,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)] if tokenizer_class_fast and (use_fast or tokenizer_class_py is None): - tokenizer_class_fast = type( - tokenizer_class_fast.__name__, (PaddleTokenizerMixin, tokenizer_class_fast), {} - ) + # Bind PaddleTokenizerMixin + tokenizer_class_fast = _bind_paddle_mixin_if_available(tokenizer_class_fast) return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) else: if tokenizer_class_py is not None: - tokenizer_class_py = type( - tokenizer_class_py.__name__, (PaddleTokenizerMixin, tokenizer_class_py), {} - ) + # Bind PaddleTokenizerMixin + tokenizer_class_py = _bind_paddle_mixin_if_available(tokenizer_class_py) return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) else: raise ValueError( diff --git a/paddleformers/transformers/configuration_utils.py b/paddleformers/transformers/configuration_utils.py index c4cbf229e1..4117091455 100644 --- a/paddleformers/transformers/configuration_utils.py +++ b/paddleformers/transformers/configuration_utils.py @@ -29,7 +29,6 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union -import paddle from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError @@ -581,6 +580,8 @@ def __init__(self, **kwargs): if "torch_dtype" in kwargs: self.dtype = kwargs.pop("torch_dtype") else: + import paddle + self.dtype = kwargs.pop("dtype", paddle.get_default_dtype()) # Is decoder is used in encoder-decoder models to differentiate encoder from decoder diff --git a/paddleformers/transformers/ernie4_5vl/__init__.py b/paddleformers/transformers/ernie4_5vl/__init__.py new file mode 100644 index 0000000000..c59ba4ea9e --- /dev/null +++ b/paddleformers/transformers/ernie4_5vl/__init__.py @@ -0,0 +1,35 @@ 
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from typing import TYPE_CHECKING + +from ...utils.lazy_import import _LazyModule + +import_structure = { + "tokenizer": ["Ernie4_5_VLTokenizer"], + "configuration": [ + "Ernie4_5_VLMoEConfig", + ], +} + +if TYPE_CHECKING: + from .configuration import * + from .tokenizer import Ernie4_5_VLTokenizer +else: + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + import_structure, + module_spec=__spec__, + ) diff --git a/paddleformers/transformers/ernie4_5vl/configuration.py b/paddleformers/transformers/ernie4_5vl/configuration.py new file mode 100644 index 0000000000..5733cdc597 --- /dev/null +++ b/paddleformers/transformers/ernie4_5vl/configuration.py @@ -0,0 +1,634 @@ +# Copyright (c) 2025 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ernie model configuration""" +import copy +from typing import Union + +from transformers import PretrainedConfig + +from ...utils.log import logger + +__all__ = [ + "ERNIE_PRETRAINED_INIT_CONFIGURATION", + "Ernie4_5_Config", + "Ernie4_5_MoEConfig", + "Ernie4_5_VLMoEConfig", +] + + +class DFNRopeVisionTransformerConfig(PretrainedConfig): + """ + Configuration class for DFNRopeVisionTransformer model. + This class inherits from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + """ + + model_type = "DFNRope_vision_transformer" + base_model_tp_plan = {} + + def __init__( + self, + depth=32, + embed_dim=1280, + hidden_size=3584, + hidden_act="quick_gelu", + mlp_ratio=4, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + attn_implementation="eager", # new added + pp_data_balance=False, + recompute=False, + attn_sep=False, + vit_first_fwd_bsz=128, + vit_num_recompute_layers=10000, + **kwargs, + ): + """ + Initialize DFNRopeVisionTransformer model configuration with default or specified parameters. + + Args: + depth (int): Number of transformer layers in the model. + embed_dim (int): Dimensionality of the embedding layer. + hidden_size (int): Dimensionality of the feedforward network. + hidden_act (str): Activation function for the feedforward network. + mlp_ratio (float): Ratio between the number of input features and + the number of output features in the feedforward network. + num_heads (int): Number of attention heads in each attention layer. 
+ in_channels (int): Number of channels in the input image. + patch_size (int): + Size of patches in the input image. Defaults to 14. + spatial_merge_size (int): + Spatial merge size for the spatial transformer module. Defaults to 2. + attn_implementation (str): Attention implementation type. Defaults to "eager". + pp_data_balance (bool): Whether to balance data during preprocessing. Defaults to False. + recompute (bool): Whether to use recompute. Defaults to False. + attn_sep (bool): Whether to separate attention computation into two stages. Defaults to False. + vit_first_fwd_bsz (int): First forward batch size for ViT. Defaults to 128. + vit_num_recompute_layers (int): Number of recomputed layers for ViT. Defaults to + """ + super().__init__(**kwargs) + + self.depth = depth + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.attn_implementation = attn_implementation + self.pp_data_balance = pp_data_balance + self.recompute = recompute + self.attn_sep = attn_sep + self.vit_first_fwd_bsz = vit_first_fwd_bsz + self.vit_num_recompute_layers = vit_num_recompute_layers + + def get(self, key, default=None): + """get config value by key""" + if hasattr(self, key): + return getattr(self, key) + else: + return default + + +ERNIE_PRETRAINED_INIT_CONFIGURATION = { + "ernie/tiny-random-ernie": { + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "model_type": "ernie", + "num_attention_heads": 2, + "num_hidden_layers": 2, + "rms_norm_eps": 1e-06, + "vocab_size": 32000, + "bos_token_id": 1, + "eos_token_id": 2, + "pad_token_id": 0, + "use_cache": False, + "recompute": False, + "use_flash_attn": True, + "use_pure_fp16": False, + }, +} + + +class Ernie4_5_Config(PretrainedConfig): + """ + Configuration class for ERNIE model. + + This class stores the configuration of an ERNIE model, defining the model architecture. + It inherits from PretrainedConfig and can be used to control model outputs. 
+ """ + + model_type = "ernie" + pretrained_init_configuration = ERNIE_PRETRAINED_INIT_CONFIGURATION + base_model_tp_plan = {} + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=11008, + max_position_embeddings=32768, + num_hidden_layers=2, + num_attention_heads=2, + initializer_range=0.02, # no use + rms_norm_eps=1e-6, + use_cache=False, + use_flash_attention=True, + use_sparse_flash_attn=True, + use_var_len_flash_attn=False, + recompute=False, + recompute_granularity="core_attn", + recompute_use_reentrant=False, + use_rmsnorm=True, + fuse_rms_norm=False, + fuse_ln=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + fuse_swiglu=False, + use_bias=False, + rope_theta=10000, + fuse_rope=False, + fuse_softmax_mask=False, + use_fast_ln=False, + weight_share_add_bias=True, + fuse_linear=False, + max_sequence_length=None, + ignored_index=-100, + add_tail_layers=False, + use_recompute_lm_head=False, + use_recompute_loss_fn=False, + refined_recompute=dict(), + attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0, + compression_ratio: float = 1.0, + num_key_value_heads=None, + use_sparse_head_and_loss_fn=False, + micro_batch_size=-1, + use_ep_comm_overlap=False, + use_fused_head_and_loss_fn=False, + token_balance_loss=False, + token_balance_seqlen=False, # calculated based on batchsize and seqlen + cachekv_quant: bool = False, + pp_seg_method="layer:ErnieDecoderLayer|EmptyLayer", + **kwargs, + ): + """ + Initialize ERNIE model configuration with default or specified parameters. + + Args: + vocab_size (int): Size of the vocabulary (number of unique tokens) + hidden_size (int): Dimensionality of the encoder layers and the pooler layer + intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer + max_position_embeddings (int): Maximum sequence length the model can handle + num_hidden_layers (int): Number of hidden layers in the Transformer encoder + num_attention_heads (int): Number of attention heads for each attention layer + rms_norm_eps (float): The epsilon used by the RMS normalization layers + use_cache (bool): Whether to use caching for faster generation (decoding) + use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation + use_sparse_flash_attn (bool): Whether to use sparse FlashAttention + use_var_len_flash_attn (bool): Whether to use variable-length FlashAttention + recompute (bool): Whether to use gradient checkpointing to save memory + recompute_granularity (str): Granularity of recomputation ("core_attn", "full", etc.) 
+ recompute_use_reentrant (bool): Whether to use reentrant checkpointing + use_rmsnorm (bool): Whether to use RMSNorm instead of LayerNorm + fuse_rms_norm (bool): Whether to fuse RMSNorm operations for optimization + fuse_ln (bool): Whether to fuse LayerNorm operations + pad_token_id (int): Token ID used for padding sequences + bos_token_id (int): Token ID used for beginning-of-sequence + eos_token_id (int): Token ID used for end-of-sequence + fuse_swiglu (bool): Whether to fuse SwiGLU operations + use_bias (bool): Whether to use bias terms in linear layers + rope_theta (float): The base period of the RoPE embeddings + fuse_rope (bool): Whether to fuse RoPE operations + use_fast_ln (bool): Whether to use optimized LayerNorm implementation + weight_share_add_bias (bool): Whether to share bias weights in certain layers + fuse_linear (bool): Whether to fuse linear operations + max_sequence_length (int): Maximum sequence length for positional embeddings + ignored_index (int): Target value that is ignored during loss computation + add_tail_layers (bool): Whether to add additional layers at the end + use_recompute_lm_head (bool): Whether to recompute gradients for language model head + use_recompute_loss_fn (bool): Whether to recompute gradients for loss function + refined_recompute (dict): Dictionary specifying refined recomputation settings + attention_probs_dropout_prob (float): Dropout probability for attention weights + hidden_dropout_prob (float): Dropout probability for hidden layers + compression_ratio (float): Ratio for KV cache compression (1.0 = no compression) + num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention) + use_sparse_head_and_loss_fn (bool): Whether to use sparse attention head and loss function + micro_batch_size (int): Size of micro batches (-1 for automatic) + use_ep_comm_overlap (bool): Whether to overlap communication with computation + use_fused_head_loss_fn (bool): Whether to use fused head and loss function + token_balance_loss (bool): Whether to balance loss by token count + token_balance_seqlen (bool): Whether to balance sequence lengths + cachekv_quant (bool): Whether to quantize key-value cache + pp_seg_method (str): Method for pipeline parallel segmentation + **kwargs: Additional keyword arguments passed to parent class + """ + + # Set default for tied embeddings if not specified. 
+ if "tie_word_embeddings" not in kwargs: + kwargs["tie_word_embeddings"] = False + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.recompute = recompute + self.recompute_granularity = recompute_granularity + self.use_flash_attention = use_flash_attention + self.use_sparse_flash_attn = use_sparse_flash_attn + self.recompute_use_reentrant = recompute_use_reentrant + self.use_var_len_flash_attn = use_var_len_flash_attn + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.fuse_swiglu = fuse_swiglu + self.fuse_rms_norm = fuse_rms_norm + self.fuse_ln = fuse_ln + self.use_rmsnorm = use_rmsnorm + self.micro_batch_size = micro_batch_size + + self.max_sequence_length = max_sequence_length + self.use_bias = use_bias + self.weight_share_add_bias = weight_share_add_bias + self.rope_theta = rope_theta + self.fuse_rope = fuse_rope + self.fuse_softmax_mask = fuse_softmax_mask + self.use_fast_ln = use_fast_ln + + self.fuse_linear = fuse_linear + self.ignored_index = ignored_index + self.add_tail_layers = add_tail_layers + self.use_recompute_lm_head = use_recompute_lm_head + self.use_recompute_loss_fn = use_recompute_loss_fn + + self.refined_recompute = refined_recompute + self.skip_recompute_ops = dict() + """ + `refined_recompute` is a dictionary that specifies fine-grained gradient recomputation settings, + which currently only takes effect in Pipeline Parallel (PP) mode. + + In PP mode, this dictionary populates `self.skip_recompute_ops` with the following structure: + - Key (`op_name`): The operation name to configure, with possible values: + * "mlp_row_ln" - MLP row-wise layer normalization + * "flash_attn" - Flash attention operation + * "attention_row_ln" - Attention row-wise layer normalization + * "attention_column_ln" - Attention column-wise layer normalization + * "mlp_column_ln" - MLP column-wise layer normalization + + - Value (`skip_num`): Controls how many times to skip recomputation: + * 0: Never skip recomputation (minimum memory usage) + * -1: Always skip recomputation (maximum memory usage) + * [0,1,...,12]: Skip recomputation for specified number of times + * ≥12: Equivalent to -1 (always skip recomputation) + + This allows precise control over memory/computation tradeoffs for different operations. 
+ """ + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.hidden_dropout_prob = hidden_dropout_prob + self.compression_ratio = compression_ratio + self.num_key_value_heads = num_key_value_heads + self.use_sparse_head_and_loss_fn = use_sparse_head_and_loss_fn + self.use_ep_comm_overlap = use_ep_comm_overlap + self.use_fused_head_and_loss_fn = use_fused_head_and_loss_fn + self.token_balance_loss = token_balance_loss + self.token_balance_seqlen = token_balance_seqlen + self.cachekv_quant = cachekv_quant + self.pp_seg_method = pp_seg_method + + def get(self, key, default=None): + """get config value by key""" + if hasattr(self, key): + return getattr(self, key) + else: + return default + + +class Ernie4_5_MoEConfig(Ernie4_5_Config): + r""" + Configuration class for ErnieMoE model architecture. + + This class stores the configuration for a [`~ErnieModel`] and is used to instantiate + an ErnieMoE model according to the specified arguments. Inherits from [`PretrainedConfig`] + and can control model outputs. + + Attributes: + Inherits all attributes from Ernie4_5_Config and adds MoE-specific configurations. + """ + + model_type = "ernie" + attribute_map = { + "n_positions": "max_position_embeddings", + "n_embd": "hidden_size", + "n_layer": "num_hidden_layers", + "n_head": "num_attention_heads", + "n_inner": "intermediate_size", + "activation_function": "hidden_act", + } + pretrained_init_configuration = ERNIE_PRETRAINED_INIT_CONFIGURATION + base_model_tp_plan = {} + + def __init__( + self, + moe_num_experts: Union[int, list] = 0, + use_recompute_moe=False, + moe_capacity=(), + moe_layer_interval=2, + moe_layer_start_index=0, + moe_layer_end_index=-1, + moe_aux_loss_lambda=1e-2, + moe_z_loss_lambda=1e-4, + moe_orthogonal_loss_lambda=1e-2, + sinkhorn_2gate=True, + sinkhorn_temp=3e-2, + global_aux_loss=False, + moe_dropout_prob=0.0, + moe_group="world", + moe_gate="top2", + moe_intermediate_size: Union[int, list] = 0, + moe_num_shared_experts: int = 0, + moe_reverse_token_drop: bool = False, + moe_gate_act: str = "softmax", + moe_norm_gate_logits=True, + moe_all_to_all_dropout: float = 0.0, + moe_k=2, + moe_use_aux_free: bool = False, + # `moe_group_experts` must be used with `moe_use_hard_gate=True` + moe_group_experts: bool = False, + moe_group_orthogonal_loss: bool = True, + enable_delay_scale_loss: bool = True, + num_acc_steps: int = 1, + fuse_gate_detach_matmul: bool = False, + dpo_config=None, + moe_multimodal_dispatch_use_allgather: str = "", + moe_use_hard_gate=False, + moe_dense_experts_token_type_id=3, + **kwargs, + ): + """ + Initialize ErnieMoE configuration with MoE-specific parameters. + + Args: + moe_num_experts: Number of experts in MoE layers + use_recompute_moe: Whether to use recomputation for MoE layers + moe_capacity: Capacity configuration for MoE layers + moe_layer_interval: Interval between MoE layers + moe_layer_start_index: Starting layer index for MoE + moe_layer_end_index: Ending layer index for MoE (-1 means last layer) + moe_aux_loss_lambda: Weight for auxiliary loss + moe_z_loss_lambda: Weight for z-loss + moe_orthogonal_loss_lambda: Weight for orthogonal loss + sinkhorn_2gate: Whether to use sinkhorn 2-gate routing + sinkhorn_temp: Temperature for sinkhorn routing + global_aux_loss: Whether to use global auxiliary loss + moe_dropout_prob: Dropout probability for MoE layers + moe_group: Group configuration for MoE experts + moe_gate: Type of gating mechanism ('top2', etc.) 
+ moe_intermediate_size: Intermediate size for MoE layers + moe_num_shared_experts: Number of shared experts + moe_reverse_token_drop: Whether to use reverse token dropping + moe_gate_act: Activation function for gating + moe_norm_gate_logits: Whether to normalize gate logits + moe_all_to_all_dropout: Dropout for all-to-all communication + moe_k: Number of experts to route to + moe_use_aux_free: Whether to use auxiliary-free routing + moe_group_experts: Whether to group experts (requires hard gating) + moe_group_orthogonal_loss: Whether to use group orthogonal loss + enable_delay_scale_loss: Whether to enable delayed loss scaling + num_acc_steps: Number of accumulation steps + fuse_gate_detach_matmul: Whether to fuse gate detach matmul + **kwargs: Additional base model configuration parameters + + Note: + When use_recompute_moe is True, recompute_granularity will be changed to full_attn. + """ + + if use_recompute_moe: + logger.warning( + "set `use_recompute_moe`=True, disabling `recompute_granularity=full`, change to full_attn." + ) + if kwargs["recompute"] and kwargs["recompute_granularity"] == "full": + kwargs["recompute_granularity"] = "full_attn" + super().__init__(**kwargs) + + self.moe_num_experts = moe_num_experts + self.use_recompute_moe = use_recompute_moe + self.moe_capacity = moe_capacity + self.moe_aux_loss_lambda = moe_aux_loss_lambda + self.moe_z_loss_lambda = moe_z_loss_lambda + self.moe_orthogonal_loss_lambda = moe_orthogonal_loss_lambda + self.global_aux_loss = global_aux_loss + self.sinkhorn_2gate = sinkhorn_2gate + self.sinkhorn_temp = sinkhorn_temp + self.moe_layer_interval = moe_layer_interval + self.moe_dropout_prob = moe_dropout_prob + self.moe_group = moe_group + self.moe_gate = moe_gate + self.moe_intermediate_size = moe_intermediate_size + self.moe_num_shared_experts = moe_num_shared_experts + self.moe_reverse_token_drop = moe_reverse_token_drop + self.moe_k = moe_k + self.moe_all_to_all_dropout = moe_all_to_all_dropout + self.moe_group_experts = moe_group_experts + self.moe_group_orthogonal_loss = moe_group_orthogonal_loss + self.enable_delay_scale_loss = enable_delay_scale_loss + self.num_acc_steps = num_acc_steps + self.moe_layer_start_index = moe_layer_start_index + self.moe_layer_end_index = self.num_hidden_layers - 1 if moe_layer_end_index == -1 else moe_layer_end_index + self.moe_gate_act = moe_gate_act + self.moe_norm_gate_logits = moe_norm_gate_logits + self.moe_use_aux_free = moe_use_aux_free + self.fuse_gate_detach_matmul = fuse_gate_detach_matmul + self.dpo_config = dpo_config + self.moe_multimodal_dispatch_use_allgather = moe_multimodal_dispatch_use_allgather + self.moe_use_hard_gate = moe_use_hard_gate + self.moe_dense_experts_token_type_id = moe_dense_experts_token_type_id + + @property + def multimodel_experts(self) -> bool: + """multimodel experts.""" + return isinstance(self.moe_num_experts, (tuple, list)) and len(self.moe_num_experts) > 1 + + @property + def use_moe(self) -> bool: + """ + Check if model is using MoE architecture. + + Returns: + bool: True if moe_num_experts > 0, False otherwise + """ + return self.moe_num_experts > 0 + + +class Ernie4_5_VLMoEConfig(Ernie4_5_MoEConfig): + """ + This is the configuration class to store the configuration of a [`~ErnieModel`]. It is used to instantiate an Ernie + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Ernie-7B. 
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Ernie model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~ErnieModel`] or [`~TFErnieModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + """ + + model_type = "ernie4_5_moe_vl" + attribute_map = { + "n_positions": "max_position_embeddings", + "n_embd": "hidden_size", + "n_layer": "num_hidden_layers", + "n_head": "num_attention_heads", + "n_inner": "intermediate_size", + "activation_function": "hidden_act", + } + base_model_tp_plan = { + "model.layers.*.self_attn.q_proj": "colwise_rep", + "model.layers.*.self_attn.k_proj": "colwise_rep", + "model.layers.*.self_attn.v_proj": "colwise_rep", + "model.layers.*.self_attn.o_proj": "rowwise_rep", + "model.layers.*.mlp.experts.*.gate_proj": "colwise", + "model.layers.*.mlp.experts.*.up_proj": "colwise", + "model.layers.*.mlp.experts.*.down_proj": "rowwise", + "model.layers.*.mlp_text.experts.*.gate_proj": "colwise", + "model.layers.*.mlp_text.experts.*.up_proj": "colwise", + "model.layers.*.mlp_text.experts.*.down_proj": "rowwise", + "model.layers.*.mlp.gate_proj": "colwise", + "model.layers.*.mlp.up_proj": "colwise", + "model.layers.*.mlp.down_proj": "rowwise", + } + + def __init__( + self, + vision_config=None, + im_patch_id=None, + pixel_hidden_size=None, + modality_detach=False, + temporal_conv_size=2, + spatial_conv_size=2, + mm_vocab_size=0, # vocab for mm specialtokens + max_text_id=None, + use_temporal_conv=True, + moe_use_size_all2all=False, + moe_num_attn_experts=False, + moe_dense_experts_token_type_id: int = 3, + moe_use_hard_gate: bool = True, + moe_fuse_experts: bool = False, + moe_use_token_type_bias: bool = False, + disable_ffn_model_parallel=False, + fuse_attn_ffn=True, + rope_3d=True, + freq_allocation=20, + using_precision_check=False, + use_recompute_resampler=False, + resampler_fuse_rms_norm=False, + moe_layer_feed_fake_token=False, + tensor_parallel_degree=1, + **kwargs, + ): + super().__init__(**kwargs) + if isinstance(vision_config, dict): + self.vision_config = DFNRopeVisionTransformerConfig(**vision_config) + else: + self.vision_config = DFNRopeVisionTransformerConfig() + self.im_patch_id = im_patch_id + 
self.pixel_hidden_size = pixel_hidden_size + self.modality_detach = modality_detach + self.temporal_conv_size = temporal_conv_size + self.spatial_conv_size = spatial_conv_size + self.mm_vocab_size = mm_vocab_size + self.max_text_id = max_text_id + self.use_temporal_conv = use_temporal_conv + + self.moe_use_size_all2all = moe_use_size_all2all + self.moe_num_attn_experts = moe_num_attn_experts + self.moe_dense_experts_token_type_id = moe_dense_experts_token_type_id + self.moe_use_hard_gate = moe_use_hard_gate + self.moe_fuse_experts = moe_fuse_experts + self.moe_use_token_type_bias = moe_use_token_type_bias + self.disable_ffn_model_parallel = disable_ffn_model_parallel + + self.fuse_attn_ffn = fuse_attn_ffn + self.rope_3d = rope_3d + self.freq_allocation = freq_allocation + self.using_precision_check = using_precision_check + self.use_recompute_resampler = use_recompute_resampler + self.resampler_fuse_rms_norm = resampler_fuse_rms_norm + self.moe_layer_feed_fake_token = moe_layer_feed_fake_token + + self.tensor_parallel_degree = tensor_parallel_degree + + @property + def multimodel_experts(self) -> bool: + """Check if model is using more than 1 multimodel experts.""" + return isinstance(self.moe_num_experts, (tuple, list)) and len(self.moe_num_experts) > 1 + + @property + def use_moe(self) -> bool: + """ + Check if model is using MoE architecture. + + Returns: + bool: True if moe_num_experts > 0, False otherwise + """ + return sum(self.moe_num_experts) > 0 if self.multimodel_experts else self.moe_num_experts > 0 + + def to_dict(self, saving_file=False): + """to_dict""" + output = copy.deepcopy(self.__dict__) + if self.vision_config: + output["vision_config"] = ( + self.vision_config.to_dict() + if isinstance(self.vision_config, (DFNRopeVisionTransformerConfig)) + else self.vision_config + ) + + output["model_type"] = self.__class__.model_type + return output diff --git a/paddleformers/transformers/ernie4_5vl/tokenizer.py b/paddleformers/transformers/ernie4_5vl/tokenizer.py new file mode 100644 index 0000000000..679f76c801 --- /dev/null +++ b/paddleformers/transformers/ernie4_5vl/tokenizer.py @@ -0,0 +1,484 @@ +# Copyright (c) 2025 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for Ernie4_5_VL.""" + +import os + +import sentencepiece as spm +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import logging + +# Fix relative import issues +from ..tokenizer_utils import PaddleTokenizerMixin + +logger = logging.get_logger(__name__) + +__all__ = [ + "Ernie4_5_VLTokenizer", +] + + +class Ernie4_5_VLTokenizer(PaddleTokenizerMixin, PreTrainedTokenizer): + """ + ERNIE 4.5 VL Tokenizer based on SentencePiece with smart tensor support. + + This tokenizer is designed for multimodal inputs including text, images, and videos. + It inherits from PaddleTokenizerMixin for smart tensor conversion and PreTrainedTokenizer + for standard HuggingFace functionality. 
+ + Features: + - SentencePiece-based tokenization + - Multimodal token support (image/video placeholders) + - Smart tensor conversion (Paddle/PyTorch/NumPy) + - Chat template support + - Enhanced encoding methods including encode_chat_input + """ + + vocab_files_names = { + "vocab_file": "tokenizer.model", + } + + model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"] + padding_side = "right" + + # ERNIE 4.5 VL specific tokens + SPECIAL_TOKENS = { + "space": "", + "gender": "", + "image_start": "<|im_start|>", + "image_end": "<|im_end|>", + "image_placeholder": "<|IMAGE_PLACEHOLDER|>", + "video_start": "<|VIDEO_START|>", + "video_end": "<|VIDEO_END|>", + "video_placeholder": "<|VIDEO_PLACEHOLDER|>", + } + + def __init__( + self, + vocab_file=None, + bos_token="", + cls_token="", + eos_token="", + mask_token="", + pad_token="", + sep_token="", + unk_token="", + additional_special_tokens=None, + **kwargs, + ): + """ + Initialize the ERNIE 4.5 VL Tokenizer. + + Args: + vocab_file: Path to the SentencePiece vocabulary file + bos_token: Beginning of sequence token + cls_token: Classification token + eos_token: End of sequence token + mask_token: Masking token + pad_token: Padding token + sep_token: Separator token + unk_token: Unknown token + additional_special_tokens: Additional special tokens + **kwargs: Additional keyword arguments + """ + # Handle possible parameter renaming + if vocab_file is None: + for key in ["tokenizer_file", "model_file", "spm_file"]: + if key in kwargs: + vocab_file = kwargs.pop(key) + break + + if vocab_file is None: + raise ValueError( + "vocab_file is required. Please provide the path to the tokenizer.model file " + "or ensure it's available in the model directory." + ) + + # Initialize SentencePiece model first, as parent __init__ might call get_vocab() + self.vocab_file = vocab_file + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(vocab_file) + + # Set default additional special tokens + if additional_special_tokens is None: + additional_special_tokens = [self.SPECIAL_TOKENS["space"], self.SPECIAL_TOKENS["gender"]] + + # Call PreTrainedTokenizer's __init__ + super().__init__( + bos_token=bos_token, + cls_token=cls_token, + eos_token=eos_token, + mask_token=mask_token, + pad_token=pad_token, + sep_token=sep_token, + unk_token=unk_token, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + # Save initialization parameters for save_pretrained + self.init_kwargs = { + "vocab_file": vocab_file, + "bos_token": bos_token, + "cls_token": cls_token, + "eos_token": eos_token, + "mask_token": mask_token, + "pad_token": pad_token, + "sep_token": sep_token, + "unk_token": unk_token, + "additional_special_tokens": additional_special_tokens, + } + self.init_kwargs.update(kwargs) + + # Set default attributes + self.split_special_tokens = False + + # Set internal attributes + self._bos_token = bos_token + self._eos_token = eos_token + self._pad_token = pad_token + self._unk_token = unk_token + self._cls_token = cls_token + self._sep_token = sep_token + self._mask_token = mask_token + self._additional_special_tokens = additional_special_tokens + + # Set context manager attributes + self._in_target_context_manager = False + + # Set chat template + self.chat_template = None + + # Set initialization inputs + self.init_inputs = [] + + # Set added tokens decoder + self._added_tokens_decoder = {} + + # ==================== Pickle Support ==================== + + def __getstate__(self): + """Support for pickle 
serialization.""" + state = self.__dict__.copy() + del state["sp_model"] + return state + + def __setstate__(self, state): + """Support for pickle deserialization.""" + self.__dict__.update(state) + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(self.vocab_file) + + # ==================== Special Token Properties ==================== + + @property + def space_token(self): + """Return the space token.""" + return self.SPECIAL_TOKENS["space"] + + @property + def space_token_id(self): + """Return the ID of the space token.""" + return self.sp_model.piece_to_id(self.SPECIAL_TOKENS["space"]) + + @property + def gend_token(self): + """Return the gender token.""" + return self.SPECIAL_TOKENS["gender"] + + @property + def gend_token_id(self): + """Return the ID of the gender token.""" + return self.sp_model.piece_to_id(self.SPECIAL_TOKENS["gender"]) + + @property + def im_start_id(self): + """Return the ID of the image start token.""" + return self.sp_model.piece_to_id(self.SPECIAL_TOKENS["image_start"]) + + @property + def im_end_id(self): + """Return the ID of the image end token.""" + return self.sp_model.piece_to_id(self.SPECIAL_TOKENS["image_end"]) + + @property + def image_placeholder_id(self): + """Return the ID of the image placeholder token.""" + return self.sp_model.piece_to_id(self.SPECIAL_TOKENS["image_placeholder"]) + + @property + def video_placeholder_id(self): + """Return the ID of the video placeholder token.""" + return self.sp_model.piece_to_id(self.SPECIAL_TOKENS["video_placeholder"]) + + # ==================== Core Tokenization Methods ==================== + + @property + def vocab_size(self): + """Return the size of the vocabulary.""" + return self.sp_model.vocab_size() + + def get_vocab(self): + """Return the vocabulary as a dictionary mapping tokens to IDs.""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text): + """Tokenize the input text into pieces.""" + return self.sp_model.encode_as_pieces(text) + + def _convert_token_to_id(self, token): + """Convert a token to its corresponding ID.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, id): + """Convert an ID to its corresponding token.""" + return self.sp_model.id_to_piece(id) + + def convert_tokens_to_string(self, tokens): + """Convert a sequence of tokens back to a string.""" + current_sub_tokens = [] + out_string = "" + + for token in tokens: + if token in self.all_special_tokens: + out_string += self.sp_model.decode(current_sub_tokens) + token + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + # ==================== Core Tokenization Methods ==================== + + def convert_tokens_to_ids(self, tokens): + """Convert a sequence of tokens to a sequence of IDs.""" + if isinstance(tokens, str): + return self._convert_token_to_id(tokens) + elif isinstance(tokens, list): + return [self._convert_token_to_id(token) for token in tokens] + else: + raise TypeError(f"tokens should be a string or a list of strings, got {type(tokens)}") + + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + """Convert a sequence of IDs to a sequence of tokens.""" + if isinstance(ids, int): + return self._convert_id_to_token(ids) + elif isinstance(ids, list): + return [self._convert_id_to_token(id) for id in ids] + else: + raise TypeError(f"ids should be an int or a list 
of ints, got {type(ids)}") + + def _add_tokens(self, new_tokens, special_tokens=False): + """Add new tokens to the tokenizer.""" + if not special_tokens and new_tokens: + raise ValueError("Adding regular tokens is not supported for SentencePiece tokenizers") + return 0 + + def tokenize(self, text, **kwargs): + """Tokenize the input text.""" + return self._tokenize(text) + + def _decode(self, token_ids, skip_special_tokens=False, **kwargs): + """Decode a sequence of token IDs to a string.""" + if isinstance(token_ids, int): + token_ids = [token_ids] + + # Filter out special tokens if requested + if skip_special_tokens: + token_ids = [i for i in token_ids if i not in self.all_special_ids] + + # Convert IDs to tokens + tokens = [self._convert_id_to_token(id) for id in token_ids] + + # Convert tokens to string + return self.convert_tokens_to_string(tokens) + + def decode(self, token_ids, skip_special_tokens=False, **kwargs): + """Decode a sequence of token IDs to a string.""" + return self._decode(token_ids, skip_special_tokens=skip_special_tokens, **kwargs) + + def encode(self, text, **kwargs): + """Encode text to token IDs.""" + tokens = self._tokenize(text) + return self.convert_tokens_to_ids(tokens) + + def encode_plus(self, text, **kwargs): + """Encode text to token IDs with additional information.""" + # Get basic encoding + input_ids = self.encode(text, **kwargs) + + # Create attention mask + attention_mask = [1] * len(input_ids) + + # Handle padding if requested + padding = kwargs.get("padding", False) + if padding: + max_length = kwargs.get("max_length", None) + if max_length is not None: + # Pad to max_length + while len(input_ids) < max_length: + input_ids.append(self.pad_token_id) + attention_mask.append(0) + + result = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + # Add token_type_ids if requested + if kwargs.get("return_token_type_ids", False): + result["token_type_ids"] = [0] * len(input_ids) + + return result + + # ==================== Enhanced Encoding Methods ==================== + + def encode_chat_input(self, messages, add_generation_prompt=False, **kwargs): + """ + Encode chat messages into token IDs. + + Args: + messages: List of message dictionaries with 'role' and 'content' keys + add_generation_prompt: Whether to add a generation prompt + **kwargs: Additional arguments for encoding + + Returns: + Encoded token IDs or encoding result + """ + # Apply chat template + chat_text = self.apply_chat_template(messages, add_generation_prompt=add_generation_prompt) + + # Encode text + if kwargs.get("return_tensors") is not None: + return self(chat_text, **kwargs) + else: + return self.encode(chat_text, **kwargs) + + def encode_multimodal(self, text, images=None, videos=None, **kwargs): + """ + Encode multimodal input including text, images, and videos. 
+ + Args: + text: Input text + images: List of image paths or URLs + videos: List of video paths or URLs + **kwargs: Additional encoding arguments + + Returns: + Encoding result with multimodal tokens + """ + # Build multimodal text + multimodal_text = text + + if images: + for i, image in enumerate(images): + multimodal_text += f" {self.SPECIAL_TOKENS['image_start']} {self.SPECIAL_TOKENS['image_placeholder']} {self.SPECIAL_TOKENS['image_end']}" + + if videos: + for i, video in enumerate(videos): + multimodal_text += f" {self.SPECIAL_TOKENS['video_start']} {self.SPECIAL_TOKENS['video_placeholder']} {self.SPECIAL_TOKENS['video_end']}" + + # Encode multimodal text + return self(multimodal_text, **kwargs) + + # ==================== Additional Utility Methods ==================== + + def get_special_tokens_mask(self, token_ids, already_has_special_tokens=False): + """Get the special tokens mask.""" + if already_has_special_tokens: + return [0] * len(token_ids) + + special_tokens_mask = [0] * len(token_ids) + for i, token_id in enumerate(token_ids): + if token_id in self.all_special_ids: + special_tokens_mask[i] = 1 + + return special_tokens_mask + + def num_special_tokens_to_add(self, pair=False): + """Return the number of special tokens that will be added.""" + return 0 + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """Build model inputs from a sequence or a pair of sequences.""" + if token_ids_1 is None: + return token_ids_0 + else: + return token_ids_0 + [self.sep_token_id] + token_ids_1 + + def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): + """Create token type IDs from a sequence or a pair of sequences.""" + if token_ids_1 is None: + return [0] * len(token_ids_0) + else: + return [0] * len(token_ids_0) + [1] * (len(token_ids_1) + 1) + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + """Prepare text for tokenization.""" + return text, kwargs + + # ==================== Utility Methods ==================== + + def save_vocabulary(self, save_directory, filename_prefix=None): + """ + Save the vocabulary and special tokens file to a directory. 
+ + Args: + save_directory: The directory to save the vocabulary to + filename_prefix: Prefix to add to the filename + + Returns: + Paths to the saved files + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + + # Construct output vocabulary file path + out_vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"], + ) + + # Copy or create vocabulary file + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + import shutil + + shutil.copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): + """ + Load tokenizer from pretrained model + + Args: + pretrained_model_name_or_path: Model name or path + *args: Other positional arguments + **kwargs: Other keyword arguments + + Returns: + Loaded tokenizer instance + """ + # Call parent's from_pretrained method to handle all file downloads and path resolution + return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) diff --git a/paddleformers/transformers/llama/tokenizer.py b/paddleformers/transformers/llama/tokenizer.py index 11490d3b26..42bd963eb7 100644 --- a/paddleformers/transformers/llama/tokenizer.py +++ b/paddleformers/transformers/llama/tokenizer.py @@ -25,10 +25,12 @@ import unicodedata from typing import Collection, Dict, List, Set, Tuple, Union +from transformers.tokenization_utils import PreTrainedTokenizer + from ...utils.import_utils import is_tiktoken_available from ...utils.log import logger from ..legacy.tokenizer_utils_base import AddedToken -from ..tokenizer_utils import PretrainedTokenizer +from ..tokenizer_utils import PaddleTokenizerMixin VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} @@ -55,7 +57,7 @@ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: } -class Llama3Tokenizer(PretrainedTokenizer): +class Llama3Tokenizer(PaddleTokenizerMixin, PreTrainedTokenizer): """QWen tokenizer.""" model_input_names = ["input_ids", "attention_mask", "position_ids"] diff --git a/paddleformers/transformers/tokenizer_utils.py b/paddleformers/transformers/tokenizer_utils.py index f01aa85915..25066c81a4 100644 --- a/paddleformers/transformers/tokenizer_utils.py +++ b/paddleformers/transformers/tokenizer_utils.py @@ -19,9 +19,17 @@ import os import re from functools import wraps -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import Any, Dict, List, Union from transformers import BatchEncoding + +try: + from .legacy.tokenizer_utils import PretrainedTokenizer + + PreTrainedTokenizer = PretrainedTokenizer +except: + from transformers.tokenization_utils import PreTrainedTokenizer + from transformers.tokenization_utils_base import ( ADDED_TOKENS_FILE, CHAT_TEMPLATE_FILE, @@ -34,17 +42,6 @@ from ..utils.download import DownloadSource, resolve_file_path from ..utils.log import logger -if TYPE_CHECKING: - from transformers.tokenization_utils import PreTrainedTokenizer - -# legacy PretrainedTokenizer, which is different from huggingface PreTrainedTokenizer -try: - from .legacy.tokenizer_utils import PretrainedTokenizer - - PretrainedTokenizer = PretrainedTokenizer -except: - pass - class 
TensorType(ExplicitEnum): """ @@ -213,13 +210,9 @@ def from_pretrained( pass except Exception as e: raise e - # 获得cache_dir的目录 - for file_id, file_path in resolved_vocab_files.items(): - if resolved_vocab_files[file_id] is not None: - cache_dir = os.path.dirname(resolved_vocab_files[file_id]) - break if not any(key in resolved_vocab_files for key in cls.vocab_files_names.keys()): + hf_link = f"https://huggingface.co/{pretrained_model_name_or_path}" modelscope_link = f"https://modelscope.cn/models/{pretrained_model_name_or_path}" encoded_model_name = pretrained_model_name_or_path.replace("/", "%2F") @@ -347,9 +340,9 @@ def _encode_chat_inputs( ans.append(ans_roundi) non_learnable_parts = self._extract_non_learnable_parts(origin_msg, ans) - assert len(non_learnable_parts) == len( - ans - ), f"Get non_learnable_parts len: {len(non_learnable_parts)}, but ans len: {len(ans)}." + # assert len(non_learnable_parts) == len( + # ans + # ), f"Get non_learnable_parts len: {len(non_learnable_parts)}, but ans len: {len(ans)}." conversation_ids = [] for i in range(len(non_learnable_parts)): diff --git a/paddleformers/utils/download/download.py b/paddleformers/utils/download/download.py index e17a1ab806..fc393dad8a 100644 --- a/paddleformers/utils/download/download.py +++ b/paddleformers/utils/download/download.py @@ -28,7 +28,11 @@ RepositoryNotFoundError, RevisionNotFoundError, ) -from paddle import __version__ + +try: + from paddle import __version__ +except: + __version__ = None from requests import HTTPError from ..log import logger diff --git a/paddleformers/utils/downloader.py b/paddleformers/utils/downloader.py index 838c20fbe1..1e0162a894 100644 --- a/paddleformers/utils/downloader.py +++ b/paddleformers/utils/downloader.py @@ -24,7 +24,6 @@ from collections import OrderedDict from typing import Optional, Union -import paddle.distributed as dist import requests from filelock import FileLock from huggingface_hub import get_hf_file_metadata, hf_hub_url @@ -530,6 +529,7 @@ def get_static_model_on_pdc(remote_path, local_path, timeout, enable_flash_devic Returns: str: path to load static model """ + # TODO: This function will be removed in a future release. try: base_dir, target_dir = os.path.split(os.path.normpath(local_path)) if not os.path.exists(base_dir) and base_dir != "": @@ -543,6 +543,9 @@ def get_static_model_on_pdc(remote_path, local_path, timeout, enable_flash_devic persistent_path = local_path device_id = int(os.getenv("FLAGS_selected_gpus", "0")) + + import paddle.distributed as dist + if device_id != 0: logger.info("Waiting local process 0...") dist.barrier() diff --git a/tests/transformers/ernie4_5vl/__init__.py b/tests/transformers/ernie4_5vl/__init__.py new file mode 100644 index 0000000000..595add0aed --- /dev/null +++ b/tests/transformers/ernie4_5vl/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/transformers/ernie4_5vl/test_tokenizer.py b/tests/transformers/ernie4_5vl/test_tokenizer.py
new file mode 100644
index 0000000000..8c479a655d
--- /dev/null
+++ b/tests/transformers/ernie4_5vl/test_tokenizer.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import shutil
+import unittest
+
+from paddleformers.transformers import Ernie4_5_VLTokenizer
+
+HUB_FLAG = "huggingface"
+
+
+class Ernie4_5_VL_TokenizationTest(unittest.TestCase):
+    from_pretrained_id = "baidu/ERNIE-4.5-VL-28B-A3B-Base-PT"
+    tokenizer_class = Ernie4_5_VLTokenizer
+    test_slow_tokenizer = True
+    space_between_special_tokens = False
+    from_pretrained_kwargs = None
+    test_seq2seq = False
+
+    def setUp(self):
+        self.test_dirs = ["./slow_tokenizer"]
+        for test_dir in self.test_dirs:
+            if os.path.exists(test_dir):
+                shutil.rmtree(test_dir)
+
+    def tearDown(self):
+        for test_dir in self.test_dirs:
+            if os.path.exists(test_dir):
+                shutil.rmtree(test_dir)
+
+    def test_slow_tokenizer_from_pretrained(self):
+        tokenizer = Ernie4_5_VLTokenizer.from_pretrained(
+            self.from_pretrained_id, download_hub=HUB_FLAG, trust_remote_code=True
+        )
+        self.assertTrue(tokenizer is not None)
+
+    def test_slow_tokenizer_save_pretrained(self):
+        tokenizer = Ernie4_5_VLTokenizer.from_pretrained(
+            self.from_pretrained_id, download_hub=HUB_FLAG, trust_remote_code=True
+        )
+        tokenizer.model_max_length = 512
+        tokenizer.save_pretrained("./slow_tokenizer")
+        self.assertTrue(os.path.exists("./slow_tokenizer/tokenizer_config.json"))
+
+    def test_tokenize(self):
+        tokenizer = Ernie4_5_VLTokenizer.from_pretrained(
+            self.from_pretrained_id, download_hub=HUB_FLAG, trust_remote_code=True
+        )
+        text = "hello world, this is a tokenizer test"
+        output_dict = tokenizer(text)
+        decode_text = tokenizer.decode(output_dict["input_ids"], skip_special_tokens=True)
+        self.assertEqual(text, decode_text)
+
+
+# Ernie4_5_VL_TokenizationTest().test_slow_tokenizer_from_pretrained()
diff --git a/tests/transformers/qwen/test_tokenizer.py b/tests/transformers/qwen/test_tokenizer.py
index 127dbd4fd2..4dfe2c0719 100644
--- a/tests/transformers/qwen/test_tokenizer.py
+++ b/tests/transformers/qwen/test_tokenizer.py
@@ -19,7 +19,7 @@
 from paddleformers.transformers import QWenTokenizer
 
 
-class Qwen2TokenizationTest(unittest.TestCase):
+class QwenTokenizationTest(unittest.TestCase):
     from_pretrained_id = "PaddleNLP/qwen-7b"
     tokenizer_class = QWenTokenizer
     test_slow_tokenizer = True
@@ -56,4 +56,4 @@ def test_tokenize(self):
         self.assertEqual(text, decode_text)
 
 
-Qwen2TokenizationTest().test_slow_tokenizer_from_pretrained()
+QwenTokenizationTest().test_slow_tokenizer_from_pretrained()
diff --git a/tests/transformers/test_hf_tokenizer.py b/tests/transformers/test_hf_tokenizer.py
index 794c618515..10ccc418d8 100644
--- a/tests/transformers/test_hf_tokenizer.py
+++ b/tests/transformers/test_hf_tokenizer.py
@@ -19,7 +19,7 @@
 from paddleformers.transformers import AutoTokenizer, Qwen2Tokenizer
 
 
-@unittest.skip("don't support multisource download")
+@unittest.skip("multi source download CI not support")
 class TestHFMultiSourceTokenizer(unittest.TestCase):
     def encode(self, tokenizer):
         input_text = "hello world, 你好"
@@ -68,7 +68,7 @@ def test_auto_tokenizer(self):
 
 class TestHFTokenizer(unittest.TestCase):
     def setUp(self):
-        self.tokenizer = AutoTokenizer.from_pretrained("PaddleNLP/Qwen2.5-7B")
+        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B", from_hf_hub=True)
 
     def test_encode(self):
         input_text = "hello world, this is paddle format checker"
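Two idioms recur throughout this patch: optional imports guarded by try/except (paddle, _get_arch_info, paddle.__version__) and runtime mixin binding through type() with a fallback to the original class (_bind_paddle_mixin_if_available). The minimal, self-contained sketch below shows both idioms in isolation; OptionalMixin and PlainTokenizer are placeholder names for illustration, not classes from this repository.

# Sketch only: placeholder names, not the actual paddleformers classes.

# Idiom 1: optional dependency import. The module stays importable even
# when the optional package is missing; callers check for None at use time.
try:
    import paddle  # optional dependency
except ImportError:
    paddle = None


class OptionalMixin:
    """Stand-in for PaddleTokenizerMixin: adds behaviour only when mixed in."""

    def backend(self):
        return "paddle" if paddle is not None else "numpy"


class PlainTokenizer:
    """Stand-in for a HuggingFace tokenizer class."""

    def tokenize(self, text):
        return text.split()


# Idiom 2: bind the mixin at runtime, falling back to the original class
# if the dynamic subclass cannot be created (e.g. a metaclass conflict).
def bind_mixin_if_available(tokenizer_class, mixin=OptionalMixin):
    try:
        return type(tokenizer_class.__name__, (mixin, tokenizer_class), {})
    except TypeError:
        return tokenizer_class


if __name__ == "__main__":
    tok_cls = bind_mixin_if_available(PlainTokenizer)
    tok = tok_cls()
    print(tok.tokenize("hello world"), tok.backend())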