Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions gen_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True

#----------------------------------------------------------------------------

def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), **video_kwargs):
def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), feature=None, **video_kwargs):
grid_w = grid_dims[0]
grid_h = grid_dims[1]

Expand All @@ -56,6 +56,12 @@ def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind=
for idx in range(num_keyframes*grid_h*grid_w):
all_seeds[idx] = seeds[idx % len(seeds)]

if len(all_seeds) > 1 and feature is not None:
raise ValueError('Cannot explore a feature for more than a single image')

if len(all_seeds) == 1 and feature is None:
raise ValueError('Must specify a feature if exploring an image')

if shuffle_seed is not None:
rng = np.random.RandomState(seed=shuffle_seed)
rng.shuffle(all_seeds)
Expand All @@ -78,12 +84,16 @@ def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind=

# Render video.
video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs)
modifier = 1
for frame_idx in tqdm(range(num_keyframes * w_frames)):
imgs = []
for yi in range(grid_h):
for xi in range(grid_w):
interp = grid[yi][xi]
w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)
if feature is not None:
w[feature] = w[feature] * modifier
modifier + .01
img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0]
imgs.append(img)
video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
Expand Down Expand Up @@ -133,6 +143,7 @@ def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE')
@click.option('--feature', type=int, help='Feature to explore', default=None)
def generate_images(
network_pkl: str,
seeds: List[int],
Expand All @@ -141,7 +152,8 @@ def generate_images(
grid: Tuple[int,int],
num_keyframes: Optional[int],
w_frames: int,
output: str
output: str,
feature: int
):
"""Render a latent vector interpolation video.

Expand Down Expand Up @@ -170,7 +182,7 @@ def generate_images(
with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi)
gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi, feature=feature)

#----------------------------------------------------------------------------

Expand Down