Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/snake/mrd_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
CartesianFrameDataLoader,
MRDLoader,
NonCartesianFrameDataLoader,
SliceDataloader,
parse_sim_conf,
parse_waveform_information,
read_mrd_header,
Expand All @@ -16,6 +17,7 @@
"MRDLoader",
"CartesianFrameDataLoader",
"NonCartesianFrameDataLoader",
"SliceDataloader",
"parse_sim_conf",
"parse_waveform_information",
"make_base_mrd",
Expand Down
41 changes: 41 additions & 0 deletions src/snake/mrd_utils/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,48 @@ def get_coil_cov(self) -> NDArray | None:
"""Load the coil covariance from the dataset."""
return self._get_image_data("coil_cov")

class SliceDataloader(MRDLoader):
"""Load slice MRD files k-space frames iteratively."""

def __init__(self,
frame_dl: MRDLoader,
):
super().__init__(
filename=frame_dl._filename,
dataset_name=frame_dl._dataset_name,
writeable=frame_dl._writeable,
swmr=frame_dl._swmr,
squeeze_dims=frame_dl._squeeze_dims
)
self.frame = frame_dl
self.is_cartesian = isinstance(self.frame, CartesianFrameDataLoader)
self.is_non_cartesian = isinstance(self.frame, NonCartesianFrameDataLoader)
self.get_kspace_frame = self._get_kspace_frame

def __getattr__(self, name: str) -> Any:
return getattr(self.frame, name)

@property
def n_frames(self) -> int:
"""Number of frames."""
return self.frame.n_acquisition

@property
def shape(self) -> tuple[int, int]:
"""Shape of the volume."""
return self.frame.shape[:2]

def _get_kspace_frame(
self, idx: int, shot_dim: bool = False
) -> tuple[NDArray[np.float32], NDArray[np.complex64]]:
"""Get the k-space frame."""
n_acq_per_frame = self.frame.n_acquisition // self.frame.n_frames
traj, data = self.frame.get_kspace_frame(idx // self.frame.n_shots)
traj = traj.reshape(n_acq_per_frame, -1, 3)
data = data.reshape(self.frame.n_coils, n_acq_per_frame, -1)

return traj[idx%self.frame.n_shots,:,:2], data[:,idx%self.frame.n_shots,:]

class CartesianFrameDataLoader(MRDLoader):
"""Load cartesian MRD files k-space frames iteratively.

Expand Down
2 changes: 1 addition & 1 deletion src/snake/toolkit/cli/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def reconstruction(cfg: DictConfig) -> None:
raise ValueError("No dynamic data found matching waveform name")

bold_signal = good_d.data[0]
bold_sample_time = np.arange(len(bold_signal)) * local_sim_conf.seq.TR / 1000
bold_sample_time = np.arange(len(bold_signal)) * sim_conf.seq.TR / 1000
del phantom
del dyn_datas
gc.collect()
Expand Down
5 changes: 1 addition & 4 deletions src/snake/toolkit/reconstructors/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,11 @@ def init_nufft(
shape = data_loader.shape
traj, _ = data_loader.get_kspace_frame(0)

if data_loader.slice_2d:
shape = data_loader.shape[:2]
traj = traj.reshape(data_loader.n_shots, -1, traj.shape[-1])[0, :, :2]

kwargs = dict(
shape=shape,
n_coils=data_loader.n_coils,
smaps=smaps,
smaps=smaps[...,0].copy() if data_loader.slice_2d else smaps,
)
if density_compensation is False:
kwargs["density"] = None
Expand Down
40 changes: 25 additions & 15 deletions src/snake/toolkit/reconstructors/pysap.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
CartesianFrameDataLoader,
MRDLoader,
NonCartesianFrameDataLoader,
SliceDataloader,
)
from snake.core.parallel import (
ArrayProps,
Expand Down Expand Up @@ -83,12 +84,13 @@ def reconstruct(
) -> NDArray:
"""Reconstruct data with zero-filled method."""
with data_loader:
if isinstance(data_loader, CartesianFrameDataLoader):
return self._reconstruct_cartesian(data_loader)
elif isinstance(data_loader, NonCartesianFrameDataLoader):
return self._reconstruct_nufft(data_loader)
if isinstance(data_loader, SliceDataloader):
reconstruct_method = self._reconstruct_nufft if data_loader.is_non_cartesian else self._reconstruct_cartesian
elif isinstance(data_loader, CartesianFrameDataLoader | NonCartesianFrameDataLoader):
reconstruct_method = self._reconstruct_cartesian if isinstance(data_loader, CartesianFrameDataLoader) else self._reconstruct_nufft
else:
raise ValueError("Unknown dataloader")
return reconstruct_method(data_loader)

def _reconstruct_cartesian(
self,
Expand Down Expand Up @@ -144,19 +146,19 @@ def _reconstruct_nufft(
final_images = np.empty(
(data_loader.n_frames, *data_loader.shape), dtype=np.float32
)
smaps = data_loader.get_smaps()

for i in tqdm(range(data_loader.n_frames)):
traj, data = data_loader.get_kspace_frame(i)
if data_loader.slice_2d:
nufft_operator.samples = traj.reshape(
data_loader.n_shots, -1, traj.shape[-1]
)[0, :, :2]
data = np.reshape(data, (data.shape[0], data_loader.n_shots, -1))
for j in range(data.shape[1]):
final_images[i, :, :, j] = abs(nufft_operator.adj_op(data[:, j]))
else:
nufft_operator.samples = traj
final_images[i] = abs(nufft_operator.adj_op(data))
#fix: density compensation should update every frame when rotating(writer: dyn_traj?)
if smaps is not None and data_loader.slice_2d:
nufft_operator.smaps = smaps[...,i % data_loader.frame.n_shots]
nufft_operator.samples = traj
final_images[i] = abs(nufft_operator.adj_op(data))
if isinstance(data_loader, SliceDataloader):
final_images = np.moveaxis(final_images.reshape(
(data_loader.frame.n_frames, -1, *final_images.shape[-2:])
), 1, -1)
return final_images


Expand Down Expand Up @@ -271,7 +273,7 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
samples=traj,
shape=data_loader.shape,
n_coils=data_loader.n_coils,
smaps=smaps,
smaps=smaps[..., 0].copy() if data_loader.slice_2d else smaps,
# smaps=xp.array(smaps) if smaps is not None else None,
density=density_compensation,
squeeze_dims=True,
Expand Down Expand Up @@ -331,6 +333,10 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:

pbar_frames.update(1)
if self.restart_strategy != RestartStrategy.REFINE:
if isinstance(data_loader, SliceDataloader):
final_estimate = np.moveaxis(final_estimate.reshape(
(data_loader.frame.n_frames, -1, *final_estimate.shape[-2:])
), 1, -1)
return final_estimate
# else, we do a second pass on the data using the last iteration as a slotion.
pbar_frames.reset()
Expand All @@ -353,6 +359,10 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
else:
final_estimate[i, ...] = abs(x_iter)
pbar_frames.update(1)
if isinstance(data_loader, SliceDataloader):
final_estimate = np.moveaxis(final_estimate.reshape(
(data_loader.frame.n_frames, -1, *final_estimate.shape[-2:])
), 1, -1)
return final_estimate

def _reconstruct_frame(
Expand Down
Loading