Skip to content

Commit a85aef8

Browse files
committed
rf
1 parent 548a52b commit a85aef8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+3374
-52
lines changed

alper_env.yaml

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
name: robodiff
2+
channels:
3+
- pytorch
4+
- nvidia
5+
- conda-forge
6+
dependencies:
7+
- _libgcc_mutex=0.1=conda_forge
8+
- _openmp_mutex=4.5=2_kmp_llvm
9+
- aom=3.6.1=h59595ed_0
10+
- blas=2.116=mkl
11+
- blas-devel=3.9.0=16_linux64_mkl
12+
- brotli-python=1.1.0=py310hf71b8c6_2
13+
- bzip2=1.0.8=h4bc722e_7
14+
- ca-certificates=2024.8.30=hbcca054_0
15+
- certifi=2024.8.30=pyhd8ed1ab_0
16+
- cffi=1.17.1=py310h8deb56e_0
17+
- charset-normalizer=3.4.0=pyhd8ed1ab_0
18+
- cpython=3.10.15=py310hd8ed1ab_2
19+
- cuda-cudart=12.4.127=0
20+
- cuda-cupti=12.4.127=0
21+
- cuda-libraries=12.4.1=0
22+
- cuda-nvrtc=12.4.127=0
23+
- cuda-nvtx=12.4.127=0
24+
- cuda-opencl=12.6.77=0
25+
- cuda-runtime=12.4.1=0
26+
- cuda-version=12.6=3
27+
- expat=2.6.3=h5888daf_0
28+
- ffmpeg=4.4.2=gpl_hdf48244_113
29+
- filelock=3.16.1=pyhd8ed1ab_0
30+
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
31+
- font-ttf-inconsolata=3.000=h77eed37_0
32+
- font-ttf-source-code-pro=2.038=h77eed37_0
33+
- font-ttf-ubuntu=0.83=h77eed37_3
34+
- fontconfig=2.14.2=h14ed4e7_0
35+
- fonts-conda-ecosystem=1=0
36+
- fonts-conda-forge=1=0
37+
- freetype=2.12.1=h267a509_2
38+
- gettext=0.22.5=he02047a_3
39+
- gettext-tools=0.22.5=he02047a_3
40+
- giflib=5.2.2=hd590300_0
41+
- gmp=6.3.0=hac33072_2
42+
- gmpy2=2.1.5=py310he8512ff_2
43+
- gnutls=3.7.9=hb077bed_0
44+
- h2=4.1.0=pyhd8ed1ab_0
45+
- hpack=4.0.0=pyh9f0ad1d_0
46+
- hyperframe=6.0.1=pyhd8ed1ab_0
47+
- icu=75.1=he02047a_0
48+
- idna=3.10=pyhd8ed1ab_0
49+
- jinja2=3.1.4=pyhd8ed1ab_0
50+
- lame=3.100=h166bdaf_1003
51+
- lcms2=2.16=hb7c19ff_0
52+
- ld_impl_linux-64=2.43=h712a8e2_1
53+
- lerc=4.0.0=h27087fc_0
54+
- libasprintf=0.22.5=he8f35ee_3
55+
- libasprintf-devel=0.22.5=he8f35ee_3
56+
- libblas=3.9.0=16_linux64_mkl
57+
- libcblas=3.9.0=16_linux64_mkl
58+
- libcublas=12.4.5.8=0
59+
- libcufft=11.2.1.3=0
60+
- libcufile=1.11.1.6=0
61+
- libcurand=10.3.7.77=0
62+
- libcusolver=11.6.1.9=0
63+
- libcusparse=12.3.1.170=0
64+
- libdeflate=1.22=hb9d3cd8_0
65+
- libdrm=2.4.123=hb9d3cd8_0
66+
- libegl=1.7.0=ha4b6fd6_1
67+
- libexpat=2.6.3=h5888daf_0
68+
- libffi=3.4.2=h7f98852_5
69+
- libgcc=14.2.0=h77fa898_1
70+
- libgcc-ng=14.2.0=h69a702a_1
71+
- libgettextpo=0.22.5=he02047a_3
72+
- libgettextpo-devel=0.22.5=he02047a_3
73+
- libgfortran=14.2.0=h69a702a_1
74+
- libgfortran-ng=14.2.0=h69a702a_1
75+
- libgfortran5=14.2.0=hd5240d6_1
76+
- libgl=1.7.0=ha4b6fd6_1
77+
- libglvnd=1.7.0=ha4b6fd6_1
78+
- libglx=1.7.0=ha4b6fd6_1
79+
- libgomp=14.2.0=h77fa898_1
80+
- libhwloc=2.11.1=default_hecaa2ac_1000
81+
- libiconv=1.17=hd590300_2
82+
- libidn2=2.3.7=hd590300_0
83+
- libjpeg-turbo=3.0.0=hd590300_1
84+
- liblapack=3.9.0=16_linux64_mkl
85+
- liblapacke=3.9.0=16_linux64_mkl
86+
- libnpp=12.2.5.30=0
87+
- libnsl=2.0.1=hd590300_0
88+
- libnvfatbin=12.6.77=0
89+
- libnvjitlink=12.4.127=0
90+
- libnvjpeg=12.3.1.117=0
91+
- libpciaccess=0.18=hd590300_0
92+
- libpng=1.6.44=hadc24fc_0
93+
- libsqlite=3.46.1=hadc24fc_0
94+
- libstdcxx=14.2.0=hc0a3c3a_1
95+
- libstdcxx-ng=14.2.0=h4852527_1
96+
- libtasn1=4.19.0=h166bdaf_0
97+
- libtiff=4.7.0=he137b08_1
98+
- libunistring=0.9.10=h7f98852_0
99+
- libuuid=2.38.1=h0b41bf4_0
100+
- libva=2.22.0=h8a09558_1
101+
- libvpx=1.13.1=h59595ed_0
102+
- libwebp=1.4.0=h2c329e2_0
103+
- libwebp-base=1.4.0=hd590300_0
104+
- libxcb=1.17.0=h8a09558_0
105+
- libxcrypt=4.4.36=hd590300_1
106+
- libxml2=2.12.7=he7c6b58_4
107+
- libzlib=1.3.1=hb9d3cd8_2
108+
- llvm-openmp=15.0.7=h0cdce71_0
109+
- markupsafe=3.0.1=py310h89163eb_1
110+
- mkl=2022.1.0=h84fe81f_915
111+
- mkl-devel=2022.1.0=ha770c72_916
112+
- mkl-include=2022.1.0=h84fe81f_915
113+
- mpc=1.3.1=h24ddda3_1
114+
- mpfr=4.2.1=h90cbb55_3
115+
- mpmath=1.3.0=pyhd8ed1ab_0
116+
- ncurses=6.5=he02047a_1
117+
- nettle=3.9.1=h7ab15ed_0
118+
- networkx=3.4.1=pyhd8ed1ab_0
119+
- numpy=2.1.2=py310hd6e36ab_0
120+
- openh264=2.3.1=hcb278e6_2
121+
- openjpeg=2.5.2=h488ebb8_0
122+
- openssl=3.3.2=hb9d3cd8_0
123+
- p11-kit=0.24.1=hc5aa10d_0
124+
- pillow=11.0.0=py310hfeaa1f3_0
125+
- pip=24.2=pyh8b19718_1
126+
- pthread-stubs=0.4=hb9d3cd8_1002
127+
- pycparser=2.22=pyhd8ed1ab_0
128+
- pysocks=1.7.1=pyha2e5f31_6
129+
- python=3.10.15=h4a871b0_2_cpython
130+
- python_abi=3.10=5_cp310
131+
- pytorch=2.5.0=py3.10_cuda12.4_cudnn9.1.0_0
132+
- pytorch-cuda=12.4=hc786d27_7
133+
- pytorch-mutex=1.0=cuda
134+
- pyyaml=6.0.2=py310ha75aee5_1
135+
- readline=8.2=h8228510_1
136+
- requests=2.32.3=pyhd8ed1ab_0
137+
- svt-av1=1.4.1=hcb278e6_0
138+
- tbb=2021.13.0=h84d6215_0
139+
- tk=8.6.13=noxft_h4845f30_101
140+
- torchaudio=2.5.0=py310_cu124
141+
- torchtriton=3.1.0=py310
142+
- torchvision=0.20.0=py310_cu124
143+
- typing_extensions=4.12.2=pyha770c72_0
144+
- tzdata=2024b=hc8b5060_0
145+
- urllib3=2.2.3=pyhd8ed1ab_0
146+
- wayland=1.23.1=h3e06ad9_0
147+
- wayland-protocols=1.37=hd8ed1ab_0
148+
- wheel=0.44.0=pyhd8ed1ab_0
149+
- x264=1!164.3095=h166bdaf_2
150+
- x265=3.5=h924138e_3
151+
- xorg-libx11=1.8.10=h4f16b4b_0
152+
- xorg-libxau=1.0.11=hb9d3cd8_1
153+
- xorg-libxdmcp=1.1.5=hb9d3cd8_0
154+
- xorg-libxext=1.3.6=hb9d3cd8_0
155+
- xorg-libxfixes=6.0.1=hb9d3cd8_0
156+
- xorg-xorgproto=2024.1=hb9d3cd8_1
157+
- xz=5.2.6=h166bdaf_0
158+
- yaml=0.2.5=h7f98852_2
159+
- zstandard=0.23.0=py310ha39cb0e_1
160+
- zstd=1.5.6=ha6fb4c9_0
161+
- pip:
162+
- aiosignal==1.3.1
163+
- antlr4-python3-runtime==4.9.3
164+
- asciitree==0.3.3
165+
- attrs==24.2.0
166+
- av==13.1.0
167+
- blessed==1.20.0
168+
- click==8.1.7
169+
- cloudpickle==3.1.0
170+
- contourpy==1.3.0
171+
- cycler==0.12.1
172+
- diffusers==0.31.0
173+
- dill==0.3.9
174+
- docker-pycreds==0.4.0
175+
- egl-probe==1.0.2
176+
- einops==0.8.0
177+
- etils==1.10.0
178+
- evdev==1.7.1
179+
- fasteners==0.19
180+
- fonttools==4.54.1
181+
- frozenlist==1.5.0
182+
- fsspec==2024.9.0
183+
- gitdb==4.0.11
184+
- gitpython==3.1.43
185+
- glfw==2.7.0
186+
- gpustat==1.1.1
187+
- gym==0.26.2
188+
- gym-notices==0.0.8
189+
- h5py==3.12.1
190+
- huggingface-hub==0.26.0
191+
- hydra-core==1.3.2
192+
- imagecodecs==2024.9.22
193+
- imageio==2.36.0
194+
- imageio-ffmpeg==0.5.1
195+
- importlib-metadata==8.5.0
196+
- importlib-resources==6.4.5
197+
- jsonschema==4.23.0
198+
- jsonschema-specifications==2024.10.1
199+
- kiwisolver==1.4.7
200+
- lazy-loader==0.4
201+
- llvmlite==0.43.0
202+
- markdown==3.7
203+
- matplotlib==3.9.2
204+
- msgpack==1.1.0
205+
- mujoco==3.2.4
206+
- numba==0.60.0
207+
- numcodecs==0.13.1
208+
- nvidia-ml-py==12.560.30
209+
- omegaconf==2.3.0
210+
- opencv-python==4.10.0.84
211+
- packaging==24.1
212+
- platformdirs==4.3.6
213+
- protobuf==5.28.2
214+
- psutil==6.1.0
215+
- pygame==2.6.1
216+
- pymunk==6.9.0
217+
- pynput==1.7.7
218+
- pyopengl==3.1.7
219+
- pyparsing==3.2.0
220+
- python-xlib==0.33
221+
- pytz==2024.2
222+
- ray==2.39.0
223+
- referencing==0.35.1
224+
- regex==2024.9.11
225+
- robosuite==1.4.1
226+
- rpds-py==0.21.0
227+
- safetensors==0.4.5
228+
- scikit-image==0.24.0
229+
- scipy==1.14.1
230+
- sentry-sdk==2.17.0
231+
- setproctitle==1.3.3
232+
- setuptools==75.2.0
233+
- shapely==2.0.6
234+
- six==1.16.0
235+
- smmap==5.0.1
236+
- sympy==1.13.1
237+
- tensorboardx==2.6.2.2
238+
- termcolor==2.5.0
239+
- threadpoolctl==3.5.0
240+
- tifffile==2024.9.20
241+
- tqdm==4.66.5
242+
- wandb==0.18.5
243+
- wcwidth==0.2.13
244+
- zarr==2.18.3
245+
- zipp==3.20.2
246+
prefix: /local/vondrick/alper/miniforge3/envs/robodiff

