From f6198ca68433927cf5d0a96e050168499cd9e179 Mon Sep 17 00:00:00 2001 From: alineyyy Date: Fri, 14 Mar 2025 14:33:37 +0100 Subject: [PATCH 1/2] slicedataloader for zerofilled --- src/snake/mrd_utils/__init__.py | 2 + src/snake/mrd_utils/loader.py | 41 ++++++++++++++++++++ src/snake/toolkit/cli/reconstruction.py | 2 +- src/snake/toolkit/reconstructors/fourier.py | 5 +-- src/snake/toolkit/reconstructors/pysap.py | 42 +++++++++++++-------- 5 files changed, 71 insertions(+), 21 deletions(-) diff --git a/src/snake/mrd_utils/__init__.py b/src/snake/mrd_utils/__init__.py index e5f273c6..bdcce696 100644 --- a/src/snake/mrd_utils/__init__.py +++ b/src/snake/mrd_utils/__init__.py @@ -4,6 +4,7 @@ CartesianFrameDataLoader, MRDLoader, NonCartesianFrameDataLoader, + SliceDataloader, parse_sim_conf, parse_waveform_information, read_mrd_header, @@ -16,6 +17,7 @@ "MRDLoader", "CartesianFrameDataLoader", "NonCartesianFrameDataLoader", + "SliceDataloader", "parse_sim_conf", "parse_waveform_information", "make_base_mrd", diff --git a/src/snake/mrd_utils/loader.py b/src/snake/mrd_utils/loader.py index 751daeec..fdc19e20 100644 --- a/src/snake/mrd_utils/loader.py +++ b/src/snake/mrd_utils/loader.py @@ -378,7 +378,48 @@ def get_coil_cov(self) -> NDArray | None: """Load the coil covariance from the dataset.""" return self._get_image_data("coil_cov") +class SliceDataloader(MRDLoader): + """Load slice MRD files k-space frames iteratively.""" + + def __init__(self, + frame_dl: MRDLoader, + ): + super().__init__( + filename=frame_dl._filename, + dataset_name=frame_dl._dataset_name, + writeable=frame_dl._writeable, + swmr=frame_dl._swmr, + squeeze_dims=frame_dl._squeeze_dims + ) + self.frame = frame_dl + self.is_cartesian = isinstance(self.frame, CartesianFrameDataLoader) + self.is_non_cartesian = isinstance(self.frame, NonCartesianFrameDataLoader) + self.get_kspace_frame = self._get_kspace_frame + + def __getattr__(self, name: str) -> Any: + return getattr(self.frame, name) + + @property + def n_frames(self) -> int: + """Number of frames.""" + return self.frame.n_acquisition + + @property + def shape(self) -> tuple[int, int]: + """Shape of the volume.""" + return self.frame.shape[:2] + + def _get_kspace_frame( + self, idx: int, shot_dim: bool = False + ) -> tuple[NDArray[np.float32], NDArray[np.complex64]]: + """Get the k-space frame.""" + n_acq_per_frame = self.frame.n_acquisition // self.frame.n_frames + traj, data = self.frame.get_kspace_frame(idx // self.frame.n_shots) + traj = traj.reshape(n_acq_per_frame, -1, 3) + data = data.reshape(self.frame.n_coils, n_acq_per_frame, -1) + return traj[idx%self.frame.n_shots,:,:2], data[:,idx%self.frame.n_shots,:] + class CartesianFrameDataLoader(MRDLoader): """Load cartesian MRD files k-space frames iteratively. diff --git a/src/snake/toolkit/cli/reconstruction.py b/src/snake/toolkit/cli/reconstruction.py index 6e51b672..2822a760 100644 --- a/src/snake/toolkit/cli/reconstruction.py +++ b/src/snake/toolkit/cli/reconstruction.py @@ -81,7 +81,7 @@ def reconstruction(cfg: DictConfig) -> None: raise ValueError("No dynamic data found matching waveform name") bold_signal = good_d.data[0] - bold_sample_time = np.arange(len(bold_signal)) * local_sim_conf.seq.TR / 1000 + bold_sample_time = np.arange(len(bold_signal)) * sim_conf.seq.TR / 1000 del phantom del dyn_datas gc.collect() diff --git a/src/snake/toolkit/reconstructors/fourier.py b/src/snake/toolkit/reconstructors/fourier.py index 8ebb56c5..ba559727 100644 --- a/src/snake/toolkit/reconstructors/fourier.py +++ b/src/snake/toolkit/reconstructors/fourier.py @@ -60,14 +60,11 @@ def init_nufft( shape = data_loader.shape traj, _ = data_loader.get_kspace_frame(0) - if data_loader.slice_2d: - shape = data_loader.shape[:2] - traj = traj.reshape(data_loader.n_shots, -1, traj.shape[-1])[0, :, :2] kwargs = dict( shape=shape, n_coils=data_loader.n_coils, - smaps=smaps, + smaps=smaps.squeeze() if data_loader.slice_2d else smaps, ) if density_compensation is False: kwargs["density"] = None diff --git a/src/snake/toolkit/reconstructors/pysap.py b/src/snake/toolkit/reconstructors/pysap.py index 66d68a39..4cc1a654 100644 --- a/src/snake/toolkit/reconstructors/pysap.py +++ b/src/snake/toolkit/reconstructors/pysap.py @@ -14,6 +14,7 @@ CartesianFrameDataLoader, MRDLoader, NonCartesianFrameDataLoader, + SliceDataloader, ) from snake.core.parallel import ( ArrayProps, @@ -83,12 +84,13 @@ def reconstruct( ) -> NDArray: """Reconstruct data with zero-filled method.""" with data_loader: - if isinstance(data_loader, CartesianFrameDataLoader): - return self._reconstruct_cartesian(data_loader) - elif isinstance(data_loader, NonCartesianFrameDataLoader): - return self._reconstruct_nufft(data_loader) + if isinstance(data_loader, SliceDataloader): + reconstruct_method = self._reconstruct_nufft if data_loader.is_non_cartesian else self._reconstruct_cartesian + elif isinstance(data_loader, CartesianFrameDataLoader | NonCartesianFrameDataLoader): + reconstruct_method = self._reconstruct_cartesian if isinstance(data_loader, CartesianFrameDataLoader) else self._reconstruct_nufft else: raise ValueError("Unknown dataloader") + return reconstruct_method(data_loader) def _reconstruct_cartesian( self, @@ -144,19 +146,19 @@ def _reconstruct_nufft( final_images = np.empty( (data_loader.n_frames, *data_loader.shape), dtype=np.float32 ) + smaps = data_loader.get_smaps() for i in tqdm(range(data_loader.n_frames)): traj, data = data_loader.get_kspace_frame(i) - if data_loader.slice_2d: - nufft_operator.samples = traj.reshape( - data_loader.n_shots, -1, traj.shape[-1] - )[0, :, :2] - data = np.reshape(data, (data.shape[0], data_loader.n_shots, -1)) - for j in range(data.shape[1]): - final_images[i, :, :, j] = abs(nufft_operator.adj_op(data[:, j])) - else: - nufft_operator.samples = traj - final_images[i] = abs(nufft_operator.adj_op(data)) + #fix: density compensation should update every frame when rotating(writer: dyn_traj?) + if smaps is not None and data_loader.slice_2d: + nufft_operator.smaps = smaps[...,i % data_loader.frame.n_shots] + nufft_operator.samples = traj + final_images[i] = abs(nufft_operator.adj_op(data)) + if data_loader.slice_2d: + final_images = np.moveaxis(final_images.reshape( + (data_loader.frame.n_frames, -1, *final_images.shape[-2:]) + ), 1, -1) return final_images @@ -271,7 +273,7 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray: samples=traj, shape=data_loader.shape, n_coils=data_loader.n_coils, - smaps=smaps, + smaps=smaps.squeeze() if data_loader.slice_2d else smaps, # smaps=xp.array(smaps) if smaps is not None else None, density=density_compensation, squeeze_dims=True, @@ -331,6 +333,10 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray: pbar_frames.update(1) if self.restart_strategy != RestartStrategy.REFINE: + if data_loader.slice_2d: + final_estimate = np.moveaxis(final_estimate.reshape( + (data_loader.frame.n_frames, -1, *final_estimate.shape[-2:]) + ), 1, -1) return final_estimate # else, we do a second pass on the data using the last iteration as a slotion. pbar_frames.reset() @@ -353,6 +359,10 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray: else: final_estimate[i, ...] = abs(x_iter) pbar_frames.update(1) + if data_loader.slice_2d: + final_estimate = np.moveaxis(final_estimate.reshape( + (data_loader.frame.n_frames, -1, *final_estimate.shape[-2:]) + ), 1, -1) return final_estimate def _reconstruct_frame( @@ -373,7 +383,7 @@ def _reconstruct_frame( grad_op=grad_op, linear_op=copy.deepcopy(self.space_linear_op), prox_op=copy.deepcopy(self.space_prox_op), - x_init=x_init, + x_init=x_init.get(), synthesis_init=False, metric_kwargs={}, compute_backend=self.compute_backend, From 42a6d075824c0be5938d4ac4e5c87c0c78c0d01a Mon Sep 17 00:00:00 2001 From: alineyyy Date: Fri, 14 Mar 2025 17:59:20 +0100 Subject: [PATCH 2/2] wip: bugs in single slice (cs) --- src/snake/toolkit/reconstructors/fourier.py | 2 +- src/snake/toolkit/reconstructors/pysap.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/snake/toolkit/reconstructors/fourier.py b/src/snake/toolkit/reconstructors/fourier.py index ba559727..e4163292 100644 --- a/src/snake/toolkit/reconstructors/fourier.py +++ b/src/snake/toolkit/reconstructors/fourier.py @@ -64,7 +64,7 @@ def init_nufft( kwargs = dict( shape=shape, n_coils=data_loader.n_coils, - smaps=smaps.squeeze() if data_loader.slice_2d else smaps, + smaps=smaps[...,0].copy() if data_loader.slice_2d else smaps, ) if density_compensation is False: kwargs["density"] = None diff --git a/src/snake/toolkit/reconstructors/pysap.py b/src/snake/toolkit/reconstructors/pysap.py index 4cc1a654..871e990f 100644 --- a/src/snake/toolkit/reconstructors/pysap.py +++ b/src/snake/toolkit/reconstructors/pysap.py @@ -155,7 +155,7 @@ def _reconstruct_nufft( nufft_operator.smaps = smaps[...,i % data_loader.frame.n_shots] nufft_operator.samples = traj final_images[i] = abs(nufft_operator.adj_op(data)) - if data_loader.slice_2d: + if isinstance(data_loader, SliceDataloader): final_images = np.moveaxis(final_images.reshape( (data_loader.frame.n_frames, -1, *final_images.shape[-2:]) ), 1, -1) @@ -273,7 +273,7 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray: samples=traj, shape=data_loader.shape, n_coils=data_loader.n_coils, - smaps=smaps.squeeze() if data_loader.slice_2d else smaps, + smaps=smaps[..., 0].copy() if data_loader.slice_2d else smaps, # smaps=xp.array(smaps) if smaps is not None else None, density=density_compensation, squeeze_dims=True, @@ -333,7 +333,7 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray: pbar_frames.update(1) if self.restart_strategy != RestartStrategy.REFINE: - if data_loader.slice_2d: + if isinstance(data_loader, SliceDataloader): final_estimate = np.moveaxis(final_estimate.reshape( (data_loader.frame.n_frames, -1, *final_estimate.shape[-2:]) ), 1, -1) @@ -359,7 +359,7 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray: else: final_estimate[i, ...] = abs(x_iter) pbar_frames.update(1) - if data_loader.slice_2d: + if isinstance(data_loader, SliceDataloader): final_estimate = np.moveaxis(final_estimate.reshape( (data_loader.frame.n_frames, -1, *final_estimate.shape[-2:]) ), 1, -1) @@ -383,7 +383,7 @@ def _reconstruct_frame( grad_op=grad_op, linear_op=copy.deepcopy(self.space_linear_op), prox_op=copy.deepcopy(self.space_prox_op), - x_init=x_init.get(), + x_init=x_init, synthesis_init=False, metric_kwargs={}, compute_backend=self.compute_backend,