Commit 79f05e4

[Multimodal] Always enable hashing mm data (vllm-project#23308)
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
1 parent f8daddc commit 79f05e4

File tree

15 files changed: +95 -149 lines

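
In short: before this commit, multi-modal hashes were only produced when the processor cache was enabled (`mm_processor_cache_gb > 0`, per the removed `processor_return_mm_hashes` property below); afterwards the `return_mm_hashes` flag is gone from the preprocessing and processor APIs and multi-modal data is always hashed. A minimal standalone sketch of that behavioural shift, assuming a hypothetical `hash_mm_item` helper rather than vLLM's actual hashing code:

import hashlib


def hash_mm_item(data: bytes) -> str:
    # Hypothetical helper: content hash of a single multi-modal item.
    return hashlib.blake2b(data, digest_size=16).hexdigest()


def hashes_before(items: list[bytes], mm_processor_cache_gb: float) -> list:
    # Old behaviour: hashes were tied to the processor cache being enabled.
    if mm_processor_cache_gb > 0:
        return [hash_mm_item(item) for item in items]
    return [None] * len(items)


def hashes_after(items: list[bytes]) -> list:
    # New behaviour: hashes are always produced, independent of cache config.
    return [hash_mm_item(item) for item in items]


if __name__ == "__main__":
    images = [b"fake-image-bytes-1", b"fake-image-bytes-2"]
    print(hashes_before(images, mm_processor_cache_gb=0))  # [None, None]
    print(hashes_after(images))                            # two hex digests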

vllm/config/__init__.py

Lines changed: 0 additions & 9 deletions
@@ -1685,15 +1685,6 @@ def uses_mrope(self) -> bool:
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None

-    @property
-    def processor_return_mm_hashes(self) -> bool:
-        """Whether the multi-modal processor should output hashes."""
-        mm_config = self.multimodal_config
-        if mm_config is None:
-            return False
-
-        return mm_config.mm_processor_cache_gb > 0
-
     @property
     def enable_mm_processor_cache(self) -> bool:
         """Whether the multi-modal processor cache should be enabled."""

vllm/inputs/preprocess.py

Lines changed: 11 additions & 42 deletions
@@ -254,7 +254,6 @@ def _process_multimodal(
         mm_processor_kwargs: Optional[Mapping[str, object]],
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
-        return_mm_hashes: bool = False,
     ) -> MultiModalInputs:
         """
         Apply the model's multi-modal processor to a multi-modal prompt,
@@ -271,8 +270,7 @@
         return mm_processor.apply(prompt,
                                   mm_data,
                                   hf_processor_mm_kwargs=mm_processor_kwargs,
-                                  tokenization_kwargs=tokenization_kwargs,
-                                  return_mm_hashes=return_mm_hashes)
+                                  tokenization_kwargs=tokenization_kwargs)

     async def _process_multimodal_async(
         self,
@@ -281,7 +279,6 @@ async def _process_multimodal_async(
         mm_processor_kwargs: Optional[Mapping[str, object]],
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
-        return_mm_hashes: bool = False,
     ) -> MultiModalInputs:
         """
         Async version of
@@ -297,8 +294,7 @@
         return mm_processor.apply(prompt,
                                   mm_data,
                                   hf_processor_mm_kwargs=mm_processor_kwargs,
-                                  tokenization_kwargs=tokenization_kwargs,
-                                  return_mm_hashes=return_mm_hashes)
+                                  tokenization_kwargs=tokenization_kwargs)

     def _process_embeds(
         self,
@@ -335,7 +331,6 @@ def _process_tokens(
         parsed_content: TokensPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
-        return_mm_hashes: bool = False,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_token_ids = parsed_content["prompt_token_ids"]
         token_type_ids = parsed_content.get("token_type_ids")
@@ -348,7 +343,6 @@
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                return_mm_hashes=return_mm_hashes,
             )
         else:
             inputs = token_inputs(
@@ -366,7 +360,6 @@ async def _process_tokens_async(
         parsed_content: TokensPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
-        return_mm_hashes: bool = False,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_token_ids = parsed_content["prompt_token_ids"]
         token_type_ids = parsed_content.get("token_type_ids")
@@ -379,7 +372,6 @@
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                return_mm_hashes=return_mm_hashes,
             )
         else:
             inputs = token_inputs(
@@ -397,7 +389,6 @@ def _process_text(
         parsed_content: TextPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
-        return_mm_hashes: bool = False,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_text = parsed_content["prompt"]

@@ -409,7 +400,6 @@
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                return_mm_hashes=return_mm_hashes,
             )
         else:
             prompt_token_ids = self._tokenize_prompt(
@@ -432,7 +422,6 @@ async def _process_text_async(
         parsed_content: TextPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
-        return_mm_hashes: bool = False,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_text = parsed_content["prompt"]

@@ -444,7 +433,6 @@
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                return_mm_hashes=return_mm_hashes,
             )
         else:
             prompt_token_ids = await self._tokenize_prompt_async(
@@ -467,7 +455,6 @@ def _prompt_to_llm_inputs(
         prompt: SingletonPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
-        return_mm_hashes: bool = False,
     ) -> SingletonInputs:
         """
         Extract the singleton inputs from a prompt.
@@ -476,7 +463,6 @@

         * prompt: single encoder or decoder input prompt
         * lora_request: this is only valid for decoder prompts
-        * return_mm_hashes: whether to return multimodal hashes

         Returns:

@@ -490,21 +476,18 @@
             return self._process_tokens(
                 parsed["content"],
                 lora_request=lora_request,
-                return_mm_hashes=return_mm_hashes,
             )
         if parsed["type"] == "text":
             return self._process_text(
                 parsed["content"],
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                return_mm_hashes=return_mm_hashes,
             )
         if parsed["type"] == "str":
             return self._process_text(
                 TextPrompt(prompt=parsed["content"]),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                return_mm_hashes=return_mm_hashes,
             )

         assert_never(parsed)
@@ -514,7 +497,6 @@ async def _prompt_to_llm_inputs_async(
         prompt: SingletonPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
-        return_mm_hashes: bool = False,
     ) -> SingletonInputs:
         """
         Async version of
@@ -528,21 +510,18 @@ async def _prompt_to_llm_inputs_async(
             return await self._process_tokens_async(
                 parsed["content"],
                 lora_request=lora_request,
-                return_mm_hashes=return_mm_hashes,
             )
         if parsed["type"] == "text":
             return await self._process_text_async(
                 parsed["content"],
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                return_mm_hashes=return_mm_hashes,
             )
         if parsed["type"] == "str":
             return await self._process_text_async(
                 TextPrompt(prompt=parsed["content"]),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                return_mm_hashes=return_mm_hashes,
             )

         assert_never(parsed)
@@ -785,7 +764,6 @@ def _process_decoder_only_prompt(
         prompt: SingletonPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
-        return_mm_hashes: bool = False,
     ) -> DecoderOnlyInputs:
         """
         For decoder-only models:
@@ -796,7 +774,6 @@

         * prompt: input prompt
         * lora_request
-        * return_mm_hashes

         Returns:

@@ -807,7 +784,6 @@
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
-            return_mm_hashes=return_mm_hashes,
         )

         return self._build_decoder_only_llm_inputs(prompt_comps)
@@ -817,7 +793,6 @@ async def _process_decoder_only_prompt_async(
         prompt: SingletonPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
-        return_mm_hashes: bool = False,
     ) -> DecoderOnlyInputs:
         """
         Async version of
@@ -827,7 +802,6 @@
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
-            return_mm_hashes=return_mm_hashes,
         )

         return self._build_decoder_only_llm_inputs(prompt_comps)
@@ -837,17 +811,15 @@ def preprocess(
         prompt: PromptType,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
-        return_mm_hashes: bool = False,
     ) -> ProcessorInputs:
         """Preprocess the input prompt."""
         if self.model_config.is_encoder_decoder:
-            assert not return_mm_hashes, (
-                "Multimodal hashes for encoder-decoder models should not be ",
-                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
-            # input prompts to encoder & decoder
+            # input prompts to encoder & decoder.
             return self._process_encoder_decoder_prompt(
-                prompt, tokenization_kwargs)
+                prompt,
+                tokenization_kwargs,
+            )

         if is_explicit_encoder_decoder_prompt(prompt):
             raise ValueError("Cannot pass encoder-decoder prompt "
@@ -858,27 +830,25 @@ def preprocess(
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
-            return_mm_hashes=return_mm_hashes,
         )

     async def preprocess_async(
         self,
         prompt: PromptType,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
-        return_mm_hashes: bool = False,
     ) -> ProcessorInputs:
         """
         Async version of
         [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
         """
         if self.model_config.is_encoder_decoder:
-            assert not return_mm_hashes, (
-                "Multimodal hashes for encoder-decoder models should not be ",
-                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
-            # input prompts to encoder & decoder
-            return await self._process_encoder_decoder_prompt_async(prompt)
+            # input prompts to encoder & decoder.
+            return await self._process_encoder_decoder_prompt_async(
+                prompt,
+                tokenization_kwargs,
+            )

         if is_explicit_encoder_decoder_prompt(prompt):
             raise ValueError("Cannot pass encoder-decoder prompt "
@@ -889,5 +859,4 @@ async def preprocess_async(
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
-            return_mm_hashes=return_mm_hashes,
         )
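
With `return_mm_hashes` gone, `InputPreprocessor.preprocess` and the processor's `apply` no longer thread a flag through `_prompt_to_llm_inputs` / `_process_multimodal`. A self-contained sketch of the new call shape, using fake stand-in classes rather than vLLM's real `InputPreprocessor` and multi-modal processor:

from typing import Any, Mapping, Optional


class FakeMMProcessor:
    # Stand-in for the multi-modal processor vLLM resolves per model.
    def apply(
        self,
        prompt: str,
        mm_data: Mapping[str, Any],
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
    ) -> dict[str, Any]:
        # No return_mm_hashes switch: per this commit, a real processor hashes
        # mm_data unconditionally and attaches the hashes to its output.
        return {"prompt": prompt, "mm_item_count": len(mm_data)}


class FakeInputPreprocessor:
    def __init__(self, mm_processor: FakeMMProcessor) -> None:
        self.mm_processor = mm_processor

    def preprocess(
        self,
        prompt: str,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[object] = None,
    ) -> dict[str, Any]:
        # Mirrors the slimmed-down signature: no return_mm_hashes parameter.
        return self.mm_processor.apply(
            prompt,
            {"image": ["<image bytes>"]},
            hf_processor_mm_kwargs={},
            tokenization_kwargs=tokenization_kwargs,
        )


print(FakeInputPreprocessor(FakeMMProcessor()).preprocess("describe the image"))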

vllm/model_executor/models/deepseek_vl2.py

Lines changed: 0 additions & 4 deletions
@@ -290,8 +290,6 @@ def _cached_apply_hf_processor(
         mm_data_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-        *,
-        return_mm_hashes: bool,
     ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         # The processor logic is different for len(images) <= 2 vs > 2
         # Since the processing cache assumes that the processor output is
@@ -303,15 +301,13 @@
                 mm_data_items=mm_data_items,
                 hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                 tokenization_kwargs=tokenization_kwargs,
-                return_mm_hashes=return_mm_hashes,
             )

         return super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
-            return_mm_hashes=return_mm_hashes,
         )


vllm/model_executor/models/h2ovl.py

Lines changed: 0 additions & 4 deletions
@@ -479,8 +479,6 @@ def _cached_apply_hf_processor(
         mm_data_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-        *,
-        return_mm_hashes: bool,
     ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         # The processor logic is different for len(images) <= 1 vs > 1
         # Since the processing cache assumes that the processor output is
@@ -492,15 +490,13 @@
                 mm_data_items=mm_data_items,
                 hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                 tokenization_kwargs=tokenization_kwargs,
-                return_mm_hashes=return_mm_hashes,
             )

         return super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
-            return_mm_hashes=return_mm_hashes,
         )

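
deepseek_vl2 and h2ovl (above) and pixtral (below) override `_cached_apply_hf_processor`; the change simply drops the keyword-only `return_mm_hashes` parameter from the override and from the `super()` call. A minimal sketch of the resulting pattern, with a fake base class standing in for vLLM's multi-modal processor base:

from typing import Mapping


class FakeBaseProcessor:
    # Stand-in for vLLM's multi-modal processor base class.
    def _cached_apply_hf_processor(
        self,
        prompt: str,
        mm_data_items: Mapping[str, list],
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], dict, bool]:
        return ([101, 102, 103], {}, True)  # dummy values matching the tuple shape


class FakeModelProcessor(FakeBaseProcessor):
    def _cached_apply_hf_processor(
        self,
        prompt: str,
        mm_data_items: Mapping[str, list],
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> tuple[list[int], dict, bool]:
        # Model-specific branching (e.g. on image count) would go here; the
        # super() call no longer forwards return_mm_hashes.
        return super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )


print(FakeModelProcessor()._cached_apply_hf_processor("p", {"image": []}, {}, {}))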

vllm/model_executor/models/llava.py

Lines changed: 1 addition & 2 deletions
@@ -795,7 +795,6 @@ def apply(
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Optional[Mapping[str, object]] = None,
-        return_mm_hashes: bool = False,
     ) -> MultiModalInputs:
         hf_config = self.info.get_hf_config()
         image_token_id = hf_config.image_token_index
@@ -807,7 +806,7 @@
         )

         result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
-                               tokenization_kwargs, return_mm_hashes)
+                               tokenization_kwargs)

         mm_items = self._to_mm_items(mm_data)
         mm_item_counts = mm_items.get_all_counts()

vllm/model_executor/models/mllama.py

Lines changed: 1 addition & 2 deletions
@@ -168,10 +168,9 @@ def apply(
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Optional[Mapping[str, object]] = None,
-        return_mm_hashes: bool = False,
     ) -> MultiModalEncDecInputs:
         mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
-                                  tokenization_kwargs, return_mm_hashes)
+                                  tokenization_kwargs)

         image_token_id = self.info.get_hf_config().image_token_index
         # Check that the number of image tokens in the decoder prompt matches

vllm/model_executor/models/paligemma.py

Lines changed: 1 addition & 2 deletions
@@ -194,10 +194,9 @@ def apply(
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Optional[Mapping[str, object]] = None,
-        return_mm_hashes: bool = False,
     ) -> MultiModalInputs:
         mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
-                                  tokenization_kwargs, return_mm_hashes)
+                                  tokenization_kwargs)
         prompt_token_ids = mm_inputs["prompt_token_ids"]

         tokenizer = self.info.get_tokenizer()
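
llava, mllama and paligemma (above) override `apply` instead; there the trailing positional `return_mm_hashes` argument simply disappears from the `super().apply(...)` call. A short sketch of that pattern, again with hypothetical stand-ins rather than the real vLLM classes:

from typing import Mapping, Optional


class FakeBaseProcessor:
    def apply(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
    ) -> dict:
        return {"prompt_token_ids": [1, 2, 3], "mm_data": dict(mm_data)}


class FakeLlavaLikeProcessor(FakeBaseProcessor):
    def apply(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
    ) -> dict:
        # Positional forwarding, minus the old trailing return_mm_hashes arg.
        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                               tokenization_kwargs)
        # Model-specific post-processing of `result` would follow here.
        return result


print(FakeLlavaLikeProcessor().apply("<image> describe", {"image": [b""]}, {}))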

vllm/model_executor/models/pixtral.py

Lines changed: 0 additions & 3 deletions
@@ -308,15 +308,12 @@ def _cached_apply_hf_processor(
         mm_data_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-        *,
-        return_mm_hashes: bool,
     ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
-            return_mm_hashes=return_mm_hashes,
         )

         # NOTE: The tokens are already inserted by the chat template
