diff --git a/examples/simple_trainer_2dgs.py b/examples/simple_trainer_2dgs.py index 109008587..19c8011fb 100644 --- a/examples/simple_trainer_2dgs.py +++ b/examples/simple_trainer_2dgs.py @@ -31,7 +31,7 @@ from gsplat.rendering import rasterization_2dgs, rasterization_2dgs_inria_wrapper from gsplat.strategy import DefaultStrategy - +from gsplat.utils import save_ply @dataclass class Config: @@ -67,6 +67,10 @@ class Config: eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Steps to save the model save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Whether to save ply file (storage size can be large) + save_ply: bool = False + # Steps to save the model as ply + ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Initialization strategy init_type: str = "sfm" @@ -170,6 +174,7 @@ class Config: def adjust_steps(self, factor: float): self.eval_steps = [int(i * factor) for i in self.eval_steps] self.save_steps = [int(i * factor) for i in self.save_steps] + self.ply_steps = [int(i * factor) for i in self.ply_steps] self.max_steps = int(self.max_steps * factor) self.sh_degree_interval = int(self.sh_degree_interval * factor) self.refine_start_iter = int(self.refine_start_iter * factor) @@ -265,6 +270,8 @@ def __init__(self, cfg: Config) -> None: os.makedirs(self.stats_dir, exist_ok=True) self.render_dir = f"{cfg.result_dir}/renders" os.makedirs(self.render_dir, exist_ok=True) + self.ply_dir = f"{cfg.result_dir}/ply" + os.makedirs(self.ply_dir, exist_ok=True) # Tensorboard self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") @@ -724,6 +731,25 @@ def train(self): f"{self.ckpt_dir}/ckpt_{step}.pt", ) + if ( + step in [i - 1 for i in cfg.ply_steps] + or step == max_steps - 1 + and cfg.save_ply + ): + rgb = None + if self.cfg.app_opt: + # eval at origin to bake the appeareance into the colors + rgb = self.app_module( + features=self.splats["features"], + embed_ids=None, + dirs=torch.zeros_like(self.splats["means"][None, :, :]), + sh_degree=sh_degree_to_use, + ) + rgb = rgb + self.splats["colors"] + rgb = torch.sigmoid(rgb).squeeze(0) + + save_ply(self.splats, f"{self.ply_dir}/point_cloud_{step}.ply", rgb, is_2dgs=True) + # eval the full set if step in [i - 1 for i in cfg.eval_steps] or step == max_steps - 1: self.eval(step) diff --git a/examples/simple_viewer.py b/examples/simple_viewer.py index 5d513dc18..9e913fffe 100644 --- a/examples/simple_viewer.py +++ b/examples/simple_viewer.py @@ -21,7 +21,7 @@ from gsplat._helper import load_test_data from gsplat.distributed import cli -from gsplat.rendering import rasterization +from gsplat.rendering import rasterization, rasterization_2dgs def main(local_rank: int, world_rank, world_size: int, args): @@ -169,6 +169,8 @@ def viewer_render_fn(camera_state: nerfview.CameraState, img_wh: Tuple[int, int] if args.backend == "gsplat": rasterization_fn = rasterization + elif args.backend == "gsplat-2dgs": + rasterization_fn = rasterization_2dgs elif args.backend == "inria": from gsplat import rasterization_inria_wrapper @@ -176,7 +178,7 @@ def viewer_render_fn(camera_state: nerfview.CameraState, img_wh: Tuple[int, int] else: raise ValueError - render_colors, render_alphas, meta = rasterization_fn( + render_colors, render_alphas, _, _, _, _, meta = rasterization_fn( means, # [N, 3] quats, # [N, 4] scales, # [N, 3] diff --git a/gsplat/utils.py b/gsplat/utils.py index f103e07bb..3df3a13f6 100644 --- a/gsplat/utils.py +++ b/gsplat/utils.py @@ -7,13 +7,15 @@ import numpy as np -def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None): +def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None, is_2dgs: bool = False): # Convert all tensors to numpy arrays in one go print(f"Saving ply to {dir}") numpy_data = {k: v.detach().cpu().numpy() for k, v in splats.items()} means = numpy_data["means"] scales = numpy_data["scales"] + if is_2dgs: + scales[:, 2] = np.log(1e-6) quats = numpy_data["quats"] opacities = numpy_data["opacities"]