Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion examples/simple_trainer_2dgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from gsplat.rendering import rasterization_2dgs, rasterization_2dgs_inria_wrapper
from gsplat.strategy import DefaultStrategy

from gsplat.utils import save_ply

@dataclass
class Config:
Expand Down Expand Up @@ -67,6 +67,10 @@ class Config:
eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Steps to save the model
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Whether to save ply file (storage size can be large)
save_ply: bool = False
# Steps to save the model as ply
ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])

# Initialization strategy
init_type: str = "sfm"
Expand Down Expand Up @@ -170,6 +174,7 @@ class Config:
def adjust_steps(self, factor: float):
self.eval_steps = [int(i * factor) for i in self.eval_steps]
self.save_steps = [int(i * factor) for i in self.save_steps]
self.ply_steps = [int(i * factor) for i in self.ply_steps]
self.max_steps = int(self.max_steps * factor)
self.sh_degree_interval = int(self.sh_degree_interval * factor)
self.refine_start_iter = int(self.refine_start_iter * factor)
Expand Down Expand Up @@ -265,6 +270,8 @@ def __init__(self, cfg: Config) -> None:
os.makedirs(self.stats_dir, exist_ok=True)
self.render_dir = f"{cfg.result_dir}/renders"
os.makedirs(self.render_dir, exist_ok=True)
self.ply_dir = f"{cfg.result_dir}/ply"
os.makedirs(self.ply_dir, exist_ok=True)

# Tensorboard
self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
Expand Down Expand Up @@ -724,6 +731,25 @@ def train(self):
f"{self.ckpt_dir}/ckpt_{step}.pt",
)

if (
step in [i - 1 for i in cfg.ply_steps]
or step == max_steps - 1
and cfg.save_ply
):
rgb = None
if self.cfg.app_opt:
# eval at origin to bake the appeareance into the colors
rgb = self.app_module(
features=self.splats["features"],
embed_ids=None,
dirs=torch.zeros_like(self.splats["means"][None, :, :]),
sh_degree=sh_degree_to_use,
)
rgb = rgb + self.splats["colors"]
rgb = torch.sigmoid(rgb).squeeze(0)

save_ply(self.splats, f"{self.ply_dir}/point_cloud_{step}.ply", rgb, is_2dgs=True)

# eval the full set
if step in [i - 1 for i in cfg.eval_steps] or step == max_steps - 1:
self.eval(step)
Expand Down
6 changes: 4 additions & 2 deletions examples/simple_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from gsplat._helper import load_test_data
from gsplat.distributed import cli
from gsplat.rendering import rasterization
from gsplat.rendering import rasterization, rasterization_2dgs


def main(local_rank: int, world_rank, world_size: int, args):
Expand Down Expand Up @@ -169,14 +169,16 @@ def viewer_render_fn(camera_state: nerfview.CameraState, img_wh: Tuple[int, int]

if args.backend == "gsplat":
rasterization_fn = rasterization
elif args.backend == "gsplat-2dgs":
rasterization_fn = rasterization_2dgs
elif args.backend == "inria":
from gsplat import rasterization_inria_wrapper

rasterization_fn = rasterization_inria_wrapper
else:
raise ValueError

render_colors, render_alphas, meta = rasterization_fn(
render_colors, render_alphas, _, _, _, _, meta = rasterization_fn(
means, # [N, 3]
quats, # [N, 4]
scales, # [N, 3]
Expand Down
4 changes: 3 additions & 1 deletion gsplat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
import numpy as np


def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None):
def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None, is_2dgs: bool = False):
# Convert all tensors to numpy arrays in one go
print(f"Saving ply to {dir}")
numpy_data = {k: v.detach().cpu().numpy() for k, v in splats.items()}

means = numpy_data["means"]
scales = numpy_data["scales"]
if is_2dgs:
scales[:, 2] = np.log(1e-6)
quats = numpy_data["quats"]
opacities = numpy_data["opacities"]

Expand Down