Cpu memory optimization #3845
base: main
Changes from all commits
754743b
2140c49
a016bc0
6ea89ae
c286767
2540824
c7f8b12
711446c
35d5861
503f320
6b1950c
1e2e669
33ca588
d99f183
66b40bd
880b639
fddc075
@@ -42,6 +42,7 @@
)
from torch_tensorrt.dynamo.utils import (
    deallocate_module,
+    get_cpu_memory_usage,
    get_flat_args_with_check,
    get_output_metadata,
    parse_graph_io,
@@ -675,7 +676,7 @@ def compile(
        "l2_limit_for_tiling": l2_limit_for_tiling,
        "offload_module_to_cpu": offload_module_to_cpu,
    }

+    logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
    settings = CompilationSettings(**compilation_options)
    logger.info("Compilation Settings: %s\n", settings)
    exported_program = pre_export_lowering(exported_program, settings)
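The get_cpu_memory_usage() calls added throughout compilation log host-memory usage at each stage. The utility itself is not shown in this diff; a minimal sketch of what it could look like, assuming it reports the resident set size of the current process in MB via psutil (the real implementation in torch_tensorrt.dynamo.utils may differ):

```python
import os

import psutil


def get_cpu_memory_usage() -> float:
    """Resident set size of the current process in MB (illustrative sketch)."""
    return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
```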
@@ -689,14 +690,17 @@
    # Apply lowering on the graph module
    gm = post_lowering(gm, settings)
+    logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
    logger.debug("Lowered Input graph: " + str(gm.graph))

    # Move the weights in the state_dict to CPU
    if offload_module_to_cpu:
        deallocate_module(gm, delete_module=False)
        deallocate_module(exported_program.module(), delete_module=False)
        logger.info(
            "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
        )
+        logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB")
    else:
        remaining_memory, total_memory = torch.cuda.mem_get_info()
        if remaining_memory < total_memory // 2:
@@ -858,6 +862,11 @@ def preserve_module_specs(
    # Iterate over all components that can be accelerated
    # Generate the corresponding TRT Module for those

+    # Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function.
+    # This is done to release CPU memory.
+    for attr in dir(gm):
Review comment: Let's make this opt-in, similar to malloc trim.

Review comment: Shouldn't this be cleared no matter what?
+        if attr.startswith("_frozen_param"):
+            delattr(gm, attr)
    for name, _ in partitioned_module.named_children():
        submodule = getattr(partitioned_module, name)
        # filter on the GraphModule
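The review comments above debate whether this cleanup should be opt-in, like the malloc-trim behavior. A minimal sketch of gating the frozen-parameter deletion behind an environment variable; the helper and the TORCHTRT_CLEAR_FROZEN_PARAMS flag name are hypothetical and not part of this PR:

```python
import os

import torch


def maybe_clear_frozen_params(gm: torch.fx.GraphModule) -> None:
    """Drop lifted `_frozen_param*` constants from a GraphModule to release CPU memory.

    The environment flag is hypothetical; it mirrors the reviewer's suggestion to make
    the cleanup opt-in the same way as malloc_trim instead of running unconditionally.
    """
    if os.environ.get("TORCHTRT_CLEAR_FROZEN_PARAMS", "1") == "0":
        return
    for attr in dir(gm):
        if attr.startswith("_frozen_param"):
            delattr(gm, attr)
```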
@@ -1231,7 +1240,7 @@ def convert_exported_program_to_serialized_trt_engine(
    # Prepare torch_trt inputs
    trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
-    trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
+    trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs)
    device = to_torch_tensorrt_device(device)
    enabled_precisions = {dtype._from(p) for p in enabled_precisions}
@@ -50,7 +50,12 @@
from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
from torch_tensorrt.dynamo.observer import Observer
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
+from torch_tensorrt.dynamo.utils import (
+    DYNAMIC_DIM,
+    deallocate_module,
+    get_cpu_memory_usage,
+    to_torch_device,
+)
from torch_tensorrt.logging import TRT_LOGGER

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -65,7 +70,7 @@ class UnsupportedOperatorException(RuntimeError):

class TRTInterpreterResult(NamedTuple):
-    serialized_engine: bytes
+    engine: trt.ICudaEngine
    input_names: Sequence[str]
    output_names: Sequence[str]
    weight_name_map: Optional[dict[Any, Any]]
@@ -512,8 +517,7 @@ def _save_weight_mapping(self) -> None:
        _LOGGER.info("Building weight name mapping...")
        # Stage 1: Name mapping
        torch_device = to_torch_device(self.compilation_settings.device)
-        self.module.to(torch_device)
-        sd = self.module.state_dict()
+        sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
        weight_name_map: dict[str, Any] = {}
        weight_refit_map = self.ctx.weight_refit_map
        constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
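The change above builds per-tensor device copies for the weight mapping instead of calling self.module.to(torch_device), so the module's own parameters keep their original (typically CPU) storage. A small illustration of the difference, assuming a CUDA device is available:

```python
import torch

mod = torch.nn.Linear(8, 8)  # stand-in for the FX module being interpreted

# Old pattern: mod.to("cuda") mutates the module in place, so its weights stay
# resident on the GPU for the rest of compilation.
# New pattern: create GPU copies only for the mapping; the module is untouched.
sd = {k: v.to("cuda") for k, v in mod.state_dict().items()}

assert next(mod.parameters()).device.type == "cpu"
assert all(v.device.type == "cuda" for v in sd.values())
```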
@@ -592,13 +596,11 @@ def _save_weight_mapping(self) -> None:
        torch.cuda.empty_cache()

    @needs_refit  # type: ignore[misc]
-    def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
+    def _insert_engine_to_cache(self, hash_val: str, engine: trt.ICudaEngine) -> None:
+        serialized_engine = engine.serialize()
        # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
        # if not self.compilation_settings.strip_engine_weights:
        #     # set EXCLUDE_WEIGHTS flag to strip weights
        #     runtime = trt.Runtime(TRT_LOGGER)
        #     engine = runtime.deserialize_cuda_engine(serialized_engine)

        #     serialization_config = engine.create_serialization_config()
        #     serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
        #     serialized_engine = engine.serialize_with_config(
@@ -733,6 +735,9 @@ def run(
                return interpreter_result  # type: ignore[no-any-return]

        self._construct_trt_network_def()
+        _LOGGER.debug(
+            f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
+        )

        if not self.compilation_settings.immutable_weights:
            self._save_weight_mapping()
@@ -750,16 +755,19 @@ def run(
            self._create_timing_cache(
                builder_config, self.compilation_settings.timing_cache_path
            )
-            serialized_engine = self.builder.build_serialized_network(
+            cuda_engine = self.builder.build_engine_with_config(
Review comment: Curious what is the benefit in CPU memory, if we return cuda_engine instead of serialized_engine?
                self.ctx.net, builder_config
            )
-            assert serialized_engine
+            assert cuda_engine

+            _LOGGER.debug(
+                f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB"
+            )

            _LOGGER.info(
                f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
            )
            _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

            self.ctx.clear_cpu_weights_reference_holder()

            self._save_timing_cache(
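For context on the reviewer's question: build_serialized_network() returns host memory that the old run() then copied once more through BytesIO (see the next hunk), while build_engine_with_config() returns a live ICudaEngine whose serialization can be deferred to the one place that actually needs bytes. A rough, illustrative sketch of the two patterns using the public TensorRT Python API (builder, network, and config construction omitted); this is not code from the PR:

```python
import io

import tensorrt as trt


def build_old(
    builder: trt.Builder, network: trt.INetworkDefinition, config: trt.IBuilderConfig
) -> bytes:
    # Old path: the builder returns serialized host memory, which was then
    # copied again through a BytesIO buffer to obtain Python bytes.
    serialized = builder.build_serialized_network(network, config)
    with io.BytesIO() as buf:
        buf.write(serialized)
        return buf.getvalue()


def build_new(
    builder: trt.Builder, network: trt.INetworkDefinition, config: trt.IBuilderConfig
) -> trt.ICudaEngine:
    # New path: keep the live engine and let callers serialize it once,
    # only where bytes are required (engine cache, runtime module).
    return builder.build_engine_with_config(network, config)
```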
@@ -772,14 +780,10 @@ def run(
                and self.compilation_settings.cache_built_engines
                and self.engine_cache is not None
            ):
-                self._insert_engine_to_cache(hash_val, serialized_engine)
-
-        with io.BytesIO() as engine_bytes:
-            engine_bytes.write(serialized_engine)
-            engine_str = engine_bytes.getvalue()
+                self._insert_engine_to_cache(hash_val, cuda_engine)

        return TRTInterpreterResult(
-            engine_str,
+            cuda_engine,
            self._input_names,
            self._output_names,
            self.weight_name_map,
@@ -1,24 +1,34 @@
from __future__ import annotations

import io
import logging
-from typing import Any, List, Optional, Sequence
+from typing import Any, List, NamedTuple, Optional, Sequence

import torch
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings
-from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
-    TRTInterpreter,
-    TRTInterpreterResult,
-)
+from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo.utils import get_output_dtypes
+from torch_tensorrt.dynamo.utils import (
+    get_cpu_memory_usage,
+    get_output_dtypes,
+    release_memory,
+)

logger = logging.getLogger(__name__)


+class SerializedInterpreterResult(NamedTuple):
+    serialized_engine: bytes
+    input_names: Sequence[str]
+    output_names: Sequence[str]
+    weight_name_map: Optional[dict[Any, Any]]
+    requires_output_allocator: bool


def infer_module_output_dtypes(
    module: torch.fx.GraphModule,
    truncate_double: bool = False,
@@ -29,7 +39,7 @@ def infer_module_output_dtypes(
    """
    outputs = [node for node in module.graph.nodes if node.op == "output"]
    outputs = outputs[0].args
-    return get_output_dtypes(outputs, truncate_double)  # type: ignore[no-any-return]
+    return get_output_dtypes(outputs, truncate_double)  # type: ignore


def interpret_module_to_result(
@@ -39,7 +49,7 @@ def interpret_module_to_result(
    arg_inputs: Optional[Sequence[Input]] = None,
    kwarg_inputs: Optional[dict[str, Any]] = None,
    engine_cache: Optional[BaseEngineCache] = None,
-) -> TRTInterpreterResult:
+) -> SerializedInterpreterResult:
    """Interpret an FX module to a TRTInterpreterResult
    Args:
        module: FX GraphModule to interpret
@@ -65,7 +75,32 @@ def interpret_module_to_result(
    )

    interpreter_result = interpreter.run()
-    return interpreter_result
+    # Delete the frozen parameters from the module to release CPU memory
Review comment: Can we gate this by the same env variable as the malloc_trim?

Review comment: Do you mean something like …

Review comment: I would say something like …
+    del interpreter
+    for attr in dir(module):
+        if attr.startswith("_frozen_param"):
+            delattr(module, attr)
+    release_memory()
+    logger.debug(
+        f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
+    )
+
+    serialized_engine = interpreter_result.engine.serialize()
+    with io.BytesIO() as engine_bytes:
+        engine_bytes.write(serialized_engine)
+        serialized_engine = engine_bytes.getvalue()
+    logger.debug(
+        f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB"
+    )
+    serialized_interpreter_result = SerializedInterpreterResult(
+        serialized_engine=serialized_engine,
+        input_names=interpreter_result.input_names,
+        output_names=interpreter_result.output_names,
+        weight_name_map=interpreter_result.weight_name_map,
+        requires_output_allocator=interpreter_result.requires_output_allocator,
+    )
+
+    return serialized_interpreter_result


def convert_module(
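release_memory() is imported from torch_tensorrt.dynamo.utils but its body is not part of this diff, and the review thread above discusses gating this cleanup behind the same environment variable as malloc_trim. A minimal sketch of what such a helper might do on Linux/glibc; the TORCHTRT_MALLOC_TRIM flag name is hypothetical:

```python
import ctypes
import gc
import os
import platform

import torch


def release_memory() -> None:
    """Best-effort return of freed CPU memory to the OS (illustrative sketch)."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Hypothetical opt-in flag; the variable actually used by torch_tensorrt may differ.
    if platform.system() == "Linux" and os.environ.get("TORCHTRT_MALLOC_TRIM", "0") == "1":
        try:
            ctypes.CDLL("libc.so.6").malloc_trim(0)
        except (OSError, AttributeError):
            pass
```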
Review comment: Aren't these the same?