diffusion_policy/common/schedulers.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
2+
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import math
17+
from dataclasses import dataclass
18+
from typing import List, Optional, Tuple, Union
19+
20+
import numpy as np
21+
import torch
22+
23+
from diffusers.configuration_utils import ConfigMixin, register_to_config
24+
from diffusers.utils import BaseOutput, logging
25+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
26+
27+
28+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29+
30+
def logit(x):
31+
return torch.log(x) - torch.log(1-x)
32+
33+
def logit_normal_pdf(x, m, s):
34+
x = torch.tensor(x).clamp(1e-7, 1-1e-7)
35+
return (1/(s * math.sqrt(2*math.pi))) * (1/x * (1-x)) * torch.exp(-(logit(x)-m)**2/(2*s**2))
36+
37+
@dataclass
38+
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
39+
"""
40+
Output class for the scheduler's `step` function output.
41+
42+
Args:
43+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
44+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
45+
denoising loop.
46+
"""
47+
48+
prev_sample: torch.FloatTensor
49+
50+
51+
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
52+
"""
53+
Euler scheduler.
54+
55+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
56+
methods the library implements for all schedulers such as loading and saving.
57+
58+
Args:
59+
num_train_timesteps (`int`, defaults to 1000):
60+
The number of diffusion steps to train the model.
61+
timestep_spacing (`str`, defaults to `"linspace"`):
62+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
63+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
64+
shift (`float`, defaults to 1.0):
65+
The shift value for the timestep schedule.
66+
"""
67+
68+
_compatibles = []
69+
order = 1
70+
71+
@register_to_config
72+
def __init__(
73+
self,
74+
num_train_timesteps: int = 1024,
75+
prediction_type: str = 'flow',
76+
sampling_weight: str = 'logit_normal',
77+
):
78+
self.prediction_type = prediction_type
79+
self.num_train_timesteps = num_train_timesteps
80+
self.sampling_weight = sampling_weight
81+
82+
self.timesteps = None
83+
84+
def add_noise(self, original_samples, noise, timesteps):
85+
86+
timesteps = timesteps.to(original_samples.device).float()/self.num_train_timesteps
87+
88+
while len(timesteps.shape) < len(original_samples.shape):
89+
timesteps = timesteps.unsqueeze(-1)
90+
91+
return original_samples * timesteps + noise * (1 - timesteps)
92+
93+
def sample_timesteps(self, bsz, device):
94+
if self.sampling_weight == 'logit_normal':
95+
x = torch.linspace(0, 1, self.num_train_timesteps, device=device)
96+
prob = logit_normal_pdf(x, m=0.0, s=1.0) + 1e-3
97+
prob = prob / prob.sum()
98+
99+
sample = torch.multinomial(prob, bsz, replacement=True).long()
100+
return sample
101+
else:
102+
return torch.randint(0, self.num_train_timesteps, (bsz,), device=device).long()
103+
104+
def set_timesteps(self, num_inference_steps):
105+
"""
106+
Don't judge me, I just tried matching the Diffusion Policy inference API
107+
"""
108+
self.timesteps = np.linspace(0, self.num_train_timesteps, num_inference_steps+1)[:-1]
109+
110+
def step(self, model_output, timestep, sample, generator=None, **kwargs):
111+
112+
dt = 1.0 / len(self.timesteps)
113+
sample = model_output * dt + sample
114+
115+
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=sample)
116+
117+
def __len__(self):
118+
return self.config.num_train_timesteps
119+
120+
if __name__ == "__main__":
121+
scheduler = FlowMatchEulerDiscreteScheduler(1024)
122+
print(scheduler.add_noise(torch.randn(1, 1024), torch.tensor([1.,2.,3.,4.,5.]), noise=torch.randn(1, 1024)))
123+
124+
scheduler.set_timesteps(8)
125+
print(scheduler.add_noise(torch.randn(1, 1024), torch.tensor([1.,2.,3.,4.,5.]), noise=torch.randn(1, 1024)))
126+
127+
scheduler.set_timesteps(16)
128+
print(scheduler.add_noise(torch.randn(1, 1024), torch.tensor([1.,2.,3.,4.,5.]), noise=torch.randn(1, 1024)))
129+
130+
#do a step
131+
print(scheduler.step(torch.randn(1, 1024), torch.tensor([1.,2.,3.,4.,5.]), torch.randn(1, 1024)))

0 commit comments

Comments
 (0)