diff --git a/.github/actions/load-data/action.yml b/.github/actions/load-data/action.yml
index b78fe20..0a77efc 100644
--- a/.github/actions/load-data/action.yml
+++ b/.github/actions/load-data/action.yml
@@ -31,4 +31,4 @@ runs:
       name: Download datasets from Google Drive
       shell: bash
       run: |
-        rclone copy remote:"SampleData" ./testing_data --drive-shared-with-me
\ No newline at end of file
+        rclone copy remote:"SampleData" ./testing_data --drive-shared-with-me
diff --git a/.github/workflows/all_os_versions.txt b/.github/workflows/all_os_versions.txt
index 1bcf5cd..d9eca36 100644
--- a/.github/workflows/all_os_versions.txt
+++ b/.github/workflows/all_os_versions.txt
@@ -1 +1 @@
-["ubuntu-latest", "macos-latest", "windows-2022"]
\ No newline at end of file
+["ubuntu-latest", "macos-latest", "windows-2022"]
diff --git a/.github/workflows/all_python_versions.txt b/.github/workflows/all_python_versions.txt
index 350c415..7a7daf6 100644
--- a/.github/workflows/all_python_versions.txt
+++ b/.github/workflows/all_python_versions.txt
@@ -1 +1 @@
-["3.10", "3.11", "3.12", "3.13"]
\ No newline at end of file
+["3.10", "3.11", "3.12", "3.13"]
diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml
index ae8fe85..370b88f 100644
--- a/.github/workflows/run-tests.yml
+++ b/.github/workflows/run-tests.yml
@@ -50,7 +50,7 @@ jobs:
         run: |
           python -m pip install "."
           python -m pip install --group test
-
+
       - name: Prepare data for tests
         uses: ./.github/actions/load-data
         with:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..2c4259b
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,26 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v6.0.0
+  hooks:
+  - id: check-yaml
+  - id: end-of-file-fixer
+  - id: trailing-whitespace
+
+- repo: https://github.com/psf/black
+  rev: 25.1.0
+  hooks:
+  - id: black
+    exclude: ^docs/
+
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: v0.13.0
+  hooks:
+  - id: ruff
+    args: [ --fix ]
+
+- repo: https://github.com/codespell-project/codespell
+  rev: v2.4.1
+  hooks:
+  - id: codespell
+    additional_dependencies:
+      - tomli
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 82cdfe9..0f15132 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -63,4 +63,4 @@
 
 # GuPPy-v1.1.1 (July 6th, 2021)
 
-It is the GuPPy's first release for people to use and give us feedbacks on it
\ No newline at end of file
+It is the GuPPy's first release for people to use and give us feedbacks on it
diff --git a/MANIFEST.in b/MANIFEST.in
index 603ac05..e24a868 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,2 +1,2 @@
 include *.md
-recursive-include src *.ipynb
\ No newline at end of file
+recursive-include src *.ipynb
diff --git a/README.md b/README.md
index b564104..851839a 100644
--- a/README.md
+++ b/README.md
@@ -14,15 +14,15 @@ To install the latest stable release of GuPPy through PyPI, simply run the follo
 wing command:
 ```
 pip install guppy
 ```
-We recommend that you install the package inside a [virtual environment](https://docs.python.org/3/tutorial/venv.html).
-A simple way of doing this is to use a [conda environment](https://docs.conda.io/projects/conda/en/latest/user-guide/concepts/environments.html) from the `conda` package manager ([installation instructions](https://docs.conda.io/en/latest/miniconda.html)).
+We recommend that you install the package inside a [virtual environment](https://docs.python.org/3/tutorial/venv.html).
+A simple way of doing this is to use a [conda environment](https://docs.conda.io/projects/conda/en/latest/user-guide/concepts/environments.html) from the `conda` package manager ([installation instructions](https://docs.conda.io/en/latest/miniconda.html)).
 Detailed instructions on how to use conda environments can be found in their [documentation](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html).
 
 ### Installation from GitHub
 
-To install the latest development version of GuPPy from GitHub, you can clone the repository and install the package manually.
-This has the advantage of allowing you to access the latest features and bug fixes that may not yet be available in the stable release.
-To install the conversion from GitHub you will need to use `git` ([installation instructions](https://github.com/git-guides/install-git)).
+To install the latest development version of GuPPy from GitHub, you can clone the repository and install the package manually.
+This has the advantage of allowing you to access the latest features and bug fixes that may not yet be available in the stable release.
+To install the conversion from GitHub you will need to use `git` ([installation instructions](https://github.com/git-guides/install-git)).
 From a terminal or command prompt, execute the following commands:
 1. Clone the repository:
@@ -88,5 +88,3 @@ This will launch the GuPPy user interface, where you can begin analyzing your fi
 - [Gabriela Lopez](https://github.com/glopez924)
 - [Talia Lerner](https://github.com/talialerner)
 - [Paul Adkisson](https://github.com/pauladkisson)
-
-
diff --git a/pyproject.toml b/pyproject.toml
index 0d6d6ca..4527fba 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,12 +21,12 @@ authors = [
 license = { file = "LICENSE" }
 
 keywords = [
-    "neuroscience",
-    "fiber-photometry",
-    "calcium-imaging",
-    "data-analysis",
-    "gui",
-    "visualization",
+    "neuroscience",
+    "fiber-photometry",
+    "calcium-imaging",
+    "data-analysis",
+    "gui",
+    "visualization",
     "signal-processing",
 ]
 classifiers = [
@@ -64,6 +64,10 @@ test = [
     "pytest-xdist" # Runs tests on parallel
 ]
 
+dev = [
+    "pre-commit",
+]
+
 [project.scripts]
 guppy = "guppy.main:main"
 
@@ -73,3 +77,52 @@ guppy = "guppy.main:main"
 
 [tool.setuptools.packages.find]
 where = ["src"]
+
+[tool.black]
+line-length = 120
+target-version = ['py310', 'py311', 'py312', 'py313']
+include = '\.pyi?$'
+extend-exclude = '''
+/(
+    \.toml
+    |\.yml
+    |\.txt
+    |\.sh
+    |\.git
+    |\.ini
+    | \.hg
+    | \.mypy_cache
+    | \.tox
+    | \.venv
+    | build
+    | dist
+)/
+'''
+
+
+[tool.ruff]
+
+[tool.ruff.lint]
+select = [
+    "F401",  # Unused import
+    "I",  # All isort rules
+    # TODO: add docstring rules after adding comprehensive docstrings to all public functions/classes
+    # "D101",  # Missing docstring in public class
+    # "D103",  # Missing docstring in public function
+    "UP006",  # non-pep585 annotation (tuple, list -> tuple, list)
+    "UP007"  # non-pep604 annotation (Union[x, y] -> X | Y )
+]
+fixable = ["ALL"]
+
+[tool.ruff.lint.per-file-ignores]
+"**__init__.py" = ["F401", "I"]  # We are not enforcing import rules in __init__'s
+
+[tool.ruff.lint.isort]
+relative-imports-order = "closest-to-furthest"
+known-first-party = ["guppy"]
+
+
+[tool.codespell]
+skip = '.git*,*.pdf,*.css,*.svg'
+check-hidden = true
+ignore-words-list = 'assertin,sortings'
diff --git a/src/guppy/combineDataFn.py b/src/guppy/combineDataFn.py
index a420a79..d832dba 100755
--- a/src/guppy/combineDataFn.py
+++ b/src/guppy/combineDataFn.py
@@ -1,169 +1,178 @@
-import
os -import glob -import json -import numpy as np -import h5py -import re import fnmatch import logging +import os +import re + +import h5py +import numpy as np logger = logging.getLogger(__name__) -def find_files(path, glob_path, ignore_case = False): - rule = re.compile(fnmatch.translate(glob_path), re.IGNORECASE) if ignore_case \ - else re.compile(fnmatch.translate(glob_path)) - no_bytes_path = os.listdir(os.path.expanduser(path)) - str_path = [] - - # converting byte object to string - for x in no_bytes_path: - try: - str_path.append(x.decode('utf-8')) - except: - str_path.append(x) - - return [os.path.join(path,n) for n in str_path if rule.match(n)] + +def find_files(path, glob_path, ignore_case=False): + rule = ( + re.compile(fnmatch.translate(glob_path), re.IGNORECASE) + if ignore_case + else re.compile(fnmatch.translate(glob_path)) + ) + no_bytes_path = os.listdir(os.path.expanduser(path)) + str_path = [] + + # converting byte object to string + for x in no_bytes_path: + try: + str_path.append(x.decode("utf-8")) + except: + str_path.append(x) + + return [os.path.join(path, n) for n in str_path if rule.match(n)] + def read_hdf5(event, filepath, key): - if event: - op = os.path.join(filepath, event+'.hdf5') - else: - op = filepath + if event: + op = os.path.join(filepath, event + ".hdf5") + else: + op = filepath - if os.path.exists(op): - with h5py.File(op, 'r') as f: - arr = np.asarray(f[key]) - else: - raise Exception('{}.hdf5 file does not exist'.format(event)) + if os.path.exists(op): + with h5py.File(op, "r") as f: + arr = np.asarray(f[key]) + else: + raise Exception("{}.hdf5 file does not exist".format(event)) + + return arr - return arr def write_hdf5(data, event, filepath, key): - op = os.path.join(filepath, event+'.hdf5') - - if not os.path.exists(op): - with h5py.File(op, 'w') as f: - if type(data) is np.ndarray: - f.create_dataset(key, data=data, maxshape=(None,), chunks=True) - else: - f.create_dataset(key, data=data) - else: - with h5py.File(op, 'r+') as f: - if key in list(f.keys()): - if type(data) is np.ndarray: - f[key].resize(data.shape) - arr = f[key] - arr[:] = data - else: - arr = f[key] - arr = data - else: - f.create_dataset(key, data=data, maxshape=(None,), chunks=True) + op = os.path.join(filepath, event + ".hdf5") + + if not os.path.exists(op): + with h5py.File(op, "w") as f: + if type(data) is np.ndarray: + f.create_dataset(key, data=data, maxshape=(None,), chunks=True) + else: + f.create_dataset(key, data=data) + else: + with h5py.File(op, "r+") as f: + if key in list(f.keys()): + if type(data) is np.ndarray: + f[key].resize(data.shape) + arr = f[key] + arr[:] = data + else: + arr = f[key] + arr = data + else: + f.create_dataset(key, data=data, maxshape=(None,), chunks=True) def decide_naming_convention(filepath): - path_1 = find_files(filepath, 'control*', ignore_case=True) #glob.glob(os.path.join(filepath, 'control*')) - - path_2 = find_files(filepath, 'signal*', ignore_case=True) #glob.glob(os.path.join(filepath, 'signal*')) - - path = sorted(path_1 + path_2, key=str.casefold) + path_1 = find_files(filepath, "control*", ignore_case=True) # glob.glob(os.path.join(filepath, 'control*')) + + path_2 = find_files(filepath, "signal*", ignore_case=True) # glob.glob(os.path.join(filepath, 'signal*')) + + path = sorted(path_1 + path_2, key=str.casefold) - if len(path)%2 != 0: - raise Exception('There are not equal number of Control and Signal data') - - path = np.asarray(path).reshape(2,-1) + if len(path) % 2 != 0: + raise Exception("There are not equal number of 
Control and Signal data") - return path + path = np.asarray(path).reshape(2, -1) + + return path def eliminateData(filepath, timeForLightsTurnOn, event, sampling_rate, naming): - - arr = np.array([]) - ts_arr = np.array([]) - for i in range(len(filepath)): - ts = read_hdf5('timeCorrection_'+naming, filepath[i], 'timestampNew') - data = read_hdf5(event, filepath[i], 'data').reshape(-1) - - #index = np.where((ts>coords[i,0]) & (tscoords[i,0]) & (ts 1: - mean = np.nanmean(psth[:,single_trials_index], axis=1).reshape(-1,1) - err = np.nanstd(psth[:,single_trials_index], axis=1)/math.sqrt(psth[:,single_trials_index].shape[1]) - err = err.reshape(-1,1) - psth = np.hstack((psth,mean)) - psth = np.hstack((psth, err)) - #timestamps = np.asarray(read_Df(filepath, 'ts_psth', '')) - #psth = np.hstack((psth, timestamps)) - try: - ts = read_hdf5(event, filepath, 'ts') - ts = np.append(ts, ['mean', 'err']) - except: - ts = None - - if len(columns)==0: - df = pd.DataFrame(psth, index=None, columns=ts, dtype='float32') - else: - columns = np.asarray(columns) - columns = np.append(columns, ['mean', 'err']) - df = pd.DataFrame(psth, index=None, columns=columns, dtype='float32') - - df.to_hdf(op, key='df', mode='w') + if name: + op = os.path.join(filepath, event + "_{}.h5".format(name)) + else: + op = os.path.join(filepath, event + ".h5") + + # check if file already exists + # if os.path.exists(op): + # return 0 + + # removing psth binned trials + columns = list(np.array(columns, dtype="str")) + regex = re.compile("bin_*") + single_trials_index = [i for i in range(len(columns)) if not regex.match(columns[i])] + single_trials_index = [i for i in range(len(columns)) if columns[i] != "timestamps"] + + psth = psth.T + if psth.ndim > 1: + mean = np.nanmean(psth[:, single_trials_index], axis=1).reshape(-1, 1) + err = np.nanstd(psth[:, single_trials_index], axis=1) / math.sqrt(psth[:, single_trials_index].shape[1]) + err = err.reshape(-1, 1) + psth = np.hstack((psth, mean)) + psth = np.hstack((psth, err)) + # timestamps = np.asarray(read_Df(filepath, 'ts_psth', '')) + # psth = np.hstack((psth, timestamps)) + try: + ts = read_hdf5(event, filepath, "ts") + ts = np.append(ts, ["mean", "err"]) + except: + ts = None + + if len(columns) == 0: + df = pd.DataFrame(psth, index=None, columns=ts, dtype="float32") + else: + columns = np.asarray(columns) + columns = np.append(columns, ["mean", "err"]) + df = pd.DataFrame(psth, index=None, columns=columns, dtype="float32") + + df.to_hdf(op, key="df", mode="w") def getCorrCombinations(filepath, inputParameters): - selectForComputePsth = inputParameters['selectForComputePsth'] - if selectForComputePsth=='z_score': - path = glob.glob(os.path.join(filepath, 'z_score_*')) - elif selectForComputePsth=='dff': - path = glob.glob(os.path.join(filepath, 'dff_*')) + selectForComputePsth = inputParameters["selectForComputePsth"] + if selectForComputePsth == "z_score": + path = glob.glob(os.path.join(filepath, "z_score_*")) + elif selectForComputePsth == "dff": + path = glob.glob(os.path.join(filepath, "dff_*")) else: - path = glob.glob(os.path.join(filepath, 'z_score_*')) + glob.glob(os.path.join(filepath, 'dff_*')) - + path = glob.glob(os.path.join(filepath, "z_score_*")) + glob.glob(os.path.join(filepath, "dff_*")) + names = list() type = list() for i in range(len(path)): - basename = (os.path.basename(path[i])).split('.')[0] - names.append(basename.split('_')[-1]) - type.append((os.path.basename(path[i])).split('.')[0].split('_'+names[-1], 1)[0]) - + basename = 
(os.path.basename(path[i])).split(".")[0] + names.append(basename.split("_")[-1]) + type.append((os.path.basename(path[i])).split(".")[0].split("_" + names[-1], 1)[0]) + names = list(np.unique(np.array(names))) type = list(np.unique(np.array(type))) corr_info = list() - if len(names)<=1: + if len(names) <= 1: logger.info("Cross-correlation cannot be computed because only one signal is present.") return corr_info, type - elif len(names)==2: + elif len(names) == 2: corr_info = names else: corr_info = names corr_info.append(names[0]) - + return corr_info, type - def helperCrossCorrelation(arr_A, arr_B, sample_rate): cross_corr = list() - for (a, b) in zip(arr_A, arr_B): + for a, b in zip(arr_A, arr_B): if np.isnan(a).any() or np.isnan(b).any(): - corr = signal.correlate(a, b, method='direct') + corr = signal.correlate(a, b, method="direct") else: corr = signal.correlate(a, b) - corr_norm = corr/ np.max(np.abs(corr)) + corr_norm = corr / np.max(np.abs(corr)) cross_corr.append(corr_norm) lag = signal.correlation_lags(len(a), len(b)) - lag_msec = np.array(lag / sample_rate, dtype='float32') - - cross_corr_arr = np.array(cross_corr, dtype='float32') - lag_msec = lag_msec.reshape(1,-1) + lag_msec = np.array(lag / sample_rate, dtype="float32") + + cross_corr_arr = np.array(cross_corr, dtype="float32") + lag_msec = lag_msec.reshape(1, -1) cross_corr_arr = np.concatenate((cross_corr_arr, lag_msec), axis=0) return cross_corr_arr def computeCrossCorrelation(filepath, event, inputParameters): - isCompute = inputParameters['computeCorr'] - removeArtifacts = inputParameters['removeArtifacts'] - artifactsRemovalMethod = inputParameters['artifactsRemovalMethod'] - if isCompute==True: - if removeArtifacts==True and artifactsRemovalMethod=='concatenate': - raise Exception("For cross-correlation, when removeArtifacts is True, artifacts removal method\ - should be replace with NaNs and not concatenate") + isCompute = inputParameters["computeCorr"] + removeArtifacts = inputParameters["removeArtifacts"] + artifactsRemovalMethod = inputParameters["artifactsRemovalMethod"] + if isCompute == True: + if removeArtifacts == True and artifactsRemovalMethod == "concatenate": + raise Exception( + "For cross-correlation, when removeArtifacts is True, artifacts removal method\ + should be replace with NaNs and not concatenate" + ) corr_info, type = getCorrCombinations(filepath, inputParameters) - if 'control' in event.lower() or 'signal' in event.lower(): + if "control" in event.lower() or "signal" in event.lower(): return else: for i in range(1, len(corr_info)): logger.debug(f"Computing cross-correlation for event {event}...") for j in range(len(type)): - psth_a = read_Df(filepath, event+'_'+corr_info[i-1], type[j]+'_'+corr_info[i-1]) - psth_b = read_Df(filepath, event+'_'+corr_info[i], type[j]+'_'+corr_info[i]) - sample_rate = 1/(psth_a['timestamps'][1]-psth_a['timestamps'][0]) - psth_a = psth_a.drop(columns=['timestamps', 'err', 'mean']) - psth_b = psth_b.drop(columns=['timestamps', 'err', 'mean']) + psth_a = read_Df(filepath, event + "_" + corr_info[i - 1], type[j] + "_" + corr_info[i - 1]) + psth_b = read_Df(filepath, event + "_" + corr_info[i], type[j] + "_" + corr_info[i]) + sample_rate = 1 / (psth_a["timestamps"][1] - psth_a["timestamps"][0]) + psth_a = psth_a.drop(columns=["timestamps", "err", "mean"]) + psth_b = psth_b.drop(columns=["timestamps", "err", "mean"]) cols_a, cols_b = np.array(psth_a.columns), np.array(psth_b.columns) - if np.intersect1d(cols_a, cols_b).size>0: + if np.intersect1d(cols_a, cols_b).size > 
0: cols = list(np.intersect1d(cols_a, cols_b)) else: cols = list(cols_a) arr_A, arr_B = np.array(psth_a).T, np.array(psth_b).T cross_corr = helperCrossCorrelation(arr_A, arr_B, sample_rate) - cols.append('timestamps') - create_Df(make_dir(filepath), 'corr_'+event, type[j]+'_'+corr_info[i-1]+'_'+corr_info[i], cross_corr, cols) + cols.append("timestamps") + create_Df( + make_dir(filepath), + "corr_" + event, + type[j] + "_" + corr_info[i - 1] + "_" + corr_info[i], + cross_corr, + cols, + ) logger.info(f"Cross-correlation for event {event} computed.") diff --git a/src/guppy/computePsth.py b/src/guppy/computePsth.py index b08764c..671d1d3 100755 --- a/src/guppy/computePsth.py +++ b/src/guppy/computePsth.py @@ -1,773 +1,857 @@ # coding: utf-8 -import os -import sys -import json import glob -import re -import h5py +import json +import logging import math +import multiprocessing as mp +import os +import re import subprocess -import numpy as np -import pandas as pd +import sys +from collections import OrderedDict from itertools import repeat -import multiprocessing as mp + +import h5py +import numpy as np +import pandas as pd from scipy import signal as ss -from collections import OrderedDict -from pathlib import Path + +from .computeCorr import computeCrossCorrelation, getCorrCombinations, make_dir from .preprocess import get_all_stores_for_combining_data -from .computeCorr import computeCrossCorrelation -from .computeCorr import getCorrCombinations -from .computeCorr import make_dir -import logging logger = logging.getLogger(__name__) + def takeOnlyDirs(paths): - removePaths = [] - for p in paths: - if os.path.isfile(p): - removePaths.append(p) - return list(set(paths)-set(removePaths)) + removePaths = [] + for p in paths: + if os.path.isfile(p): + removePaths.append(p) + return list(set(paths) - set(removePaths)) + def writeToFile(value: str): - with open(os.path.join(os.path.expanduser('~'), 'pbSteps.txt'), 'a') as file: - file.write(value) + with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: + file.write(value) + # function to read hdf5 file def read_hdf5(event, filepath, key): - if event: - event = event.replace("\\","_") - event = event.replace("/","_") - op = os.path.join(filepath, event+'.hdf5') - else: - op = filepath + if event: + event = event.replace("\\", "_") + event = event.replace("/", "_") + op = os.path.join(filepath, event + ".hdf5") + else: + op = filepath - if os.path.exists(op): - with h5py.File(op, 'r') as f: - arr = np.asarray(f[key]) - else: - raise Exception('{}.hdf5 file does not exist'.format(event)) + if os.path.exists(op): + with h5py.File(op, "r") as f: + arr = np.asarray(f[key]) + else: + raise Exception("{}.hdf5 file does not exist".format(event)) + + return arr - return arr # function to write hdf5 file def write_hdf5(data, event, filepath, key): - event = event.replace("\\","_") - event = event.replace("/","_") - op = os.path.join(filepath, event+'.hdf5') - - # if file does not exist create a new file - if not os.path.exists(op): - with h5py.File(op, 'w') as f: - if type(data) is np.ndarray: - f.create_dataset(key, data=data, maxshape=(None, ), chunks=True) - else: - f.create_dataset(key, data=data) - # if file already exists, append data to it or add a new key to it - else: - with h5py.File(op, 'r+') as f: - if key in list(f.keys()): - if type(data) is np.ndarray: - f[key].resize(data.shape) - arr = f[key] - arr[:] = data - else: - arr = f[key] - arr = data - else: - f.create_dataset(key, data=data, maxshape=(None, ), chunks=True) + 
event = event.replace("\\", "_") + event = event.replace("/", "_") + op = os.path.join(filepath, event + ".hdf5") + + # if file does not exist create a new file + if not os.path.exists(op): + with h5py.File(op, "w") as f: + if type(data) is np.ndarray: + f.create_dataset(key, data=data, maxshape=(None,), chunks=True) + else: + f.create_dataset(key, data=data) + # if file already exists, append data to it or add a new key to it + else: + with h5py.File(op, "r+") as f: + if key in list(f.keys()): + if type(data) is np.ndarray: + f[key].resize(data.shape) + arr = f[key] + arr[:] = data + else: + arr = f[key] + arr = data + else: + f.create_dataset(key, data=data, maxshape=(None,), chunks=True) def create_Df_area_peak(filepath, arr, name, index=[]): - op = os.path.join(filepath, 'peak_AUC_'+name+'.h5') - dirname = os.path.dirname(filepath) + op = os.path.join(filepath, "peak_AUC_" + name + ".h5") + dirname = os.path.dirname(filepath) + + df = pd.DataFrame(arr, index=index) - df = pd.DataFrame(arr, index=index) + df.to_hdf(op, key="df", mode="w") - df.to_hdf(op, key='df', mode='w') def read_Df_area_peak(filepath, name): - op = os.path.join(filepath, 'peak_AUC_'+name+'.h5') - df = pd.read_hdf(op, key='df', mode='r') + op = os.path.join(filepath, "peak_AUC_" + name + ".h5") + df = pd.read_hdf(op, key="df", mode="r") + + return df - return df def create_csv_area_peak(filepath, arr, name, index=[]): - op = os.path.join(filepath, 'peak_AUC_'+name+'.csv') - df = pd.DataFrame(arr, index=index) - - df.to_csv(op) + op = os.path.join(filepath, "peak_AUC_" + name + ".csv") + df = pd.DataFrame(arr, index=index) + + df.to_csv(op) # function to create dataframe for each event PSTH and save it to h5 file def create_Df(filepath, event, name, psth, columns=[]): - event = event.replace("\\","_") - event = event.replace("/","_") - if name: - op = os.path.join(filepath, event+'_{}.h5'.format(name)) - else: - op = os.path.join(filepath, event+'.h5') - - # check if file already exists - #if os.path.exists(op): - # return 0 - - # removing psth binned trials - columns = np.array(columns, dtype='str') - regex = re.compile('bin_*') - single_trials = columns[[i for i in range(len(columns)) if not regex.match(columns[i])]] - single_trials_index = [i for i in range(len(single_trials)) if single_trials[i]!='timestamps'] - - psth = psth.T - if psth.ndim > 1: - mean = np.nanmean(psth[:,single_trials_index], axis=1).reshape(-1,1) - err = np.nanstd(psth[:,single_trials_index], axis=1)/math.sqrt(psth[:,single_trials_index].shape[1]) - err = err.reshape(-1,1) - psth = np.hstack((psth,mean)) - psth = np.hstack((psth, err)) - #timestamps = np.asarray(read_Df(filepath, 'ts_psth', '')) - #psth = np.hstack((psth, timestamps)) - try: - ts = read_hdf5(event, filepath, 'ts') - ts = np.append(ts, ['mean', 'err']) - except: - ts = None - - if len(columns)==0: - df = pd.DataFrame(psth, index=None, columns=ts, dtype='float32') - else: - columns = np.asarray(columns) - columns = np.append(columns, ['mean', 'err']) - df = pd.DataFrame(psth, index=None, columns=list(columns), dtype='float32') - - df.to_hdf(op, key='df', mode='w') + event = event.replace("\\", "_") + event = event.replace("/", "_") + if name: + op = os.path.join(filepath, event + "_{}.h5".format(name)) + else: + op = os.path.join(filepath, event + ".h5") + + # check if file already exists + # if os.path.exists(op): + # return 0 + + # removing psth binned trials + columns = np.array(columns, dtype="str") + regex = re.compile("bin_*") + single_trials = columns[[i for i in 
range(len(columns)) if not regex.match(columns[i])]] + single_trials_index = [i for i in range(len(single_trials)) if single_trials[i] != "timestamps"] + + psth = psth.T + if psth.ndim > 1: + mean = np.nanmean(psth[:, single_trials_index], axis=1).reshape(-1, 1) + err = np.nanstd(psth[:, single_trials_index], axis=1) / math.sqrt(psth[:, single_trials_index].shape[1]) + err = err.reshape(-1, 1) + psth = np.hstack((psth, mean)) + psth = np.hstack((psth, err)) + # timestamps = np.asarray(read_Df(filepath, 'ts_psth', '')) + # psth = np.hstack((psth, timestamps)) + try: + ts = read_hdf5(event, filepath, "ts") + ts = np.append(ts, ["mean", "err"]) + except: + ts = None + + if len(columns) == 0: + df = pd.DataFrame(psth, index=None, columns=ts, dtype="float32") + else: + columns = np.asarray(columns) + columns = np.append(columns, ["mean", "err"]) + df = pd.DataFrame(psth, index=None, columns=list(columns), dtype="float32") + + df.to_hdf(op, key="df", mode="w") # function to read h5 file and make a dataframe from it def read_Df(filepath, event, name): - event = event.replace("\\","_") - event = event.replace("/","_") - if name: - op = os.path.join(filepath, event+'_{}.h5'.format(name)) - else: - op = os.path.join(filepath, event+'.h5') - df = pd.read_hdf(op, key='df', mode='r') + event = event.replace("\\", "_") + event = event.replace("/", "_") + if name: + op = os.path.join(filepath, event + "_{}.h5".format(name)) + else: + op = os.path.join(filepath, event + ".h5") + df = pd.read_hdf(op, key="df", mode="r") - return df + return df # function to create PSTH trials corresponding to each event timestamp def rowFormation(z_score, thisIndex, nTsPrev, nTsPost): - - if nTsPrev(thisIndex+nTsPost): - res = z_score[thisIndex-nTsPrev-1:thisIndex+nTsPost] - elif nTsPrev>=thisIndex and z_score.shape[0]>(thisIndex+nTsPost): - mismatch = nTsPrev-thisIndex+1 - res = np.zeros(nTsPrev+nTsPost+1) - res[:mismatch] = np.nan - res[mismatch:] = z_score[:thisIndex+nTsPost] - elif nTsPrev>=thisIndex and z_score.shape[0]<(thisIndex+nTsPost): - mismatch1 = nTsPrev-thisIndex+1 - mismatch2 = (thisIndex+nTsPost)-z_score.shape[0] - res1 = np.full(mismatch1, np.nan) - res2 = z_score - res3 = np.full(mismatch2, np.nan) - res = np.concatenate((res1, np.concatenate((res2, res3)))) - else: - mismatch = (thisIndex+nTsPost)-z_score.shape[0] - res1 = np.zeros(mismatch) - res1[:] = np.nan - res2 = z_score[thisIndex-nTsPrev-1:z_score.shape[0]] - res = np.concatenate((res2, res1)) - - return res + + if nTsPrev < thisIndex and z_score.shape[0] > (thisIndex + nTsPost): + res = z_score[thisIndex - nTsPrev - 1 : thisIndex + nTsPost] + elif nTsPrev >= thisIndex and z_score.shape[0] > (thisIndex + nTsPost): + mismatch = nTsPrev - thisIndex + 1 + res = np.zeros(nTsPrev + nTsPost + 1) + res[:mismatch] = np.nan + res[mismatch:] = z_score[: thisIndex + nTsPost] + elif nTsPrev >= thisIndex and z_score.shape[0] < (thisIndex + nTsPost): + mismatch1 = nTsPrev - thisIndex + 1 + mismatch2 = (thisIndex + nTsPost) - z_score.shape[0] + res1 = np.full(mismatch1, np.nan) + res2 = z_score + res3 = np.full(mismatch2, np.nan) + res = np.concatenate((res1, np.concatenate((res2, res3)))) + else: + mismatch = (thisIndex + nTsPost) - z_score.shape[0] + res1 = np.zeros(mismatch) + res1[:] = np.nan + res2 = z_score[thisIndex - nTsPrev - 1 : z_score.shape[0]] + res = np.concatenate((res2, res1)) + + return res # function to calculate baseline for each PSTH trial and do baseline correction def baselineCorrection(filepath, arr, timeAxis, baselineStart, baselineEnd): 
- #timeAxis = read_Df(filepath, 'ts_psth', '') - #timeAxis = np.asarray(timeAxis).reshape(-1) - baselineStrtPt = np.where(timeAxis>=baselineStart)[0] - baselineEndPt = np.where(timeAxis>=baselineEnd)[0] + # timeAxis = read_Df(filepath, 'ts_psth', '') + # timeAxis = np.asarray(timeAxis).reshape(-1) + baselineStrtPt = np.where(timeAxis >= baselineStart)[0] + baselineEndPt = np.where(timeAxis >= baselineEnd)[0] + + # logger.info(baselineStrtPt[0], baselineEndPt[0]) + if baselineStart == 0 and baselineEnd == 0: + return arr - #logger.info(baselineStrtPt[0], baselineEndPt[0]) - if baselineStart==0 and baselineEnd==0: - return arr - - baseline = np.nanmean(arr[baselineStrtPt[0]:baselineEndPt[0]]) - baselineSub = np.subtract(arr, baseline) + baseline = np.nanmean(arr[baselineStrtPt[0] : baselineEndPt[0]]) + baselineSub = np.subtract(arr, baseline) - return baselineSub + return baselineSub # helper function to make PSTH for each event -def helper_psth(z_score, event, filepath, - nSecPrev, nSecPost, timeInterval, - bin_psth_trials, use_time_or_trials, - baselineStart, baselineEnd, - naming, just_use_signal): - - event = event.replace("\\","_") - event = event.replace("/","_") - - sampling_rate = read_hdf5('timeCorrection_'+naming, filepath, 'sampling_rate')[0] - - # calculate time before event timestamp and time after event timestamp - nTsPrev = int(round(nSecPrev*sampling_rate)) - nTsPost = int(round(nSecPost*sampling_rate)) - - totalTs = (-1*nTsPrev) + nTsPost - increment = ((-1*nSecPrev)+nSecPost)/totalTs - timeAxis = np.linspace(nSecPrev, nSecPost+increment, totalTs+1) - timeAxisNew = np.concatenate((timeAxis, timeAxis[::-1])) - - # avoid writing same data to same file in multi-processing - #if not os.path.exists(os.path.join(filepath, 'ts_psth.h5')): - # logger.info('file not exists') - # create_Df(filepath, 'ts_psth', '', timeAxis) - # time.sleep(2) - - ts = read_hdf5(event+'_'+naming, filepath, 'ts') - - # reject timestamps for which baseline cannot be calculated because of nan values - new_ts = [] - for i in range(ts.shape[0]): - thisTime = ts[i] # -1 not needed anymore - if thisTime0: - timestamps = read_hdf5('timeCorrection_'+naming, filepath, 'timestampNew') - timestamps = np.divide(timestamps, 60) - ts_min = np.divide(ts, 60) - bin_steps = np.arange(timestamps[0], timestamps[-1]+bin_psth_trials, bin_psth_trials) - indices_each_step = dict() - for i in range(1, bin_steps.shape[0]): - indices_each_step[f"{np.around(bin_steps[i-1],0)}-{np.around(bin_steps[i],0)}"] = np.where((ts_min>=bin_steps[i-1]) & (ts_min<=bin_steps[i]))[0] - elif use_time_or_trials=='# of trials' and bin_psth_trials>0: - bin_steps = np.arange(0, ts.shape[0], bin_psth_trials) - if bin_steps[-1] 0: + timestamps = read_hdf5("timeCorrection_" + naming, filepath, "timestampNew") + timestamps = np.divide(timestamps, 60) + ts_min = np.divide(ts, 60) + bin_steps = np.arange(timestamps[0], timestamps[-1] + bin_psth_trials, bin_psth_trials) + indices_each_step = dict() + for i in range(1, bin_steps.shape[0]): + indices_each_step[f"{np.around(bin_steps[i-1],0)}-{np.around(bin_steps[i],0)}"] = np.where( + (ts_min >= bin_steps[i - 1]) & (ts_min <= bin_steps[i]) + )[0] + elif use_time_or_trials == "# of trials" and bin_psth_trials > 0: + bin_steps = np.arange(0, ts.shape[0], bin_psth_trials) + if bin_steps[-1] < ts.shape[0]: + bin_steps = np.concatenate((bin_steps, [ts.shape[0]]), axis=0) + indices_each_step = dict() + for i in range(1, bin_steps.shape[0]): + indices_each_step[f"{bin_steps[i-1]}-{bin_steps[i]}"] = 
np.arange(bin_steps[i - 1], bin_steps[i]) + else: + indices_each_step = dict() + + psth_bin, psth_bin_baselineUncorrected = [], [] + if indices_each_step: + keys = list(indices_each_step.keys()) + for k in keys: + # no trials in a given bin window, just put all the nan values + if indices_each_step[k].shape[0] == 0: + psth_bin.append(np.full(psth.shape[1], np.nan)) + psth_bin_baselineUncorrected.append(np.full(psth_baselineUncorrected.shape[1], np.nan)) + psth_bin.append(np.full(psth.shape[1], np.nan)) + psth_bin_baselineUncorrected.append(np.full(psth_baselineUncorrected.shape[1], np.nan)) + else: + index = indices_each_step[k] + arr = psth[index, :] + # mean of bins + psth_bin.append(np.nanmean(psth[index, :], axis=0)) + psth_bin_baselineUncorrected.append(np.nanmean(psth_baselineUncorrected[index, :], axis=0)) + psth_bin.append(np.nanstd(psth[index, :], axis=0) / math.sqrt(psth[index, :].shape[0])) + # error of bins + psth_bin_baselineUncorrected.append( + np.nanstd(psth_baselineUncorrected[index, :], axis=0) + / math.sqrt(psth_baselineUncorrected[index, :].shape[0]) + ) + + # adding column names + columns.append(f"bin_({k})") + columns.append(f"bin_err_({k})") + + psth = np.concatenate((psth, psth_bin), axis=0) + psth_baselineUncorrected = np.concatenate((psth_baselineUncorrected, psth_bin_baselineUncorrected), axis=0) + + timeAxis = timeAxis.reshape(1, -1) + psth = np.concatenate((psth, timeAxis), axis=0) + psth_baselineUncorrected = np.concatenate((psth_baselineUncorrected, timeAxis), axis=0) + columns.append("timestamps") + + return psth, psth_baselineUncorrected, columns # function to create PSTH for each event using function helper_psth and save the PSTH to h5 file def storenamePsth(filepath, event, inputParameters): - event = event.replace("\\","_") - event = event.replace("/","_") - - selectForComputePsth = inputParameters['selectForComputePsth'] - bin_psth_trials = inputParameters['bin_psth_trials'] - use_time_or_trials = inputParameters['use_time_or_trials'] - - if selectForComputePsth=='z_score': - path = glob.glob(os.path.join(filepath, 'z_score_*')) - elif selectForComputePsth=='dff': - path = glob.glob(os.path.join(filepath, 'dff_*')) - else: - path = glob.glob(os.path.join(filepath, 'z_score_*')) + glob.glob(os.path.join(filepath, 'dff_*')) - - b = np.divide(np.ones((100,)), 100) - a = 1 - - #storesList = storesList - #sampling_rate = read_hdf5(storesList[0,0], filepath, 'sampling_rate') - nSecPrev, nSecPost = inputParameters['nSecPrev'], inputParameters['nSecPost'] - baselineStart, baselineEnd = inputParameters['baselineCorrectionStart'], inputParameters['baselineCorrectionEnd'] - timeInterval = inputParameters['timeInterval'] - - if 'control' in event.lower() or 'signal' in event.lower(): - return 0 - else: - for i in range(len(path)): - logger.info(f"Computing PSTH for event {event}...") - basename = (os.path.basename(path[i])).split('.')[0] - name_1 = basename.split('_')[-1] - control = read_hdf5('control_'+name_1, os.path.dirname(path[i]), 'data') - if (control==0).all()==True: - signal = read_hdf5('signal_'+name_1, os.path.dirname(path[i]), 'data') - z_score = ss.filtfilt(b, a, signal) - just_use_signal = True - else: - z_score = read_hdf5('', path[i], 'data') - just_use_signal = False - psth, psth_baselineUncorrected, cols = helper_psth(z_score, event, filepath, - nSecPrev, nSecPost, timeInterval, - bin_psth_trials, use_time_or_trials, - baselineStart, baselineEnd, - name_1, just_use_signal) - - create_Df(filepath, event+'_'+name_1+'_baselineUncorrected', basename, 
psth_baselineUncorrected, columns=cols) # extra - create_Df(filepath, event+'_'+name_1, basename, psth, columns=cols) - logger.info(f"PSTH for event {event} computed.") + event = event.replace("\\", "_") + event = event.replace("/", "_") + + selectForComputePsth = inputParameters["selectForComputePsth"] + bin_psth_trials = inputParameters["bin_psth_trials"] + use_time_or_trials = inputParameters["use_time_or_trials"] + + if selectForComputePsth == "z_score": + path = glob.glob(os.path.join(filepath, "z_score_*")) + elif selectForComputePsth == "dff": + path = glob.glob(os.path.join(filepath, "dff_*")) + else: + path = glob.glob(os.path.join(filepath, "z_score_*")) + glob.glob(os.path.join(filepath, "dff_*")) + + b = np.divide(np.ones((100,)), 100) + a = 1 + + # storesList = storesList + # sampling_rate = read_hdf5(storesList[0,0], filepath, 'sampling_rate') + nSecPrev, nSecPost = inputParameters["nSecPrev"], inputParameters["nSecPost"] + baselineStart, baselineEnd = inputParameters["baselineCorrectionStart"], inputParameters["baselineCorrectionEnd"] + timeInterval = inputParameters["timeInterval"] + + if "control" in event.lower() or "signal" in event.lower(): + return 0 + else: + for i in range(len(path)): + logger.info(f"Computing PSTH for event {event}...") + basename = (os.path.basename(path[i])).split(".")[0] + name_1 = basename.split("_")[-1] + control = read_hdf5("control_" + name_1, os.path.dirname(path[i]), "data") + if (control == 0).all() == True: + signal = read_hdf5("signal_" + name_1, os.path.dirname(path[i]), "data") + z_score = ss.filtfilt(b, a, signal) + just_use_signal = True + else: + z_score = read_hdf5("", path[i], "data") + just_use_signal = False + psth, psth_baselineUncorrected, cols = helper_psth( + z_score, + event, + filepath, + nSecPrev, + nSecPost, + timeInterval, + bin_psth_trials, + use_time_or_trials, + baselineStart, + baselineEnd, + name_1, + just_use_signal, + ) + + create_Df( + filepath, + event + "_" + name_1 + "_baselineUncorrected", + basename, + psth_baselineUncorrected, + columns=cols, + ) # extra + create_Df(filepath, event + "_" + name_1, basename, psth, columns=cols) + logger.info(f"PSTH for event {event} computed.") def helperPSTHPeakAndArea(psth_mean, timestamps, sampling_rate, peak_startPoint, peak_endPoint): - peak_startPoint = np.asarray(peak_startPoint) - peak_endPoint = np.asarray(peak_endPoint) - - peak_startPoint = peak_startPoint[~np.isnan(peak_startPoint)] - peak_endPoint = peak_endPoint[~np.isnan(peak_endPoint)] + peak_startPoint = np.asarray(peak_startPoint) + peak_endPoint = np.asarray(peak_endPoint) - if peak_startPoint.shape[0]!=peak_endPoint.shape[0]: - logger.error('Number of Peak Start Time and Peak End Time are unequal.') - raise Exception('Number of Peak Start Time and Peak End Time are unequal.') + peak_startPoint = peak_startPoint[~np.isnan(peak_startPoint)] + peak_endPoint = peak_endPoint[~np.isnan(peak_endPoint)] - if np.less_equal(peak_endPoint, peak_startPoint).any()==True: - logger.error('Peak End Time is lesser than or equal to Peak Start Time. Please check the Peak parameters window.') - raise Exception('Peak End Time is lesser than or equal to Peak Start Time. 
Please check the Peak parameters window.') + if peak_startPoint.shape[0] != peak_endPoint.shape[0]: + logger.error("Number of Peak Start Time and Peak End Time are unequal.") + raise Exception("Number of Peak Start Time and Peak End Time are unequal.") + if np.less_equal(peak_endPoint, peak_startPoint).any() == True: + logger.error( + "Peak End Time is lesser than or equal to Peak Start Time. Please check the Peak parameters window." + ) + raise Exception( + "Peak End Time is lesser than or equal to Peak Start Time. Please check the Peak parameters window." + ) - peak_area = OrderedDict() + peak_area = OrderedDict() - if peak_startPoint.shape[0]==0 or peak_endPoint.shape[0]==0: - peak_area['peak'] = np.nan - peak_area['area'] = np.nan + if peak_startPoint.shape[0] == 0 or peak_endPoint.shape[0] == 0: + peak_area["peak"] = np.nan + peak_area["area"] = np.nan - for i in range(peak_startPoint.shape[0]): - startPtForPeak = np.where(timestamps>=peak_startPoint[i])[0] - endPtForPeak = np.where(timestamps>=peak_endPoint[i])[0] - if len(startPtForPeak)>=1 and len(endPtForPeak)>=1: - peakPoint_pos = startPtForPeak[0] + np.argmax(psth_mean[startPtForPeak[0]:endPtForPeak[0],:], axis=0) - peakPoint_neg = startPtForPeak[0] + np.argmin(psth_mean[startPtForPeak[0]:endPtForPeak[0],:], axis=0) - peak_area['peak_pos_'+str(i+1)] = np.amax(psth_mean[peakPoint_pos],axis=0) - peak_area['peak_neg_'+str(i+1)] = np.amin(psth_mean[peakPoint_neg],axis=0) - peak_area['area_'+str(i+1)] = np.trapz(psth_mean[startPtForPeak[0]:endPtForPeak[0],:], axis=0) - else: - peak_area['peak_'+str(i+1)] = np.nan - peak_area['area_'+str(i+1)] = np.nan + for i in range(peak_startPoint.shape[0]): + startPtForPeak = np.where(timestamps >= peak_startPoint[i])[0] + endPtForPeak = np.where(timestamps >= peak_endPoint[i])[0] + if len(startPtForPeak) >= 1 and len(endPtForPeak) >= 1: + peakPoint_pos = startPtForPeak[0] + np.argmax(psth_mean[startPtForPeak[0] : endPtForPeak[0], :], axis=0) + peakPoint_neg = startPtForPeak[0] + np.argmin(psth_mean[startPtForPeak[0] : endPtForPeak[0], :], axis=0) + peak_area["peak_pos_" + str(i + 1)] = np.amax(psth_mean[peakPoint_pos], axis=0) + peak_area["peak_neg_" + str(i + 1)] = np.amin(psth_mean[peakPoint_neg], axis=0) + peak_area["area_" + str(i + 1)] = np.trapz(psth_mean[startPtForPeak[0] : endPtForPeak[0], :], axis=0) + else: + peak_area["peak_" + str(i + 1)] = np.nan + peak_area["area_" + str(i + 1)] = np.nan - return peak_area + return peak_area # function to compute PSTH peak and area using the function helperPSTHPeakAndArea save the values to h5 and csv files. 
def findPSTHPeakAndArea(filepath, event, inputParameters): - - event = event.replace("\\","_") - event = event.replace("/","_") - - #sampling_rate = read_hdf5(storesList[0,0], filepath, 'sampling_rate') - peak_startPoint = inputParameters['peak_startPoint'] - peak_endPoint = inputParameters['peak_endPoint'] - selectForComputePsth = inputParameters['selectForComputePsth'] - - - if selectForComputePsth=='z_score': - path = glob.glob(os.path.join(filepath, 'z_score_*')) - elif selectForComputePsth=='dff': - path = glob.glob(os.path.join(filepath, 'dff_*')) - else: - path = glob.glob(os.path.join(filepath, 'z_score_*')) + glob.glob(os.path.join(filepath, 'dff_*')) - - - if 'control' in event.lower() or 'signal' in event.lower(): - return 0 - else: - for i in range(len(path)): - logger.info(f"Computing peak and area for PSTH mean signal for event {event}...") - basename = (os.path.basename(path[i])).split('.')[0] - name_1 = basename.split('_')[-1] - sampling_rate = read_hdf5('timeCorrection_'+name_1, filepath, 'sampling_rate')[0] - psth = read_Df(filepath, event+'_'+name_1, basename) - cols = list(psth.columns) - regex = re.compile('bin_[(]') - bin_names = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] - regex_trials = re.compile('[+-]?([0-9]*[.])?[0-9]+') - trials_names = [cols[i] for i in range(len(cols)) if regex_trials.match(cols[i])] - psth_mean_bin_names = trials_names + bin_names + ['mean'] - psth_mean_bin_mean = np.asarray(psth[psth_mean_bin_names]) - timestamps = np.asarray(psth['timestamps']).ravel() #np.asarray(read_Df(filepath, 'ts_psth', '')).ravel() - peak_area = helperPSTHPeakAndArea(psth_mean_bin_mean, timestamps, sampling_rate, peak_startPoint, peak_endPoint) # peak, area = - #arr = np.array([[peak, area]]) - fileName = [os.path.basename(os.path.dirname(filepath))] - index = [fileName[0]+'_'+s for s in psth_mean_bin_names] - create_Df_area_peak(filepath, peak_area, event+'_'+name_1+'_'+basename, index=index) # columns=['peak', 'area'] - create_csv_area_peak(filepath, peak_area, event+'_'+name_1+'_'+basename, index=index) - logger.info(f"Peak and Area for PSTH mean signal for event {event} computed.") + + event = event.replace("\\", "_") + event = event.replace("/", "_") + + # sampling_rate = read_hdf5(storesList[0,0], filepath, 'sampling_rate') + peak_startPoint = inputParameters["peak_startPoint"] + peak_endPoint = inputParameters["peak_endPoint"] + selectForComputePsth = inputParameters["selectForComputePsth"] + + if selectForComputePsth == "z_score": + path = glob.glob(os.path.join(filepath, "z_score_*")) + elif selectForComputePsth == "dff": + path = glob.glob(os.path.join(filepath, "dff_*")) + else: + path = glob.glob(os.path.join(filepath, "z_score_*")) + glob.glob(os.path.join(filepath, "dff_*")) + + if "control" in event.lower() or "signal" in event.lower(): + return 0 + else: + for i in range(len(path)): + logger.info(f"Computing peak and area for PSTH mean signal for event {event}...") + basename = (os.path.basename(path[i])).split(".")[0] + name_1 = basename.split("_")[-1] + sampling_rate = read_hdf5("timeCorrection_" + name_1, filepath, "sampling_rate")[0] + psth = read_Df(filepath, event + "_" + name_1, basename) + cols = list(psth.columns) + regex = re.compile("bin_[(]") + bin_names = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] + regex_trials = re.compile("[+-]?([0-9]*[.])?[0-9]+") + trials_names = [cols[i] for i in range(len(cols)) if regex_trials.match(cols[i])] + psth_mean_bin_names = trials_names + bin_names + ["mean"] + 
psth_mean_bin_mean = np.asarray(psth[psth_mean_bin_names]) + timestamps = np.asarray(psth["timestamps"]).ravel() # np.asarray(read_Df(filepath, 'ts_psth', '')).ravel() + peak_area = helperPSTHPeakAndArea( + psth_mean_bin_mean, timestamps, sampling_rate, peak_startPoint, peak_endPoint + ) # peak, area = + # arr = np.array([[peak, area]]) + fileName = [os.path.basename(os.path.dirname(filepath))] + index = [fileName[0] + "_" + s for s in psth_mean_bin_names] + create_Df_area_peak( + filepath, peak_area, event + "_" + name_1 + "_" + basename, index=index + ) # columns=['peak', 'area'] + create_csv_area_peak(filepath, peak_area, event + "_" + name_1 + "_" + basename, index=index) + logger.info(f"Peak and Area for PSTH mean signal for event {event} computed.") + def makeAverageDir(filepath): - op = os.path.join(filepath, 'average') - if not os.path.exists(op): - os.mkdir(op) + op = os.path.join(filepath, "average") + if not os.path.exists(op): + os.mkdir(op) + + return op - return op def psth_shape_check(psth): - each_ln = [] - for i in range(len(psth)): - each_ln.append(psth[i].shape[0]) + each_ln = [] + for i in range(len(psth)): + each_ln.append(psth[i].shape[0]) + + each_ln = np.asarray(each_ln) + keep_ln = each_ln[-1] - each_ln = np.asarray(each_ln) - keep_ln = each_ln[-1] + for i in range(len(psth)): + if psth[i].shape[0] > keep_ln: + psth[i] = psth[i][:keep_ln] + elif psth[i].shape[0] < keep_ln: + psth[i] = np.append(psth[i], np.full(keep_ln - len(psth[i]), np.nan)) + else: + psth[i] = psth[i] - for i in range(len(psth)): - if psth[i].shape[0]>keep_ln: - psth[i] = psth[i][:keep_ln] - elif psth[i].shape[0]0: - psth_bins.append(df[bins_cols]) - - if len(psth)==0: - logger.warning('Somthing is wrong with the file search pattern.') - continue - - if len(bins_cols)>0: - df_bins = pd.concat(psth_bins, axis=1) - df_bins_mean = df_bins.groupby(by=df_bins.columns, axis=1).mean() - df_bins_err = df_bins.groupby(by=df_bins.columns, axis=1).std()/math.sqrt(df_bins.shape[1]) - cols_err = list(df_bins_err.columns) - dict_err = {} - for i in cols_err: - split = i.split('_') - dict_err[i] = '{}_err_{}'.format(split[0], split[1]) - df_bins_err = df_bins_err.rename(columns=dict_err) - columns = columns + list(df_bins_mean.columns) + list(df_bins_err.columns) - df_bins_mean_err = pd.concat([df_bins_mean, df_bins_err], axis=1).T - psth, df_bins_mean_err = np.asarray(psth), np.asarray(df_bins_mean_err) - psth = np.concatenate((psth, df_bins_mean_err), axis=0) - else: - psth = psth_shape_check(psth) - psth = np.asarray(psth) - - timestamps = np.asarray(df['timestamps']).reshape(1,-1) - psth = np.concatenate((psth, timestamps), axis=0) - columns = columns + ['timestamps'] - create_Df(op, temp_path[j][1], temp_path[j][2], psth, columns=columns) - - - # read PSTH peak and area for each event and combine them. 
Save the final output to an average folder - for i in range(len(new_path)): - arr = [] - index = [] - temp_path = new_path[i] - for j in range(len(temp_path)): - if not os.path.exists(os.path.join(temp_path[j][0], 'peak_AUC_'+temp_path[j][1]+'_'+temp_path[j][2]+'.h5')): - continue - else: - df = read_Df_area_peak(temp_path[j][0], temp_path[j][1]+'_'+temp_path[j][2]) - arr.append(df) - index.append(list(df.index)) - - if len(arr)==0: - logger.warning('Somthing is wrong with the file search pattern.') - continue - index = list(np.concatenate(index)) - new_df = pd.concat(arr, axis=0) #os.path.join(filepath, 'peak_AUC_'+name+'.csv') - new_df.to_csv(os.path.join(op, 'peak_AUC_{}_{}.csv'.format(temp_path[j][1], temp_path[j][2])), index=index) - new_df.to_hdf(os.path.join(op, 'peak_AUC_{}_{}.h5'.format(temp_path[j][1], temp_path[j][2])), key='df', mode='w', index=index) - - # read cross-correlation files and combine them. Save the final output to an average folder - type = [] - for i in range(len(folderNames)): - _, temp_type = getCorrCombinations(folderNames[i], inputParameters) - type.append(temp_type) - - type = np.unique(np.array(type)) - for i in range(len(type)): - corr = [] - columns = [] - df = None - for j in range(len(folderNames)): - corr_info, _ = getCorrCombinations(folderNames[j], inputParameters) - for k in range(1, len(corr_info)): - path = os.path.join(folderNames[j], 'cross_correlation_output', 'corr_'+event+'_'+type[i]+'_'+corr_info[k-1]+'_'+corr_info[k]) - if not os.path.exists(path+'.h5'): - continue - else: - df = read_Df(os.path.join(folderNames[j], 'cross_correlation_output'), 'corr_'+event, type[i]+'_'+corr_info[k-1]+'_'+corr_info[k]) - corr.append(df['mean']) - columns.append(os.path.basename(folderNames[j])) - - if not isinstance(df, pd.DataFrame): - break - - corr = np.array(corr) - timestamps = np.array(df['timestamps']).reshape(1,-1) - corr = np.concatenate((corr, timestamps), axis=0) - columns.append('timestamps') - create_Df(make_dir(op), 'corr_'+event, type[i]+'_'+corr_info[k-1]+'_'+corr_info[k], corr, columns=columns) - - logger.info('Group of data averaged.') + event = event.replace("\\", "_") + event = event.replace("/", "_") + + logger.debug("Averaging group of data...") + path = [] + abspath = inputParameters["abspath"] + selectForComputePsth = inputParameters["selectForComputePsth"] + path_temp_len = [] + op = makeAverageDir(abspath) + + # combining paths to all the selected folders for doing average + for i in range(len(folderNames)): + if selectForComputePsth == "z_score": + path_temp = glob.glob(os.path.join(folderNames[i], "z_score_*")) + elif selectForComputePsth == "dff": + path_temp = glob.glob(os.path.join(folderNames[i], "dff_*")) + else: + path_temp = glob.glob(os.path.join(folderNames[i], "z_score_*")) + glob.glob( + os.path.join(folderNames[i], "dff_*") + ) + + path_temp_len.append(len(path_temp)) + # path_temp = glob.glob(os.path.join(folderNames[i], 'z_score_*')) + for j in range(len(path_temp)): + basename = (os.path.basename(path_temp[j])).split(".")[0] + write_hdf5(np.array([]), basename, op, "data") + name_1 = basename.split("_")[-1] + temp = [folderNames[i], event + "_" + name_1, basename] + path.append(temp) + + # processing of all the paths + path_temp_len = np.asarray(path_temp_len) + max_len = np.argmax(path_temp_len) + + naming = [] + for i in range(len(path)): + naming.append(path[i][2]) + naming = np.unique(np.asarray(naming)) + + new_path = [[] for _ in range(path_temp_len[max_len])] + for i in range(len(path)): + idx = 
np.where(naming == path[i][2])[0][0] + new_path[idx].append(path[i]) + + # read PSTH for each event and make the average of it. Save the final output to an average folder. + for i in range(len(new_path)): + psth, psth_bins = [], [] + columns = [] + bins_cols = [] + temp_path = new_path[i] + for j in range(len(temp_path)): + # logger.info(os.path.join(temp_path[j][0], temp_path[j][1]+'_{}.h5'.format(temp_path[j][2]))) + if not os.path.exists(os.path.join(temp_path[j][0], temp_path[j][1] + "_{}.h5".format(temp_path[j][2]))): + continue + else: + df = read_Df(temp_path[j][0], temp_path[j][1], temp_path[j][2]) # filepath, event, name + cols = list(df.columns) + regex = re.compile("bin_[(]") + bins_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] + psth.append(np.asarray(df["mean"])) + columns.append(os.path.basename(temp_path[j][0])) + if len(bins_cols) > 0: + psth_bins.append(df[bins_cols]) + + if len(psth) == 0: + logger.warning("Something is wrong with the file search pattern.") + continue + + if len(bins_cols) > 0: + df_bins = pd.concat(psth_bins, axis=1) + df_bins_mean = df_bins.groupby(by=df_bins.columns, axis=1).mean() + df_bins_err = df_bins.groupby(by=df_bins.columns, axis=1).std() / math.sqrt(df_bins.shape[1]) + cols_err = list(df_bins_err.columns) + dict_err = {} + for i in cols_err: + split = i.split("_") + dict_err[i] = "{}_err_{}".format(split[0], split[1]) + df_bins_err = df_bins_err.rename(columns=dict_err) + columns = columns + list(df_bins_mean.columns) + list(df_bins_err.columns) + df_bins_mean_err = pd.concat([df_bins_mean, df_bins_err], axis=1).T + psth, df_bins_mean_err = np.asarray(psth), np.asarray(df_bins_mean_err) + psth = np.concatenate((psth, df_bins_mean_err), axis=0) + else: + psth = psth_shape_check(psth) + psth = np.asarray(psth) + + timestamps = np.asarray(df["timestamps"]).reshape(1, -1) + psth = np.concatenate((psth, timestamps), axis=0) + columns = columns + ["timestamps"] + create_Df(op, temp_path[j][1], temp_path[j][2], psth, columns=columns) + + # read PSTH peak and area for each event and combine them. Save the final output to an average folder + for i in range(len(new_path)): + arr = [] + index = [] + temp_path = new_path[i] + for j in range(len(temp_path)): + if not os.path.exists( + os.path.join(temp_path[j][0], "peak_AUC_" + temp_path[j][1] + "_" + temp_path[j][2] + ".h5") + ): + continue + else: + df = read_Df_area_peak(temp_path[j][0], temp_path[j][1] + "_" + temp_path[j][2]) + arr.append(df) + index.append(list(df.index)) + + if len(arr) == 0: + logger.warning("Something is wrong with the file search pattern.") + continue + index = list(np.concatenate(index)) + new_df = pd.concat(arr, axis=0) # os.path.join(filepath, 'peak_AUC_'+name+'.csv') + new_df.to_csv(os.path.join(op, "peak_AUC_{}_{}.csv".format(temp_path[j][1], temp_path[j][2])), index=index) + new_df.to_hdf( + os.path.join(op, "peak_AUC_{}_{}.h5".format(temp_path[j][1], temp_path[j][2])), + key="df", + mode="w", + index=index, + ) + + # read cross-correlation files and combine them. 
Save the final output to an average folder + type = [] + for i in range(len(folderNames)): + _, temp_type = getCorrCombinations(folderNames[i], inputParameters) + type.append(temp_type) + + type = np.unique(np.array(type)) + for i in range(len(type)): + corr = [] + columns = [] + df = None + for j in range(len(folderNames)): + corr_info, _ = getCorrCombinations(folderNames[j], inputParameters) + for k in range(1, len(corr_info)): + path = os.path.join( + folderNames[j], + "cross_correlation_output", + "corr_" + event + "_" + type[i] + "_" + corr_info[k - 1] + "_" + corr_info[k], + ) + if not os.path.exists(path + ".h5"): + continue + else: + df = read_Df( + os.path.join(folderNames[j], "cross_correlation_output"), + "corr_" + event, + type[i] + "_" + corr_info[k - 1] + "_" + corr_info[k], + ) + corr.append(df["mean"]) + columns.append(os.path.basename(folderNames[j])) + + if not isinstance(df, pd.DataFrame): + break + + corr = np.array(corr) + timestamps = np.array(df["timestamps"]).reshape(1, -1) + corr = np.concatenate((corr, timestamps), axis=0) + columns.append("timestamps") + create_Df( + make_dir(op), "corr_" + event, type[i] + "_" + corr_info[k - 1] + "_" + corr_info[k], corr, columns=columns + ) + + logger.info("Group of data averaged.") def psthForEachStorename(inputParameters): - logger.info("Computing PSTH, Peak and Area for each event...") - inputParameters = inputParameters - - - #storesList = np.genfromtxt(inputParameters['storesListPath'], dtype='str', delimiter=',') - - folderNames = inputParameters['folderNames'] - folderNamesForAvg = inputParameters['folderNamesForAvg'] - average = inputParameters['averageForGroup'] - combine_data = inputParameters['combine_data'] - numProcesses = inputParameters['numberOfCores'] - inputParameters['step'] = 0 - if numProcesses==0: - numProcesses = mp.cpu_count() - elif numProcesses>mp.cpu_count(): - logger.warning('Warning : # of cores parameter set is greater than the cores available \ - available in your machine') - numProcesses = mp.cpu_count()-1 - - logger.info(f"Average for group : {average}") - - # for average following if statement will be executed - if average==True: - if len(folderNamesForAvg)>0: - storesListPath = [] - for i in range(len(folderNamesForAvg)): - filepath = folderNamesForAvg[i] - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*')))) - storesListPath = np.concatenate(storesListPath) - storesList = np.asarray([[],[]]) - for i in range(storesListPath.shape[0]): - storesList = np.concatenate((storesList, np.genfromtxt(os.path.join(storesListPath[i], 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1)), axis=1) - storesList = np.unique(storesList, axis=1) - op = makeAverageDir(inputParameters['abspath']) - np.savetxt(os.path.join(op, 'storesList.csv'), storesList, delimiter=",", fmt='%s') - pbMaxValue = 0 - for j in range(storesList.shape[1]): - if 'control' in storesList[1,j].lower() or 'signal' in storesList[1,j].lower(): - continue - else: - pbMaxValue += 1 - writeToFile(str((1+pbMaxValue+1)*10)+'\n'+str(10)+'\n') - for k in range(storesList.shape[1]): - if 'control' in storesList[1,k].lower() or 'signal' in storesList[1,k].lower(): - continue - else: - averageForGroup(storesListPath, storesList[1,k], inputParameters) - writeToFile(str(10+((inputParameters['step']+1)*10))+'\n') - inputParameters['step'] += 1 - - else: - logger.error('Not a single folder name is provided in folderNamesForAvg in inputParamters File.') - raise Exception('Not a single folder name is provided in 
folderNamesForAvg in inputParamters File.') - - # for individual analysis following else statement will be executed - else: - if combine_data==True: - storesListPath = [] - for i in range(len(folderNames)): - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(folderNames[i], '*_output_*')))) - storesListPath = list(np.concatenate(storesListPath).flatten()) - op = get_all_stores_for_combining_data(storesListPath) - writeToFile(str((len(op)+len(op)+1)*10)+'\n'+str(10)+'\n') - for i in range(len(op)): - storesList = np.asarray([[],[]]) - for j in range(len(op[i])): - storesList = np.concatenate((storesList, np.genfromtxt(os.path.join(op[i][j], 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1)), axis=1) - storesList = np.unique(storesList, axis=1) - for k in range(storesList.shape[1]): - storenamePsth(op[i][0], storesList[1,k], inputParameters) - findPSTHPeakAndArea(op[i][0], storesList[1,k], inputParameters) - computeCrossCorrelation(op[i][0], storesList[1,k], inputParameters) - writeToFile(str(10+((inputParameters['step']+1)*10))+'\n') - inputParameters['step'] += 1 - else: - storesListPath = [] - for i in range(len(folderNames)): - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(folderNames[i], '*_output_*')))) - storesListPath = np.concatenate(storesListPath) - writeToFile(str((storesListPath.shape[0]+storesListPath.shape[0]+1)*10)+'\n'+str(10)+'\n') - for i in range(len(folderNames)): - logger.debug(f"Computing PSTH, Peak and Area for each event in {folderNames[i]}") - storesListPath = takeOnlyDirs(glob.glob(os.path.join(folderNames[i], '*_output_*'))) - for j in range(len(storesListPath)): - filepath = storesListPath[j] - storesList = np.genfromtxt(os.path.join(filepath, 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1) - - with mp.Pool(numProcesses) as p: - p.starmap(storenamePsth, zip(repeat(filepath), storesList[1,:], repeat(inputParameters))) - - with mp.Pool(numProcesses) as pq: - pq.starmap(findPSTHPeakAndArea, zip(repeat(filepath), storesList[1,:], repeat(inputParameters))) - - with mp.Pool(numProcesses) as cr: - cr.starmap(computeCrossCorrelation, zip(repeat(filepath), storesList[1,:], repeat(inputParameters))) - - #for k in range(storesList.shape[1]): - # storenamePsth(filepath, storesList[1,k], inputParameters) - # findPSTHPeakAndArea(filepath, storesList[1,k], inputParameters) - - writeToFile(str(10+((inputParameters['step']+1)*10))+'\n') - inputParameters['step'] += 1 - logger.info(f"PSTH, Area and Peak are computed for all events in {folderNames[i]}.") - logger.info("PSTH, Area and Peak are computed for all events.") - return inputParameters + logger.info("Computing PSTH, Peak and Area for each event...") + inputParameters = inputParameters + + # storesList = np.genfromtxt(inputParameters['storesListPath'], dtype='str', delimiter=',') + + folderNames = inputParameters["folderNames"] + folderNamesForAvg = inputParameters["folderNamesForAvg"] + average = inputParameters["averageForGroup"] + combine_data = inputParameters["combine_data"] + numProcesses = inputParameters["numberOfCores"] + inputParameters["step"] = 0 + if numProcesses == 0: + numProcesses = mp.cpu_count() + elif numProcesses > mp.cpu_count(): + logger.warning( + "Warning : # of cores parameter set is greater than the cores available \ + available in your machine" + ) + numProcesses = mp.cpu_count() - 1 + + logger.info(f"Average for group : {average}") + + # for average following if statement will be executed + if average == True: + if len(folderNamesForAvg) > 0: + 
storesListPath = [] + for i in range(len(folderNamesForAvg)): + filepath = folderNamesForAvg[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = np.concatenate(storesListPath) + storesList = np.asarray([[], []]) + for i in range(storesListPath.shape[0]): + storesList = np.concatenate( + ( + storesList, + np.genfromtxt( + os.path.join(storesListPath[i], "storesList.csv"), dtype="str", delimiter="," + ).reshape(2, -1), + ), + axis=1, + ) + storesList = np.unique(storesList, axis=1) + op = makeAverageDir(inputParameters["abspath"]) + np.savetxt(os.path.join(op, "storesList.csv"), storesList, delimiter=",", fmt="%s") + pbMaxValue = 0 + for j in range(storesList.shape[1]): + if "control" in storesList[1, j].lower() or "signal" in storesList[1, j].lower(): + continue + else: + pbMaxValue += 1 + writeToFile(str((1 + pbMaxValue + 1) * 10) + "\n" + str(10) + "\n") + for k in range(storesList.shape[1]): + if "control" in storesList[1, k].lower() or "signal" in storesList[1, k].lower(): + continue + else: + averageForGroup(storesListPath, storesList[1, k], inputParameters) + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + + else: + logger.error("Not a single folder name is provided in folderNamesForAvg in inputParamters File.") + raise Exception("Not a single folder name is provided in folderNamesForAvg in inputParamters File.") + + # for individual analysis following else statement will be executed + else: + if combine_data == True: + storesListPath = [] + for i in range(len(folderNames)): + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(folderNames[i], "*_output_*")))) + storesListPath = list(np.concatenate(storesListPath).flatten()) + op = get_all_stores_for_combining_data(storesListPath) + writeToFile(str((len(op) + len(op) + 1) * 10) + "\n" + str(10) + "\n") + for i in range(len(op)): + storesList = np.asarray([[], []]) + for j in range(len(op[i])): + storesList = np.concatenate( + ( + storesList, + np.genfromtxt(os.path.join(op[i][j], "storesList.csv"), dtype="str", delimiter=",").reshape( + 2, -1 + ), + ), + axis=1, + ) + storesList = np.unique(storesList, axis=1) + for k in range(storesList.shape[1]): + storenamePsth(op[i][0], storesList[1, k], inputParameters) + findPSTHPeakAndArea(op[i][0], storesList[1, k], inputParameters) + computeCrossCorrelation(op[i][0], storesList[1, k], inputParameters) + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + else: + storesListPath = [] + for i in range(len(folderNames)): + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(folderNames[i], "*_output_*")))) + storesListPath = np.concatenate(storesListPath) + writeToFile(str((storesListPath.shape[0] + storesListPath.shape[0] + 1) * 10) + "\n" + str(10) + "\n") + for i in range(len(folderNames)): + logger.debug(f"Computing PSTH, Peak and Area for each event in {folderNames[i]}") + storesListPath = takeOnlyDirs(glob.glob(os.path.join(folderNames[i], "*_output_*"))) + for j in range(len(storesListPath)): + filepath = storesListPath[j] + storesList = np.genfromtxt( + os.path.join(filepath, "storesList.csv"), dtype="str", delimiter="," + ).reshape(2, -1) + + with mp.Pool(numProcesses) as p: + p.starmap(storenamePsth, zip(repeat(filepath), storesList[1, :], repeat(inputParameters))) + + with mp.Pool(numProcesses) as pq: + pq.starmap( + findPSTHPeakAndArea, zip(repeat(filepath), storesList[1, :], repeat(inputParameters)) + ) + + with 
mp.Pool(numProcesses) as cr: + cr.starmap( + computeCrossCorrelation, zip(repeat(filepath), storesList[1, :], repeat(inputParameters)) + ) + + # for k in range(storesList.shape[1]): + # storenamePsth(filepath, storesList[1,k], inputParameters) + # findPSTHPeakAndArea(filepath, storesList[1,k], inputParameters) + + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + logger.info(f"PSTH, Area and Peak are computed for all events in {folderNames[i]}.") + logger.info("PSTH, Area and Peak are computed for all events.") + return inputParameters + def main(input_parameters): - try: - inputParameters = psthForEachStorename(input_parameters) - subprocess.call([sys.executable, "-m", "guppy.findTransientsFreqAndAmp", json.dumps(inputParameters)]) - logger.info('#'*400) - except Exception as e: - with open(os.path.join(os.path.expanduser('~'), 'pbSteps.txt'), 'a') as file: - file.write(str(-1)+"\n") - logger.error(str(e)) - raise e + try: + inputParameters = psthForEachStorename(input_parameters) + subprocess.call([sys.executable, "-m", "guppy.findTransientsFreqAndAmp", json.dumps(inputParameters)]) + logger.info("#" * 400) + except Exception as e: + with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: + file.write(str(-1) + "\n") + logger.error(str(e)) + raise e + if __name__ == "__main__": - input_parameters = json.loads(sys.argv[1]) - main(input_parameters=input_parameters) + input_parameters = json.loads(sys.argv[1]) + main(input_parameters=input_parameters) diff --git a/src/guppy/findTransientsFreqAndAmp.py b/src/guppy/findTransientsFreqAndAmp.py index 0ac43ce..795ace1 100755 --- a/src/guppy/findTransientsFreqAndAmp.py +++ b/src/guppy/findTransientsFreqAndAmp.py @@ -1,70 +1,72 @@ -import os -import sys import glob -import h5py import json +import logging import math -import numpy as np -import pandas as pd import multiprocessing as mp -from scipy.signal import argrelextrema -import matplotlib.pyplot as plt +import os +import sys from itertools import repeat -from pathlib import Path + +import h5py +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from scipy.signal import argrelextrema + from .preprocess import get_all_stores_for_combining_data -import logging logger = logging.getLogger(__name__) + def takeOnlyDirs(paths): - removePaths = [] - for p in paths: - if os.path.isfile(p): - removePaths.append(p) - return list(set(paths)-set(removePaths)) + removePaths = [] + for p in paths: + if os.path.isfile(p): + removePaths.append(p) + return list(set(paths) - set(removePaths)) + def writeToFile(value: str): - with open(os.path.join(os.path.expanduser('~'), 'pbSteps.txt'), 'a') as file: - file.write(value) + with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: + file.write(value) + def read_hdf5(event, filepath, key): - if event: - op = os.path.join(filepath, event+'.hdf5') - else: - op = filepath + if event: + op = os.path.join(filepath, event + ".hdf5") + else: + op = filepath - if os.path.exists(op): - with h5py.File(op, 'r') as f: - arr = np.asarray(f[key]) - else: - logger.error(f"{event}.hdf5 file does not exist") - raise Exception('{}.hdf5 file does not exist'.format(event)) + if os.path.exists(op): + with h5py.File(op, "r") as f: + arr = np.asarray(f[key]) + else: + logger.error(f"{event}.hdf5 file does not exist") + raise Exception("{}.hdf5 file does not exist".format(event)) - return arr + return arr def processChunks(arrValues, arrIndexes, highAmpFilt, 
transientsThresh): - - arrValues = arrValues[~np.isnan(arrValues)] - median = np.median(arrValues) - - mad = np.median(np.abs(arrValues-median)) - - firstThreshold = median + (highAmpFilt*mad) - - - greaterThanMad = np.where(arrValues>firstThreshold)[0] - + + arrValues = arrValues[~np.isnan(arrValues)] + median = np.median(arrValues) + + mad = np.median(np.abs(arrValues - median)) + + firstThreshold = median + (highAmpFilt * mad) + + greaterThanMad = np.where(arrValues > firstThreshold)[0] arr = np.arange(arrValues.shape[0]) lowerThanMad = np.isin(arr, greaterThanMad, invert=True) - filteredOut = arrValues[np.where(lowerThanMad==True)[0]] - + filteredOut = arrValues[np.where(lowerThanMad == True)[0]] + filteredOutMedian = np.median(filteredOut) - filteredOutMad = np.median(np.abs(filteredOut-np.median(filteredOut))) - secondThreshold = filteredOutMedian+(transientsThresh*filteredOutMad) + filteredOutMad = np.median(np.abs(filteredOut - np.median(filteredOut))) + secondThreshold = filteredOutMedian + (transientsThresh * filteredOutMad) - greaterThanThreshIndex = np.where(arrValues>secondThreshold)[0] + greaterThanThreshIndex = np.where(arrValues > secondThreshold)[0] greaterThanThreshValues = arrValues[greaterThanThreshIndex] temp = np.zeros(arrValues.shape[0]) temp[greaterThanThreshIndex] = greaterThanThreshValues @@ -73,282 +75,302 @@ def processChunks(arrValues, arrIndexes, highAmpFilt, transientsThresh): firstThresholdY = np.full(arrValues.shape[0], firstThreshold) secondThresholdY = np.full(arrValues.shape[0], secondThreshold) - newPeaks = np.full(arrValues.shape[0], np.nan) newPeaks[peaks] = peaks + arrIndexes[0] - #madY = np.full(arrValues.shape[0], mad) + # madY = np.full(arrValues.shape[0], mad) medianY = np.full(arrValues.shape[0], median) filteredOutMedianY = np.full(arrValues.shape[0], filteredOutMedian) return peaks, mad, filteredOutMad, medianY, filteredOutMedianY, firstThresholdY, secondThresholdY - def createChunks(z_score, sampling_rate, window): - - logger.debug('Creating chunks for multiprocessing...') - windowPoints = math.ceil(sampling_rate*window) - remainderPoints = math.ceil((sampling_rate*window) - (z_score.shape[0]%windowPoints)) + logger.debug("Creating chunks for multiprocessing...") + windowPoints = math.ceil(sampling_rate * window) + remainderPoints = math.ceil((sampling_rate * window) - (z_score.shape[0] % windowPoints)) - if remainderPoints==windowPoints: - padded_z_score = z_score - z_score_index = np.arange(padded_z_score.shape[0]) - else: - padding = np.full(remainderPoints, np.nan) - padded_z_score = np.concatenate((z_score, padding)) - z_score_index = np.arange(padded_z_score.shape[0]) + if remainderPoints == windowPoints: + padded_z_score = z_score + z_score_index = np.arange(padded_z_score.shape[0]) + else: + padding = np.full(remainderPoints, np.nan) + padded_z_score = np.concatenate((z_score, padding)) + z_score_index = np.arange(padded_z_score.shape[0]) - reshape = padded_z_score.shape[0]/windowPoints + reshape = padded_z_score.shape[0] / windowPoints - if reshape.is_integer()==True: - z_score_chunks = padded_z_score.reshape(int(reshape), -1) - z_score_chunks_index = z_score_index.reshape(int(reshape), -1) - else: - logger.error('Reshaping values should be integer.') - raise Exception('Reshaping values should be integer.') - logger.info('Chunks are created for multiprocessing.') - return z_score_chunks, z_score_chunks_index + if reshape.is_integer() == True: + z_score_chunks = padded_z_score.reshape(int(reshape), -1) + z_score_chunks_index = 
z_score_index.reshape(int(reshape), -1) + else: + logger.error("Reshaping values should be integer.") + raise Exception("Reshaping values should be integer.") + logger.info("Chunks are created for multiprocessing.") + return z_score_chunks, z_score_chunks_index def calculate_freq_amp(arr, z_score, z_score_chunks_index, timestamps): - peaks = arr[:,0] - filteredOutMedian = arr[:,4] - count = 0 - peaksAmp = np.array([]) - peaksInd = np.array([]) - for i in range(z_score_chunks_index.shape[0]): - count += peaks[i].shape[0] - peaksIndexes = peaks[i]+z_score_chunks_index[i][0] - peaksInd = np.concatenate((peaksInd, peaksIndexes)) - amps = z_score[peaksIndexes]-filteredOutMedian[i][0] - peaksAmp = np.concatenate((peaksAmp, amps)) - - peaksInd = peaksInd.ravel() - peaksInd = peaksInd.astype(int) - #logger.info(timestamps) - freq = peaksAmp.shape[0]/((timestamps[-1]-timestamps[0])/60) - - return freq, peaksAmp, peaksInd + peaks = arr[:, 0] + filteredOutMedian = arr[:, 4] + count = 0 + peaksAmp = np.array([]) + peaksInd = np.array([]) + for i in range(z_score_chunks_index.shape[0]): + count += peaks[i].shape[0] + peaksIndexes = peaks[i] + z_score_chunks_index[i][0] + peaksInd = np.concatenate((peaksInd, peaksIndexes)) + amps = z_score[peaksIndexes] - filteredOutMedian[i][0] + peaksAmp = np.concatenate((peaksAmp, amps)) + + peaksInd = peaksInd.ravel() + peaksInd = peaksInd.astype(int) + # logger.info(timestamps) + freq = peaksAmp.shape[0] / ((timestamps[-1] - timestamps[0]) / 60) + + return freq, peaksAmp, peaksInd + def create_Df(filepath, arr, name, index=[], columns=[]): - op = os.path.join(filepath, 'freqAndAmp_'+name+'.h5') - dirname = os.path.dirname(filepath) + op = os.path.join(filepath, "freqAndAmp_" + name + ".h5") + dirname = os.path.dirname(filepath) - df = pd.DataFrame(arr, index=index, columns=columns) + df = pd.DataFrame(arr, index=index, columns=columns) + + df.to_hdf(op, key="df", mode="w") - df.to_hdf(op, key='df', mode='w') def create_csv(filepath, arr, name, index=[], columns=[]): - op = os.path.join(filepath, name) - df = pd.DataFrame(arr, index=index, columns=columns) - df.to_csv(op) + op = os.path.join(filepath, name) + df = pd.DataFrame(arr, index=index, columns=columns) + df.to_csv(op) + def read_Df(filepath, name): - op = os.path.join(filepath, 'freqAndAmp_'+name+'.h5') - df = pd.read_hdf(op, key='df', mode='r') + op = os.path.join(filepath, "freqAndAmp_" + name + ".h5") + df = pd.read_hdf(op, key="df", mode="r") + + return df - return df def visuzlize_peaks(filepath, z_score, timestamps, peaksIndex): - - dirname = os.path.dirname(filepath) - - basename = (os.path.basename(filepath)).split('.')[0] - fig = plt.figure() - ax = fig.add_subplot(111) - ax.plot(timestamps,z_score, '-', - timestamps[peaksIndex], z_score[peaksIndex], 'o') - ax.set_title(basename) - fig.suptitle(os.path.basename(dirname)) - #plt.show() + + dirname = os.path.dirname(filepath) + + basename = (os.path.basename(filepath)).split(".")[0] + fig = plt.figure() + ax = fig.add_subplot(111) + ax.plot(timestamps, z_score, "-", timestamps[peaksIndex], z_score[peaksIndex], "o") + ax.set_title(basename) + fig.suptitle(os.path.basename(dirname)) + # plt.show() + def findFreqAndAmp(filepath, inputParameters, window=15, numProcesses=mp.cpu_count()): - logger.debug('Calculating frequency and amplitude of transients in z-score data....') - selectForTransientsComputation = inputParameters['selectForTransientsComputation'] - highAmpFilt = inputParameters['highAmpFilt'] - transientsThresh = 
inputParameters['transientsThresh'] - - if selectForTransientsComputation=='z_score': - path = glob.glob(os.path.join(filepath, 'z_score_*')) - elif selectForTransientsComputation=='dff': - path = glob.glob(os.path.join(filepath, 'dff_*')) - else: - path = glob.glob(os.path.join(filepath, 'z_score_*')) + glob.glob(os.path.join(filepath, 'dff_*')) - - for i in range(len(path)): - basename = (os.path.basename(path[i])).split('.')[0] - name_1 = basename.split('_')[-1] - sampling_rate = read_hdf5('timeCorrection_'+name_1, filepath, 'sampling_rate')[0] - z_score = read_hdf5('', path[i], 'data') - not_nan_indices = ~np.isnan(z_score) - z_score = z_score[not_nan_indices] - z_score_chunks, z_score_chunks_index = createChunks(z_score, sampling_rate, window) - - - with mp.Pool(numProcesses) as p: - result = p.starmap(processChunks, zip(z_score_chunks, z_score_chunks_index, repeat(highAmpFilt), repeat(transientsThresh))) - - - result = np.asarray(result, dtype=object) - ts = read_hdf5('timeCorrection_'+name_1, filepath, 'timestampNew') - ts = ts[not_nan_indices] - freq, peaksAmp, peaksInd = calculate_freq_amp(result, z_score, z_score_chunks_index, ts) - peaks_occurrences = np.array([ts[peaksInd], peaksAmp]).T - arr = np.array([[freq, np.mean(peaksAmp)]]) - fileName = [os.path.basename(os.path.dirname(filepath))] - create_Df(filepath, arr, basename, index=fileName ,columns=['freq (events/min)', 'amplitude']) - create_csv(filepath, arr, 'freqAndAmp_'+basename+'.csv', - index=fileName, columns=['freq (events/min)', 'amplitude']) - create_csv(filepath, peaks_occurrences, 'transientsOccurrences_'+basename+'.csv', - index=np.arange(peaks_occurrences.shape[0]),columns=['timestamps', 'amplitude']) - visuzlize_peaks(path[i], z_score, ts, peaksInd) - logger.info('Frequency and amplitude of transients in z_score data are calculated.') - + logger.debug("Calculating frequency and amplitude of transients in z-score data....") + selectForTransientsComputation = inputParameters["selectForTransientsComputation"] + highAmpFilt = inputParameters["highAmpFilt"] + transientsThresh = inputParameters["transientsThresh"] + + if selectForTransientsComputation == "z_score": + path = glob.glob(os.path.join(filepath, "z_score_*")) + elif selectForTransientsComputation == "dff": + path = glob.glob(os.path.join(filepath, "dff_*")) + else: + path = glob.glob(os.path.join(filepath, "z_score_*")) + glob.glob(os.path.join(filepath, "dff_*")) + + for i in range(len(path)): + basename = (os.path.basename(path[i])).split(".")[0] + name_1 = basename.split("_")[-1] + sampling_rate = read_hdf5("timeCorrection_" + name_1, filepath, "sampling_rate")[0] + z_score = read_hdf5("", path[i], "data") + not_nan_indices = ~np.isnan(z_score) + z_score = z_score[not_nan_indices] + z_score_chunks, z_score_chunks_index = createChunks(z_score, sampling_rate, window) + + with mp.Pool(numProcesses) as p: + result = p.starmap( + processChunks, zip(z_score_chunks, z_score_chunks_index, repeat(highAmpFilt), repeat(transientsThresh)) + ) + + result = np.asarray(result, dtype=object) + ts = read_hdf5("timeCorrection_" + name_1, filepath, "timestampNew") + ts = ts[not_nan_indices] + freq, peaksAmp, peaksInd = calculate_freq_amp(result, z_score, z_score_chunks_index, ts) + peaks_occurrences = np.array([ts[peaksInd], peaksAmp]).T + arr = np.array([[freq, np.mean(peaksAmp)]]) + fileName = [os.path.basename(os.path.dirname(filepath))] + create_Df(filepath, arr, basename, index=fileName, columns=["freq (events/min)", "amplitude"]) + create_csv( + filepath, arr, 
"freqAndAmp_" + basename + ".csv", index=fileName, columns=["freq (events/min)", "amplitude"] + ) + create_csv( + filepath, + peaks_occurrences, + "transientsOccurrences_" + basename + ".csv", + index=np.arange(peaks_occurrences.shape[0]), + columns=["timestamps", "amplitude"], + ) + visuzlize_peaks(path[i], z_score, ts, peaksInd) + logger.info("Frequency and amplitude of transients in z_score data are calculated.") def makeAverageDir(filepath): - op = os.path.join(filepath, 'average') - if not os.path.exists(op): - os.mkdir(op) + op = os.path.join(filepath, "average") + if not os.path.exists(op): + os.mkdir(op) + + return op - return op def averageForGroup(folderNames, inputParameters): - logger.debug('Combining results for frequency and amplitude of transients in z-score data...') - path = [] - abspath = inputParameters['abspath'] - selectForTransientsComputation = inputParameters['selectForTransientsComputation'] - path_temp_len = [] - - for i in range(len(folderNames)): - if selectForTransientsComputation=='z_score': - path_temp = glob.glob(os.path.join(folderNames[i], 'z_score_*')) - elif selectForTransientsComputation=='dff': - path_temp = glob.glob(os.path.join(folderNames[i], 'dff_*')) - else: - path_temp = glob.glob(os.path.join(folderNames[i], 'z_score_*')) + glob.glob(os.path.join(folderNames[i], 'dff_*')) - - path_temp_len.append(len(path_temp)) - - for j in range(len(path_temp)): - basename = (os.path.basename(path_temp[j])).split('.')[0] - #name = name[0] - temp = [folderNames[i], basename] - path.append(temp) - - - path_temp_len = np.asarray(path_temp_len) - max_len = np.argmax(path_temp_len) - - naming = [] - for i in range(len(path)): - naming.append(path[i][1]) - naming = np.unique(np.asarray(naming)) - - - new_path = [[] for _ in range(path_temp_len[max_len])] - for i in range(len(path)): - idx = np.where(naming==path[i][1])[0][0] - new_path[idx].append(path[i]) - - op = makeAverageDir(abspath) - - - for i in range(len(new_path)): - arr = [] #np.zeros((len(new_path[i]), 2)) - fileName = [] - temp_path = new_path[i] - for j in range(len(temp_path)): - if not os.path.exists(os.path.join(temp_path[j][0], 'freqAndAmp_'+temp_path[j][1]+'.h5')): - continue - else: - df = read_Df(temp_path[j][0], temp_path[j][1]) - arr.append(np.array([df['freq (events/min)'][0], df['amplitude'][0]])) - fileName.append(os.path.basename(temp_path[j][0])) - - arr = np.asarray(arr) - create_Df(op, arr, temp_path[j][1], index=fileName, columns=['freq (events/min)', 'amplitude']) - create_csv(op, arr, 'freqAndAmp_'+temp_path[j][1]+'.csv', index=fileName, columns=['freq (events/min)', 'amplitude']) - logger.info('Results for frequency and amplitude of transients in z-score data are combined.') + logger.debug("Combining results for frequency and amplitude of transients in z-score data...") + path = [] + abspath = inputParameters["abspath"] + selectForTransientsComputation = inputParameters["selectForTransientsComputation"] + path_temp_len = [] + + for i in range(len(folderNames)): + if selectForTransientsComputation == "z_score": + path_temp = glob.glob(os.path.join(folderNames[i], "z_score_*")) + elif selectForTransientsComputation == "dff": + path_temp = glob.glob(os.path.join(folderNames[i], "dff_*")) + else: + path_temp = glob.glob(os.path.join(folderNames[i], "z_score_*")) + glob.glob( + os.path.join(folderNames[i], "dff_*") + ) + + path_temp_len.append(len(path_temp)) + + for j in range(len(path_temp)): + basename = (os.path.basename(path_temp[j])).split(".")[0] + # name = name[0] + temp = 
[folderNames[i], basename] + path.append(temp) + + path_temp_len = np.asarray(path_temp_len) + max_len = np.argmax(path_temp_len) + + naming = [] + for i in range(len(path)): + naming.append(path[i][1]) + naming = np.unique(np.asarray(naming)) + + new_path = [[] for _ in range(path_temp_len[max_len])] + for i in range(len(path)): + idx = np.where(naming == path[i][1])[0][0] + new_path[idx].append(path[i]) + + op = makeAverageDir(abspath) + + for i in range(len(new_path)): + arr = [] # np.zeros((len(new_path[i]), 2)) + fileName = [] + temp_path = new_path[i] + for j in range(len(temp_path)): + if not os.path.exists(os.path.join(temp_path[j][0], "freqAndAmp_" + temp_path[j][1] + ".h5")): + continue + else: + df = read_Df(temp_path[j][0], temp_path[j][1]) + arr.append(np.array([df["freq (events/min)"][0], df["amplitude"][0]])) + fileName.append(os.path.basename(temp_path[j][0])) + + arr = np.asarray(arr) + create_Df(op, arr, temp_path[j][1], index=fileName, columns=["freq (events/min)", "amplitude"]) + create_csv( + op, + arr, + "freqAndAmp_" + temp_path[j][1] + ".csv", + index=fileName, + columns=["freq (events/min)", "amplitude"], + ) + logger.info("Results for frequency and amplitude of transients in z-score data are combined.") + def executeFindFreqAndAmp(inputParameters): - logger.info('Finding transients in z-score data and calculating frequency and amplitude....') - - inputParameters = inputParameters - - average = inputParameters['averageForGroup'] - folderNamesForAvg = inputParameters['folderNamesForAvg'] - folderNames = inputParameters['folderNames'] - combine_data = inputParameters['combine_data'] - moving_window = inputParameters['moving_window'] - numProcesses = inputParameters['numberOfCores'] - if numProcesses==0: - numProcesses = mp.cpu_count() - elif numProcesses>mp.cpu_count(): - logger.warning('Warning : # of cores parameter set is greater than the cores available \ - available in your machine') - numProcesses = mp.cpu_count()-1 - - if average==True: - if len(folderNamesForAvg)>0: - storesListPath = [] - for i in range(len(folderNamesForAvg)): - filepath = folderNamesForAvg[i] - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*')))) - storesListPath = np.concatenate(storesListPath) - averageForGroup(storesListPath, inputParameters) - writeToFile(str(10+((inputParameters['step']+1)*10))+'\n') - inputParameters['step'] += 1 - else: - logger.error('Not a single folder name is provided in folderNamesForAvg in inputParamters File.') - raise Exception('Not a single folder name is provided in folderNamesForAvg in inputParamters File.') - - - else: - if combine_data==True: - storesListPath = [] - for i in range(len(folderNames)): - filepath = folderNames[i] - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*')))) - storesListPath = list(np.concatenate(storesListPath).flatten()) - op = get_all_stores_for_combining_data(storesListPath) - for i in range(len(op)): - filepath = op[i][0] - storesList = np.genfromtxt(os.path.join(filepath, 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1) - findFreqAndAmp(filepath, inputParameters, window=moving_window, numProcesses=numProcesses) - writeToFile(str(10+((inputParameters['step']+1)*10))+'\n') - inputParameters['step'] += 1 - plt.show() - else: - for i in range(len(folderNames)): - logger.debug(f"Finding transients in z-score data of {folderNames[i]} and calculating frequency and amplitude.") - filepath = folderNames[i] - storesListPath = 
takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*'))) - for j in range(len(storesListPath)): - filepath = storesListPath[j] - storesList = np.genfromtxt(os.path.join(filepath, 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1) - findFreqAndAmp(filepath, inputParameters, window=moving_window, numProcesses=numProcesses) - writeToFile(str(10+((inputParameters['step']+1)*10))+'\n') - inputParameters['step'] += 1 - logger.info('Transients in z-score data found and frequency and amplitude are calculated.') - plt.show() - - logger.info('Transients in z-score data found and frequency and amplitude are calculated.') + logger.info("Finding transients in z-score data and calculating frequency and amplitude....") + + inputParameters = inputParameters + + average = inputParameters["averageForGroup"] + folderNamesForAvg = inputParameters["folderNamesForAvg"] + folderNames = inputParameters["folderNames"] + combine_data = inputParameters["combine_data"] + moving_window = inputParameters["moving_window"] + numProcesses = inputParameters["numberOfCores"] + if numProcesses == 0: + numProcesses = mp.cpu_count() + elif numProcesses > mp.cpu_count(): + logger.warning( + "Warning : # of cores parameter set is greater than the cores available \ + available in your machine" + ) + numProcesses = mp.cpu_count() - 1 + + if average == True: + if len(folderNamesForAvg) > 0: + storesListPath = [] + for i in range(len(folderNamesForAvg)): + filepath = folderNamesForAvg[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = np.concatenate(storesListPath) + averageForGroup(storesListPath, inputParameters) + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + else: + logger.error("Not a single folder name is provided in folderNamesForAvg in inputParamters File.") + raise Exception("Not a single folder name is provided in folderNamesForAvg in inputParamters File.") + + else: + if combine_data == True: + storesListPath = [] + for i in range(len(folderNames)): + filepath = folderNames[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = list(np.concatenate(storesListPath).flatten()) + op = get_all_stores_for_combining_data(storesListPath) + for i in range(len(op)): + filepath = op[i][0] + storesList = np.genfromtxt( + os.path.join(filepath, "storesList.csv"), dtype="str", delimiter="," + ).reshape(2, -1) + findFreqAndAmp(filepath, inputParameters, window=moving_window, numProcesses=numProcesses) + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + plt.show() + else: + for i in range(len(folderNames)): + logger.debug( + f"Finding transients in z-score data of {folderNames[i]} and calculating frequency and amplitude." 
+ ) + filepath = folderNames[i] + storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) + for j in range(len(storesListPath)): + filepath = storesListPath[j] + storesList = np.genfromtxt( + os.path.join(filepath, "storesList.csv"), dtype="str", delimiter="," + ).reshape(2, -1) + findFreqAndAmp(filepath, inputParameters, window=moving_window, numProcesses=numProcesses) + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + logger.info("Transients in z-score data found and frequency and amplitude are calculated.") + plt.show() + + logger.info("Transients in z-score data found and frequency and amplitude are calculated.") if __name__ == "__main__": - try: - executeFindFreqAndAmp(json.loads(sys.argv[1])) - logger.info('#'*400) - except Exception as e: - with open(os.path.join(os.path.expanduser('~'), 'pbSteps.txt'), 'a') as file: - file.write(str(-1)+"\n") - logger.error(str(e)) - raise e + try: + executeFindFreqAndAmp(json.loads(sys.argv[1])) + logger.info("#" * 400) + except Exception as e: + with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: + file.write(str(-1) + "\n") + logger.error(str(e)) + raise e diff --git a/src/guppy/logging_config.py b/src/guppy/logging_config.py index 2de2e14..ccb2339 100644 --- a/src/guppy/logging_config.py +++ b/src/guppy/logging_config.py @@ -13,12 +13,13 @@ import logging import os from pathlib import Path + from platformdirs import user_log_dir def get_log_file(): """Get the platform-appropriate log file path. - + Returns ------- Path @@ -31,9 +32,9 @@ def get_log_file(): def setup_logging(*, level=None, console_output=True): """Configure centralized logging for GuPPy. - + This should be called once at application startup, before importing other modules. - + Parameters ---------- level : int, optional @@ -44,31 +45,28 @@ def setup_logging(*, level=None, console_output=True): """ # Determine log level if level is None: - env_level = os.environ.get('GUPPY_LOG_LEVEL', 'INFO').upper() + env_level = os.environ.get("GUPPY_LOG_LEVEL", "INFO").upper() level = getattr(logging, env_level, logging.INFO) - + # Get log file path log_file = get_log_file() - + # Configure root logger for guppy logger = logging.getLogger("guppy") logger.setLevel(level) - + # Prevent duplicate handlers if setup_logging is called multiple times if logger.handlers: return - + # Create formatter - formatter = logging.Formatter( - '%(asctime)s %(name)s %(levelname)s %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - + formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S") + # File handler file_handler = logging.FileHandler(log_file) file_handler.setFormatter(formatter) logger.addHandler(file_handler) - + # Console handler (optional) if console_output: console_handler = logging.StreamHandler() diff --git a/src/guppy/main.py b/src/guppy/main.py index 070b324..9117385 100644 --- a/src/guppy/main.py +++ b/src/guppy/main.py @@ -1,18 +1,22 @@ """ Main entry point for GuPPy (Guided Photometry Analysis in Python) """ + from . 
import logging_config # Logging must be configured before importing application modules so that module-level loggers inherit the proper handlers and formatters logging_config.setup_logging() import panel as pn + from .savingInputParameters import savingInputParameters + def main(): """Main entry point for GuPPy""" template = savingInputParameters() pn.serve(template, show=True) + if __name__ == "__main__": main() diff --git a/src/guppy/preprocess.py b/src/guppy/preprocess.py index be7c0bf..a4370c0 100755 --- a/src/guppy/preprocess.py +++ b/src/guppy/preprocess.py @@ -1,1190 +1,1226 @@ -import os -import sys -import json +import fnmatch import glob -import time +import json +import logging +import os import re -import fnmatch +import shutil +import sys + +import h5py +import matplotlib.pyplot as plt import numpy as np import pandas as pd -import h5py -import math -import shutil from scipy import signal as ss from scipy.optimize import curve_fit -import matplotlib.pyplot as plt -from matplotlib.widgets import MultiCursor -from pathlib import Path + from .combineDataFn import processTimestampsForCombiningData -import logging logger = logging.getLogger(__name__) # Only set matplotlib backend if not in CI environment -if not os.getenv('CI'): - plt.switch_backend('TKAgg') +if not os.getenv("CI"): + plt.switch_backend("TKAgg") + def takeOnlyDirs(paths): - removePaths = [] - for p in paths: - if os.path.isfile(p): - removePaths.append(p) - return list(set(paths)-set(removePaths)) + removePaths = [] + for p in paths: + if os.path.isfile(p): + removePaths.append(p) + return list(set(paths) - set(removePaths)) + def writeToFile(value: str): - with open(os.path.join(os.path.expanduser('~'), 'pbSteps.txt'), 'a') as file: - file.write(value) + with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: + file.write(value) + # find files by ignoring the case sensitivity -def find_files(path, glob_path, ignore_case = False): - rule = re.compile(fnmatch.translate(glob_path), re.IGNORECASE) if ignore_case \ - else re.compile(fnmatch.translate(glob_path)) +def find_files(path, glob_path, ignore_case=False): + rule = ( + re.compile(fnmatch.translate(glob_path), re.IGNORECASE) + if ignore_case + else re.compile(fnmatch.translate(glob_path)) + ) - no_bytes_path = os.listdir(os.path.expanduser(path)) - str_path = [] + no_bytes_path = os.listdir(os.path.expanduser(path)) + str_path = [] - # converting byte object to string - for x in no_bytes_path: - try: - str_path.append(x.decode('utf-8')) - except: - str_path.append(x) - return [os.path.join(path,n) for n in str_path if rule.match(n)] + # converting byte object to string + for x in no_bytes_path: + try: + str_path.append(x.decode("utf-8")) + except: + str_path.append(x) + return [os.path.join(path, n) for n in str_path if rule.match(n)] # curve fit exponential function -def curveFitFn(x,a,b,c): - return a+(b*np.exp(-(1/c)*x)) +def curveFitFn(x, a, b, c): + return a + (b * np.exp(-(1 / c) * x)) # helper function to create control channel using signal channel # by curve fitting signal channel to exponential function # when there is no isosbestic control channel is present def helper_create_control_channel(signal, timestamps, window): - # check if window is greater than signal shape - if window>signal.shape[0]: - window = ((signal.shape[0]+1)/2)+1 - if window%2 != 0: - window = window - else: - window = window + 1 + # check if window is greater than signal shape + if window > signal.shape[0]: + window = ((signal.shape[0] + 1) / 2) + 1 + if 
window % 2 != 0: + window = window + else: + window = window + 1 - filtered_signal = ss.savgol_filter(signal, window_length=window, polyorder=3) + filtered_signal = ss.savgol_filter(signal, window_length=window, polyorder=3) - p0 = [5,50,60] + p0 = [5, 50, 60] - try: - popt, pcov = curve_fit(curveFitFn, timestamps, filtered_signal, p0) - except Exception as e: - logger.error(str(e)) + try: + popt, pcov = curve_fit(curveFitFn, timestamps, filtered_signal, p0) + except Exception as e: + logger.error(str(e)) - #logger.info('Curve Fit Parameters : ', popt) - control = curveFitFn(timestamps,*popt) + # logger.info('Curve Fit Parameters : ', popt) + control = curveFitFn(timestamps, *popt) + + return control - return control # main function to create control channel using # signal channel and save it to a file def create_control_channel(filepath, arr, window=5001): - storenames = arr[0,:] - storesList = arr[1,:] - - - for i in range(storesList.shape[0]): - event_name, event = storesList[i], storenames[i] - if 'control' in event_name.lower() and 'cntrl' in event.lower(): - logger.debug('Creating control channel from signal channel using curve-fitting') - name = event_name.split('_')[-1] - signal = read_hdf5('signal_'+name, filepath, 'data') - timestampNew = read_hdf5('timeCorrection_'+name, filepath, 'timestampNew') - sampling_rate = np.full(timestampNew.shape, np.nan) - sampling_rate[0] = read_hdf5('timeCorrection_'+name, filepath, 'sampling_rate')[0] - - control = helper_create_control_channel(signal, timestampNew, window) - - write_hdf5(control, event_name, filepath, 'data') - d = { - 'timestamps': timestampNew, - 'data': control, - 'sampling_rate': sampling_rate - } - df = pd.DataFrame(d) - df.to_csv(os.path.join(os.path.dirname(filepath), event.lower()+'.csv'), index=False) - logger.info('Control channel from signal channel created using curve-fitting') - - -# function to add control channel when there is no + storenames = arr[0, :] + storesList = arr[1, :] + + for i in range(storesList.shape[0]): + event_name, event = storesList[i], storenames[i] + if "control" in event_name.lower() and "cntrl" in event.lower(): + logger.debug("Creating control channel from signal channel using curve-fitting") + name = event_name.split("_")[-1] + signal = read_hdf5("signal_" + name, filepath, "data") + timestampNew = read_hdf5("timeCorrection_" + name, filepath, "timestampNew") + sampling_rate = np.full(timestampNew.shape, np.nan) + sampling_rate[0] = read_hdf5("timeCorrection_" + name, filepath, "sampling_rate")[0] + + control = helper_create_control_channel(signal, timestampNew, window) + + write_hdf5(control, event_name, filepath, "data") + d = {"timestamps": timestampNew, "data": control, "sampling_rate": sampling_rate} + df = pd.DataFrame(d) + df.to_csv(os.path.join(os.path.dirname(filepath), event.lower() + ".csv"), index=False) + logger.info("Control channel from signal channel created using curve-fitting") + + +# function to add control channel when there is no # isosbestic control channel and update the storeslist file def add_control_channel(filepath, arr): - storenames = arr[0,:] - storesList = np.char.lower(arr[1,:]) - - keep_control = np.array([]) - # check a case if there is isosbestic control channel present - for i in range(storesList.shape[0]): - if 'control' in storesList[i].lower(): - name = storesList[i].split('_')[-1] - new_str = 'signal_'+str(name).lower() - find_signal = [True for i in storesList if i==new_str] - if len(find_signal)>1: - logger.error('Error in naming convention of files 
or Error in storesList file') - raise Exception('Error in naming convention of files or Error in storesList file') - if len(find_signal)==0: - logger.error("Isosbectic control channel parameter is set to False and still \ - storeslist file shows there is control channel present") - raise Exception("Isosbectic control channel parameter is set to False and still \ - storeslist file shows there is control channel present") - else: - continue - - for i in range(storesList.shape[0]): - if 'signal' in storesList[i].lower(): - name = storesList[i].split('_')[-1] - new_str = 'control_'+str(name).lower() - find_signal = [True for i in storesList if i==new_str] - if len(find_signal)==0: - src, dst = os.path.join(filepath, arr[0,i]+'.hdf5'), os.path.join(filepath, 'cntrl'+str(i)+'.hdf5') - shutil.copyfile(src,dst) - arr = np.concatenate((arr, [['cntrl'+str(i)],['control_'+str(arr[1,i].split('_')[-1])]]), axis=1) - - np.savetxt(os.path.join(filepath, 'storesList.csv'), arr, delimiter=",", fmt='%s') - - return arr + storenames = arr[0, :] + storesList = np.char.lower(arr[1, :]) + + keep_control = np.array([]) + # check a case if there is isosbestic control channel present + for i in range(storesList.shape[0]): + if "control" in storesList[i].lower(): + name = storesList[i].split("_")[-1] + new_str = "signal_" + str(name).lower() + find_signal = [True for i in storesList if i == new_str] + if len(find_signal) > 1: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + if len(find_signal) == 0: + logger.error( + "Isosbectic control channel parameter is set to False and still \ + storeslist file shows there is control channel present" + ) + raise Exception( + "Isosbectic control channel parameter is set to False and still \ + storeslist file shows there is control channel present" + ) + else: + continue + + for i in range(storesList.shape[0]): + if "signal" in storesList[i].lower(): + name = storesList[i].split("_")[-1] + new_str = "control_" + str(name).lower() + find_signal = [True for i in storesList if i == new_str] + if len(find_signal) == 0: + src, dst = os.path.join(filepath, arr[0, i] + ".hdf5"), os.path.join( + filepath, "cntrl" + str(i) + ".hdf5" + ) + shutil.copyfile(src, dst) + arr = np.concatenate((arr, [["cntrl" + str(i)], ["control_" + str(arr[1, i].split("_")[-1])]]), axis=1) + + np.savetxt(os.path.join(filepath, "storesList.csv"), arr, delimiter=",", fmt="%s") + + return arr + # check if dealing with TDT files or csv files def check_TDT(filepath): - path = glob.glob(os.path.join(filepath, '*.tsq')) - if len(path)>0: - return True - else: - return False + path = glob.glob(os.path.join(filepath, "*.tsq")) + if len(path) > 0: + return True + else: + return False + # function to read hdf5 file def read_hdf5(event, filepath, key): - if event: - event = event.replace("\\","_") - event = event.replace("/","_") - op = os.path.join(filepath, event+'.hdf5') - else: - op = filepath - - if os.path.exists(op): - with h5py.File(op, 'r') as f: - arr = np.asarray(f[key]) - else: - logger.error(f"{event}.hdf5 file does not exist") - raise Exception('{}.hdf5 file does not exist'.format(event)) - - return arr + if event: + event = event.replace("\\", "_") + event = event.replace("/", "_") + op = os.path.join(filepath, event + ".hdf5") + else: + op = filepath + + if os.path.exists(op): + with h5py.File(op, "r") as f: + arr = np.asarray(f[key]) + else: + logger.error(f"{event}.hdf5 file does not 
exist") + raise Exception("{}.hdf5 file does not exist".format(event)) + + return arr + # function to write hdf5 file def write_hdf5(data, event, filepath, key): - event = event.replace("\\","_") - event = event.replace("/","_") - op = os.path.join(filepath, event+'.hdf5') - - # if file does not exist create a new file - if not os.path.exists(op): - with h5py.File(op, 'w') as f: - if type(data) is np.ndarray: - f.create_dataset(key, data=data, maxshape=(None,), chunks=True) - else: - f.create_dataset(key, data=data) - - # if file already exists, append data to it or add a new key to it - else: - with h5py.File(op, 'r+') as f: - if key in list(f.keys()): - if type(data) is np.ndarray: - f[key].resize(data.shape) - arr = f[key] - arr[:] = data - else: - arr = f[key] - arr = data - else: - if type(data) is np.ndarray: - f.create_dataset(key, data=data, maxshape=(None,), chunks=True) - else: - f.create_dataset(key, data=data) + event = event.replace("\\", "_") + event = event.replace("/", "_") + op = os.path.join(filepath, event + ".hdf5") + + # if file does not exist create a new file + if not os.path.exists(op): + with h5py.File(op, "w") as f: + if type(data) is np.ndarray: + f.create_dataset(key, data=data, maxshape=(None,), chunks=True) + else: + f.create_dataset(key, data=data) + + # if file already exists, append data to it or add a new key to it + else: + with h5py.File(op, "r+") as f: + if key in list(f.keys()): + if type(data) is np.ndarray: + f[key].resize(data.shape) + arr = f[key] + arr[:] = data + else: + arr = f[key] + arr = data + else: + if type(data) is np.ndarray: + f.create_dataset(key, data=data, maxshape=(None,), chunks=True) + else: + f.create_dataset(key, data=data) # function to check control and signal channel has same length # if not, take a smaller length and do pre-processing def check_cntrl_sig_length(filepath, channels_arr, storenames, storesList): - indices = [] - for i in range(channels_arr.shape[1]): - idx_c = np.where(storesList==channels_arr[0,i])[0] - idx_s = np.where(storesList==channels_arr[1,i])[0] - control = read_hdf5(storenames[idx_c[0]], filepath, 'data') - signal = read_hdf5(storenames[idx_s[0]], filepath, 'data') - if control.shape[0]signal.shape[0]: - indices.append(storesList[idx_s[0]]) - else: - indices.append(storesList[idx_s[0]]) + indices = [] + for i in range(channels_arr.shape[1]): + idx_c = np.where(storesList == channels_arr[0, i])[0] + idx_s = np.where(storesList == channels_arr[1, i])[0] + control = read_hdf5(storenames[idx_c[0]], filepath, "data") + signal = read_hdf5(storenames[idx_s[0]], filepath, "data") + if control.shape[0] < signal.shape[0]: + indices.append(storesList[idx_c[0]]) + elif control.shape[0] > signal.shape[0]: + indices.append(storesList[idx_s[0]]) + else: + indices.append(storesList[idx_s[0]]) - return indices + return indices # function to correct timestamps after eliminating first few seconds of the data (for csv data) def timestampCorrection_csv(filepath, timeForLightsTurnOn, storesList): - - logger.debug(f"Correcting timestamps by getting rid of the first {timeForLightsTurnOn} seconds and convert timestamps to seconds") - storenames = storesList[0,:] - storesList = storesList[1,:] - arr = [] - for i in range(storesList.shape[0]): - if 'control' in storesList[i].lower() or 'signal' in storesList[i].lower(): - arr.append(storesList[i]) + logger.debug( + f"Correcting timestamps by getting rid of the first {timeForLightsTurnOn} seconds and convert timestamps to seconds" + ) + storenames = storesList[0, :] + storesList 
= storesList[1, :] - arr = sorted(arr, key=str.casefold) - try: - arr = np.asarray(arr).reshape(2,-1) - except: - logger.error('Error in saving stores list file or spelling mistake for control or signal') - raise Exception('Error in saving stores list file or spelling mistake for control or signal') + arr = [] + for i in range(storesList.shape[0]): + if "control" in storesList[i].lower() or "signal" in storesList[i].lower(): + arr.append(storesList[i]) - indices = check_cntrl_sig_length(filepath, arr, storenames, storesList) + arr = sorted(arr, key=str.casefold) + try: + arr = np.asarray(arr).reshape(2, -1) + except: + logger.error("Error in saving stores list file or spelling mistake for control or signal") + raise Exception("Error in saving stores list file or spelling mistake for control or signal") - for i in range(arr.shape[1]): - name_1 = arr[0,i].split('_')[-1] - name_2 = arr[1,i].split('_')[-1] - #dirname = os.path.dirname(path[i]) - idx = np.where(storesList==indices[i])[0] + indices = check_cntrl_sig_length(filepath, arr, storenames, storesList) - if idx.shape[0]==0: - logger.error(f"{arr[0,i]} does not exist in the stores list file.") - raise Exception('{} does not exist in the stores list file.'.format(arr[0,i])) + for i in range(arr.shape[1]): + name_1 = arr[0, i].split("_")[-1] + name_2 = arr[1, i].split("_")[-1] + # dirname = os.path.dirname(path[i]) + idx = np.where(storesList == indices[i])[0] - timestamp = read_hdf5(storenames[idx][0], filepath, 'timestamps') - sampling_rate = read_hdf5(storenames[idx][0], filepath, 'sampling_rate') + if idx.shape[0] == 0: + logger.error(f"{arr[0,i]} does not exist in the stores list file.") + raise Exception("{} does not exist in the stores list file.".format(arr[0, i])) - if name_1==name_2: - correctionIndex = np.where(timestamp>=timeForLightsTurnOn)[0] - timestampNew = timestamp[correctionIndex] - write_hdf5(timestampNew, 'timeCorrection_'+name_1, filepath, 'timestampNew') - write_hdf5(correctionIndex, 'timeCorrection_'+name_1, filepath, 'correctionIndex') - write_hdf5(np.asarray(sampling_rate), 'timeCorrection_'+name_1, filepath, 'sampling_rate') + timestamp = read_hdf5(storenames[idx][0], filepath, "timestamps") + sampling_rate = read_hdf5(storenames[idx][0], filepath, "sampling_rate") - else: - logger.error('Error in naming convention of files or Error in storesList file') - raise Exception('Error in naming convention of files or Error in storesList file') + if name_1 == name_2: + correctionIndex = np.where(timestamp >= timeForLightsTurnOn)[0] + timestampNew = timestamp[correctionIndex] + write_hdf5(timestampNew, "timeCorrection_" + name_1, filepath, "timestampNew") + write_hdf5(correctionIndex, "timeCorrection_" + name_1, filepath, "correctionIndex") + write_hdf5(np.asarray(sampling_rate), "timeCorrection_" + name_1, filepath, "sampling_rate") - logger.info("Timestamps corrected and converted to seconds.") + else: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + logger.info("Timestamps corrected and converted to seconds.") # function to correct timestamps after eliminating first few seconds of the data (for TDT data) def timestampCorrection_tdt(filepath, timeForLightsTurnOn, storesList): - logger.debug(f"Correcting timestamps by getting rid of the first {timeForLightsTurnOn} seconds and convert timestamps to seconds") - storenames = storesList[0,:] - storesList = storesList[1,:] - - arr = [] - for i in 
range(storesList.shape[0]): - if 'control' in storesList[i].lower() or 'signal' in storesList[i].lower(): - arr.append(storesList[i]) - - arr = sorted(arr, key=str.casefold) - - try: - arr = np.asarray(arr).reshape(2,-1) - except: - logger.error('Error in saving stores list file or spelling mistake for control or signal') - raise Exception('Error in saving stores list file or spelling mistake for control or signal') - - indices = check_cntrl_sig_length(filepath, arr, storenames, storesList) - - for i in range(arr.shape[1]): - name_1 = arr[0,i].split('_')[-1] - name_2 = arr[1,i].split('_')[-1] - #dirname = os.path.dirname(path[i]) - idx = np.where(storesList==indices[i])[0] - - if idx.shape[0]==0: - logger.error(f"{arr[0,i]} does not exist in the stores list file.") - raise Exception('{} does not exist in the stores list file.'.format(arr[0,i])) - - timestamp = read_hdf5(storenames[idx][0], filepath, 'timestamps') - npoints = read_hdf5(storenames[idx][0], filepath, 'npoints') - sampling_rate = read_hdf5(storenames[idx][0], filepath, 'sampling_rate') - - if name_1==name_2: - timeRecStart = timestamp[0] - timestamps = np.subtract(timestamp, timeRecStart) - adder = np.arange(npoints)/sampling_rate - lengthAdder = adder.shape[0] - timestampNew = np.zeros((len(timestamps), lengthAdder)) - for i in range(lengthAdder): - timestampNew[:,i] = np.add(timestamps, adder[i]) - timestampNew = (timestampNew.T).reshape(-1, order='F') - correctionIndex = np.where(timestampNew>=timeForLightsTurnOn)[0] - timestampNew = timestampNew[correctionIndex] - - write_hdf5(np.asarray([timeRecStart]), 'timeCorrection_'+name_1, filepath, 'timeRecStart') - write_hdf5(timestampNew, 'timeCorrection_'+name_1, filepath, 'timestampNew') - write_hdf5(correctionIndex, 'timeCorrection_'+name_1, filepath, 'correctionIndex') - write_hdf5(np.asarray([sampling_rate]), 'timeCorrection_'+name_1, filepath, 'sampling_rate') - else: - logger.error('Error in naming convention of files or Error in storesList file') - raise Exception('Error in naming convention of files or Error in storesList file') - - logger.info("Timestamps corrected and converted to seconds.") - #return timeRecStart, correctionIndex, timestampNew - - -# function to apply correction to control, signal and event timestamps + logger.debug( + f"Correcting timestamps by getting rid of the first {timeForLightsTurnOn} seconds and convert timestamps to seconds" + ) + storenames = storesList[0, :] + storesList = storesList[1, :] + + arr = [] + for i in range(storesList.shape[0]): + if "control" in storesList[i].lower() or "signal" in storesList[i].lower(): + arr.append(storesList[i]) + + arr = sorted(arr, key=str.casefold) + + try: + arr = np.asarray(arr).reshape(2, -1) + except: + logger.error("Error in saving stores list file or spelling mistake for control or signal") + raise Exception("Error in saving stores list file or spelling mistake for control or signal") + + indices = check_cntrl_sig_length(filepath, arr, storenames, storesList) + + for i in range(arr.shape[1]): + name_1 = arr[0, i].split("_")[-1] + name_2 = arr[1, i].split("_")[-1] + # dirname = os.path.dirname(path[i]) + idx = np.where(storesList == indices[i])[0] + + if idx.shape[0] == 0: + logger.error(f"{arr[0,i]} does not exist in the stores list file.") + raise Exception("{} does not exist in the stores list file.".format(arr[0, i])) + + timestamp = read_hdf5(storenames[idx][0], filepath, "timestamps") + npoints = read_hdf5(storenames[idx][0], filepath, "npoints") + sampling_rate = read_hdf5(storenames[idx][0], 
filepath, "sampling_rate") + + if name_1 == name_2: + timeRecStart = timestamp[0] + timestamps = np.subtract(timestamp, timeRecStart) + adder = np.arange(npoints) / sampling_rate + lengthAdder = adder.shape[0] + timestampNew = np.zeros((len(timestamps), lengthAdder)) + for i in range(lengthAdder): + timestampNew[:, i] = np.add(timestamps, adder[i]) + timestampNew = (timestampNew.T).reshape(-1, order="F") + correctionIndex = np.where(timestampNew >= timeForLightsTurnOn)[0] + timestampNew = timestampNew[correctionIndex] + + write_hdf5(np.asarray([timeRecStart]), "timeCorrection_" + name_1, filepath, "timeRecStart") + write_hdf5(timestampNew, "timeCorrection_" + name_1, filepath, "timestampNew") + write_hdf5(correctionIndex, "timeCorrection_" + name_1, filepath, "correctionIndex") + write_hdf5(np.asarray([sampling_rate]), "timeCorrection_" + name_1, filepath, "sampling_rate") + else: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + + logger.info("Timestamps corrected and converted to seconds.") + # return timeRecStart, correctionIndex, timestampNew + + +# function to apply correction to control, signal and event timestamps def applyCorrection(filepath, timeForLightsTurnOn, event, displayName, naming): - cond = check_TDT(os.path.dirname(filepath)) - - if cond==True: - timeRecStart = read_hdf5('timeCorrection_'+naming, filepath, 'timeRecStart')[0] - - timestampNew = read_hdf5('timeCorrection_'+naming, filepath, 'timestampNew') - correctionIndex = read_hdf5('timeCorrection_'+naming, filepath, 'correctionIndex') - - if 'control' in displayName.lower() or 'signal' in displayName.lower(): - split_name = displayName.split('_')[-1] - if split_name==naming: - pass - else: - correctionIndex = read_hdf5('timeCorrection_'+split_name, filepath, 'correctionIndex') - arr = read_hdf5(event, filepath, 'data') - if (arr==0).all()==True: - arr = arr - else: - arr = arr[correctionIndex] - write_hdf5(arr, displayName, filepath, 'data') - else: - arr = read_hdf5(event, filepath, 'timestamps') - if cond==True: - res = (arr>=timeRecStart).all() - if res==True: - arr = np.subtract(arr, timeRecStart) - arr = np.subtract(arr, timeForLightsTurnOn) - else: - arr = np.subtract(arr, timeForLightsTurnOn) - else: - arr = np.subtract(arr, timeForLightsTurnOn) - write_hdf5(arr, displayName+'_'+naming, filepath, 'ts') - - #if isosbestic_control==False and 'control' in displayName.lower(): - # control = create_control_channel(filepath, displayName) - # write_hdf5(control, displayName, filepath, 'data') + cond = check_TDT(os.path.dirname(filepath)) + + if cond == True: + timeRecStart = read_hdf5("timeCorrection_" + naming, filepath, "timeRecStart")[0] + + timestampNew = read_hdf5("timeCorrection_" + naming, filepath, "timestampNew") + correctionIndex = read_hdf5("timeCorrection_" + naming, filepath, "correctionIndex") + + if "control" in displayName.lower() or "signal" in displayName.lower(): + split_name = displayName.split("_")[-1] + if split_name == naming: + pass + else: + correctionIndex = read_hdf5("timeCorrection_" + split_name, filepath, "correctionIndex") + arr = read_hdf5(event, filepath, "data") + if (arr == 0).all() == True: + arr = arr + else: + arr = arr[correctionIndex] + write_hdf5(arr, displayName, filepath, "data") + else: + arr = read_hdf5(event, filepath, "timestamps") + if cond == True: + res = (arr >= timeRecStart).all() + if res == True: + arr = np.subtract(arr, timeRecStart) + arr = 
np.subtract(arr, timeForLightsTurnOn) + else: + arr = np.subtract(arr, timeForLightsTurnOn) + else: + arr = np.subtract(arr, timeForLightsTurnOn) + write_hdf5(arr, displayName + "_" + naming, filepath, "ts") + + # if isosbestic_control==False and 'control' in displayName.lower(): + # control = create_control_channel(filepath, displayName) + # write_hdf5(control, displayName, filepath, 'data') # function to check if naming convention was followed while saving storeslist file # and apply timestamps correction using the function applyCorrection def decide_naming_convention_and_applyCorrection(filepath, timeForLightsTurnOn, event, displayName, storesList): - logger.debug("Applying correction of timestamps to the data and event timestamps") - storesList = storesList[1,:] + logger.debug("Applying correction of timestamps to the data and event timestamps") + storesList = storesList[1, :] - arr = [] - for i in range(storesList.shape[0]): - if 'control' in storesList[i].lower() or 'signal' in storesList[i].lower(): - arr.append(storesList[i]) + arr = [] + for i in range(storesList.shape[0]): + if "control" in storesList[i].lower() or "signal" in storesList[i].lower(): + arr.append(storesList[i]) - arr = sorted(arr, key=str.casefold) - arr = np.asarray(arr).reshape(2,-1) + arr = sorted(arr, key=str.casefold) + arr = np.asarray(arr).reshape(2, -1) - for i in range(arr.shape[1]): - name_1 = arr[0,i].split('_')[-1] - name_2 = arr[1,i].split('_')[-1] - #dirname = os.path.dirname(path[i]) - if name_1==name_2: - applyCorrection(filepath, timeForLightsTurnOn, event, displayName, name_1) - else: - logger.error('Error in naming convention of files or Error in storesList file') - raise Exception('Error in naming convention of files or Error in storesList file') - - logger.info("Timestamps corrections applied to the data and event timestamps.") + for i in range(arr.shape[1]): + name_1 = arr[0, i].split("_")[-1] + name_2 = arr[1, i].split("_")[-1] + # dirname = os.path.dirname(path[i]) + if name_1 == name_2: + applyCorrection(filepath, timeForLightsTurnOn, event, displayName, name_1) + else: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") + logger.info("Timestamps corrections applied to the data and event timestamps.") -# functino to plot z_score +# function to plot z_score def visualize_z_score(filepath): - name = os.path.basename(filepath) + name = os.path.basename(filepath) - path = glob.glob(os.path.join(filepath, 'z_score_*')) - - path = sorted(path) + path = glob.glob(os.path.join(filepath, "z_score_*")) + + path = sorted(path) + + for i in range(len(path)): + basename = (os.path.basename(path[i])).split(".")[0] + name_1 = basename.split("_")[-1] + x = read_hdf5("timeCorrection_" + name_1, filepath, "timestampNew") + y = read_hdf5("", path[i], "data") + fig = plt.figure() + ax = fig.add_subplot(111) + ax.plot(x, y) + ax.set_title(basename) + fig.suptitle(name) + # plt.show() - for i in range(len(path)): - basename = (os.path.basename(path[i])).split('.')[0] - name_1 = basename.split('_')[-1] - x = read_hdf5('timeCorrection_'+name_1, filepath, 'timestampNew') - y = read_hdf5('', path[i], 'data') - fig = plt.figure() - ax = fig.add_subplot(111) - ax.plot(x,y) - ax.set_title(basename) - fig.suptitle(name) - #plt.show() # function to plot deltaF/F def visualize_dff(filepath): - name = os.path.basename(filepath) + name = os.path.basename(filepath) - path = glob.glob(os.path.join(filepath, 
'dff_*')) - - path = sorted(path) + path = glob.glob(os.path.join(filepath, "dff_*")) - for i in range(len(path)): - basename = (os.path.basename(path[i])).split('.')[0] - name_1 = basename.split('_')[-1] - x = read_hdf5('timeCorrection_'+name_1, filepath, 'timestampNew') - y = read_hdf5('', path[i], 'data') - fig = plt.figure() - ax = fig.add_subplot(111) - ax.plot(x,y) - ax.set_title(basename) - fig.suptitle(name) - #plt.show() + path = sorted(path) + for i in range(len(path)): + basename = (os.path.basename(path[i])).split(".")[0] + name_1 = basename.split("_")[-1] + x = read_hdf5("timeCorrection_" + name_1, filepath, "timestampNew") + y = read_hdf5("", path[i], "data") + fig = plt.figure() + ax = fig.add_subplot(111) + ax.plot(x, y) + ax.set_title(basename) + fig.suptitle(name) + # plt.show() def visualize(filepath, x, y1, y2, y3, plot_name, removeArtifacts): - - - # plotting control and signal data - - if (y1==0).all()==True: - y1 = np.zeros(x.shape[0]) - - coords_path = os.path.join(filepath, 'coordsForPreProcessing_'+plot_name[0].split('_')[-1]+'.npy') - name = os.path.basename(filepath) - fig = plt.figure() - ax1 = fig.add_subplot(311) - line1, = ax1.plot(x,y1) - ax1.set_title(plot_name[0]) - ax2 = fig.add_subplot(312) - line2, = ax2.plot(x,y2) - ax2.set_title(plot_name[1]) - ax3 = fig.add_subplot(313) - line3, = ax3.plot(x,y2) - line3, = ax3.plot(x,y3) - ax3.set_title(plot_name[2]) - fig.suptitle(name) - - hfont = {'fontname':'DejaVu Sans'} - - if removeArtifacts==True and os.path.exists(coords_path): - ax3.set_xlabel('Time(s) \n Note : Artifacts have been removed, but are not reflected in this plot.', **hfont) - else: - ax3.set_xlabel('Time(s)', **hfont) - - global coords - coords = [] - - # clicking 'space' key on keyboard will draw a line on the plot so that user can see what chunks are selected - # and clicking 'd' key on keyboard will deselect the selected point - def onclick(event): - #global ix, iy - - if event.key == ' ': - ix, iy = event.xdata, event.ydata - logger.info(f"x = {ix}, y = {iy}") - y1_max, y1_min = np.amax(y1), np.amin(y1) - y2_max, y2_min = np.amax(y2), np.amin(y2) - - #ax1.plot([ix,ix], [y1_max, y1_min], 'k--') - #ax2.plot([ix,ix], [y2_max, y2_min], 'k--') - - ax1.axvline(ix, c='black', ls='--') - ax2.axvline(ix, c='black', ls='--') - ax3.axvline(ix, c='black', ls='--') - - fig.canvas.draw() - - global coords - coords.append((ix, iy)) - - #if len(coords) == 2: - # fig.canvas.mpl_disconnect(cid) - - return coords - - elif event.key == 'd': - if len(coords)>0: - logger.info(f"x = {coords[-1][0]}, y = {coords[-1][1]}; deleted") - del coords[-1] - ax1.lines[-1].remove() - ax2.lines[-1].remove() - ax3.lines[-1].remove() - fig.canvas.draw() - - return coords - - # close the plot will save coordinates for all the selected chunks in the data - def plt_close_event(event): - global coords - if coords and len(coords)>0: - name_1 = plot_name[0].split('_')[-1] - np.save(os.path.join(filepath, 'coordsForPreProcessing_'+name_1+'.npy'), coords) - logger.info(f"Coordinates file saved at {os.path.join(filepath, 'coordsForPreProcessing_'+name_1+'.npy')}") - fig.canvas.mpl_disconnect(cid) - coords = [] - - - cid = fig.canvas.mpl_connect('key_press_event', onclick) - cid = fig.canvas.mpl_connect('close_event', plt_close_event) - #multi = MultiCursor(fig.canvas, (ax1, ax2), color='g', lw=1, horizOn=False, vertOn=True) - - #plt.show() - #return fig + + # plotting control and signal data + + if (y1 == 0).all() == True: + y1 = np.zeros(x.shape[0]) + + coords_path = 
os.path.join(filepath, "coordsForPreProcessing_" + plot_name[0].split("_")[-1] + ".npy") + name = os.path.basename(filepath) + fig = plt.figure() + ax1 = fig.add_subplot(311) + (line1,) = ax1.plot(x, y1) + ax1.set_title(plot_name[0]) + ax2 = fig.add_subplot(312) + (line2,) = ax2.plot(x, y2) + ax2.set_title(plot_name[1]) + ax3 = fig.add_subplot(313) + (line3,) = ax3.plot(x, y2) + (line3,) = ax3.plot(x, y3) + ax3.set_title(plot_name[2]) + fig.suptitle(name) + + hfont = {"fontname": "DejaVu Sans"} + + if removeArtifacts == True and os.path.exists(coords_path): + ax3.set_xlabel("Time(s) \n Note : Artifacts have been removed, but are not reflected in this plot.", **hfont) + else: + ax3.set_xlabel("Time(s)", **hfont) + + global coords + coords = [] + + # clicking 'space' key on keyboard will draw a line on the plot so that user can see what chunks are selected + # and clicking 'd' key on keyboard will deselect the selected point + def onclick(event): + # global ix, iy + + if event.key == " ": + ix, iy = event.xdata, event.ydata + logger.info(f"x = {ix}, y = {iy}") + y1_max, y1_min = np.amax(y1), np.amin(y1) + y2_max, y2_min = np.amax(y2), np.amin(y2) + + # ax1.plot([ix,ix], [y1_max, y1_min], 'k--') + # ax2.plot([ix,ix], [y2_max, y2_min], 'k--') + + ax1.axvline(ix, c="black", ls="--") + ax2.axvline(ix, c="black", ls="--") + ax3.axvline(ix, c="black", ls="--") + + fig.canvas.draw() + + global coords + coords.append((ix, iy)) + + # if len(coords) == 2: + # fig.canvas.mpl_disconnect(cid) + + return coords + + elif event.key == "d": + if len(coords) > 0: + logger.info(f"x = {coords[-1][0]}, y = {coords[-1][1]}; deleted") + del coords[-1] + ax1.lines[-1].remove() + ax2.lines[-1].remove() + ax3.lines[-1].remove() + fig.canvas.draw() + + return coords + + # close the plot will save coordinates for all the selected chunks in the data + def plt_close_event(event): + global coords + if coords and len(coords) > 0: + name_1 = plot_name[0].split("_")[-1] + np.save(os.path.join(filepath, "coordsForPreProcessing_" + name_1 + ".npy"), coords) + logger.info(f"Coordinates file saved at {os.path.join(filepath, 'coordsForPreProcessing_'+name_1+'.npy')}") + fig.canvas.mpl_disconnect(cid) + coords = [] + + cid = fig.canvas.mpl_connect("key_press_event", onclick) + cid = fig.canvas.mpl_connect("close_event", plt_close_event) + # multi = MultiCursor(fig.canvas, (ax1, ax2), color='g', lw=1, horizOn=False, vertOn=True) + + # plt.show() + # return fig + # function to plot control and signal, also provide a feature to select chunks for artifacts removal def visualizeControlAndSignal(filepath, removeArtifacts): - path_1 = find_files(filepath, 'control_*', ignore_case=True) #glob.glob(os.path.join(filepath, 'control*')) - - path_2 = find_files(filepath, 'signal_*', ignore_case=True) #glob.glob(os.path.join(filepath, 'signal*')) - - - path = sorted(path_1 + path_2, key=str.casefold) - - if len(path)%2 != 0: - logger.error('There are not equal number of Control and Signal data') - raise Exception('There are not equal number of Control and Signal data') - - path = np.asarray(path).reshape(2,-1) - - for i in range(path.shape[1]): - - name_1 = ((os.path.basename(path[0,i])).split('.')[0]).split('_') - name_2 = ((os.path.basename(path[1,i])).split('.')[0]).split('_') - - ts_path = os.path.join(filepath, 'timeCorrection_'+name_1[-1]+'.hdf5') - cntrl_sig_fit_path = os.path.join(filepath, 'cntrl_sig_fit_'+name_1[-1]+'.hdf5') - ts = read_hdf5('', ts_path, 'timestampNew') - - control = read_hdf5('', path[0,i], 'data').reshape(-1) - 
signal = read_hdf5('', path[1,i], 'data').reshape(-1) - cntrl_sig_fit = read_hdf5('', cntrl_sig_fit_path, 'data').reshape(-1) - - plot_name = [(os.path.basename(path[0,i])).split('.')[0], - (os.path.basename(path[1,i])).split('.')[0], - (os.path.basename(cntrl_sig_fit_path)).split('.')[0]] - visualize(filepath, ts, control, signal, cntrl_sig_fit, plot_name, removeArtifacts) - - -# functino to check if the naming convention for saving storeslist file was followed or not + path_1 = find_files(filepath, "control_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'control*')) + + path_2 = find_files(filepath, "signal_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'signal*')) + + path = sorted(path_1 + path_2, key=str.casefold) + + if len(path) % 2 != 0: + logger.error("There are not equal number of Control and Signal data") + raise Exception("There are not equal number of Control and Signal data") + + path = np.asarray(path).reshape(2, -1) + + for i in range(path.shape[1]): + + name_1 = ((os.path.basename(path[0, i])).split(".")[0]).split("_") + name_2 = ((os.path.basename(path[1, i])).split(".")[0]).split("_") + + ts_path = os.path.join(filepath, "timeCorrection_" + name_1[-1] + ".hdf5") + cntrl_sig_fit_path = os.path.join(filepath, "cntrl_sig_fit_" + name_1[-1] + ".hdf5") + ts = read_hdf5("", ts_path, "timestampNew") + + control = read_hdf5("", path[0, i], "data").reshape(-1) + signal = read_hdf5("", path[1, i], "data").reshape(-1) + cntrl_sig_fit = read_hdf5("", cntrl_sig_fit_path, "data").reshape(-1) + + plot_name = [ + (os.path.basename(path[0, i])).split(".")[0], + (os.path.basename(path[1, i])).split(".")[0], + (os.path.basename(cntrl_sig_fit_path)).split(".")[0], + ] + visualize(filepath, ts, control, signal, cntrl_sig_fit, plot_name, removeArtifacts) + + +# function to check if the naming convention for saving storeslist file was followed or not def decide_naming_convention(filepath): - path_1 = find_files(filepath, 'control_*', ignore_case=True) #glob.glob(os.path.join(filepath, 'control*')) - - path_2 = find_files(filepath, 'signal_*', ignore_case=True) #glob.glob(os.path.join(filepath, 'signal*')) - - path = sorted(path_1 + path_2, key=str.casefold) - if len(path)%2 != 0: - logger.error('There are not equal number of Control and Signal data') - raise Exception('There are not equal number of Control and Signal data') - - path = np.asarray(path).reshape(2,-1) + path_1 = find_files(filepath, "control_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'control*')) + + path_2 = find_files(filepath, "signal_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'signal*')) - return path + path = sorted(path_1 + path_2, key=str.casefold) + if len(path) % 2 != 0: + logger.error("There are not equal number of Control and Signal data") + raise Exception("There are not equal number of Control and Signal data") + + path = np.asarray(path).reshape(2, -1) + + return path # function to read coordinates file which was saved by selecting chunks for artifacts removal def fetchCoords(filepath, naming, data): - path = os.path.join(filepath, 'coordsForPreProcessing_'+naming+'.npy') + path = os.path.join(filepath, "coordsForPreProcessing_" + naming + ".npy") - if not os.path.exists(path): - coords = np.array([0, data[-1]]) - else: - coords = np.load(os.path.join(filepath, 'coordsForPreProcessing_'+naming+'.npy'))[:,0] + if not os.path.exists(path): + coords = np.array([0, data[-1]]) + else: + coords = np.load(os.path.join(filepath, "coordsForPreProcessing_" + naming + ".npy"))[:, 
0] - if coords.shape[0] % 2 != 0: - logger.error('Number of values in coordsForPreProcessing file is not even.') - raise Exception('Number of values in coordsForPreProcessing file is not even.') + if coords.shape[0] % 2 != 0: + logger.error("Number of values in coordsForPreProcessing file is not even.") + raise Exception("Number of values in coordsForPreProcessing file is not even.") - coords = coords.reshape(-1,2) + coords = coords.reshape(-1, 2) - return coords + return coords # helper function to process control and signal timestamps def eliminateData(filepath, timeForLightsTurnOn, event, sampling_rate, naming): - - ts = read_hdf5('timeCorrection_'+naming, filepath, 'timestampNew') - data = read_hdf5(event, filepath, 'data').reshape(-1) - coords = fetchCoords(filepath, naming, ts) - - if (data==0).all()==True: - data = np.zeros(ts.shape[0]) - - arr = np.array([]) - ts_arr = np.array([]) - for i in range(coords.shape[0]): - - index = np.where((ts>coords[i,0]) & (ts coords[i, 0]) & (ts < coords[i, 1]))[0] + + if len(arr) == 0: + arr = np.concatenate((arr, data[index])) + sub = ts[index][0] - timeForLightsTurnOn + new_ts = ts[index] - sub + ts_arr = np.concatenate((ts_arr, new_ts)) + else: + temp = data[index] + # new = temp + (arr[-1]-temp[0]) + temp_ts = ts[index] + new_ts = temp_ts - (temp_ts[0] - ts_arr[-1]) + arr = np.concatenate((arr, temp)) + ts_arr = np.concatenate((ts_arr, new_ts + (1 / sampling_rate))) + + # logger.info(arr.shape, ts_arr.shape) + return arr, ts_arr # helper function to align event timestamps with the control and signal timestamps def eliminateTs(filepath, timeForLightsTurnOn, event, sampling_rate, naming): - - tsNew = read_hdf5('timeCorrection_'+naming, filepath, 'timestampNew') - ts = read_hdf5(event+'_'+naming, filepath, 'ts').reshape(-1) - coords = fetchCoords(filepath, naming, tsNew) - - ts_arr = np.array([]) - tsNew_arr = np.array([]) - for i in range(coords.shape[0]): - tsNew_index = np.where((tsNew>coords[i,0]) & (tsNewcoords[i,0]) & (ts coords[i, 0]) & (tsNew < coords[i, 1]))[0] + ts_index = np.where((ts > coords[i, 0]) & (ts < coords[i, 1]))[0] + + if len(tsNew_arr) == 0: + sub = tsNew[tsNew_index][0] - timeForLightsTurnOn + tsNew_arr = np.concatenate((tsNew_arr, tsNew[tsNew_index] - sub)) + ts_arr = np.concatenate((ts_arr, ts[ts_index] - sub)) + else: + temp_tsNew = tsNew[tsNew_index] + temp_ts = ts[ts_index] + new_ts = temp_ts - (temp_tsNew[0] - tsNew_arr[-1]) + new_tsNew = temp_tsNew - (temp_tsNew[0] - tsNew_arr[-1]) + tsNew_arr = np.concatenate((tsNew_arr, new_tsNew + (1 / sampling_rate))) + ts_arr = np.concatenate((ts_arr, new_ts + (1 / sampling_rate))) + + return ts_arr + + +# adding nan values to removed chunks # when using artifacts removal method - replace with NaN def addingNaNValues(filepath, event, naming): - - ts = read_hdf5('timeCorrection_'+naming, filepath, 'timestampNew') - data = read_hdf5(event, filepath, 'data').reshape(-1) - coords = fetchCoords(filepath, naming, ts) - if (data==0).all()==True: - data = np.zeros(ts.shape[0]) + ts = read_hdf5("timeCorrection_" + naming, filepath, "timestampNew") + data = read_hdf5(event, filepath, "data").reshape(-1) + coords = fetchCoords(filepath, naming, ts) + + if (data == 0).all() == True: + data = np.zeros(ts.shape[0]) - arr = np.array([]) - ts_index = np.arange(ts.shape[0]) - for i in range(coords.shape[0]): + arr = np.array([]) + ts_index = np.arange(ts.shape[0]) + for i in range(coords.shape[0]): - index = np.where((ts>coords[i,0]) & (ts coords[i, 0]) & (ts < coords[i, 1]))[0] + arr = 
np.concatenate((arr, index)) + + nan_indices = list(set(ts_index).symmetric_difference(arr)) + data[nan_indices] = np.nan + + return data - return data # remove event TTLs which falls in the removed chunks # when using artifacts removal method - replace with NaN def removeTTLs(filepath, event, naming): - tsNew = read_hdf5('timeCorrection_'+naming, filepath, 'timestampNew') - ts = read_hdf5(event+'_'+naming, filepath, 'ts').reshape(-1) - coords = fetchCoords(filepath, naming, tsNew) + tsNew = read_hdf5("timeCorrection_" + naming, filepath, "timestampNew") + ts = read_hdf5(event + "_" + naming, filepath, "ts").reshape(-1) + coords = fetchCoords(filepath, naming, tsNew) + + ts_arr = np.array([]) + for i in range(coords.shape[0]): + ts_index = np.where((ts > coords[i, 0]) & (ts < coords[i, 1]))[0] + ts_arr = np.concatenate((ts_arr, ts[ts_index])) + + return ts_arr - ts_arr = np.array([]) - for i in range(coords.shape[0]): - ts_index = np.where((ts>coords[i,0]) & (ts1: - b = np.divide(np.ones((filter_window,)), filter_window) - a = 1 - filtered_signal = ss.filtfilt(b, a, signal) - return filtered_signal - else: - raise Exception("Moving average filter window value is not correct.") + if filter_window == 0: + return signal + elif filter_window > 1: + b = np.divide(np.ones((filter_window,)), filter_window) + a = 1 + filtered_signal = ss.filtfilt(b, a, signal) + return filtered_signal + else: + raise Exception("Moving average filter window value is not correct.") + # function to filter control and signal channel, also execute above two function : controlFit and deltaFF # function will also take care if there is only signal channel and no control channel # if there is only signal channel, z-score will be computed using just signal channel def execute_controlFit_dff(control, signal, isosbestic_control, filter_window): - if isosbestic_control==False: - signal_smooth = filterSignal(filter_window, signal) #ss.filtfilt(b, a, signal) - control_fit = controlFit(control, signal_smooth) - norm_data = deltaFF(signal_smooth, control_fit) - else: - control_smooth = filterSignal(filter_window, control) #ss.filtfilt(b, a, control) - signal_smooth = filterSignal(filter_window, signal) #ss.filtfilt(b, a, signal) - control_fit = controlFit(control_smooth, signal_smooth) - norm_data = deltaFF(signal_smooth, control_fit) - - return norm_data, control_fit + if isosbestic_control == False: + signal_smooth = filterSignal(filter_window, signal) # ss.filtfilt(b, a, signal) + control_fit = controlFit(control, signal_smooth) + norm_data = deltaFF(signal_smooth, control_fit) + else: + control_smooth = filterSignal(filter_window, control) # ss.filtfilt(b, a, control) + signal_smooth = filterSignal(filter_window, signal) # ss.filtfilt(b, a, signal) + control_fit = controlFit(control_smooth, signal_smooth) + norm_data = deltaFF(signal_smooth, control_fit) + + return norm_data, control_fit + # function to compute z-score based on z-score computation method def z_score_computation(dff, timestamps, inputParameters): - zscore_method = inputParameters['zscore_method'] - baseline_start, baseline_end = inputParameters['baselineWindowStart'], inputParameters['baselineWindowEnd'] - - if zscore_method=='standard z-score': - numerator = np.subtract(dff, np.nanmean(dff)) - zscore = np.divide(numerator, np.nanstd(dff)) - elif zscore_method=='baseline z-score': - idx = np.where((timestamps>baseline_start) & (timestamps baseline_start) & (timestamps < baseline_end))[0] + if idx.shape[0] == 0: + logger.error( + "Baseline Window Parameters for 
baseline z-score computation zscore_method \ + are not correct." + ) + raise Exception( + "Baseline Window Parameters for baseline z-score computation zscore_method \ + are not correct." + ) + else: + baseline_mean = np.nanmean(dff[idx]) + baseline_std = np.nanstd(dff[idx]) + numerator = np.subtract(dff, baseline_mean) + zscore = np.divide(numerator, baseline_std) + else: + median = np.median(dff) + mad = np.median(np.abs(dff - median)) + numerator = 0.6745 * (dff - median) + zscore = np.divide(numerator, mad) + + return zscore + # helper function to compute z-score and deltaF/F -def helper_z_score(control, signal, filepath, name, inputParameters): #helper_z_score(control_smooth, signal_smooth): - - removeArtifacts = inputParameters['removeArtifacts'] - artifactsRemovalMethod = inputParameters['artifactsRemovalMethod'] - filter_window = inputParameters['filter_window'] - - isosbestic_control = inputParameters['isosbestic_control'] - tsNew = read_hdf5('timeCorrection_'+name, filepath, 'timestampNew') - coords_path = os.path.join(filepath, 'coordsForPreProcessing_'+name+'.npy') - - logger.info("Remove Artifacts : ", removeArtifacts) - - if (control==0).all()==True: - control = np.zeros(tsNew.shape[0]) - - z_score_arr = np.array([]) - norm_data_arr = np.full(tsNew.shape[0], np.nan) - control_fit_arr = np.full(tsNew.shape[0], np.nan) - temp_control_arr = np.full(tsNew.shape[0], np.nan) - - if removeArtifacts==True: - coords = fetchCoords(filepath, name, tsNew) - - # for artifacts removal, each chunk which was selected by user is being processed individually and then - # z-score is calculated - for i in range(coords.shape[0]): - tsNew_index = np.where((tsNew>coords[i,0]) & (tsNewcoords[i,1]) & (tsNew=tsNew[0]) & (tsNewcoords[-1]) & (tsNew<=tsNew[-1]))[0] - temp_control_arr[idx] = np.full(idx.shape[0], np.nan) - write_hdf5(temp_control_arr, 'control_'+name, filepath, 'data') - - return z_score_arr, norm_data_arr, control_fit_arr +def helper_z_score(control, signal, filepath, name, inputParameters): # helper_z_score(control_smooth, signal_smooth): + + removeArtifacts = inputParameters["removeArtifacts"] + artifactsRemovalMethod = inputParameters["artifactsRemovalMethod"] + filter_window = inputParameters["filter_window"] + + isosbestic_control = inputParameters["isosbestic_control"] + tsNew = read_hdf5("timeCorrection_" + name, filepath, "timestampNew") + coords_path = os.path.join(filepath, "coordsForPreProcessing_" + name + ".npy") + + logger.info("Remove Artifacts : ", removeArtifacts) + + if (control == 0).all() == True: + control = np.zeros(tsNew.shape[0]) + + z_score_arr = np.array([]) + norm_data_arr = np.full(tsNew.shape[0], np.nan) + control_fit_arr = np.full(tsNew.shape[0], np.nan) + temp_control_arr = np.full(tsNew.shape[0], np.nan) + + if removeArtifacts == True: + coords = fetchCoords(filepath, name, tsNew) + + # for artifacts removal, each chunk which was selected by user is being processed individually and then + # z-score is calculated + for i in range(coords.shape[0]): + tsNew_index = np.where((tsNew > coords[i, 0]) & (tsNew < coords[i, 1]))[0] + if isosbestic_control == False: + control_arr = helper_create_control_channel(signal[tsNew_index], tsNew[tsNew_index], window=101) + signal_arr = signal[tsNew_index] + norm_data, control_fit = execute_controlFit_dff( + control_arr, signal_arr, isosbestic_control, filter_window + ) + temp_control_arr[tsNew_index] = control_arr + if i < coords.shape[0] - 1: + blank_index = np.where((tsNew > coords[i, 1]) & (tsNew < coords[i + 1, 0]))[0] + 
temp_control_arr[blank_index] = np.full(blank_index.shape[0], np.nan) + else: + control_arr = control[tsNew_index] + signal_arr = signal[tsNew_index] + norm_data, control_fit = execute_controlFit_dff( + control_arr, signal_arr, isosbestic_control, filter_window + ) + norm_data_arr[tsNew_index] = norm_data + control_fit_arr[tsNew_index] = control_fit + + if artifactsRemovalMethod == "concatenate": + norm_data_arr = norm_data_arr[~np.isnan(norm_data_arr)] + control_fit_arr = control_fit_arr[~np.isnan(control_fit_arr)] + z_score = z_score_computation(norm_data_arr, tsNew, inputParameters) + z_score_arr = np.concatenate((z_score_arr, z_score)) + else: + tsNew_index = np.arange(tsNew.shape[0]) + norm_data, control_fit = execute_controlFit_dff(control, signal, isosbestic_control, filter_window) + z_score = z_score_computation(norm_data, tsNew, inputParameters) + z_score_arr = np.concatenate((z_score_arr, z_score)) + norm_data_arr[tsNew_index] = norm_data # np.concatenate((norm_data_arr, norm_data)) + control_fit_arr[tsNew_index] = control_fit # np.concatenate((control_fit_arr, control_fit)) + + # handle the case if there are chunks being cut in the front and the end + if isosbestic_control == False and removeArtifacts == True: + coords = coords.flatten() + # front chunk + idx = np.where((tsNew >= tsNew[0]) & (tsNew < coords[0]))[0] + temp_control_arr[idx] = np.full(idx.shape[0], np.nan) + # end chunk + idx = np.where((tsNew > coords[-1]) & (tsNew <= tsNew[-1]))[0] + temp_control_arr[idx] = np.full(idx.shape[0], np.nan) + write_hdf5(temp_control_arr, "control_" + name, filepath, "data") + + return z_score_arr, norm_data_arr, control_fit_arr # compute z-score and deltaF/F and save it to hdf5 file def compute_z_score(filepath, inputParameters): - logger.debug(f"Computing z-score for each of the data in {filepath}") - remove_artifacts = inputParameters['removeArtifacts'] - + logger.debug(f"Computing z-score for each of the data in {filepath}") + remove_artifacts = inputParameters["removeArtifacts"] - path_1 = find_files(filepath, 'control_*', ignore_case=True) #glob.glob(os.path.join(filepath, 'control*')) - path_2 = find_files(filepath, 'signal_*', ignore_case=True) #glob.glob(os.path.join(filepath, 'signal*')) - + path_1 = find_files(filepath, "control_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'control*')) + path_2 = find_files(filepath, "signal_*", ignore_case=True) # glob.glob(os.path.join(filepath, 'signal*')) - path = sorted(path_1 + path_2, key=str.casefold) + path = sorted(path_1 + path_2, key=str.casefold) + b = np.divide(np.ones((100,)), 100) + a = 1 - b = np.divide(np.ones((100,)), 100) - a = 1 + if len(path) % 2 != 0: + logger.error("There are not equal number of Control and Signal data") + raise Exception("There are not equal number of Control and Signal data") - if len(path)%2 != 0: - logger.error('There are not equal number of Control and Signal data') - raise Exception('There are not equal number of Control and Signal data') + path = np.asarray(path).reshape(2, -1) - path = np.asarray(path).reshape(2,-1) + for i in range(path.shape[1]): + name_1 = ((os.path.basename(path[0, i])).split(".")[0]).split("_") + name_2 = ((os.path.basename(path[1, i])).split(".")[0]).split("_") + # dirname = os.path.dirname(path[i]) - for i in range(path.shape[1]): - name_1 = ((os.path.basename(path[0,i])).split('.')[0]).split('_') - name_2 = ((os.path.basename(path[1,i])).split('.')[0]).split('_') - #dirname = os.path.dirname(path[i]) - - if name_1[-1]==name_2[-1]: - name = name_1[-1] - 
control = read_hdf5('', path[0,i], 'data').reshape(-1) - signal = read_hdf5('', path[1,i], 'data').reshape(-1) - #control_smooth = ss.filtfilt(b, a, control) - #signal_smooth = ss.filtfilt(b, a, signal) - #_score, dff = helper_z_score(control_smooth, signal_smooth) - z_score, dff, control_fit = helper_z_score(control, signal, filepath, name, inputParameters) - if remove_artifacts==True: - write_hdf5(z_score, 'z_score_'+name, filepath, 'data') - write_hdf5(dff, 'dff_'+name, filepath, 'data') - write_hdf5(control_fit, 'cntrl_sig_fit_'+name, filepath, 'data') - else: - write_hdf5(z_score, 'z_score_'+name, filepath, 'data') - write_hdf5(dff, 'dff_'+name, filepath, 'data') - write_hdf5(control_fit, 'cntrl_sig_fit_'+name, filepath, 'data') - else: - logger.error('Error in naming convention of files or Error in storesList file') - raise Exception('Error in naming convention of files or Error in storesList file') + if name_1[-1] == name_2[-1]: + name = name_1[-1] + control = read_hdf5("", path[0, i], "data").reshape(-1) + signal = read_hdf5("", path[1, i], "data").reshape(-1) + # control_smooth = ss.filtfilt(b, a, control) + # signal_smooth = ss.filtfilt(b, a, signal) + # _score, dff = helper_z_score(control_smooth, signal_smooth) + z_score, dff, control_fit = helper_z_score(control, signal, filepath, name, inputParameters) + if remove_artifacts == True: + write_hdf5(z_score, "z_score_" + name, filepath, "data") + write_hdf5(dff, "dff_" + name, filepath, "data") + write_hdf5(control_fit, "cntrl_sig_fit_" + name, filepath, "data") + else: + write_hdf5(z_score, "z_score_" + name, filepath, "data") + write_hdf5(dff, "dff_" + name, filepath, "data") + write_hdf5(control_fit, "cntrl_sig_fit_" + name, filepath, "data") + else: + logger.error("Error in naming convention of files or Error in storesList file") + raise Exception("Error in naming convention of files or Error in storesList file") - logger.info(f"z-score for the data in {filepath} computed.") - + logger.info(f"z-score for the data in {filepath} computed.") # function to execute timestamps corrections using functions timestampCorrection and decide_naming_convention_and_applyCorrection def execute_timestamp_correction(folderNames, inputParameters): - timeForLightsTurnOn = inputParameters['timeForLightsTurnOn'] - isosbestic_control = inputParameters['isosbestic_control'] + timeForLightsTurnOn = inputParameters["timeForLightsTurnOn"] + isosbestic_control = inputParameters["isosbestic_control"] - for i in range(len(folderNames)): - filepath = folderNames[i] - storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*'))) - cond = check_TDT(folderNames[i]) - logger.debug(f"Timestamps corrections started for {filepath}") - for j in range(len(storesListPath)): - filepath = storesListPath[j] - storesList = np.genfromtxt(os.path.join(filepath, 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1) + for i in range(len(folderNames)): + filepath = folderNames[i] + storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) + cond = check_TDT(folderNames[i]) + logger.debug(f"Timestamps corrections started for {filepath}") + for j in range(len(storesListPath)): + filepath = storesListPath[j] + storesList = np.genfromtxt(os.path.join(filepath, "storesList.csv"), dtype="str", delimiter=",").reshape( + 2, -1 + ) - if isosbestic_control==False: - storesList = add_control_channel(filepath, storesList) - + if isosbestic_control == False: + storesList = add_control_channel(filepath, storesList) - if cond==True: - 
timestampCorrection_tdt(filepath, timeForLightsTurnOn, storesList) - else: - timestampCorrection_csv(filepath, timeForLightsTurnOn, storesList) + if cond == True: + timestampCorrection_tdt(filepath, timeForLightsTurnOn, storesList) + else: + timestampCorrection_csv(filepath, timeForLightsTurnOn, storesList) - for k in range(storesList.shape[1]): - decide_naming_convention_and_applyCorrection(filepath, timeForLightsTurnOn, - storesList[0,k], storesList[1,k], storesList) - - # check if isosbestic control is false and also if new control channel is added - if isosbestic_control==False: - create_control_channel(filepath, storesList, window=101) - - writeToFile(str(10+((inputParameters['step']+1)*10))+'\n') - inputParameters['step'] += 1 - logger.info(f"Timestamps corrections finished for {filepath}") + for k in range(storesList.shape[1]): + decide_naming_convention_and_applyCorrection( + filepath, timeForLightsTurnOn, storesList[0, k], storesList[1, k], storesList + ) + # check if isosbestic control is false and also if new control channel is added + if isosbestic_control == False: + create_control_channel(filepath, storesList, window=101) + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + logger.info(f"Timestamps corrections finished for {filepath}") # for combining data, reading storeslist file from both data and create a new storeslist array def check_storeslistfile(folderNames): - storesList = np.array([[],[]]) - for i in range(len(folderNames)): - filepath = folderNames[i] - storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*'))) - for j in range(len(storesListPath)): - filepath = storesListPath[j] - storesList = np.concatenate((storesList, np.genfromtxt(os.path.join(filepath, 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1)), axis=1) - - storesList = np.unique(storesList, axis=1) - - return storesList + storesList = np.array([[], []]) + for i in range(len(folderNames)): + filepath = folderNames[i] + storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) + for j in range(len(storesListPath)): + filepath = storesListPath[j] + storesList = np.concatenate( + ( + storesList, + np.genfromtxt(os.path.join(filepath, "storesList.csv"), dtype="str", delimiter=",").reshape(2, -1), + ), + axis=1, + ) + + storesList = np.unique(storesList, axis=1) + + return storesList + def get_all_stores_for_combining_data(folderNames): - op = [] - for i in range(100): - temp = [] - match = r'[\s\S]*'+'_output_'+str(i) - for j in folderNames: - temp.append(re.findall(match, j)) - temp = sorted(list(np.concatenate(temp).flatten()), key=str.casefold) - if len(temp)>0: - op.append(temp) + op = [] + for i in range(100): + temp = [] + match = r"[\s\S]*" + "_output_" + str(i) + for j in folderNames: + temp.append(re.findall(match, j)) + temp = sorted(list(np.concatenate(temp).flatten()), key=str.casefold) + if len(temp) > 0: + op.append(temp) - return op + return op # function to combine data when there are two different data files for the same recording session # it will combine the data, do timestamps processing and save the combined data in the first output folder. 
def combineData(folderNames, inputParameters, storesList):
-    logger.debug("Combining Data from different data files...")
-    timeForLightsTurnOn = inputParameters['timeForLightsTurnOn']
-    op_folder = []
-    for i in range(len(folderNames)):
-        filepath = folderNames[i]
-        op_folder.append(takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*'))))
-
-    op_folder = list(np.concatenate(op_folder).flatten())
-    sampling_rate_fp = []
-    for i in range(len(folderNames)):
-        filepath = folderNames[i]
-        storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*')))
-        for j in range(len(storesListPath)):
-            filepath = storesListPath[j]
-            storesList_new = np.genfromtxt(os.path.join(filepath, 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1)
-            sampling_rate_fp.append(glob.glob(os.path.join(filepath, 'timeCorrection_*')))
-
-    # check if sampling rate is same for both data
-    sampling_rate_fp = np.concatenate(sampling_rate_fp)
-    sampling_rate = []
-    for i in range(sampling_rate_fp.shape[0]):
-        sampling_rate.append(read_hdf5('', sampling_rate_fp[i], 'sampling_rate'))
-
-    res = all(i == sampling_rate[0] for i in sampling_rate)
-    if res==False:
-        logger.error('To combine the data, sampling rate for both the data should be same.')
-        raise Exception('To combine the data, sampling rate for both the data should be same.')
-
-    # get the output folders informatinos
-    op = get_all_stores_for_combining_data(op_folder)
-
-    # processing timestamps for combining the data
-    processTimestampsForCombiningData(op, timeForLightsTurnOn, storesList, sampling_rate[0])
-    logger.info("Data is combined from different data files.")
-
-
-    return op
+    logger.debug("Combining Data from different data files...")
+    timeForLightsTurnOn = inputParameters["timeForLightsTurnOn"]
+    op_folder = []
+    for i in range(len(folderNames)):
+        filepath = folderNames[i]
+        op_folder.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))))
+
+    op_folder = list(np.concatenate(op_folder).flatten())
+    sampling_rate_fp = []
+    for i in range(len(folderNames)):
+        filepath = folderNames[i]
+        storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))
+        for j in range(len(storesListPath)):
+            filepath = storesListPath[j]
+            storesList_new = np.genfromtxt(
+                os.path.join(filepath, "storesList.csv"), dtype="str", delimiter=","
+            ).reshape(2, -1)
+            sampling_rate_fp.append(glob.glob(os.path.join(filepath, "timeCorrection_*")))
+
+    # check if sampling rate is same for both data
+    sampling_rate_fp = np.concatenate(sampling_rate_fp)
+    sampling_rate = []
+    for i in range(sampling_rate_fp.shape[0]):
+        sampling_rate.append(read_hdf5("", sampling_rate_fp[i], "sampling_rate"))
+
+    res = all(i == sampling_rate[0] for i in sampling_rate)
+    if res == False:
+        logger.error("To combine the data, sampling rate for both the data should be same.")
+        raise Exception("To combine the data, sampling rate for both the data should be same.")
+
+    # get the output folders information
+    op = get_all_stores_for_combining_data(op_folder)
+
+    # processing timestamps for combining the data
+    processTimestampsForCombiningData(op, timeForLightsTurnOn, storesList, sampling_rate[0])
+    logger.info("Data is combined from different data files.")
+
+    return op


# function to compute z-score and deltaF/F using functions : compute_z_score and/or processTimestampsForArtifacts
def execute_zscore(folderNames, inputParameters):
-    timeForLightsTurnOn = inputParameters['timeForLightsTurnOn']
-    remove_artifacts = inputParameters['removeArtifacts']
-    
artifactsRemovalMethod = inputParameters['artifactsRemovalMethod'] - plot_zScore_dff = inputParameters['plot_zScore_dff'] - combine_data = inputParameters['combine_data'] - isosbestic_control = inputParameters['isosbestic_control'] - - storesListPath = [] - for i in range(len(folderNames)): - if combine_data==True: - storesListPath.append([folderNames[i][0]]) - else: - filepath = folderNames[i] - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*')))) - - storesListPath = np.concatenate(storesListPath) - - for j in range(len(storesListPath)): - filepath = storesListPath[j] - storesList = np.genfromtxt(os.path.join(filepath, 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1) - - if remove_artifacts==True: - logger.debug("Removing Artifacts from the data and correcting timestamps...") - compute_z_score(filepath, inputParameters) - if artifactsRemovalMethod=='concatenate': - processTimestampsForArtifacts(filepath, timeForLightsTurnOn, storesList) - else: - addingNaNtoChunksWithArtifacts(filepath, storesList) - visualizeControlAndSignal(filepath, remove_artifacts) - logger.info("Artifacts from the data are removed and timestamps are corrected.") - else: - compute_z_score(filepath, inputParameters) - visualizeControlAndSignal(filepath, remove_artifacts) - - if plot_zScore_dff=='z_score': - visualize_z_score(filepath) - if plot_zScore_dff=='dff': - visualize_dff(filepath) - if plot_zScore_dff=='Both': - visualize_z_score(filepath) - visualize_dff(filepath) - - writeToFile(str(10+((inputParameters['step']+1)*10))+'\n') - inputParameters['step'] += 1 - - plt.show() - logger.info("Signal data and event timestamps are extracted.") + timeForLightsTurnOn = inputParameters["timeForLightsTurnOn"] + remove_artifacts = inputParameters["removeArtifacts"] + artifactsRemovalMethod = inputParameters["artifactsRemovalMethod"] + plot_zScore_dff = inputParameters["plot_zScore_dff"] + combine_data = inputParameters["combine_data"] + isosbestic_control = inputParameters["isosbestic_control"] + + storesListPath = [] + for i in range(len(folderNames)): + if combine_data == True: + storesListPath.append([folderNames[i][0]]) + else: + filepath = folderNames[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + + storesListPath = np.concatenate(storesListPath) + + for j in range(len(storesListPath)): + filepath = storesListPath[j] + storesList = np.genfromtxt(os.path.join(filepath, "storesList.csv"), dtype="str", delimiter=",").reshape(2, -1) + + if remove_artifacts == True: + logger.debug("Removing Artifacts from the data and correcting timestamps...") + compute_z_score(filepath, inputParameters) + if artifactsRemovalMethod == "concatenate": + processTimestampsForArtifacts(filepath, timeForLightsTurnOn, storesList) + else: + addingNaNtoChunksWithArtifacts(filepath, storesList) + visualizeControlAndSignal(filepath, remove_artifacts) + logger.info("Artifacts from the data are removed and timestamps are corrected.") + else: + compute_z_score(filepath, inputParameters) + visualizeControlAndSignal(filepath, remove_artifacts) + + if plot_zScore_dff == "z_score": + visualize_z_score(filepath) + if plot_zScore_dff == "dff": + visualize_dff(filepath) + if plot_zScore_dff == "Both": + visualize_z_score(filepath) + visualize_dff(filepath) + + writeToFile(str(10 + ((inputParameters["step"] + 1) * 10)) + "\n") + inputParameters["step"] += 1 + + plt.show() + logger.info("Signal data and event timestamps are extracted.") def 
extractTsAndSignal(inputParameters): - logger.debug("Extracting signal data and event timestamps...") - inputParameters = inputParameters - - #storesList = np.genfromtxt(inputParameters['storesListPath'], dtype='str', delimiter=',') - - folderNames = inputParameters['folderNames'] - timeForLightsTurnOn = inputParameters['timeForLightsTurnOn'] - isosbestic_control = inputParameters['isosbestic_control'] - remove_artifacts = inputParameters['removeArtifacts'] - combine_data = inputParameters['combine_data'] - - inputParameters['step'] = 0 - logger.info(f"Remove Artifacts : {remove_artifacts}") - logger.info(f"Combine Data : {combine_data}") - logger.info(f"Isosbestic Control Channel : {isosbestic_control}") - storesListPath = [] - for i in range(len(folderNames)): - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(folderNames[i], '*_output_*')))) - storesListPath = np.concatenate(storesListPath) - #pbMaxValue = storesListPath.shape[0] + len(folderNames) - #writeToFile(str((pbMaxValue+1)*10)+'\n'+str(10)+'\n') - if combine_data==False: - pbMaxValue = storesListPath.shape[0] + len(folderNames) - writeToFile(str((pbMaxValue+1)*10)+'\n'+str(10)+'\n') - execute_timestamp_correction(folderNames, inputParameters) - execute_zscore(folderNames, inputParameters) - else: - pbMaxValue = 1 + len(folderNames) - writeToFile(str((pbMaxValue)*10)+'\n'+str(10)+'\n') - execute_timestamp_correction(folderNames, inputParameters) - storesList = check_storeslistfile(folderNames) - op_folder = combineData(folderNames, inputParameters, storesList) - execute_zscore(op_folder, inputParameters) - - - + logger.debug("Extracting signal data and event timestamps...") + inputParameters = inputParameters + + # storesList = np.genfromtxt(inputParameters['storesListPath'], dtype='str', delimiter=',') + + folderNames = inputParameters["folderNames"] + timeForLightsTurnOn = inputParameters["timeForLightsTurnOn"] + isosbestic_control = inputParameters["isosbestic_control"] + remove_artifacts = inputParameters["removeArtifacts"] + combine_data = inputParameters["combine_data"] + + inputParameters["step"] = 0 + logger.info(f"Remove Artifacts : {remove_artifacts}") + logger.info(f"Combine Data : {combine_data}") + logger.info(f"Isosbestic Control Channel : {isosbestic_control}") + storesListPath = [] + for i in range(len(folderNames)): + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(folderNames[i], "*_output_*")))) + storesListPath = np.concatenate(storesListPath) + # pbMaxValue = storesListPath.shape[0] + len(folderNames) + # writeToFile(str((pbMaxValue+1)*10)+'\n'+str(10)+'\n') + if combine_data == False: + pbMaxValue = storesListPath.shape[0] + len(folderNames) + writeToFile(str((pbMaxValue + 1) * 10) + "\n" + str(10) + "\n") + execute_timestamp_correction(folderNames, inputParameters) + execute_zscore(folderNames, inputParameters) + else: + pbMaxValue = 1 + len(folderNames) + writeToFile(str((pbMaxValue) * 10) + "\n" + str(10) + "\n") + execute_timestamp_correction(folderNames, inputParameters) + storesList = check_storeslistfile(folderNames) + op_folder = combineData(folderNames, inputParameters, storesList) + execute_zscore(op_folder, inputParameters) + + def main(input_parameters): - try: - extractTsAndSignal(input_parameters) - logger.info('#'*400) - except Exception as e: - with open(os.path.join(os.path.expanduser('~'), 'pbSteps.txt'), 'a') as file: - file.write(str(-1)+"\n") - logger.error(str(e)) - raise e + try: + extractTsAndSignal(input_parameters) + logger.info("#" * 400) + except Exception as 
e: + with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: + file.write(str(-1) + "\n") + logger.error(str(e)) + raise e + if __name__ == "__main__": - input_parameters = json.loads(sys.argv[1]) - main(input_parameters=input_parameters) + input_parameters = json.loads(sys.argv[1]) + main(input_parameters=input_parameters) diff --git a/src/guppy/readTevTsq.py b/src/guppy/readTevTsq.py index 3a44b2e..6deb3b1 100755 --- a/src/guppy/readTevTsq.py +++ b/src/guppy/readTevTsq.py @@ -1,516 +1,545 @@ +import glob +import json +import logging +import multiprocessing as mp import os -import sys import re -import json +import sys import time -import glob -import h5py import warnings from itertools import repeat + +import h5py import numpy as np import pandas as pd -from numpy import int32, uint32, uint8, uint16, float64, int64, int32, float32 -import multiprocessing as mp -from pathlib import Path -import logging +from numpy import float32, float64, int32, int64, uint16 logger = logging.getLogger(__name__) + def takeOnlyDirs(paths): - removePaths = [] - for p in paths: - if os.path.isfile(p): - removePaths.append(p) - return list(set(paths)-set(removePaths)) + removePaths = [] + for p in paths: + if os.path.isfile(p): + removePaths.append(p) + return list(set(paths) - set(removePaths)) + def writeToFile(value: str): - with open(os.path.join(os.path.expanduser('~'), 'pbSteps.txt'), 'a') as file: - file.write(value) + with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: + file.write(value) -# functino to read tsq file + +# function to read tsq file def readtsq(filepath): - logger.debug("Trying to read tsq file.") - names = ('size', 'type', 'name', 'chan', 'sort_code', 'timestamp', - 'fp_loc', 'strobe', 'format', 'frequency') - formats = (int32, int32, 'S4', uint16, uint16, float64, int64, - float64, int32, float32) - offsets = 0, 4, 8, 12, 14, 16, 24, 24, 32, 36 - tsq_dtype = np.dtype({'names': names, 'formats': formats, - 'offsets': offsets}, align=True) - path = glob.glob(os.path.join(filepath, '*.tsq')) - if len(path)>1: - logger.error('Two tsq files are present at the location.') - raise Exception('Two tsq files are present at the location.') - elif len(path)==0: - logger.info("\033[1m"+"tsq file not found."+"\033[1m") - return 0, 0 - else: - path = path[0] - flag = 'tsq' - - # reading tsq file - tsq = np.fromfile(path, dtype=tsq_dtype) - - # creating dataframe of the data - df = pd.DataFrame(tsq) - - logger.info("Data from tsq file fetched.") - return df, flag + logger.debug("Trying to read tsq file.") + names = ("size", "type", "name", "chan", "sort_code", "timestamp", "fp_loc", "strobe", "format", "frequency") + formats = (int32, int32, "S4", uint16, uint16, float64, int64, float64, int32, float32) + offsets = 0, 4, 8, 12, 14, 16, 24, 24, 32, 36 + tsq_dtype = np.dtype({"names": names, "formats": formats, "offsets": offsets}, align=True) + path = glob.glob(os.path.join(filepath, "*.tsq")) + if len(path) > 1: + logger.error("Two tsq files are present at the location.") + raise Exception("Two tsq files are present at the location.") + elif len(path) == 0: + logger.info("\033[1m" + "tsq file not found." 
+ "\033[1m") + return 0, 0 + else: + path = path[0] + flag = "tsq" + + # reading tsq file + tsq = np.fromfile(path, dtype=tsq_dtype) + + # creating dataframe of the data + df = pd.DataFrame(tsq) + + logger.info("Data from tsq file fetched.") + return df, flag + # function to check if doric file exists def check_doric(filepath): - logger.debug('Checking if doric file exists') - path = glob.glob(os.path.join(filepath, '*.csv')) + \ - glob.glob(os.path.join(filepath, '*.doric')) - - flag_arr = [] - for i in range(len(path)): - ext = os.path.basename(path[i]).split('.')[-1] - if ext=='csv': - with warnings.catch_warnings(): - warnings.simplefilter("error") - try: - df = pd.read_csv(path[i], index_col=False, dtype=float) - except: - df = pd.read_csv(path[i], header=1, index_col=False, nrows=10) - flag = 'doric_csv' - flag_arr.append(flag) - elif ext=='doric': - flag = 'doric_doric' - flag_arr.append(flag) - else: - pass - - if len(flag_arr)>1: - logger.error('Two doric files are present at the same location') - raise Exception('Two doric files are present at the same location') - if len(flag_arr)==0: - logger.error("\033[1m"+"Doric file not found."+"\033[1m") - return 0 - logger.info('Doric file found.') - return flag_arr[0] - + logger.debug("Checking if doric file exists") + path = glob.glob(os.path.join(filepath, "*.csv")) + glob.glob(os.path.join(filepath, "*.doric")) + + flag_arr = [] + for i in range(len(path)): + ext = os.path.basename(path[i]).split(".")[-1] + if ext == "csv": + with warnings.catch_warnings(): + warnings.simplefilter("error") + try: + df = pd.read_csv(path[i], index_col=False, dtype=float) + except: + df = pd.read_csv(path[i], header=1, index_col=False, nrows=10) + flag = "doric_csv" + flag_arr.append(flag) + elif ext == "doric": + flag = "doric_doric" + flag_arr.append(flag) + else: + pass + + if len(flag_arr) > 1: + logger.error("Two doric files are present at the same location") + raise Exception("Two doric files are present at the same location") + if len(flag_arr) == 0: + logger.error("\033[1m" + "Doric file not found." 
+ "\033[1m") + return 0 + logger.info("Doric file found.") + return flag_arr[0] + # check if a particular element is there in an array or not def ismember(arr, element): - res = [1 if i==element else 0 for i in arr] + res = [1 if i == element else 0 for i in arr] return np.asarray(res) # function to write data to a hdf5 file def write_hdf5(data, event, filepath, key): - # replacing \\ or / in storenames with _ (to avoid errors while saving data) - event = event.replace("\\","_") - event = event.replace("/","_") - - op = os.path.join(filepath, event+'.hdf5') - - # if file does not exist create a new file - if not os.path.exists(op): - with h5py.File(op, 'w') as f: - if type(data) is np.ndarray: - f.create_dataset(key, data=data, maxshape=(None,), chunks=True) - else: - f.create_dataset(key, data=data) - - # if file already exists, append data to it or add a new key to it - else: - with h5py.File(op, 'r+') as f: - if key in list(f.keys()): - if type(data) is np.ndarray: - f[key].resize(data.shape) - arr = f[key] - arr[:] = data - else: - arr = f[key] - arr = data - else: - if type(data) is np.ndarray: - f.create_dataset(key, data=data, maxshape=(None,), chunks=True) - else: - f.create_dataset(key, data=data) + # replacing \\ or / in storenames with _ (to avoid errors while saving data) + event = event.replace("\\", "_") + event = event.replace("/", "_") + + op = os.path.join(filepath, event + ".hdf5") + + # if file does not exist create a new file + if not os.path.exists(op): + with h5py.File(op, "w") as f: + if type(data) is np.ndarray: + f.create_dataset(key, data=data, maxshape=(None,), chunks=True) + else: + f.create_dataset(key, data=data) + + # if file already exists, append data to it or add a new key to it + else: + with h5py.File(op, "r+") as f: + if key in list(f.keys()): + if type(data) is np.ndarray: + f[key].resize(data.shape) + arr = f[key] + arr[:] = data + else: + arr = f[key] + arr = data + else: + if type(data) is np.ndarray: + f.create_dataset(key, data=data, maxshape=(None,), chunks=True) + else: + f.create_dataset(key, data=data) # function to read event timestamps csv file. def import_csv(filepath, event, outputPath): - logger.debug("\033[1m"+"Trying to read data for {} from csv file.".format(event)+"\033[0m") - if not os.path.exists(os.path.join(filepath, event+'.csv')): - logger.error("\033[1m"+"No csv file found for event {}".format(event)+"\033[0m") - raise Exception("\033[1m"+"No csv file found for event {}".format(event)+"\033[0m") - - df = pd.read_csv(os.path.join(filepath, event+'.csv'), index_col=False) - data = df - key = list(df.columns) - - if len(key)==3: - arr1 = np.array(['timestamps', 'data', 'sampling_rate']) - arr2 = np.char.lower(np.array(key)) - if (np.sort(arr1)==np.sort(arr2)).all()==False: - logger.error("\033[1m"+"Column names should be timestamps, data and sampling_rate"+"\033[0m") - raise Exception("\033[1m"+"Column names should be timestamps, data and sampling_rate"+"\033[0m") - - if len(key)==1: - if key[0].lower()!='timestamps': - logger.error("\033[1m"+"Column names should be timestamps, data and sampling_rate"+"\033[0m") - raise Exception("\033[1m"+"Column name should be timestamps"+"\033[0m") - - if len(key)!=3 and len(key)!=1: - logger.error("\033[1m"+"Number of columns in csv file should be either three or one. Three columns if \ - the file is for control or signal data or one column if the file is for event TTLs."+"\033[0m") - raise Exception("\033[1m"+"Number of columns in csv file should be either three or one. 
Three columns if \ - the file is for control or signal data or one column if the file is for event TTLs."+"\033[0m") - - for i in range(len(key)): - write_hdf5(data[key[i]].dropna(), event, outputPath, key[i].lower()) - - logger.info("\033[1m"+"Reading data for {} from csv file is completed.".format(event)+"\033[0m") - - return data, key + logger.debug("\033[1m" + "Trying to read data for {} from csv file.".format(event) + "\033[0m") + if not os.path.exists(os.path.join(filepath, event + ".csv")): + logger.error("\033[1m" + "No csv file found for event {}".format(event) + "\033[0m") + raise Exception("\033[1m" + "No csv file found for event {}".format(event) + "\033[0m") + + df = pd.read_csv(os.path.join(filepath, event + ".csv"), index_col=False) + data = df + key = list(df.columns) + + if len(key) == 3: + arr1 = np.array(["timestamps", "data", "sampling_rate"]) + arr2 = np.char.lower(np.array(key)) + if (np.sort(arr1) == np.sort(arr2)).all() == False: + logger.error("\033[1m" + "Column names should be timestamps, data and sampling_rate" + "\033[0m") + raise Exception("\033[1m" + "Column names should be timestamps, data and sampling_rate" + "\033[0m") + + if len(key) == 1: + if key[0].lower() != "timestamps": + logger.error("\033[1m" + "Column names should be timestamps, data and sampling_rate" + "\033[0m") + raise Exception("\033[1m" + "Column name should be timestamps" + "\033[0m") + + if len(key) != 3 and len(key) != 1: + logger.error( + "\033[1m" + + "Number of columns in csv file should be either three or one. Three columns if \ + the file is for control or signal data or one column if the file is for event TTLs." + + "\033[0m" + ) + raise Exception( + "\033[1m" + + "Number of columns in csv file should be either three or one. Three columns if \ + the file is for control or signal data or one column if the file is for event TTLs." 
+ + "\033[0m" + ) + + for i in range(len(key)): + write_hdf5(data[key[i]].dropna(), event, outputPath, key[i].lower()) + + logger.info("\033[1m" + "Reading data for {} from csv file is completed.".format(event) + "\033[0m") + + return data, key + # function to save data read from tev file to hdf5 file def save_dict_to_hdf5(S, event, outputPath): - write_hdf5(S['storename'], event, outputPath, 'storename') - write_hdf5(S['sampling_rate'], event, outputPath, 'sampling_rate') - write_hdf5(S['timestamps'], event, outputPath, 'timestamps') - - write_hdf5(S['data'], event, outputPath, 'data') - write_hdf5(S['npoints'], event, outputPath, 'npoints') - write_hdf5(S['channels'], event, outputPath, 'channels') + write_hdf5(S["storename"], event, outputPath, "storename") + write_hdf5(S["sampling_rate"], event, outputPath, "sampling_rate") + write_hdf5(S["timestamps"], event, outputPath, "timestamps") + write_hdf5(S["data"], event, outputPath, "data") + write_hdf5(S["npoints"], event, outputPath, "npoints") + write_hdf5(S["channels"], event, outputPath, "channels") # function to check event data (checking whether event timestamps belongs to same event or multiple events) def check_data(S, filepath, event, outputPath): - #logger.info("Checking event storename data for creating multiple event names from single event storename...") - new_event = event.replace("\\","") - new_event = event.replace("/","") - diff = np.diff(S['data']) - arr = np.full(diff.shape[0],1) - - storesList = np.genfromtxt(os.path.join(outputPath, 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1) - - if diff.shape[0]==0: - return 0 - - if S['sampling_rate']==0 and np.all(diff==diff[0])==False: - logger.info("\033[1m"+"Data in event {} belongs to multiple behavior".format(event)+"\033[0m") - logger.debug("\033[1m"+"Create timestamp files for individual new event and change the stores list file."+"\033[0m") - i_d = np.unique(S['data']) - for i in range(i_d.shape[0]): - new_S = dict() - idx = np.where(S['data']==i_d[i])[0] - new_S['timestamps'] = S['timestamps'][idx] - new_S['storename'] = new_event+str(int(i_d[i])) - new_S['sampling_rate'] = S['sampling_rate'] - new_S['data'] = S['data'] - new_S['npoints'] = S['npoints'] - new_S['channels'] = S['channels'] - storesList = np.concatenate((storesList, [[new_event+str(int(i_d[i]))], [new_event+'_'+str(int(i_d[i]))]]), axis=1) - save_dict_to_hdf5(new_S, new_event+str(int(i_d[i])), outputPath) - - idx = np.where(storesList[0]==event)[0] - storesList = np.delete(storesList,idx,axis=1) - if not os.path.exists(os.path.join(outputPath, '.cache_storesList.csv')): - os.rename(os.path.join(outputPath, 'storesList.csv'), os.path.join(outputPath, '.cache_storesList.csv')) - if idx.shape[0]==0: - pass - else: - np.savetxt(os.path.join(outputPath, 'storesList.csv'), storesList, delimiter=",", fmt='%s') - logger.info("\033[1m"+"Timestamp files for individual new event are created \ - and the stores list file is changed."+"\033[0m") - - + # logger.info("Checking event storename data for creating multiple event names from single event storename...") + new_event = event.replace("\\", "") + new_event = event.replace("/", "") + diff = np.diff(S["data"]) + arr = np.full(diff.shape[0], 1) + + storesList = np.genfromtxt(os.path.join(outputPath, "storesList.csv"), dtype="str", delimiter=",").reshape(2, -1) + + if diff.shape[0] == 0: + return 0 + + if S["sampling_rate"] == 0 and np.all(diff == diff[0]) == False: + logger.info("\033[1m" + "Data in event {} belongs to multiple behavior".format(event) + 
"\033[0m") + logger.debug( + "\033[1m" + "Create timestamp files for individual new event and change the stores list file." + "\033[0m" + ) + i_d = np.unique(S["data"]) + for i in range(i_d.shape[0]): + new_S = dict() + idx = np.where(S["data"] == i_d[i])[0] + new_S["timestamps"] = S["timestamps"][idx] + new_S["storename"] = new_event + str(int(i_d[i])) + new_S["sampling_rate"] = S["sampling_rate"] + new_S["data"] = S["data"] + new_S["npoints"] = S["npoints"] + new_S["channels"] = S["channels"] + storesList = np.concatenate( + (storesList, [[new_event + str(int(i_d[i]))], [new_event + "_" + str(int(i_d[i]))]]), axis=1 + ) + save_dict_to_hdf5(new_S, new_event + str(int(i_d[i])), outputPath) + + idx = np.where(storesList[0] == event)[0] + storesList = np.delete(storesList, idx, axis=1) + if not os.path.exists(os.path.join(outputPath, ".cache_storesList.csv")): + os.rename(os.path.join(outputPath, "storesList.csv"), os.path.join(outputPath, ".cache_storesList.csv")) + if idx.shape[0] == 0: + pass + else: + np.savetxt(os.path.join(outputPath, "storesList.csv"), storesList, delimiter=",", fmt="%s") + logger.info( + "\033[1m" + + "Timestamp files for individual new event are created \ + and the stores list file is changed." + + "\033[0m" + ) + # function to read tev file def readtev(data, filepath, event, outputPath): - logger.debug("Reading data for event {} ...".format(event)) - tevfilepath = glob.glob(os.path.join(filepath, '*.tev')) - if len(tevfilepath)>1: - raise Exception('Two tev files are present at the location.') - else: - tevfilepath = tevfilepath[0] + logger.debug("Reading data for event {} ...".format(event)) + tevfilepath = glob.glob(os.path.join(filepath, "*.tev")) + if len(tevfilepath) > 1: + raise Exception("Two tev files are present at the location.") + else: + tevfilepath = tevfilepath[0] + data["name"] = np.asarray(data["name"], dtype=str) - data['name'] = np.asarray(data['name'], dtype=str) + allnames = np.unique(data["name"]) - allnames = np.unique(data['name']) + index = [] + for i in range(len(allnames)): + length = len(str(allnames[i])) + if length < 4: + index.append(i) - index = [] - for i in range(len(allnames)): - length = len(str(allnames[i])) - if length<4: - index.append(i) + allnames = np.delete(allnames, index, 0) + eventNew = np.array(list(event)) - allnames = np.delete(allnames, index, 0) + # logger.info(allnames) + # logger.info(eventNew) + row = ismember(data["name"], event) + if sum(row) == 0: + logger.info("\033[1m" + "Requested store name " + event + " not found (case-sensitive)." + "\033[0m") + logger.info("\033[1m" + "File contains the following TDT store names:" + "\033[0m") + logger.info("\033[1m" + str(allnames) + "\033[0m") + logger.info("\033[1m" + "TDT store name " + str(event) + " not found." 
+ "\033[0m") + import_csv(filepath, event, outputPath) - eventNew = np.array(list(event)) + return 0 - #logger.info(allnames) - #logger.info(eventNew) - row = ismember(data['name'], event) + allIndexesWhereEventIsPresent = np.where(row == 1) + first_row = allIndexesWhereEventIsPresent[0][0] + formatNew = data["format"][first_row] + 1 - if sum(row)==0: - logger.info("\033[1m"+"Requested store name "+event+" not found (case-sensitive)."+"\033[0m") - logger.info("\033[1m"+"File contains the following TDT store names:"+"\033[0m") - logger.info("\033[1m"+str(allnames)+"\033[0m") - logger.info("\033[1m"+"TDT store name "+str(event)+" not found."+"\033[0m") - import_csv(filepath, event, outputPath) + table = np.array( + [ + [0, 0, 0, 0], + [0, "float", 1, np.float32], + [0, "long", 1, np.int32], + [0, "short", 2, np.int16], + [0, "byte", 4, np.int8], + ] + ) - return 0 - - allIndexesWhereEventIsPresent = np.where(row==1) - first_row = allIndexesWhereEventIsPresent[0][0] + S = dict() - formatNew = data['format'][first_row]+1 + S["storename"] = str(event) + S["sampling_rate"] = data["frequency"][first_row] + S["timestamps"] = np.asarray(data["timestamp"][allIndexesWhereEventIsPresent[0]]) + S["channels"] = np.asarray(data["chan"][allIndexesWhereEventIsPresent[0]]) - table = np.array([[0,0,0,0], - [0,'float',1, np.float32], - [0,'long', 1, np.int32], - [0,'short',2, np.int16], - [0,'byte', 4, np.int8]]) + fp_loc = np.asarray(data["fp_loc"][allIndexesWhereEventIsPresent[0]]) + data_size = np.asarray(data["size"]) - S = dict() + if formatNew != 5: + nsample = (data_size[first_row,] - 10) * int(table[formatNew, 2]) + S["data"] = np.zeros((len(fp_loc), nsample)) + for i in range(0, len(fp_loc)): + with open(tevfilepath, "rb") as fp: + fp.seek(fp_loc[i], os.SEEK_SET) + S["data"][i, :] = np.fromfile(fp, dtype=table[formatNew, 3], count=nsample).reshape( + 1, nsample, order="F" + ) + # S['data'] = S['data'].swapaxes() + S["npoints"] = nsample + else: + S["data"] = np.asarray(data["strobe"][allIndexesWhereEventIsPresent[0]]) + S["npoints"] = 1 + S["channels"] = np.tile(1, (S["data"].shape[0],)) - S['storename'] = str(event) - S['sampling_rate'] = data['frequency'][first_row] - S['timestamps'] = np.asarray(data['timestamp'][allIndexesWhereEventIsPresent[0]]) - S['channels'] = np.asarray(data['chan'][allIndexesWhereEventIsPresent[0]]) + S["data"] = (S["data"].T).reshape(-1, order="F") + save_dict_to_hdf5(S, event, outputPath) - fp_loc = np.asarray(data['fp_loc'][allIndexesWhereEventIsPresent[0]]) - data_size = np.asarray(data['size']) + check_data(S, filepath, event, outputPath) - if formatNew != 5: - nsample = (data_size[first_row,]-10)*int(table[formatNew, 2]) - S['data'] = np.zeros((len(fp_loc), nsample)) - for i in range(0, len(fp_loc)): - with open(tevfilepath, 'rb') as fp: - fp.seek(fp_loc[i], os.SEEK_SET) - S['data'][i,:] = np.fromfile(fp, dtype=table[formatNew, 3], count=nsample).reshape(1, nsample, order='F') - #S['data'] = S['data'].swapaxes() - S['npoints'] = nsample - else: - S['data'] = np.asarray(data['strobe'][allIndexesWhereEventIsPresent[0]]) - S['npoints'] = 1 - S['channels'] = np.tile(1, (S['data'].shape[0],)) - - - S['data'] = (S['data'].T).reshape(-1, order='F') - - save_dict_to_hdf5(S, event, outputPath) - - check_data(S, filepath, event, outputPath) - - logger.info("Data for event {} fetched and stored.".format(event)) + logger.info("Data for event {} fetched and stored.".format(event)) # function to execute readtev function using multiprocessing to make it faster def 
execute_readtev(data, filepath, event, outputPath, numProcesses=mp.cpu_count()): - start = time.time() - with mp.Pool(numProcesses) as p: - p.starmap(readtev, zip(repeat(data), repeat(filepath), event, repeat(outputPath))) - #p = mp.Pool(mp.cpu_count()) - #p.starmap(readtev, zip(repeat(data), repeat(filepath), event, repeat(outputPath))) - #p.close() - #p.join() - logger.info("Time taken = {0:.5f}".format(time.time() - start)) + start = time.time() + with mp.Pool(numProcesses) as p: + p.starmap(readtev, zip(repeat(data), repeat(filepath), event, repeat(outputPath))) + # p = mp.Pool(mp.cpu_count()) + # p.starmap(readtev, zip(repeat(data), repeat(filepath), event, repeat(outputPath))) + # p.close() + # p.join() + logger.info("Time taken = {0:.5f}".format(time.time() - start)) def execute_import_csv(filepath, event, outputPath, numProcesses=mp.cpu_count()): - #logger.info("Reading data for event {} ...".format(event)) + # logger.info("Reading data for event {} ...".format(event)) + + start = time.time() + with mp.Pool(numProcesses) as p: + p.starmap(import_csv, zip(repeat(filepath), event, repeat(outputPath))) + logger.info("Time taken = {0:.5f}".format(time.time() - start)) - start = time.time() - with mp.Pool(numProcesses) as p: - p.starmap(import_csv, zip(repeat(filepath), event, repeat(outputPath))) - logger.info("Time taken = {0:.5f}".format(time.time() - start)) def access_data_doricV1(doric_file, storesList, outputPath): - keys = list(doric_file['Traces']['Console'].keys()) - for i in range(storesList.shape[1]): - if 'control' in storesList[1,i] or 'signal' in storesList[1,i]: - timestamps = np.array(doric_file['Traces']['Console']['Time(s)']['Console_time(s)']) - sampling_rate = np.array([1/(timestamps[-1]-timestamps[-2])]) - data = np.array(doric_file['Traces']['Console'][storesList[0,i]][storesList[0,i]]) - write_hdf5(sampling_rate, storesList[0,i], outputPath, 'sampling_rate') - write_hdf5(timestamps, storesList[0,i], outputPath, 'timestamps') - write_hdf5(data, storesList[0,i], outputPath, 'data') - else: - timestamps = np.array(doric_file['Traces']['Console']['Time(s)']['Console_time(s)']) - ttl = np.array(doric_file['Traces']['Console'][storesList[0,i]][storesList[0,i]]) - indices = np.where(ttl<=0)[0] - diff_indices = np.where(np.diff(indices)>1)[0] - write_hdf5(timestamps[indices[diff_indices]+1], storesList[0,i], outputPath, 'timestamps') + keys = list(doric_file["Traces"]["Console"].keys()) + for i in range(storesList.shape[1]): + if "control" in storesList[1, i] or "signal" in storesList[1, i]: + timestamps = np.array(doric_file["Traces"]["Console"]["Time(s)"]["Console_time(s)"]) + sampling_rate = np.array([1 / (timestamps[-1] - timestamps[-2])]) + data = np.array(doric_file["Traces"]["Console"][storesList[0, i]][storesList[0, i]]) + write_hdf5(sampling_rate, storesList[0, i], outputPath, "sampling_rate") + write_hdf5(timestamps, storesList[0, i], outputPath, "timestamps") + write_hdf5(data, storesList[0, i], outputPath, "data") + else: + timestamps = np.array(doric_file["Traces"]["Console"]["Time(s)"]["Console_time(s)"]) + ttl = np.array(doric_file["Traces"]["Console"][storesList[0, i]][storesList[0, i]]) + indices = np.where(ttl <= 0)[0] + diff_indices = np.where(np.diff(indices) > 1)[0] + write_hdf5(timestamps[indices[diff_indices] + 1], storesList[0, i], outputPath, "timestamps") + def separate_last_element(arr): l = arr[-1] return arr[:-1], l + def find_string(regex, arr): - for i in range(len(arr)): - if regex.match(arr[i]): - return i + for i in range(len(arr)): + if 
regex.match(arr[i]): + return i + def access_data_doricV6(doric_file, storesList, outputPath): - data = [doric_file["DataAcquisition"]] - res = [] - while len(data) != 0: - members = len(data) - while members != 0: - members -= 1 - data, last_element = separate_last_element(data) - if isinstance(last_element, h5py.Dataset) and not last_element.name.endswith("/Time"): - res.append(last_element.name) - elif isinstance(last_element, h5py.Group): - data.extend(reversed([last_element[k] for k in last_element.keys()])) - - decide_path = [] - for element in res: - sep_values = element.split('/') - if sep_values[-1]=='Values': - if f'{sep_values[-3]}/{sep_values[-2]}' in storesList[0,:]: - decide_path.append(element) - else: - if f'{sep_values[-2]}/{sep_values[-1]}' in storesList[0,:]: - decide_path.append(element) - - for i in range(storesList.shape[1]): - if 'control' in storesList[1,i] or 'signal' in storesList[1,i]: - regex = re.compile('(.*?)'+str(storesList[0,i])+'(.*?)') - idx = [i for i in range(len(decide_path)) if regex.match(decide_path[i])] - if len(idx)>1: - logger.error('More than one string matched (which should not be the case)') - raise Exception('More than one string matched (which should not be the case)') - idx = idx[0] - data = np.array(doric_file[decide_path[idx]]) - timestamps = np.array(doric_file[decide_path[idx].rsplit('/',1)[0]+'/Time']) - sampling_rate = np.array([1/(timestamps[-1]-timestamps[-2])]) - write_hdf5(sampling_rate, storesList[0,i], outputPath, 'sampling_rate') - write_hdf5(timestamps, storesList[0,i], outputPath, 'timestamps') - write_hdf5(data, storesList[0,i], outputPath, 'data') - else: - regex = re.compile('(.*?)'+storesList[0,i]+'$') - idx = [i for i in range(len(decide_path)) if regex.match(decide_path[i])] - if len(idx)>1: - logger.error('More than one string matched (which should not be the case)') - raise Exception('More than one string matched (which should not be the case)') - idx = idx[0] - ttl = np.array(doric_file[decide_path[idx]]) - timestamps = np.array(doric_file[decide_path[idx].rsplit('/',1)[0]+'/Time']) - indices = np.where(ttl<=0)[0] - diff_indices = np.where(np.diff(indices)>1)[0] - write_hdf5(timestamps[indices[diff_indices]+1], storesList[0,i], outputPath, 'timestamps') + data = [doric_file["DataAcquisition"]] + res = [] + while len(data) != 0: + members = len(data) + while members != 0: + members -= 1 + data, last_element = separate_last_element(data) + if isinstance(last_element, h5py.Dataset) and not last_element.name.endswith("/Time"): + res.append(last_element.name) + elif isinstance(last_element, h5py.Group): + data.extend(reversed([last_element[k] for k in last_element.keys()])) + + decide_path = [] + for element in res: + sep_values = element.split("/") + if sep_values[-1] == "Values": + if f"{sep_values[-3]}/{sep_values[-2]}" in storesList[0, :]: + decide_path.append(element) + else: + if f"{sep_values[-2]}/{sep_values[-1]}" in storesList[0, :]: + decide_path.append(element) + + for i in range(storesList.shape[1]): + if "control" in storesList[1, i] or "signal" in storesList[1, i]: + regex = re.compile("(.*?)" + str(storesList[0, i]) + "(.*?)") + idx = [i for i in range(len(decide_path)) if regex.match(decide_path[i])] + if len(idx) > 1: + logger.error("More than one string matched (which should not be the case)") + raise Exception("More than one string matched (which should not be the case)") + idx = idx[0] + data = np.array(doric_file[decide_path[idx]]) + timestamps = np.array(doric_file[decide_path[idx].rsplit("/", 1)[0] 
+ "/Time"]) + sampling_rate = np.array([1 / (timestamps[-1] - timestamps[-2])]) + write_hdf5(sampling_rate, storesList[0, i], outputPath, "sampling_rate") + write_hdf5(timestamps, storesList[0, i], outputPath, "timestamps") + write_hdf5(data, storesList[0, i], outputPath, "data") + else: + regex = re.compile("(.*?)" + storesList[0, i] + "$") + idx = [i for i in range(len(decide_path)) if regex.match(decide_path[i])] + if len(idx) > 1: + logger.error("More than one string matched (which should not be the case)") + raise Exception("More than one string matched (which should not be the case)") + idx = idx[0] + ttl = np.array(doric_file[decide_path[idx]]) + timestamps = np.array(doric_file[decide_path[idx].rsplit("/", 1)[0] + "/Time"]) + indices = np.where(ttl <= 0)[0] + diff_indices = np.where(np.diff(indices) > 1)[0] + write_hdf5(timestamps[indices[diff_indices] + 1], storesList[0, i], outputPath, "timestamps") + def execute_import_doric(filepath, storesList, flag, outputPath): - - if flag=='doric_csv': - path = glob.glob(os.path.join(filepath, '*.csv')) - if len(path)>1: - logger.error('An error occurred : More than one Doric csv file present at the location') - raise Exception('More than one Doric csv file present at the location') - else: - df = pd.read_csv(path[0], header=1, index_col=False) - df = df.dropna(axis=1, how='all') - df = df.dropna(axis=0, how='any') - df['Time(s)'] = df['Time(s)'] - df['Time(s)'].to_numpy()[0] - for i in range(storesList.shape[1]): - if 'control' in storesList[1,i] or 'signal' in storesList[1,i]: - timestamps = np.array(df['Time(s)']) - sampling_rate = np.array([1/(timestamps[-1]-timestamps[-2])]) - write_hdf5(sampling_rate, storesList[0,i], outputPath, 'sampling_rate') - write_hdf5(df['Time(s)'].to_numpy(), storesList[0,i], outputPath, 'timestamps') - write_hdf5(df[storesList[0,i]].to_numpy(), storesList[0,i], outputPath, 'data') - else: - ttl = df[storesList[0,i]] - indices = np.where(ttl<=0)[0] - diff_indices = np.where(np.diff(indices)>1)[0] - write_hdf5(df['Time(s)'][indices[diff_indices]+1].to_numpy(), storesList[0,i], outputPath, 'timestamps') - else: - path = glob.glob(os.path.join(filepath, '*.doric')) - if len(path)>1: - logger.error('An error occurred : More than one Doric file present at the location') - raise Exception('More than one Doric file present at the location') - else: - with h5py.File(path[0], 'r') as f: - if 'Traces' in list(f.keys()): - keys = access_data_doricV1(f, storesList, outputPath) - elif list(f.keys())==['Configurations', 'DataAcquisition']: - keys = access_data_doricV6(f, storesList, outputPath) - + + if flag == "doric_csv": + path = glob.glob(os.path.join(filepath, "*.csv")) + if len(path) > 1: + logger.error("An error occurred : More than one Doric csv file present at the location") + raise Exception("More than one Doric csv file present at the location") + else: + df = pd.read_csv(path[0], header=1, index_col=False) + df = df.dropna(axis=1, how="all") + df = df.dropna(axis=0, how="any") + df["Time(s)"] = df["Time(s)"] - df["Time(s)"].to_numpy()[0] + for i in range(storesList.shape[1]): + if "control" in storesList[1, i] or "signal" in storesList[1, i]: + timestamps = np.array(df["Time(s)"]) + sampling_rate = np.array([1 / (timestamps[-1] - timestamps[-2])]) + write_hdf5(sampling_rate, storesList[0, i], outputPath, "sampling_rate") + write_hdf5(df["Time(s)"].to_numpy(), storesList[0, i], outputPath, "timestamps") + write_hdf5(df[storesList[0, i]].to_numpy(), storesList[0, i], outputPath, "data") + else: + ttl = 
df[storesList[0, i]] + indices = np.where(ttl <= 0)[0] + diff_indices = np.where(np.diff(indices) > 1)[0] + write_hdf5( + df["Time(s)"][indices[diff_indices] + 1].to_numpy(), storesList[0, i], outputPath, "timestamps" + ) + else: + path = glob.glob(os.path.join(filepath, "*.doric")) + if len(path) > 1: + logger.error("An error occurred : More than one Doric file present at the location") + raise Exception("More than one Doric file present at the location") + else: + with h5py.File(path[0], "r") as f: + if "Traces" in list(f.keys()): + keys = access_data_doricV1(f, storesList, outputPath) + elif list(f.keys()) == ["Configurations", "DataAcquisition"]: + keys = access_data_doricV6(f, storesList, outputPath) + # function to read data from 'tsq' and 'tev' files def readRawData(inputParameters): - - logger.debug('### Reading raw data... ###') - # get input parameters - inputParameters = inputParameters - folderNames = inputParameters['folderNames'] - numProcesses = inputParameters['numberOfCores'] - storesListPath = [] - if numProcesses==0: - numProcesses = mp.cpu_count() - elif numProcesses>mp.cpu_count(): - logger.warning('Warning : # of cores parameter set is greater than the cores available \ - available in your machine') - numProcesses = mp.cpu_count()-1 - for i in range(len(folderNames)): - filepath = folderNames[i] - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*')))) - storesListPath = np.concatenate(storesListPath) - writeToFile(str((storesListPath.shape[0]+1)*10)+'\n'+str(10)+'\n') - step = 0 - for i in range(len(folderNames)): - filepath = folderNames[i] - logger.debug(f"### Reading raw data for folder {folderNames[i]}") - storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*'))) - # reading tsq file - data, flag = readtsq(filepath) - # checking if doric file exists - if flag=='tsq': - pass - else: - flag = check_doric(filepath) - - # read data corresponding to each storename selected by user while saving the storeslist file - for j in range(len(storesListPath)): - op = storesListPath[j] - if os.path.exists(os.path.join(op, '.cache_storesList.csv')): - storesList = np.genfromtxt(os.path.join(op, '.cache_storesList.csv'), dtype='str', delimiter=',').reshape(2,-1) - else: - storesList = np.genfromtxt(os.path.join(op, 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1) - - if isinstance(data, pd.DataFrame) and flag=='tsq': - execute_readtev(data, filepath, np.unique(storesList[0,:]), op, numProcesses) - elif flag=='doric_csv': - execute_import_doric(filepath, storesList, flag, op) - elif flag=='doric_doric': - execute_import_doric(filepath, storesList, flag, op) - else: - execute_import_csv(filepath, np.unique(storesList[0,:]), op, numProcesses) - - writeToFile(str(10+((step+1)*10))+'\n') - step += 1 - logger.info(f"### Raw data for folder {folderNames[i]} fetched") - logger.info('Raw data fetched and saved.') - logger.info("#" * 400) + logger.debug("### Reading raw data... 
###") + # get input parameters + inputParameters = inputParameters + folderNames = inputParameters["folderNames"] + numProcesses = inputParameters["numberOfCores"] + storesListPath = [] + if numProcesses == 0: + numProcesses = mp.cpu_count() + elif numProcesses > mp.cpu_count(): + logger.warning( + "Warning : # of cores parameter set is greater than the cores available \ + available in your machine" + ) + numProcesses = mp.cpu_count() - 1 + for i in range(len(folderNames)): + filepath = folderNames[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = np.concatenate(storesListPath) + writeToFile(str((storesListPath.shape[0] + 1) * 10) + "\n" + str(10) + "\n") + step = 0 + for i in range(len(folderNames)): + filepath = folderNames[i] + logger.debug(f"### Reading raw data for folder {folderNames[i]}") + storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) + # reading tsq file + data, flag = readtsq(filepath) + # checking if doric file exists + if flag == "tsq": + pass + else: + flag = check_doric(filepath) + + # read data corresponding to each storename selected by user while saving the storeslist file + for j in range(len(storesListPath)): + op = storesListPath[j] + if os.path.exists(os.path.join(op, ".cache_storesList.csv")): + storesList = np.genfromtxt( + os.path.join(op, ".cache_storesList.csv"), dtype="str", delimiter="," + ).reshape(2, -1) + else: + storesList = np.genfromtxt(os.path.join(op, "storesList.csv"), dtype="str", delimiter=",").reshape( + 2, -1 + ) + + if isinstance(data, pd.DataFrame) and flag == "tsq": + execute_readtev(data, filepath, np.unique(storesList[0, :]), op, numProcesses) + elif flag == "doric_csv": + execute_import_doric(filepath, storesList, flag, op) + elif flag == "doric_doric": + execute_import_doric(filepath, storesList, flag, op) + else: + execute_import_csv(filepath, np.unique(storesList[0, :]), op, numProcesses) + + writeToFile(str(10 + ((step + 1) * 10)) + "\n") + step += 1 + logger.info(f"### Raw data for folder {folderNames[i]} fetched") + logger.info("Raw data fetched and saved.") + logger.info("#" * 400) + def main(input_parameters): - logger.info('run') - try: - readRawData(input_parameters) - logger.info('#'*400) - except Exception as e: - with open(os.path.join(os.path.expanduser('~'), 'pbSteps.txt'), 'a') as file: - file.write(str(-1)+"\n") - logger.error(f"An error occurred: {e}") - raise e + logger.info("run") + try: + readRawData(input_parameters) + logger.info("#" * 400) + except Exception as e: + with open(os.path.join(os.path.expanduser("~"), "pbSteps.txt"), "a") as file: + file.write(str(-1) + "\n") + logger.error(f"An error occurred: {e}") + raise e + if __name__ == "__main__": - input_parameters = json.loads(sys.argv[1]) - main(input_parameters=input_parameters) + input_parameters = json.loads(sys.argv[1]) + main(input_parameters=input_parameters) diff --git a/src/guppy/runFiberPhotometryAnalysis.ipynb b/src/guppy/runFiberPhotometryAnalysis.ipynb index ef5828c..7cc2e93 100755 --- a/src/guppy/runFiberPhotometryAnalysis.ipynb +++ b/src/guppy/runFiberPhotometryAnalysis.ipynb @@ -17,11 +17,10 @@ "%autoreload 2\n", "\n", "%matplotlib\n", - "import os\n", - "from readTevTsq import readRawData\n", - "from preprocess import extractTsAndSignal\n", "from computePsth import psthForEachStorename\n", - "from findTransientsFreqAndAmp import executeFindFreqAndAmp" + "from findTransientsFreqAndAmp import executeFindFreqAndAmp\n", + "from preprocess import 
extractTsAndSignal\n", + "from readTevTsq import readRawData" ] }, { diff --git a/src/guppy/saveStoresList.py b/src/guppy/saveStoresList.py index 4d0355c..d69d832 100755 --- a/src/guppy/saveStoresList.py +++ b/src/guppy/saveStoresList.py @@ -4,29 +4,30 @@ # In[1]: -import os -import json import glob +import json +import logging +import os +import socket +import tkinter as tk +from pathlib import Path +from random import randint +from tkinter import StringVar, messagebox, ttk + import h5py +import holoviews as hv import numpy as np import pandas as pd -from numpy import int32, uint32, uint8, uint16, float64, int64, int32, float32 import panel as pn -from random import randint -from pathlib import Path -import holoviews as hv -import warnings -import socket -import tkinter as tk -from tkinter import ttk, StringVar, messagebox -import logging +from numpy import float32, float64, int32, int64, uint16 -#hv.extension() +# hv.extension() pn.extension() logger = logging.getLogger(__name__) -def scanPortsAndFind(start_port=5000, end_port=5200, host='127.0.0.1'): + +def scanPortsAndFind(start_port=5000, end_port=5200, host="127.0.0.1"): while True: port = randint(start_port, end_port) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -39,29 +40,32 @@ def scanPortsAndFind(start_port=5000, end_port=5200, host='127.0.0.1'): return port + def takeOnlyDirs(paths): - removePaths = [] - for p in paths: - if os.path.isfile(p): - removePaths.append(p) - return list(set(paths)-set(removePaths)) + removePaths = [] + for p in paths: + if os.path.isfile(p): + removePaths.append(p) + return list(set(paths) - set(removePaths)) + # function to show location for over-writing or creating a new stores list file. def show_dir(filepath): i = 1 while True: basename = os.path.basename(filepath) - op = os.path.join(filepath, basename+'_output_'+str(i)) + op = os.path.join(filepath, basename + "_output_" + str(i)) if not os.path.exists(op): break i += 1 return op + def make_dir(filepath): i = 1 while True: basename = os.path.basename(filepath) - op = os.path.join(filepath, basename+'_output_'+str(i)) + op = os.path.join(filepath, basename + "_output_" + str(i)) if not os.path.exists(op): os.mkdir(op) break @@ -69,6 +73,7 @@ def make_dir(filepath): return op + def check_header(df): arr = list(df.columns) check_float = [] @@ -77,24 +82,21 @@ def check_header(df): check_float.append(float(i)) except: pass - + return arr, check_float # function to read 'tsq' file def readtsq(filepath): - names = ('size', 'type', 'name', 'chan', 'sort_code', 'timestamp', - 'fp_loc', 'strobe', 'format', 'frequency') - formats = (int32, int32, 'S4', uint16, uint16, float64, int64, - float64, int32, float32) + names = ("size", "type", "name", "chan", "sort_code", "timestamp", "fp_loc", "strobe", "format", "frequency") + formats = (int32, int32, "S4", uint16, uint16, float64, int64, float64, int32, float32) offsets = 0, 4, 8, 12, 14, 16, 24, 24, 32, 36 - tsq_dtype = np.dtype({'names': names, 'formats': formats, - 'offsets': offsets}, align=True) - path = glob.glob(os.path.join(filepath, '*.tsq')) - if len(path)>1: - logger.error('Two tsq files are present at the location.') - raise Exception('Two tsq files are present at the location.') - elif len(path)==0: + tsq_dtype = np.dtype({"names": names, "formats": formats, "offsets": offsets}, align=True) + path = glob.glob(os.path.join(filepath, "*.tsq")) + if len(path) > 1: + logger.error("Two tsq files are present at the location.") + raise Exception("Two tsq files are present at the 
location.") + elif len(path) == 0: return 0 else: path = path[0] @@ -103,10 +105,10 @@ def readtsq(filepath): return df -# function to show GUI and save +# function to show GUI and save def saveStorenames(inputParameters, data, event_name, flag, filepath): - - logger.debug('Saving stores list file.') + + logger.debug("Saving stores list file.") # getting input parameters inputParameters = inputParameters @@ -115,19 +117,19 @@ def saveStorenames(inputParameters, data, event_name, flag, filepath): if isinstance(storenames_map, dict) and len(storenames_map) > 0: op = make_dir(filepath) arr = np.asarray([list(storenames_map.keys()), list(storenames_map.values())], dtype=str) - np.savetxt(os.path.join(op, 'storesList.csv'), arr, delimiter=",", fmt='%s') + np.savetxt(os.path.join(op, "storesList.csv"), arr, delimiter=",", fmt="%s") logger.info(f"Storeslist file saved at {op}") - logger.info('Storeslist : \n'+str(arr)) + logger.info("Storeslist : \n" + str(arr)) return # reading storenames from the data fetched using 'readtsq' function if isinstance(data, pd.DataFrame): - data['name'] = np.asarray(data['name'], dtype=str) - allnames = np.unique(data['name']) + data["name"] = np.asarray(data["name"], dtype=str) + allnames = np.unique(data["name"]) index = [] for i in range(len(allnames)): length = len(str(allnames[i])) - if length<4: + if length < 4: index.append(i) allnames = np.delete(allnames, index, 0) allnames = list(allnames) @@ -135,74 +137,75 @@ def saveStorenames(inputParameters, data, event_name, flag, filepath): else: allnames = [] - - if 'data_np_v2' in flag or 'data_np' in flag or 'event_np' in flag: - path_chev = glob.glob(os.path.join(filepath, '*chev*')) - path_chod = glob.glob(os.path.join(filepath, '*chod*')) - path_chpr = glob.glob(os.path.join(filepath, '*chpr*')) + if "data_np_v2" in flag or "data_np" in flag or "event_np" in flag: + path_chev = glob.glob(os.path.join(filepath, "*chev*")) + path_chod = glob.glob(os.path.join(filepath, "*chod*")) + path_chpr = glob.glob(os.path.join(filepath, "*chpr*")) combine_paths = path_chev + path_chod + path_chpr d = dict() for i in range(len(combine_paths)): - basename = (os.path.basename(combine_paths[i])).split('.')[0] + basename = (os.path.basename(combine_paths[i])).split(".")[0] df = pd.read_csv(combine_paths[i]) - d[basename] = { - 'x': np.array(df['timestamps']), - 'y': np.array(df['data']) - } + d[basename] = {"x": np.array(df["timestamps"]), "y": np.array(df["data"])} keys = list(d.keys()) - mark_down_np = pn.pane.Markdown(""" + mark_down_np = pn.pane.Markdown( + """ ### Extra Instructions to follow when using Neurophotometrics data : - - Guppy will take the NPM data, which has interleaved frames - from the signal and control channels, and divide it out into - separate channels for each site you recordded. - However, since NPM does not automatically annotate which - frames belong to the signal channel and which belong to the + - Guppy will take the NPM data, which has interleaved frames + from the signal and control channels, and divide it out into + separate channels for each site you recordded. + However, since NPM does not automatically annotate which + frames belong to the signal channel and which belong to the control channel, the user must specify this for GuPPy. 
- - Each of your recording sites will have a channel + - Each of your recording sites will have a channel named “chod” and a channel named “chev” - - View the plots below and, for each site, + - View the plots below and, for each site, determine whether the “chev” or “chod” channel is signal or control - - When you give your storenames, name the channels appropriately. - For example, “chev1” might be “signal_A” and + - When you give your storenames, name the channels appropriately. + For example, “chev1” might be “signal_A” and “chod1” might be “control_A” (or vice versa). - """) - plot_select = pn.widgets.Select(name='Select channel to see correspondings channels', options=keys, value=keys[0]) - + """ + ) + plot_select = pn.widgets.Select( + name="Select channel to see correspondings channels", options=keys, value=keys[0] + ) + @pn.depends(plot_select=plot_select) def plot(plot_select): - return hv.Curve((d[plot_select]['x'], d[plot_select]['y'])).opts(width=550) + return hv.Curve((d[plot_select]["x"], d[plot_select]["y"])).opts(width=550) + else: pass - # finalizing all the storenames + # finalizing all the storenames allnames = allnames + event_name - # instructions about how to save the storeslist file - mark_down = pn.pane.Markdown(""" + mark_down = pn.pane.Markdown( + """ - ### Instructions to follow : + ### Instructions to follow : - - Check Storenames to repeat checkbox and see instructions in “Github Wiki” for duplicating storenames. + - Check Storenames to repeat checkbox and see instructions in “Github Wiki” for duplicating storenames. Otherwise do not check the Storenames to repeat checkbox.
- Select storenames from list and click “Select Storenames” to populate area below.
- Enter names for storenames, in order, using the following naming convention:
Isosbestic = “control_region” (ex: Dv1A= control_DMS)
Signal= “signal_region” (ex: Dv2A= signal_DMS)
TTLs can be named using any convention (ex: PrtR = RewardedPortEntries) but should be kept consistent for later group analysis - + ``` - {"storenames": ["Dv1A", "Dv2A", - "Dv3B", "Dv4B", - "LNRW", "LNnR", - "PrtN", "PrtR", - "RNPS"], - "names_for_storenames": ["control_DMS", "signal_DMS", - "control_DLS", "signal_DLS", - "RewardedNosepoke", "UnrewardedNosepoke", - "UnrewardedPort", "RewardedPort", + {"storenames": ["Dv1A", "Dv2A", + "Dv3B", "Dv4B", + "LNRW", "LNnR", + "PrtN", "PrtR", + "RNPS"], + "names_for_storenames": ["control_DMS", "signal_DMS", + "control_DLS", "signal_DLS", + "RewardedNosepoke", "UnrewardedNosepoke", + "UnrewardedPort", "RewardedPort", "InactiveNosepoke"]} ``` - If user has saved storenames before, clicking "Select Storenames" button will pop up a dialog box @@ -212,153 +215,162 @@ def plot(plot_select): - Select “create new” or “overwrite” to generate a new storenames list or replace a previous one - Click Save - """, width=550) - + """, + width=550, + ) # creating GUI template - template = pn.template.BootstrapTemplate(title='Storenames GUI - {}'.format(os.path.basename(filepath), mark_down)) - - + template = pn.template.BootstrapTemplate(title="Storenames GUI - {}".format(os.path.basename(filepath), mark_down)) # creating different buttons and selectors for the GUI - cross_selector = pn.widgets.CrossSelector(name='Store Names Selection', value=[], options=allnames, width=600) - multi_choice = pn.widgets.MultiChoice(name='Select Storenames which you want more than once (multi-choice: multiple options selection)', value=[], options=allnames) + cross_selector = pn.widgets.CrossSelector(name="Store Names Selection", value=[], options=allnames, width=600) + multi_choice = pn.widgets.MultiChoice( + name="Select Storenames which you want more than once (multi-choice: multiple options selection)", + value=[], + options=allnames, + ) - literal_input_1 = pn.widgets.LiteralInput(name='Number of times you want the above storename (list)', - value=[], type=list) - #literal_input_2 = pn.widgets.LiteralInput(name='Names for Storenames (list)', type=list) + literal_input_1 = pn.widgets.LiteralInput( + name="Number of times you want the above storename (list)", value=[], type=list + ) + # literal_input_2 = pn.widgets.LiteralInput(name='Names for Storenames (list)', type=list) + + repeat_storenames = pn.widgets.Checkbox(name="Storenames to repeat", value=False) + repeat_storename_wd = pn.WidgetBox("", width=600) - repeat_storenames = pn.widgets.Checkbox(name='Storenames to repeat', value=False) - repeat_storename_wd = pn.WidgetBox('', width=600) def callback(target, event): - if event.new==True: + if event.new == True: target.objects = [multi_choice, literal_input_1] - elif event.new==False: + elif event.new == False: target.clear() - repeat_storenames.link(repeat_storename_wd, callbacks={'value': callback}) - #repeat_storename_wd = pn.WidgetBox('Storenames to repeat (leave blank if not needed)', multi_choice, literal_input_1, background="white", width=600) - update_options = pn.widgets.Button(name='Select Storenames', width=600) - save = pn.widgets.Button(name='Save', width=600) + repeat_storenames.link(repeat_storename_wd, callbacks={"value": callback}) + # repeat_storename_wd = pn.WidgetBox('Storenames to repeat (leave blank if not needed)', multi_choice, literal_input_1, background="white", width=600) - text = pn.widgets.LiteralInput(value=[], name='Selected Store Names', type=list, width=600) + update_options = pn.widgets.Button(name="Select Storenames", width=600) + save = 
pn.widgets.Button(name="Save", width=600) - path = pn.widgets.TextInput(name='Location to Stores List file', width=600) + text = pn.widgets.LiteralInput(value=[], name="Selected Store Names", type=list, width=600) - mark_down_for_overwrite = pn.pane.Markdown(""" Select option from below if user wants to over-write a file or create a new file. - **Creating a new file will make a new ouput folder and will get saved at that location.** - If user selects to over-write a file **Select location of the file to over-write** will provide - the existing options of the ouput folders where user needs to over-write the file""", width=600) + path = pn.widgets.TextInput(name="Location to Stores List file", width=600) - select_location = pn.widgets.Select(name='Select location of the file to over-write', - value='None', options=['None'], width=600) + mark_down_for_overwrite = pn.pane.Markdown( + """ Select option from below if user wants to over-write a file or create a new file. + **Creating a new file will make a new output folder and will get saved at that location.** + If user selects to over-write a file **Select location of the file to over-write** will provide + the existing options of the output folders where user needs to over-write the file""", + width=600, + ) + select_location = pn.widgets.Select( + name="Select location of the file to over-write", value="None", options=["None"], width=600 + ) - overwrite_button = pn.widgets.MenuButton(name='over-write storeslist file or create a new one? ', - items=['over_write_file', 'create_new_file'], - button_type='default', split=True, width=600) - - literal_input_2 = pn.widgets.CodeEditor(value="""{}""", - theme='tomorrow', - language='json', - height=250, width=600) + overwrite_button = pn.widgets.MenuButton( + name="over-write storeslist file or create a new one? ", + items=["over_write_file", "create_new_file"], + button_type="default", + split=True, + width=600, + ) - alert = pn.pane.Alert('#### No alerts !!', alert_type='danger', height=80, width=600) + literal_input_2 = pn.widgets.CodeEditor(value="""{}""", theme="tomorrow", language="json", height=250, width=600) + alert = pn.pane.Alert("#### No alerts !!", alert_type="danger", height=80, width=600) - take_widgets = pn.WidgetBox( - multi_choice, - literal_input_1 - ) + take_widgets = pn.WidgetBox(multi_choice, literal_input_1) - change_widgets = pn.WidgetBox( - text - ) + change_widgets = pn.WidgetBox(text) - storenames = [] storename_dropdowns = {} storename_textboxes = {} - - if len(allnames)==0: - alert.object = '####Alert !! \n No storenames found. There are not any TDT files or csv files to look for storenames.' + + if len(allnames) == 0: + alert.object = ( + "####Alert !! \n No storenames found. There are not any TDT files or csv files to look for storenames." + ) # on clicking overwrite_button, following function is executed def overwrite_button_actions(event): - if event.new=='over_write_file': - select_location.options = takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*'))) - #select_location.value = select_location.options[0] + if event.new == "over_write_file": + select_location.options = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) + # select_location.value = select_location.options[0] else: select_location.options = [show_dir(filepath)] - #select_location.value = select_location.options[0] - + # select_location.value = select_location.options[0] + def fetchValues(event): global storenames - alert.object = '#### No alerts !!' 
- + alert.object = "#### No alerts !!" + if not storename_dropdowns or not len(storenames) > 0: - alert.object = '####Alert !! \n No storenames selected.' + alert.object = "####Alert !! \n No storenames selected." return - + storenames_cache = dict() - if os.path.exists(os.path.join(Path.home(), '.storesList.json')): - with open(os.path.join(Path.home(), '.storesList.json')) as f: + if os.path.exists(os.path.join(Path.home(), ".storesList.json")): + with open(os.path.join(Path.home(), ".storesList.json")) as f: storenames_cache = json.load(f) - + comboBoxValues, textBoxValues = [], [] dropdown_keys = list(storename_dropdowns.keys()) textbox_keys = list(storename_textboxes.keys()) if storename_textboxes else [] - + # Get dropdown values for key in dropdown_keys: comboBoxValues.append(storename_dropdowns[key].value) - + # Get textbox values (matching with dropdown keys) for key in dropdown_keys: if key in storename_textboxes: textbox_value = storename_textboxes[key].value or "" textBoxValues.append(textbox_value) - + # Validation: Check for whitespace if len(textbox_value.split()) > 1: - alert.object = '####Alert !! \n Whitespace is not allowed in the text box entry.' + alert.object = "####Alert !! \n Whitespace is not allowed in the text box entry." return - + # Validation: Check for empty required fields dropdown_value = storename_dropdowns[key].value - if not textbox_value and dropdown_value not in storenames_cache and dropdown_value in ['control', 'signal', 'event TTLs']: - alert.object = '####Alert !! \n One of the text box entry is empty.' + if ( + not textbox_value + and dropdown_value not in storenames_cache + and dropdown_value in ["control", "signal", "event TTLs"] + ): + alert.object = "####Alert !! \n One of the text box entry is empty." return else: # For cached values, use the dropdown value directly textBoxValues.append(storename_dropdowns[key].value) - + if len(comboBoxValues) != len(textBoxValues): - alert.object = '####Alert !! \n Number of entries in combo box and text box should be same.' + alert.object = "####Alert !! \n Number of entries in combo box and text box should be same." return - + names_for_storenames = [] for i in range(len(comboBoxValues)): - if comboBoxValues[i] == 'control' or comboBoxValues[i] == "signal": - if '_' in textBoxValues[i]: - alert.object = '####Alert !! \n Please do not use underscore in region name.' + if comboBoxValues[i] == "control" or comboBoxValues[i] == "signal": + if "_" in textBoxValues[i]: + alert.object = "####Alert !! \n Please do not use underscore in region name." 
return names_for_storenames.append("{}_{}".format(comboBoxValues[i], textBoxValues[i])) - elif comboBoxValues[i] == 'event TTLs': + elif comboBoxValues[i] == "event TTLs": names_for_storenames.append(textBoxValues[i]) else: names_for_storenames.append(comboBoxValues[i]) - + d = dict() d["storenames"] = text.value d["names_for_storenames"] = names_for_storenames literal_input_2.value = str(json.dumps(d, indent=2)) - + # Panel-based storename configuration (replaces Tkinter dialog) storename_config_widgets = pn.Column(visible=False) - show_config_button = pn.widgets.Button(name='Show Selected Configuration', width=600) - - # on clicking 'Select Storenames' button, following function is executed + show_config_button = pn.widgets.Button(name="Show Selected Configuration", width=600) + + # on clicking 'Select Storenames' button, following function is executed def update_values(event): global storenames, vars_list arr = [] @@ -371,122 +383,133 @@ def update_values(event): for j in range(arr[1][i]): new_arr.append(arr[0][i]) - if len(new_arr)>0: + if len(new_arr) > 0: storenames = cross_selector.value + new_arr else: storenames = cross_selector.value - + for w in change_widgets: w.value = storenames storenames_cache = dict() - if os.path.exists(os.path.join(Path.home(), '.storesList.json')): - with open(os.path.join(Path.home(), '.storesList.json')) as f: + if os.path.exists(os.path.join(Path.home(), ".storesList.json")): + with open(os.path.join(Path.home(), ".storesList.json")) as f: storenames_cache = json.load(f) - + # Create Panel widgets for storename configuration config_widgets = [] storename_dropdowns.clear() storename_textboxes.clear() - + if len(storenames) > 0: - config_widgets.append(pn.pane.Markdown("## Configure Storenames\nSelect appropriate options for each storename and provide names as needed:")) - + config_widgets.append( + pn.pane.Markdown( + "## Configure Storenames\nSelect appropriate options for each storename and provide names as needed:" + ) + ) + for i, storename in enumerate(storenames): # Create a row for each storename row_widgets = [] - + # Label label = pn.pane.Markdown(f"**{storename}:**") row_widgets.append(label) - + # Dropdown options if storename in storenames_cache: options = storenames_cache[storename] - default_value = options[0] if options else '' + default_value = options[0] if options else "" else: - options = ['', 'control', 'signal', 'event TTLs'] - default_value = '' - + options = ["", "control", "signal", "event TTLs"] + default_value = "" + # Create unique key for widget - widget_key = f"{storename}_{i}" if f"{storename}_{i}" not in storename_dropdowns else f"{storename}_{i}_{len(storename_dropdowns)}" - - dropdown = pn.widgets.Select( - name='Type', - value=default_value, - options=options, - width=150 + widget_key = ( + f"{storename}_{i}" + if f"{storename}_{i}" not in storename_dropdowns + else f"{storename}_{i}_{len(storename_dropdowns)}" ) + + dropdown = pn.widgets.Select(name="Type", value=default_value, options=options, width=150) storename_dropdowns[widget_key] = dropdown row_widgets.append(dropdown) - + # Text input (only show if not cached or if control/signal/event TTLs selected) - if storename not in storenames_cache or default_value in ['control', 'signal', 'event TTLs']: + if storename not in storenames_cache or default_value in ["control", "signal", "event TTLs"]: textbox = pn.widgets.TextInput( - name='Name', - value='', - placeholder='Enter region/event name', - width=200 + name="Name", value="", placeholder="Enter region/event 
name", width=200 ) storename_textboxes[widget_key] = textbox row_widgets.append(textbox) - + # Add helper text based on selection def create_help_function(dropdown_widget, help_pane_container): @pn.depends(dropdown_widget.param.value, watch=True) def update_help(dropdown_value): - if dropdown_value == 'control': - help_pane_container[0] = pn.pane.Markdown("*Type appropriate region name*", styles={'color': 'gray', 'font-size': '12px'}) - elif dropdown_value == 'signal': - help_pane_container[0] = pn.pane.Markdown("*Type appropriate region name*", styles={'color': 'gray', 'font-size': '12px'}) - elif dropdown_value == 'event TTLs': - help_pane_container[0] = pn.pane.Markdown("*Type event name for the TTLs*", styles={'color': 'gray', 'font-size': '12px'}) + if dropdown_value == "control": + help_pane_container[0] = pn.pane.Markdown( + "*Type appropriate region name*", styles={"color": "gray", "font-size": "12px"} + ) + elif dropdown_value == "signal": + help_pane_container[0] = pn.pane.Markdown( + "*Type appropriate region name*", styles={"color": "gray", "font-size": "12px"} + ) + elif dropdown_value == "event TTLs": + help_pane_container[0] = pn.pane.Markdown( + "*Type event name for the TTLs*", styles={"color": "gray", "font-size": "12px"} + ) else: - help_pane_container[0] = pn.pane.Markdown("", styles={'color': 'gray', 'font-size': '12px'}) + help_pane_container[0] = pn.pane.Markdown( + "", styles={"color": "gray", "font-size": "12px"} + ) + return update_help - + help_container = [pn.pane.Markdown("")] help_function = create_help_function(dropdown, help_container) help_function(dropdown.value) # Initialize row_widgets.append(help_container[0]) - + # Add the row to config widgets config_widgets.append(pn.Row(*row_widgets, margin=(5, 0))) - + # Add show button config_widgets.append(pn.Spacer(height=20)) config_widgets.append(show_config_button) - config_widgets.append(pn.pane.Markdown("*Click 'Show Selected Configuration' to apply your selections.*", styles={'font-size': '12px', 'color': 'gray'})) - + config_widgets.append( + pn.pane.Markdown( + "*Click 'Show Selected Configuration' to apply your selections.*", + styles={"font-size": "12px", "color": "gray"}, + ) + ) + # Update the configuration panel storename_config_widgets.objects = config_widgets storename_config_widgets.visible = len(storenames) > 0 - - # on clicking save button, following function is executed def save_button(event=None): global storenames - + d = json.loads(literal_input_2.value) arr1, arr2 = np.asarray(d["storenames"]), np.asarray(d["names_for_storenames"]) - if np.where(arr2=="")[0].size>0: - alert.object = '#### Alert !! \n Empty string in the list names_for_storenames.' - logger.error('Empty string in the list names_for_storenames.') - raise Exception('Empty string in the list names_for_storenames.') + if np.where(arr2 == "")[0].size > 0: + alert.object = "#### Alert !! \n Empty string in the list names_for_storenames." + logger.error("Empty string in the list names_for_storenames.") + raise Exception("Empty string in the list names_for_storenames.") else: - alert.object = '#### No alerts !!' + alert.object = "#### No alerts !!" - if arr1.shape[0]!=arr2.shape[0]: - alert.object = '#### Alert !! \n Length of list storenames and names_for_storenames is not equal.' - logger.error('Length of list storenames and names_for_storenames is not equal.') - raise Exception('Length of list storenames and names_for_storenames is not equal.') + if arr1.shape[0] != arr2.shape[0]: + alert.object = "#### Alert !! 
\n Length of list storenames and names_for_storenames is not equal." + logger.error("Length of list storenames and names_for_storenames is not equal.") + raise Exception("Length of list storenames and names_for_storenames is not equal.") else: - alert.object = '#### No alerts !!' + alert.object = "#### No alerts !!" - - if not os.path.exists(os.path.join(Path.home(), '.storesList.json')): + if not os.path.exists(os.path.join(Path.home(), ".storesList.json")): storenames_cache = dict() for i in range(arr1.shape[0]): @@ -496,10 +519,10 @@ def save_button(event=None): else: storenames_cache[arr1[i]] = [arr2[i]] - with open(os.path.join(Path.home(), '.storesList.json'), 'w') as f: - json.dump(storenames_cache, f, indent=4) + with open(os.path.join(Path.home(), ".storesList.json"), "w") as f: + json.dump(storenames_cache, f, indent=4) else: - with open(os.path.join(Path.home(), '.storesList.json')) as f: + with open(os.path.join(Path.home(), ".storesList.json")) as f: storenames_cache = json.load(f) for i in range(arr1.shape[0]): @@ -509,19 +532,18 @@ def save_button(event=None): else: storenames_cache[arr1[i]] = [arr2[i]] - with open(os.path.join(Path.home(), '.storesList.json'), 'w') as f: + with open(os.path.join(Path.home(), ".storesList.json"), "w") as f: json.dump(storenames_cache, f, indent=4) arr = np.asarray([arr1, arr2]) logger.info(arr) if not os.path.exists(select_location.value): os.mkdir(select_location.value) - - np.savetxt(os.path.join(select_location.value, 'storesList.csv'), arr, delimiter=",", fmt='%s') - path.value = os.path.join(select_location.value, 'storesList.csv') + + np.savetxt(os.path.join(select_location.value, "storesList.csv"), arr, delimiter=",", fmt="%s") + path.value = os.path.join(select_location.value, "storesList.csv") logger.info(f"Storeslist file saved at {select_location.value}") - logger.info('Storeslist : \n'+str(arr)) - + logger.info("Storeslist : \n" + str(arr)) # Connect button callbacks update_options.on_click(update_values) @@ -532,22 +554,46 @@ def save_button(event=None): # creating widgets, adding them to template and showing a GUI on a new browser window number = scanPortsAndFind(start_port=5000, end_port=5200) - if 'data_np_v2' in flag or 'data_np' in flag or 'event_np' in flag: - widget_1 = pn.Column('# '+os.path.basename(filepath), mark_down, mark_down_np, plot_select, plot) - widget_2 = pn.Column(repeat_storenames, repeat_storename_wd, pn.Spacer(height=20), - cross_selector, update_options, - storename_config_widgets, pn.Spacer(height=10), - text, literal_input_2, alert, mark_down_for_overwrite, - overwrite_button, select_location, save, path) + if "data_np_v2" in flag or "data_np" in flag or "event_np" in flag: + widget_1 = pn.Column("# " + os.path.basename(filepath), mark_down, mark_down_np, plot_select, plot) + widget_2 = pn.Column( + repeat_storenames, + repeat_storename_wd, + pn.Spacer(height=20), + cross_selector, + update_options, + storename_config_widgets, + pn.Spacer(height=10), + text, + literal_input_2, + alert, + mark_down_for_overwrite, + overwrite_button, + select_location, + save, + path, + ) template.main.append(pn.Row(widget_1, widget_2)) else: - widget_1 = pn.Column('# '+os.path.basename(filepath), mark_down) - widget_2 = pn.Column(repeat_storenames, repeat_storename_wd, pn.Spacer(height=20), - cross_selector, update_options, - storename_config_widgets, pn.Spacer(height=10), - text, literal_input_2, alert, mark_down_for_overwrite, - overwrite_button, select_location, save, path) + widget_1 = pn.Column("# " + 
os.path.basename(filepath), mark_down) + widget_2 = pn.Column( + repeat_storenames, + repeat_storename_wd, + pn.Spacer(height=20), + cross_selector, + update_options, + storename_config_widgets, + pn.Spacer(height=10), + text, + literal_input_2, + alert, + mark_down_for_overwrite, + overwrite_button, + select_location, + save, + path, + ) template.main.append(pn.Row(widget_1, widget_2)) template.show(port=number) @@ -557,139 +603,170 @@ def save_button(event=None): def check_channels(state): state = state.astype(int) unique_state = np.unique(state[2:12]) - if unique_state.shape[0]>3: - logger.error("Looks like there are more than 3 channels in the file. Reading of these files\ - are not supported. Reach out to us if you get this error message.") - raise Exception("Looks like there are more than 3 channels in the file. Reading of these files\ - are not supported. Reach out to us if you get this error message.") + if unique_state.shape[0] > 3: + logger.error( + "Looks like there are more than 3 channels in the file. Reading of these files\ + are not supported. Reach out to us if you get this error message." + ) + raise Exception( + "Looks like there are more than 3 channels in the file. Reading of these files\ + are not supported. Reach out to us if you get this error message." + ) return unique_state.shape[0], unique_state - + + # function to decide NPM timestamps unit (seconds, ms or us) def decide_ts_unit_for_npm(df, timestamp_column_name=None, time_unit=None, headless=False): col_names = np.array(list(df.columns)) - col_names_ts = [''] + col_names_ts = [""] for name in col_names: - if 'timestamp' in name.lower(): + if "timestamp" in name.lower(): col_names_ts.append(name) - - ts_unit = 'seconds' - if len(col_names_ts)>2: + + ts_unit = "seconds" + if len(col_names_ts) > 2: # Headless path: auto-select column/unit without any UI if headless: if timestamp_column_name is not None: - assert timestamp_column_name in col_names_ts, f"Provided timestamp_column_name '{timestamp_column_name}' not found in columns {col_names_ts[1:]}" + assert ( + timestamp_column_name in col_names_ts + ), f"Provided timestamp_column_name '{timestamp_column_name}' not found in columns {col_names_ts[1:]}" chosen = timestamp_column_name else: chosen = col_names_ts[0] - df.insert(1, 'Timestamp', df[chosen]) + df.insert(1, "Timestamp", df[chosen]) df = df.drop(col_names_ts[1:], axis=1) - valid_units = {'seconds', 'milliseconds', 'microseconds'} - ts_unit = time_unit if (isinstance(time_unit, str) and time_unit in valid_units) else 'seconds' + valid_units = {"seconds", "milliseconds", "microseconds"} + ts_unit = time_unit if (isinstance(time_unit, str) and time_unit in valid_units) else "seconds" return df, ts_unit - #def comboBoxSelected(event): + # def comboBoxSelected(event): # logger.info(event.widget.get()) - + window = tk.Tk() - window.title('Select appropriate options for timestamps') - window.geometry('500x200') + window.title("Select appropriate options for timestamps") + window.geometry("500x200") holdComboboxValues = dict() - timestamps_label = ttk.Label(window, - text="Select which timetamps to use : ").grid(row=0, column=1, pady=25, padx=25) - holdComboboxValues['timestamps'] = StringVar() - timestamps_combo = ttk.Combobox(window, - values=col_names_ts, - textvariable=holdComboboxValues['timestamps']) + timestamps_label = ttk.Label(window, text="Select which timestamps to use : ").grid( + row=0, column=1, pady=25, padx=25 + ) + holdComboboxValues["timestamps"] = StringVar() + timestamps_combo = 
ttk.Combobox(window, values=col_names_ts, textvariable=holdComboboxValues["timestamps"]) timestamps_combo.grid(row=0, column=2, pady=25, padx=25) timestamps_combo.current(0) - #timestamps_combo.bind("<>", comboBoxSelected) + # timestamps_combo.bind("<>", comboBoxSelected) - time_unit_label = ttk.Label(window, text="Select timetamps unit : ").grid(row=1, column=1, pady=25, padx=25) - holdComboboxValues['time_unit'] = StringVar() - time_unit_combo = ttk.Combobox(window, - values=['', 'seconds', 'milliseconds', 'microseconds'], - textvariable=holdComboboxValues['time_unit']) + time_unit_label = ttk.Label(window, text="Select timestamps unit : ").grid(row=1, column=1, pady=25, padx=25) + holdComboboxValues["time_unit"] = StringVar() + time_unit_combo = ttk.Combobox( + window, values=["", "seconds", "milliseconds", "microseconds"], textvariable=holdComboboxValues["time_unit"] + ) time_unit_combo.grid(row=1, column=2, pady=25, padx=25) time_unit_combo.current(0) - #time_unit_combo.bind("<>", comboBoxSelected) + # time_unit_combo.bind("<>", comboBoxSelected) window.lift() window.after(500, lambda: window.lift()) window.mainloop() - if holdComboboxValues['timestamps'].get(): - df.insert(1, 'Timestamp', df[holdComboboxValues['timestamps'].get()]) + if holdComboboxValues["timestamps"].get(): + df.insert(1, "Timestamp", df[holdComboboxValues["timestamps"].get()]) df = df.drop(col_names_ts[1:], axis=1) else: - messagebox.showerror('All options not selected', 'All the options for timestamps \ - were not selected. Please select appropriate options') - logger.error('All the options for timestamps \ - were not selected. Please select appropriate options') - raise Exception('All the options for timestamps \ - were not selected. Please select appropriate options') - if holdComboboxValues['time_unit'].get(): - if holdComboboxValues['time_unit'].get()=='seconds': - ts_unit = holdComboboxValues['time_unit'].get() - elif holdComboboxValues['time_unit'].get()=='milliseconds': - ts_unit = holdComboboxValues['time_unit'].get() + messagebox.showerror( + "All options not selected", + "All the options for timestamps \ + were not selected. Please select appropriate options", + ) + logger.error( + "All the options for timestamps \ + were not selected. Please select appropriate options" + ) + raise Exception( + "All the options for timestamps \ + were not selected. Please select appropriate options" + ) + if holdComboboxValues["time_unit"].get(): + if holdComboboxValues["time_unit"].get() == "seconds": + ts_unit = holdComboboxValues["time_unit"].get() + elif holdComboboxValues["time_unit"].get() == "milliseconds": + ts_unit = holdComboboxValues["time_unit"].get() else: - ts_unit = holdComboboxValues['time_unit'].get() + ts_unit = holdComboboxValues["time_unit"].get() else: - messagebox.showerror('All options not selected', 'All the options for timestamps \ - were not selected. Please select appropriate options') - logger.error('All the options for timestamps \ - were not selected. Please select appropriate options') - raise Exception('All the options for timestamps \ - were not selected. Please select appropriate options') + messagebox.showerror( + "All options not selected", + "All the options for timestamps \ + were not selected. Please select appropriate options", + ) + logger.error( + "All the options for timestamps \ + were not selected. Please select appropriate options" + ) + raise Exception( + "All the options for timestamps \ + were not selected. 
Please select appropriate options" + ) else: pass - + return df, ts_unit - + + # function to decide indices of interleaved channels # in neurophotometrics data def decide_indices(file, df, flag, num_ch=2): - ch_name = [file+'chev', file+'chod', file+'chpr'] - if len(ch_name)0: + if len(value) > 0: columns_isstr = False df = pd.read_csv(path[i], header=None) cols = np.array(list(df.columns), dtype=str) @@ -807,148 +886,161 @@ def import_np_doric_csv(filepath, isosbestic_control, num_ch, inputParameters=No columns_isstr = True cols = np.array(list(df.columns), dtype=str) # check the structure of dataframe and assign flag to the type of file - if len(cols)==1: - if cols[0].lower()!='timestamps': - logger.error("\033[1m"+"Column name should be timestamps (all lower-cases)"+"\033[0m") - raise Exception("\033[1m"+"Column name should be timestamps (all lower-cases)"+"\033[0m") + if len(cols) == 1: + if cols[0].lower() != "timestamps": + logger.error("\033[1m" + "Column name should be timestamps (all lower-cases)" + "\033[0m") + raise Exception("\033[1m" + "Column name should be timestamps (all lower-cases)" + "\033[0m") else: - flag = 'event_csv' - elif len(cols)==3: - arr1 = np.array(['timestamps', 'data', 'sampling_rate']) + flag = "event_csv" + elif len(cols) == 3: + arr1 = np.array(["timestamps", "data", "sampling_rate"]) arr2 = np.char.lower(np.array(cols)) - if (np.sort(arr1)==np.sort(arr2)).all()==False: - logger.error("\033[1m"+"Column names should be timestamps, data and sampling_rate (all lower-cases)"+"\033[0m") - raise Exception("\033[1m"+"Column names should be timestamps, data and sampling_rate (all lower-cases)"+"\033[0m") + if (np.sort(arr1) == np.sort(arr2)).all() == False: + logger.error( + "\033[1m" + + "Column names should be timestamps, data and sampling_rate (all lower-cases)" + + "\033[0m" + ) + raise Exception( + "\033[1m" + + "Column names should be timestamps, data and sampling_rate (all lower-cases)" + + "\033[0m" + ) else: - flag = 'data_csv' - elif len(cols)==2: - flag = 'event_or_data_np' - elif len(cols)>=2: - flag = 'data_np' + flag = "data_csv" + elif len(cols) == 2: + flag = "event_or_data_np" + elif len(cols) >= 2: + flag = "data_np" else: - logger.error('Number of columns in csv file does not make sense.') - raise Exception('Number of columns in csv file does not make sense.') - + logger.error("Number of columns in csv file does not make sense.") + raise Exception("Number of columns in csv file does not make sense.") - if columns_isstr == True and ('flags' in np.char.lower(np.array(cols)) or 'ledstate' in np.char.lower(np.array(cols))): - flag = flag+'_v2' + if columns_isstr == True and ( + "flags" in np.char.lower(np.array(cols)) or "ledstate" in np.char.lower(np.array(cols)) + ): + flag = flag + "_v2" else: flag = flag - # used assigned flags to process the files and read the data - if flag=='event_or_data_np': - arr = list(df.iloc[:,1]) + if flag == "event_or_data_np": + arr = list(df.iloc[:, 1]) check_float = [True for i in arr if isinstance(i, float)] - if len(arr)==len(check_float) and columns_isstr == False: - flag = 'data_np' - elif columns_isstr == True and ('value' in np.char.lower(np.array(cols))): - flag = 'event_np' + if len(arr) == len(check_float) and columns_isstr == False: + flag = "data_np" + elif columns_isstr == True and ("value" in np.char.lower(np.array(cols))): + flag = "event_np" else: - flag = 'event_np' + flag = "event_np" else: pass - + flag_arr.append(flag) logger.info(flag) - if flag=='event_csv' or flag=='data_csv': - name = 
os.path.basename(path[i]).split('.')[0] + if flag == "event_csv" or flag == "data_csv": + name = os.path.basename(path[i]).split(".")[0] event_from_filename.append(name) - elif flag=='data_np': - file = f'file{str(i)}_' + elif flag == "data_np": + file = f"file{str(i)}_" df, indices_dict, num_channels = decide_indices(file, df, flag, num_ch) keys = list(indices_dict.keys()) for k in range(len(keys)): for j in range(df.shape[1]): - if j==0: - timestamps = df.iloc[:,j][indices_dict[keys[k]]] - #timestamps_odd = df.iloc[:,j][odd_indices] + if j == 0: + timestamps = df.iloc[:, j][indices_dict[keys[k]]] + # timestamps_odd = df.iloc[:,j][odd_indices] else: d = dict() - d['timestamps'] = timestamps - d['data'] = df.iloc[:,j][indices_dict[keys[k]]] - + d["timestamps"] = timestamps + d["data"] = df.iloc[:, j][indices_dict[keys[k]]] + df_ch = pd.DataFrame(d) - df_ch.to_csv(os.path.join(dirname, keys[k]+str(j)+'.csv'), index=False) - event_from_filename.append(keys[k]+str(j)) - - elif flag=='event_np': - type_val = np.array(df.iloc[:,1]) + df_ch.to_csv(os.path.join(dirname, keys[k] + str(j) + ".csv"), index=False) + event_from_filename.append(keys[k] + str(j)) + + elif flag == "event_np": + type_val = np.array(df.iloc[:, 1]) type_val_unique = np.unique(type_val) if headless: response = 1 if bool(npm_split_events) else 0 else: window = tk.Tk() - if len(type_val_unique)>1: - response = messagebox.askyesno('Multiple event TTLs', 'Based on the TTL file,\ + if len(type_val_unique) > 1: + response = messagebox.askyesno( + "Multiple event TTLs", + "Based on the TTL file,\ it looks like TTLs \ - belongs to multipe behavior type. \ + belongs to multiple behavior type. \ Do you want to create multiple files for each \ - behavior type ?') + behavior type ?", + ) else: response = 0 window.destroy() - if response==1: - timestamps = np.array(df.iloc[:,0]) + if response == 1: + timestamps = np.array(df.iloc[:, 0]) for j in range(len(type_val_unique)): - idx = np.where(type_val==type_val_unique[j]) + idx = np.where(type_val == type_val_unique[j]) d = dict() - d['timestamps'] = timestamps[idx] + d["timestamps"] = timestamps[idx] df_new = pd.DataFrame(d) - df_new.to_csv(os.path.join(dirname, 'event'+str(type_val_unique[j])+'.csv'), index=False) - event_from_filename.append('event'+str(type_val_unique[j])) + df_new.to_csv(os.path.join(dirname, "event" + str(type_val_unique[j]) + ".csv"), index=False) + event_from_filename.append("event" + str(type_val_unique[j])) else: - timestamps = np.array(df.iloc[:,0]) + timestamps = np.array(df.iloc[:, 0]) d = dict() - d['timestamps'] = timestamps + d["timestamps"] = timestamps df_new = pd.DataFrame(d) - df_new.to_csv(os.path.join(dirname, 'event'+str(0)+'.csv'), index=False) - event_from_filename.append('event'+str(0)) + df_new.to_csv(os.path.join(dirname, "event" + str(0) + ".csv"), index=False) + event_from_filename.append("event" + str(0)) else: - file = f'file{str(i)}_' + file = f"file{str(i)}_" df, ts_unit = decide_ts_unit_for_npm( - df, - timestamp_column_name=npm_timestamp_column_name, - time_unit=npm_time_unit, - headless=headless + df, timestamp_column_name=npm_timestamp_column_name, time_unit=npm_time_unit, headless=headless ) df, indices_dict, num_channels = decide_indices(file, df, flag) keys = list(indices_dict.keys()) for k in range(len(keys)): for j in range(df.shape[1]): - if j==0: - timestamps = df.iloc[:,j][indices_dict[keys[k]]] - #timestamps_odd = df.iloc[:,j][odd_indices] + if j == 0: + timestamps = df.iloc[:, j][indices_dict[keys[k]]] + # timestamps_odd = 
df.iloc[:,j][odd_indices] else: d = dict() - d['timestamps'] = timestamps - d['data'] = df.iloc[:,j][indices_dict[keys[k]]] - + d["timestamps"] = timestamps + d["data"] = df.iloc[:, j][indices_dict[keys[k]]] + df_ch = pd.DataFrame(d) - df_ch.to_csv(os.path.join(dirname, keys[k]+str(j)+'.csv'), index=False) - event_from_filename.append(keys[k]+str(j)) - - path_chev = glob.glob(os.path.join(filepath, '*chev*')) - path_chod = glob.glob(os.path.join(filepath, '*chod*')) - path_chpr = glob.glob(os.path.join(filepath, '*chpr*')) - path_event = glob.glob(os.path.join(filepath, 'event*')) - #path_sig = glob.glob(os.path.join(filepath, 'sig*')) + df_ch.to_csv(os.path.join(dirname, keys[k] + str(j) + ".csv"), index=False) + event_from_filename.append(keys[k] + str(j)) + + path_chev = glob.glob(os.path.join(filepath, "*chev*")) + path_chod = glob.glob(os.path.join(filepath, "*chod*")) + path_chpr = glob.glob(os.path.join(filepath, "*chpr*")) + path_event = glob.glob(os.path.join(filepath, "event*")) + # path_sig = glob.glob(os.path.join(filepath, 'sig*')) path_chev_chod_chpr = [path_chev, path_chod, path_chpr] - if (('data_np_v2' in flag_arr or 'data_np' in flag_arr) and ('event_np' in flag_arr) and (i==len(path)-1)) or \ - (('data_np_v2' in flag_arr or 'data_np' in flag_arr) and (i==len(path)-1)): # i==len(path)-1 and or 'event_np' in flag + if ( + ("data_np_v2" in flag_arr or "data_np" in flag_arr) + and ("event_np" in flag_arr) + and (i == len(path) - 1) + ) or ( + ("data_np_v2" in flag_arr or "data_np" in flag_arr) and (i == len(path) - 1) + ): # i==len(path)-1 and or 'event_np' in flag num_path_chev, num_path_chod, num_path_chpr = len(path_chev), len(path_chod), len(path_chpr) arr_len, no_ch = [], [] for i in range(len(path_chev_chod_chpr)): - if len(path_chev_chod_chpr[i])>0: + if len(path_chev_chod_chpr[i]) > 0: arr_len.append(len(path_chev_chod_chpr[i])) else: continue unique_arr_len = np.unique(np.array(arr_len)) - if 'data_np_v2' in flag_arr: - if ts_unit == 'seconds': + if "data_np_v2" in flag_arr: + if ts_unit == "seconds": divisor = 1 - elif ts_unit == 'milliseconds': + elif ts_unit == "milliseconds": divisor = 1e3 else: divisor = 1e6 @@ -958,57 +1050,60 @@ def import_np_doric_csv(filepath, isosbestic_control, num_ch, inputParameters=No for j in range(len(path_event)): df_event = pd.read_csv(path_event[j]) df_chev = pd.read_csv(path_chev[0]) - df_event['timestamps'] = (df_event['timestamps']-df_chev['timestamps'][0])/divisor + df_event["timestamps"] = (df_event["timestamps"] - df_chev["timestamps"][0]) / divisor df_event.to_csv(path_event[j], index=False) - if unique_arr_len.shape[0]==1: + if unique_arr_len.shape[0] == 1: for j in range(len(path_chev)): - if file+'chev' in indices_dict.keys(): + if file + "chev" in indices_dict.keys(): df_chev = pd.read_csv(path_chev[j]) - df_chev['timestamps'] = (df_chev['timestamps']-df_chev['timestamps'][0])/divisor - df_chev['sampling_rate'] = np.full(df_chev.shape[0], np.nan) - df_chev.at[0,'sampling_rate'] = df_chev.shape[0]/(df_chev['timestamps'].iloc[-1] - df_chev['timestamps'].iloc[0]) + df_chev["timestamps"] = (df_chev["timestamps"] - df_chev["timestamps"][0]) / divisor + df_chev["sampling_rate"] = np.full(df_chev.shape[0], np.nan) + df_chev.at[0, "sampling_rate"] = df_chev.shape[0] / ( + df_chev["timestamps"].iloc[-1] - df_chev["timestamps"].iloc[0] + ) df_chev.to_csv(path_chev[j], index=False) - if file+'chod' in indices_dict.keys(): + if file + "chod" in indices_dict.keys(): df_chod = pd.read_csv(path_chod[j]) - df_chod['timestamps'] = 
df_chev['timestamps'] - df_chod['sampling_rate'] = np.full(df_chod.shape[0], np.nan) - df_chod.at[0,'sampling_rate'] = df_chev['sampling_rate'][0] + df_chod["timestamps"] = df_chev["timestamps"] + df_chod["sampling_rate"] = np.full(df_chod.shape[0], np.nan) + df_chod.at[0, "sampling_rate"] = df_chev["sampling_rate"][0] df_chod.to_csv(path_chod[j], index=False) - if file+'chpr' in indices_dict.keys(): + if file + "chpr" in indices_dict.keys(): df_chpr = pd.read_csv(path_chpr[j]) - df_chpr['timestamps'] = df_chev['timestamps'] - df_chpr['sampling_rate'] = np.full(df_chpr.shape[0], np.nan) - df_chpr.at[0,'sampling_rate'] = df_chev['sampling_rate'][0] + df_chpr["timestamps"] = df_chev["timestamps"] + df_chpr["sampling_rate"] = np.full(df_chpr.shape[0], np.nan) + df_chpr.at[0, "sampling_rate"] = df_chev["sampling_rate"][0] df_chpr.to_csv(path_chpr[j], index=False) else: - logger.error('Number of channels should be same for all regions.') - raise Exception('Number of channels should be same for all regions.') + logger.error("Number of channels should be same for all regions.") + raise Exception("Number of channels should be same for all regions.") else: pass - logger.info('Importing of either NPM or Doric or csv file is done.') + logger.info("Importing of either NPM or Doric or csv file is done.") return event_from_filename, flag_arr # function to read input parameters and run the saveStorenames function def execute(inputParameters): - - + inputParameters = inputParameters - folderNames = inputParameters['folderNames'] - isosbestic_control = inputParameters['isosbestic_control'] - num_ch = inputParameters['noChannels'] + folderNames = inputParameters["folderNames"] + isosbestic_control = inputParameters["isosbestic_control"] + num_ch = inputParameters["noChannels"] logger.info(folderNames) try: for i in folderNames: - filepath = os.path.join(inputParameters['abspath'], i) + filepath = os.path.join(inputParameters["abspath"], i) data = readtsq(filepath) - event_name, flag = import_np_doric_csv(filepath, isosbestic_control, num_ch, inputParameters=inputParameters) + event_name, flag = import_np_doric_csv( + filepath, isosbestic_control, num_ch, inputParameters=inputParameters + ) saveStorenames(inputParameters, data, event_name, flag, filepath) - logger.info('#'*400) + logger.info("#" * 400) except Exception as e: logger.error(str(e)) raise e diff --git a/src/guppy/savingInputParameters.py b/src/guppy/savingInputParameters.py index 98f9267..cd515ab 100644 --- a/src/guppy/savingInputParameters.py +++ b/src/guppy/savingInputParameters.py @@ -1,27 +1,28 @@ +import json +import logging import os +import subprocess import sys import time -import subprocess -import json -import panel as pn -import numpy as np -import pandas as pd import tkinter as tk -from tkinter import ttk -from tkinter import filedialog from threading import Thread -from pathlib import Path -from .visualizePlot import visualizeResults +from tkinter import filedialog, ttk + +import numpy as np +import pandas as pd +import panel as pn + from .saveStoresList import execute -import logging +from .visualizePlot import visualizeResults logger = logging.getLogger(__name__) + def savingInputParameters(): pn.extension() # Determine base folder path (headless-friendly via env var) - base_dir_env = os.environ.get('GUPPY_BASE_DIR') + base_dir_env = os.environ.get("GUPPY_BASE_DIR") is_headless = base_dir_env and os.path.isdir(base_dir_env) if is_headless: global folder_path @@ -32,6 +33,7 @@ def savingInputParameters(): folder_selection = tk.Tk() 
folder_selection.title("Select the folder path where your data is located") folder_selection.geometry("700x200") + def select_folder(): global folder_path folder_path = filedialog.askdirectory(title="Select the folder path where your data is located") @@ -39,7 +41,7 @@ def select_folder(): logger.info(f"Folder path set to {folder_path}") folder_selection.destroy() else: - folder_path = os.path.expanduser('~') + folder_path = os.path.expanduser("~") logger.info(f"Folder path set to {folder_path}") select_button = ttk.Button(folder_selection, text="Select a Folder", command=select_folder) @@ -49,7 +51,7 @@ def select_folder(): current_dir = os.getcwd() def make_dir(filepath): - op = os.path.join(filepath, 'inputParameters') + op = os.path.join(filepath, "inputParameters") if not os.path.exists(op): os.mkdir(op) return op @@ -64,246 +66,292 @@ def extractTs(): def psthComputation(): inputParameters = getInputParameters() - inputParameters['curr_dir'] = current_dir + inputParameters["curr_dir"] = current_dir subprocess.call([sys.executable, "-m", "guppy.computePsth", json.dumps(inputParameters)]) - def readPBIncrementValues(progressBar): logger.info("Read progress bar increment values function started...") - file_path = os.path.join(os.path.expanduser('~'), 'pbSteps.txt') + file_path = os.path.join(os.path.expanduser("~"), "pbSteps.txt") if os.path.exists(file_path): os.remove(file_path) increment, maximum = 0, 100 progressBar.value = increment - progressBar.bar_color = 'success' + progressBar.bar_color = "success" while True: try: - with open(file_path, 'r') as file: + with open(file_path, "r") as file: content = file.readlines() - if len(content)==0: + if len(content) == 0: pass else: maximum = int(content[0]) increment = int(content[-1]) - - if increment==-1: - progressBar.bar_color = 'danger' + + if increment == -1: + progressBar.bar_color = "danger" os.remove(file_path) break progressBar.max = maximum progressBar.value = increment - time.sleep(0.001) + time.sleep(0.001) except FileNotFoundError: time.sleep(0.001) except PermissionError: - time.sleep(0.001) + time.sleep(0.001) except Exception as e: # Handle other exceptions that may occur logger.info(f"An error occurred while reading the file: {e}") break - if increment==maximum: + if increment == maximum: os.remove(file_path) break logger.info("Read progress bar increment values stopped.") - - # progress bars = PB - read_progress = pn.indicators.Progress(name='Progress', value=100, max=100, width=300) - extract_progress = pn.indicators.Progress(name='Progress', value=100, max=100, width=300) - psth_progress = pn.indicators.Progress(name='Progress', value=100, max=100, width=300) + # progress bars = PB + read_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300) + extract_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300) + psth_progress = pn.indicators.Progress(name="Progress", value=100, max=100, width=300) - template = pn.template.BootstrapTemplate(title='Input Parameters GUI') + template = pn.template.BootstrapTemplate(title="Input Parameters GUI") mark_down_1 = pn.pane.Markdown("""**Select folders for the analysis from the file selector below**""", width=600) - files_1 = pn.widgets.FileSelector(folder_path, name='folderNames', width=950) - + files_1 = pn.widgets.FileSelector(folder_path, name="folderNames", width=950) - explain_time_artifacts = pn.pane.Markdown(""" - - ***Number of cores :*** Number of cores used for analysis. 
Try to - keep it less than the number of cores in your machine. - - ***Combine Data? :*** Make this parameter ``` True ``` if user wants to combine - the data, especially when there is two different + explain_time_artifacts = pn.pane.Markdown( + """ + - ***Number of cores :*** Number of cores used for analysis. Try to + keep it less than the number of cores in your machine. + - ***Combine Data? :*** Make this parameter ``` True ``` if user wants to combine + the data, especially when there is two different data files for the same recording session.
- ***Isosbestic Control Channel? :*** Make this parameter ``` False ``` if the user does not want to use the isosbestic control channel in the analysis.<br>
- ***Eliminate first few seconds :*** This parameter cuts out the first x seconds from the data. Default is 1 second.<br>
- ***Window for Moving Average filter :*** The filtering of signals - is done using moving average filter. Default window used for moving + is done using moving average filter. Default window used for moving average filter is 100 datapoints. Change it based on the requirement.
- - ***Moving Window (transients detection) :*** Transients in the z-score - and/or \u0394F/F are detected using this moving window. + - ***Moving Window (transients detection) :*** Transients in the z-score + and/or \u0394F/F are detected using this moving window. Default is 15 seconds. Change it based on the requirement.
- ***High Amplitude filtering threshold (HAFT) (transients detection) :*** High amplitude - events greater than x times the MAD above the median are filtered out. Here, x is + events greater than x times the MAD above the median are filtered out. Here, x is high amplitude filtering threshold. Default is 2. - ***Transients detection threshold (TD Thresh):*** Peaks with local maxima greater than x times the MAD above the median of the trace (after filtering high amplitude events) are detected as transients. Here, x is transients detection threshold. Default is 3. - ***Number of channels (Neurophotometrics only) :*** Number of - channels used while recording, when data files has no column names mentioning "Flags" + channels used while recording, when data files has no column names mentioning "Flags" or "LedState". - - ***removeArtifacts? :*** Make this parameter ``` True``` if there are + - ***removeArtifacts? :*** Make this parameter ``` True``` if there are artifacts and user wants to remove the artifacts. - - ***removeArtifacts method :*** Selecting ```concatenate``` will remove bad + - ***removeArtifacts method :*** Selecting ```concatenate``` will remove bad chunks and concatenate the selected good chunks together. Selecting ```replace with NaN``` will replace bad chunks with NaN values. - """, width=350) + """, + width=350, + ) - timeForLightsTurnOn = pn.widgets.LiteralInput(name='Eliminate first few seconds (int)', value=1, type=int, width=320) + timeForLightsTurnOn = pn.widgets.LiteralInput( + name="Eliminate first few seconds (int)", value=1, type=int, width=320 + ) - isosbestic_control = pn.widgets.Select(name='Isosbestic Control Channel? (bool)', value=True, options=[True, False], width=320) + isosbestic_control = pn.widgets.Select( + name="Isosbestic Control Channel? (bool)", value=True, options=[True, False], width=320 + ) - numberOfCores = pn.widgets.LiteralInput(name='# of cores (int)', value=2, type=int, width=150) + numberOfCores = pn.widgets.LiteralInput(name="# of cores (int)", value=2, type=int, width=150) - combine_data = pn.widgets.Select(name='Combine Data? (bool)', value=False, options=[True, False], width=150) + combine_data = pn.widgets.Select(name="Combine Data? (bool)", value=False, options=[True, False], width=150) - computePsth = pn.widgets.Select(name='z_score and/or \u0394F/F? (psth)', options=['z_score', 'dff', 'Both'], width=320) + computePsth = pn.widgets.Select( + name="z_score and/or \u0394F/F? (psth)", options=["z_score", "dff", "Both"], width=320 + ) - transients = pn.widgets.Select(name='z_score and/or \u0394F/F? (transients)', options=['z_score', 'dff', 'Both'], width=320) + transients = pn.widgets.Select( + name="z_score and/or \u0394F/F? 
(transients)", options=["z_score", "dff", "Both"], width=320 + ) - plot_zScore_dff = pn.widgets.Select(name='z-score plot and/or \u0394F/F plot?', options=['z_score', 'dff', 'Both', 'None'], value='None', width=320) + plot_zScore_dff = pn.widgets.Select( + name="z-score plot and/or \u0394F/F plot?", options=["z_score", "dff", "Both", "None"], value="None", width=320 + ) - moving_wd = pn.widgets.LiteralInput(name='Moving Window for transients detection (s) (int)', value=15, type=int, width=320) + moving_wd = pn.widgets.LiteralInput( + name="Moving Window for transients detection (s) (int)", value=15, type=int, width=320 + ) - highAmpFilt = pn.widgets.LiteralInput(name='HAFT (int)', value=2, type=int, width=150) + highAmpFilt = pn.widgets.LiteralInput(name="HAFT (int)", value=2, type=int, width=150) - transientsThresh = pn.widgets.LiteralInput(name='TD Thresh (int)', value=3, type=int, width=150) + transientsThresh = pn.widgets.LiteralInput(name="TD Thresh (int)", value=3, type=int, width=150) - moving_avg_filter = pn.widgets.LiteralInput(name='Window for Moving Average filter (int)', - value=100, type=int, width=320) + moving_avg_filter = pn.widgets.LiteralInput( + name="Window for Moving Average filter (int)", value=100, type=int, width=320 + ) - removeArtifacts = pn.widgets.Select(name='removeArtifacts? (bool)', - value=False, options=[True, False], width=150) + removeArtifacts = pn.widgets.Select(name="removeArtifacts? (bool)", value=False, options=[True, False], width=150) - artifactsRemovalMethod = pn.widgets.Select(name='removeArtifacts method', - value='concatenate', - options=['concatenate', 'replace with NaN'], - width=150) + artifactsRemovalMethod = pn.widgets.Select( + name="removeArtifacts method", value="concatenate", options=["concatenate", "replace with NaN"], width=150 + ) - no_channels_np = pn.widgets.LiteralInput(name='Number of channels (Neurophotometrics only)', - value=2, type=int, width=320) + no_channels_np = pn.widgets.LiteralInput( + name="Number of channels (Neurophotometrics only)", value=2, type=int, width=320 + ) - z_score_computation = pn.widgets.Select(name='z-score computation Method', - options=['standard z-score', 'baseline z-score', 'modified z-score'], - value='standard z-score', width=200) - - baseline_wd_strt = pn.widgets.LiteralInput(name='Baseline Window Start Time (s) (int)', value=0, type=int, width=200) - baseline_wd_end = pn.widgets.LiteralInput(name='Baseline Window End Time (s) (int)', value=0, type=int, width=200) + z_score_computation = pn.widgets.Select( + name="z-score computation Method", + options=["standard z-score", "baseline z-score", "modified z-score"], + value="standard z-score", + width=200, + ) - explain_z_score = pn.pane.Markdown(""" + baseline_wd_strt = pn.widgets.LiteralInput( + name="Baseline Window Start Time (s) (int)", value=0, type=int, width=200 + ) + baseline_wd_end = pn.widgets.LiteralInput(name="Baseline Window End Time (s) (int)", value=0, type=int, width=200) + + explain_z_score = pn.pane.Markdown( + """ ***Note :***
- Details about z-score computation methods are explained in the GitHub wiki.<br>
- The details will help the user decide which computation method to use for + their data.<br>
- - Baseline Window Parameters should be kept 0 unless you are using baseline
+ - Baseline Window Parameters should be kept 0 unless you are using baseline
z-score computation method. The parameters are in seconds. - """, width=580) + """, + width=580, + ) - explain_nsec = pn.pane.Markdown(""" + explain_nsec = pn.pane.Markdown( + """ - ***Time Interval :*** To omit bursts of event timestamps, user defined time interval is set so that if the time difference between two timestamps is less than this defined time interval, it will be deleted for the calculation of PSTH. - ***Compute Cross-correlation :*** Make this parameter ```True```, when user wants - to compute cross-correlation between PSTHs of two different signals or signals + to compute cross-correlation between PSTHs of two different signals or signals recorded from different brain regions. - """, width=580) + """, + width=580, + ) - nSecPrev = pn.widgets.LiteralInput(name='Seconds before 0 (int)', value=-10, type=int, width=120) + nSecPrev = pn.widgets.LiteralInput(name="Seconds before 0 (int)", value=-10, type=int, width=120) - nSecPost = pn.widgets.LiteralInput(name='Seconds after 0 (int)', value=20, type=int, width=120) + nSecPost = pn.widgets.LiteralInput(name="Seconds after 0 (int)", value=20, type=int, width=120) - computeCorr = pn.widgets.Select(name='Compute Cross-correlation (bool)', - options=[True, False], - value=False, width=200) + computeCorr = pn.widgets.Select( + name="Compute Cross-correlation (bool)", options=[True, False], value=False, width=200 + ) - timeInterval = pn.widgets.LiteralInput(name='Time Interval (s)', value=2, type=int, width=120) + timeInterval = pn.widgets.LiteralInput(name="Time Interval (s)", value=2, type=int, width=120) - use_time_or_trials = pn.widgets.Select(name='Bin PSTH trials (str)', - options = ['Time (min)', '# of trials'], - value='Time (min)', width=120) + use_time_or_trials = pn.widgets.Select( + name="Bin PSTH trials (str)", options=["Time (min)", "# of trials"], value="Time (min)", width=120 + ) - bin_psth_trials = pn.widgets.LiteralInput(name='Time(min) / # of trials \n for binning? (int)', - value=0, type=int, width=200) + bin_psth_trials = pn.widgets.LiteralInput( + name="Time(min) / # of trials \n for binning? (int)", value=0, type=int, width=200 + ) - explain_baseline = pn.pane.Markdown(""" + explain_baseline = pn.pane.Markdown( + """ ***Note :***
- If the user does not want to do baseline correction, + set both parameters to 0.<br>
- If the first event timestamp is less than the length of the baseline window, it will be rejected in the PSTH computation step.<br>
- - Baseline parameters must be within the PSTH parameters + - Baseline parameters must be within the PSTH parameters set in the PSTH parameters section. - """, width=580) - - baselineCorrectionStart = pn.widgets.LiteralInput(name='Baseline Correction Start time(int)', value=-5, type=int, width=200) - - baselineCorrectionEnd = pn.widgets.LiteralInput(name='Baseline Correction End time(int)', value=0, type=int, width=200) - - zscore_param_wd = pn.WidgetBox("### Z-score Parameters", explain_z_score, - z_score_computation, - pn.Row(baseline_wd_strt, baseline_wd_end), - width=600) - - psth_param_wd = pn.WidgetBox("### PSTH Parameters", explain_nsec, - pn.Row(nSecPrev, nSecPost, computeCorr), - pn.Row(timeInterval, use_time_or_trials, bin_psth_trials), - width=600) - - baseline_param_wd = pn.WidgetBox("### Baseline Parameters", explain_baseline, - pn.Row(baselineCorrectionStart, baselineCorrectionEnd), - width=600) - peak_explain = pn.pane.Markdown(""" + """, + width=580, + ) + + baselineCorrectionStart = pn.widgets.LiteralInput( + name="Baseline Correction Start time(int)", value=-5, type=int, width=200 + ) + + baselineCorrectionEnd = pn.widgets.LiteralInput( + name="Baseline Correction End time(int)", value=0, type=int, width=200 + ) + + zscore_param_wd = pn.WidgetBox( + "### Z-score Parameters", + explain_z_score, + z_score_computation, + pn.Row(baseline_wd_strt, baseline_wd_end), + width=600, + ) + + psth_param_wd = pn.WidgetBox( + "### PSTH Parameters", + explain_nsec, + pn.Row(nSecPrev, nSecPost, computeCorr), + pn.Row(timeInterval, use_time_or_trials, bin_psth_trials), + width=600, + ) + + baseline_param_wd = pn.WidgetBox( + "### Baseline Parameters", explain_baseline, pn.Row(baselineCorrectionStart, baselineCorrectionEnd), width=600 + ) + peak_explain = pn.pane.Markdown( + """ ***Note :***
- Peak and area are computed within the window set below.<br>
- Peak and AUC parameters must be within the PSTH parameters set in the PSTH parameters section.
- - Please make sure when user changes the parameters in the table below, click on any other cell after + - Please make sure when user changes the parameters in the table below, click on any other cell after changing a value in a particular cell. - """, width=580) - - - start_end_point_df = pd.DataFrame({'Peak Start time': [-5, 0, 5, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan], - 'Peak End time': [0, 3, 10, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan]}) - - df_widget = pn.widgets.Tabulator(start_end_point_df, name='DataFrame', - show_index=False, widths=280) - - - peak_param_wd = pn.WidgetBox("### Peak and AUC Parameters", - peak_explain, df_widget, - width=600) + """, + width=580, + ) + + start_end_point_df = pd.DataFrame( + { + "Peak Start time": [-5, 0, 5, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], + "Peak End time": [0, 3, 10, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], + } + ) + df_widget = pn.widgets.Tabulator(start_end_point_df, name="DataFrame", show_index=False, widths=280) + peak_param_wd = pn.WidgetBox("### Peak and AUC Parameters", peak_explain, df_widget, width=600) - mark_down_2 = pn.pane.Markdown("""**Select folders for the average analysis from the file selector below**""", width=600) + mark_down_2 = pn.pane.Markdown( + """**Select folders for the average analysis from the file selector below**""", width=600 + ) - files_2 = pn.widgets.FileSelector(folder_path, name='folderNamesForAvg', width=950) + files_2 = pn.widgets.FileSelector(folder_path, name="folderNamesForAvg", width=950) - averageForGroup = pn.widgets.Select(name='Average Group? (bool)', value=False, options=[True, False], width=435) + averageForGroup = pn.widgets.Select(name="Average Group? (bool)", value=False, options=[True, False], width=435) - visualizeAverageResults = pn.widgets.Select(name='Visualize Average Results? (bool)', - value=False, options=[True, False], width=435) + visualizeAverageResults = pn.widgets.Select( + name="Visualize Average Results? (bool)", value=False, options=[True, False], width=435 + ) - visualize_zscore_or_dff = pn.widgets.Select(name='z-score or \u0394F/F? (for visualization)', options=['z_score', 'dff'], width=435) + visualize_zscore_or_dff = pn.widgets.Select( + name="z-score or \u0394F/F? 
(for visualization)", options=["z_score", "dff"], width=435 + ) individual_analysis_wd_2 = pn.Column( - explain_time_artifacts, pn.Row(numberOfCores, combine_data), - isosbestic_control, timeForLightsTurnOn, - moving_avg_filter, computePsth, transients, plot_zScore_dff, - moving_wd, pn.Row(highAmpFilt, transientsThresh), - no_channels_np, pn.Row(removeArtifacts, artifactsRemovalMethod) - ) + explain_time_artifacts, + pn.Row(numberOfCores, combine_data), + isosbestic_control, + timeForLightsTurnOn, + moving_avg_filter, + computePsth, + transients, + plot_zScore_dff, + moving_wd, + pn.Row(highAmpFilt, transientsThresh), + no_channels_np, + pn.Row(removeArtifacts, artifactsRemovalMethod), + ) group_analysis_wd_1 = pn.Column(mark_down_2, files_2, averageForGroup, width=800) visualization_wd = pn.Row(visualize_zscore_or_dff, pn.Spacer(width=60), visualizeAverageResults) - def getInputParameters(): abspath = getAbsPath() inputParameters = { @@ -328,8 +376,8 @@ def getInputParameters(): "use_time_or_trials": use_time_or_trials.value, "baselineCorrectionStart": baselineCorrectionStart.value, "baselineCorrectionEnd": baselineCorrectionEnd.value, - "peak_startPoint": list(df_widget.value['Peak Start time']), #startPoint.value, - "peak_endPoint": list(df_widget.value['Peak End time']), #endPoint.value, + "peak_startPoint": list(df_widget.value["Peak Start time"]), # startPoint.value, + "peak_endPoint": list(df_widget.value["Peak End time"]), # endPoint.value, "selectForComputePsth": computePsth.value, "selectForTransientsComputation": transients.value, "moving_window": moving_wd.value, @@ -339,43 +387,43 @@ def getInputParameters(): "visualize_zscore_or_dff": visualize_zscore_or_dff.value, "folderNamesForAvg": files_2.value, "averageForGroup": averageForGroup.value, - "visualizeAverageResults": visualizeAverageResults.value + "visualizeAverageResults": visualizeAverageResults.value, } return inputParameters def checkSameLocation(arr, abspath): - #abspath = [] + # abspath = [] for i in range(len(arr)): abspath.append(os.path.dirname(arr[i])) abspath = np.asarray(abspath) abspath = np.unique(abspath) - if len(abspath)>1: - logger.error('All the folders selected should be at the same location') - raise Exception('All the folders selected should be at the same location') - + if len(abspath) > 1: + logger.error("All the folders selected should be at the same location") + raise Exception("All the folders selected should be at the same location") + return abspath def getAbsPath(): - arr_1, arr_2 = files_1.value, files_2.value - if len(arr_1)==0 and len(arr_2)==0: - logger.error('No folder is selected for analysis') - raise Exception('No folder is selected for analysis') - + arr_1, arr_2 = files_1.value, files_2.value + if len(arr_1) == 0 and len(arr_2) == 0: + logger.error("No folder is selected for analysis") + raise Exception("No folder is selected for analysis") + abspath = [] - if len(arr_1)>0: + if len(arr_1) > 0: abspath = checkSameLocation(arr_1, abspath) else: abspath = checkSameLocation(arr_2, abspath) - + abspath = np.unique(abspath) - if len(abspath)>1: - logger.error('All the folders selected should be at the same location') - raise Exception('All the folders selected should be at the same location') + if len(abspath) > 1: + logger.error("All the folders selected should be at the same location") + raise Exception("All the folders selected should be at the same location") return abspath def onclickProcess(event=None): - - logger.debug('Saving Input Parameters file.') + + logger.debug("Saving 
Input Parameters file.") abspath = getAbsPath() analysisParameters = { "combine_data": combine_data.value, @@ -394,23 +442,23 @@ def onclickProcess(event=None): "use_time_or_trials": use_time_or_trials.value, "baselineCorrectionStart": baselineCorrectionStart.value, "baselineCorrectionEnd": baselineCorrectionEnd.value, - "peak_startPoint": list(df_widget.value['Peak Start time']), #startPoint.value, - "peak_endPoint": list(df_widget.value['Peak End time']), #endPoint.value, + "peak_startPoint": list(df_widget.value["Peak Start time"]), # startPoint.value, + "peak_endPoint": list(df_widget.value["Peak End time"]), # endPoint.value, "selectForComputePsth": computePsth.value, "selectForTransientsComputation": transients.value, "moving_window": moving_wd.value, "highAmpFilt": highAmpFilt.value, - "transientsThresh": transientsThresh.value + "transientsThresh": transientsThresh.value, } for folder in files_1.value: - with open(os.path.join(folder, 'GuPPyParamtersUsed.json'), 'w') as f: + with open(os.path.join(folder, "GuPPyParamtersUsed.json"), "w") as f: json.dump(analysisParameters, f, indent=4) logger.info(f"Input Parameters file saved at {folder}") - - logger.info('#'*400) - - #path.value = (os.path.join(op, 'inputParameters.json')).replace('\\', '/') - logger.info('Input Parameters File Saved.') + + logger.info("#" * 400) + + # path.value = (os.path.join(op, 'inputParameters.json')).replace('\\', '/') + logger.info("Input Parameters File Saved.") def onclickStoresList(event=None): inputParameters = getInputParameters() @@ -431,33 +479,36 @@ def onclickextractts(event=None): thread.start() readPBIncrementValues(extract_progress) thread.join() - + def onclickpsth(event=None): thread = Thread(target=psthComputation) thread.start() readPBIncrementValues(psth_progress) thread.join() - - mark_down_ip = pn.pane.Markdown("""**Step 1 : Save Input Parameters**""", width=300) - mark_down_ip_note = pn.pane.Markdown("""***Note : ***
+ mark_down_ip_note = pn.pane.Markdown( + """***Note : ***
- Save Input Parameters will save input parameters used for the analysis in all the folders you selected for the analysis (useful for future reference). All analysis steps will run without saving input parameters. - """, width=300) - save_button = pn.widgets.Button(name='Save to file...', button_type='primary', width=300, align='end') + """, + width=300, + ) + save_button = pn.widgets.Button(name="Save to file...", button_type="primary", width=300, align="end") mark_down_storenames = pn.pane.Markdown("""**Step 2 : Open Storenames GUI
and save storenames**""", width=300) - open_storesList = pn.widgets.Button(name='Open Storenames GUI', button_type='primary', width=300, align='end') + open_storesList = pn.widgets.Button(name="Open Storenames GUI", button_type="primary", width=300, align="end") mark_down_read = pn.pane.Markdown("""**Step 3 : Read Raw Data**""", width=300) - read_rawData = pn.widgets.Button(name='Read Raw Data', button_type='primary', width=300, align='end') + read_rawData = pn.widgets.Button(name="Read Raw Data", button_type="primary", width=300, align="end") mark_down_extract = pn.pane.Markdown("""**Step 4 : Extract timestamps
and its correction**""", width=300) - extract_ts = pn.widgets.Button(name="Extract timestamps and it's correction", button_type='primary', width=300, align='end') + extract_ts = pn.widgets.Button( + name="Extract timestamps and it's correction", button_type="primary", width=300, align="end" + ) mark_down_psth = pn.pane.Markdown("""**Step 5 : PSTH Computation**""", width=300) - psth_computation = pn.widgets.Button(name="PSTH Computation", button_type='primary', width=300, align='end') + psth_computation = pn.widgets.Button(name="PSTH Computation", button_type="primary", width=300, align="end") mark_down_visualization = pn.pane.Markdown("""**Step 6 : Visualization**""", width=300) - open_visualization = pn.widgets.Button(name='Open Visualization GUI', button_type='primary', width=300, align='end') - open_terminal = pn.widgets.Button(name='Open Terminal', button_type='primary', width=300, align='end') + open_visualization = pn.widgets.Button(name="Open Visualization GUI", button_type="primary", width=300, align="end") + open_terminal = pn.widgets.Button(name="Open Terminal", button_type="primary", width=300, align="end") save_button.on_click(onclickProcess) open_storesList.on_click(onclickStoresList) @@ -466,11 +517,10 @@ def onclickpsth(event=None): psth_computation.on_click(onclickpsth) open_visualization.on_click(onclickVisualization) - template.sidebar.append(mark_down_ip) template.sidebar.append(mark_down_ip_note) template.sidebar.append(save_button) - #template.sidebar.append(path) + # template.sidebar.append(path) template.sidebar.append(mark_down_storenames) template.sidebar.append(open_storesList) template.sidebar.append(mark_down_read) @@ -484,20 +534,19 @@ def onclickpsth(event=None): template.sidebar.append(psth_progress) template.sidebar.append(mark_down_visualization) template.sidebar.append(open_visualization) - #template.sidebar.append(open_terminal) - + # template.sidebar.append(open_terminal) psth_baseline_param = pn.Column(zscore_param_wd, psth_param_wd, baseline_param_wd, peak_param_wd) widget = pn.Column(mark_down_1, files_1, pn.Row(individual_analysis_wd_2, psth_baseline_param)) - #file_selector = pn.WidgetBox(files_1) - styles = dict(background='WhiteSmoke') - individual = pn.Card(widget, title='Individual Analysis', styles=styles, width=1000) - group = pn.Card(group_analysis_wd_1, title='Group Analysis', styles=styles, width=1000) - visualize = pn.Card(visualization_wd, title='Visualization Parameters', styles=styles, width=1000) + # file_selector = pn.WidgetBox(files_1) + styles = dict(background="WhiteSmoke") + individual = pn.Card(widget, title="Individual Analysis", styles=styles, width=1000) + group = pn.Card(group_analysis_wd_1, title="Group Analysis", styles=styles, width=1000) + visualize = pn.Card(visualization_wd, title="Visualization Parameters", styles=styles, width=1000) - #template.main.append(file_selector) + # template.main.append(file_selector) template.main.append(individual) template.main.append(group) template.main.append(visualize) diff --git a/src/guppy/testing/api.py b/src/guppy/testing/api.py index bc8b239..587a022 100644 --- a/src/guppy/testing/api.py +++ b/src/guppy/testing/api.py @@ -10,21 +10,15 @@ from __future__ import annotations -import json import os -import numpy as np -from typing import Iterable, List +from typing import Iterable -from guppy.savingInputParameters import savingInputParameters -from guppy.saveStoresList import execute -from guppy.readTevTsq import readRawData -from guppy.preprocess import extractTsAndSignal from 
guppy.computePsth import psthForEachStorename from guppy.findTransientsFreqAndAmp import executeFindFreqAndAmp - - - - +from guppy.preprocess import extractTsAndSignal +from guppy.readTevTsq import readRawData +from guppy.saveStoresList import execute +from guppy.savingInputParameters import savingInputParameters def step1(*, base_dir: str, selected_folders: Iterable[str]) -> None: @@ -69,7 +63,15 @@ def step1(*, base_dir: str, selected_folders: Iterable[str]) -> None: template._hooks["onclickProcess"]() -def step2(*, base_dir: str, selected_folders: Iterable[str], storenames_map: dict[str, str], npm_timestamp_column_name: str | None = None, npm_time_unit: str = "seconds", npm_split_events: bool = True) -> None: +def step2( + *, + base_dir: str, + selected_folders: Iterable[str], + storenames_map: dict[str, str], + npm_timestamp_column_name: str | None = None, + npm_time_unit: str = "seconds", + npm_split_events: bool = True, +) -> None: """ Run pipeline Step 2 (Save Storenames) via the actual Panel-backed logic. @@ -157,7 +159,14 @@ def step2(*, base_dir: str, selected_folders: Iterable[str], storenames_map: dic execute(input_params) -def step3(*, base_dir: str, selected_folders: Iterable[str], npm_timestamp_column_name: str | None = None, npm_time_unit: str = "seconds", npm_split_events: bool = True) -> None: +def step3( + *, + base_dir: str, + selected_folders: Iterable[str], + npm_timestamp_column_name: str | None = None, + npm_time_unit: str = "seconds", + npm_split_events: bool = True, +) -> None: """ Run pipeline Step 3 (Read Raw Data) via the actual Panel-backed logic, headlessly. @@ -227,7 +236,14 @@ def step3(*, base_dir: str, selected_folders: Iterable[str], npm_timestamp_colum readRawData(input_params) -def step4(*, base_dir: str, selected_folders: Iterable[str], npm_timestamp_column_name: str | None = None, npm_time_unit: str = "seconds", npm_split_events: bool = True) -> None: +def step4( + *, + base_dir: str, + selected_folders: Iterable[str], + npm_timestamp_column_name: str | None = None, + npm_time_unit: str = "seconds", + npm_split_events: bool = True, +) -> None: """ Run pipeline Step 4 (Extract timestamps and signal) via the Panel-backed logic, headlessly. @@ -297,7 +313,14 @@ def step4(*, base_dir: str, selected_folders: Iterable[str], npm_timestamp_colum extractTsAndSignal(input_params) -def step5(*, base_dir: str, selected_folders: Iterable[str], npm_timestamp_column_name: str | None = None, npm_time_unit: str = "seconds", npm_split_events: bool = True) -> None: +def step5( + *, + base_dir: str, + selected_folders: Iterable[str], + npm_timestamp_column_name: str | None = None, + npm_time_unit: str = "seconds", + npm_split_events: bool = True, +) -> None: """ Run pipeline Step 5 (PSTH Computation) via the Panel-backed logic, headlessly. 
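
For reference, the keyword-only signatures above suggest the headless test API can drive the pipeline end to end roughly as sketched below. This is a minimal sketch, not part of the patch: the base directory, session folder, and storenames mapping are hypothetical, and only the `step1`-`step5` signatures shown in this diff are assumed.

```python
# Minimal sketch of driving the headless pipeline via guppy.testing.api.
# All paths and the storenames mapping below are hypothetical examples.
from guppy.testing.api import step1, step2, step3, step4, step5

base_dir = "/data/photometry_sessions"   # hypothetical base folder
folders = [f"{base_dir}/session_01"]     # hypothetical session folder

step1(base_dir=base_dir, selected_folders=folders)   # save input parameters
step2(
    base_dir=base_dir,
    selected_folders=folders,
    storenames_map={"Dv1A": "control_region", "Dv2A": "signal_region"},  # hypothetical mapping
    npm_time_unit="seconds",
)                                                     # save storenames
step3(base_dir=base_dir, selected_folders=folders)    # read raw data
step4(base_dir=base_dir, selected_folders=folders)    # extract timestamps and signal
step5(base_dir=base_dir, selected_folders=folders)    # PSTH computation
```
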
diff --git a/src/guppy/visualizePlot.py b/src/guppy/visualizePlot.py index c98bbe8..929149e 100755 --- a/src/guppy/visualizePlot.py +++ b/src/guppy/visualizePlot.py @@ -1,29 +1,31 @@ -import os import glob -import param -import re +import logging import math -import numpy as np -import pandas as pd +import os +import re import socket from random import randint -import holoviews as hv -from holoviews import opts -from bokeh.io import export_svgs, export_png -from holoviews.plotting.util import process_cmap -from holoviews.operation.datashader import datashade + import datashader as ds +import holoviews as hv import matplotlib.pyplot as plt -from pathlib import Path -from .preprocess import get_all_stores_for_combining_data -import logging -import panel as pn +import numpy as np +import pandas as pd import panel as pn +import param +from bokeh.io import export_png, export_svgs +from holoviews import opts +from holoviews.operation.datashader import datashade +from holoviews.plotting.util import process_cmap + +from .preprocess import get_all_stores_for_combining_data + pn.extension() logger = logging.getLogger(__name__) -def scanPortsAndFind(start_port=5000, end_port=5200, host='127.0.0.1'): + +def scanPortsAndFind(start_port=5000, end_port=5200, host="127.0.0.1"): while True: port = randint(start_port, end_port) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -36,699 +38,899 @@ def scanPortsAndFind(start_port=5000, end_port=5200, host='127.0.0.1'): return port + def takeOnlyDirs(paths): - removePaths = [] - for p in paths: - if os.path.isfile(p): - removePaths.append(p) - return list(set(paths)-set(removePaths)) + removePaths = [] + for p in paths: + if os.path.isfile(p): + removePaths.append(p) + return list(set(paths) - set(removePaths)) + # read h5 file as a dataframe def read_Df(filepath, event, name): - event = event.replace("\\","_") - event = event.replace("/","_") - if name: - op = os.path.join(filepath, event+'_{}.h5'.format(name)) - else: - op = os.path.join(filepath, event+'.h5') - df = pd.read_hdf(op, key='df', mode='r') + event = event.replace("\\", "_") + event = event.replace("/", "_") + if name: + op = os.path.join(filepath, event + "_{}.h5".format(name)) + else: + op = os.path.join(filepath, event + ".h5") + df = pd.read_hdf(op, key="df", mode="r") + + return df - return df # make a new directory for saving plots def make_dir(filepath): - op = os.path.join(filepath, 'saved_plots') - if not os.path.exists(op): - os.mkdir(op) + op = os.path.join(filepath, "saved_plots") + if not os.path.exists(op): + os.mkdir(op) + + return op - return op # remove unnecessary column names def remove_cols(cols): - regex = re.compile('bin_err_*') - remove_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] - remove_cols = remove_cols + ['err', 'timestamps'] - cols = [i for i in cols if i not in remove_cols] + regex = re.compile("bin_err_*") + remove_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] + remove_cols = remove_cols + ["err", "timestamps"] + cols = [i for i in cols if i not in remove_cols] + + return cols - return cols -#def look_psth_bins(event, name): +# def look_psth_bins(event, name): + # helper function to create plots def helper_plots(filepath, event, name, inputParameters): - basename = os.path.basename(filepath) - visualize_zscore_or_dff = inputParameters['visualize_zscore_or_dff'] - - # note when there are no behavior event TTLs - if len(event)==0: - logger.warning("\033[1m"+"There are no behavior event TTLs present to 
visualize.".format(event)+"\033[0m") - return 0 - - - if os.path.exists(os.path.join(filepath, 'cross_correlation_output')): - event_corr, frames = [], [] - if visualize_zscore_or_dff=='z_score': - corr_fp = glob.glob(os.path.join(filepath, 'cross_correlation_output', '*_z_score_*')) - elif visualize_zscore_or_dff=='dff': - corr_fp = glob.glob(os.path.join(filepath, 'cross_correlation_output', '*_dff_*')) - for i in range(len(corr_fp)): - filename = os.path.basename(corr_fp[i]).split('.')[0] - event_corr.append(filename) - df = pd.read_hdf(corr_fp[i], key='df', mode='r') - frames.append(df) - if len(frames)>0: - df_corr = pd.concat(frames, keys=event_corr, axis=1) - else: - event_corr = [] - df_corr = [] - else: - event_corr = [] - df_corr = None - - - # combine all the event PSTH so that it can be viewed together - if name: - event_name, name = event, name - new_event, frames, bins = [], [], {} - for i in range(len(event_name)): - - for j in range(len(name)): - new_event.append(event_name[i]+'_'+name[j].split('_')[-1]) - new_name = name[j] - temp_df = read_Df(filepath, new_event[-1], new_name) - cols = list(temp_df.columns) - regex = re.compile('bin_[(]') - bins[new_event[-1]] = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] - #bins.append(keep_cols) - frames.append(temp_df) - - df = pd.concat(frames, keys=new_event, axis=1) - else: - new_event = list(np.unique(np.array(event))) - frames, bins = [], {} - for i in range(len(new_event)): - temp_df = read_Df(filepath, new_event[i], '') - cols = list(temp_df.columns) - regex = re.compile('bin_[(]') - bins[new_event[i]] = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] - frames.append(temp_df) - - df = pd.concat(frames, keys=new_event, axis=1) - - if isinstance(df_corr, pd.DataFrame): - new_event.extend(event_corr) - df = pd.concat([df,df_corr],axis=1,sort=False).reset_index() - - columns_dict = dict() - for i in range(len(new_event)): - df_1 = df[new_event[i]] - columns = list(df_1.columns) - columns.append('All') - columns_dict[new_event[i]] = columns - - - # create a class to make GUI and plot different graphs - class Viewer(param.Parameterized): - - #class_event = new_event - - # make options array for different selectors - multiple_plots_options = [] - heatmap_options = new_event - bins_keys = list(bins.keys()) - if len(bins_keys)>0: - bins_new = bins - for i in range(len(bins_keys)): - arr = bins[bins_keys[i]] - if len(arr)>0: - #heatmap_options.append('{}_bin'.format(bins_keys[i])) - for j in arr: - multiple_plots_options.append('{}_{}'.format(bins_keys[i], j)) - - multiple_plots_options = new_event + multiple_plots_options - else: - multiple_plots_options = new_event - - # create different options and selectors - event_selector = param.ObjectSelector(default=new_event[0], objects=new_event) - event_selector_heatmap = param.ObjectSelector(default=heatmap_options[0], objects=heatmap_options) - columns = columns_dict - df_new = df - - - colormaps = plt.colormaps() - new_colormaps = ['plasma', 'plasma_r', 'magma', 'magma_r', 'inferno', 'inferno_r', 'viridis', 'viridis_r'] - set_a = set(colormaps) - set_b = set(new_colormaps) - colormaps = new_colormaps + list(set_a.difference(set_b)) - - x_min = float(inputParameters['nSecPrev'])-20 - x_max = float(inputParameters['nSecPost'])+20 - selector_for_multipe_events_plot = param.ListSelector(default=[multiple_plots_options[0]], objects=multiple_plots_options) - x = param.ObjectSelector(default=columns[new_event[0]][-4], objects=[columns[new_event[0]][-4]]) - y = 
param.ObjectSelector(default=remove_cols(columns[new_event[0]])[-2], objects=remove_cols(columns[new_event[0]])) - - trial_no = range(1, len(remove_cols(columns[heatmap_options[0]])[:-2])+1) - trial_ts = ["{} - {}".format(i,j) for i,j in zip(trial_no, remove_cols(columns[heatmap_options[0]])[:-2])] + ['All'] - heatmap_y = param.ListSelector(default=[trial_ts[-1]], objects=trial_ts) - psth_y = param.ListSelector(objects=trial_ts[:-1]) - select_trials_checkbox = param.ListSelector(default=['just trials'], objects=['mean', 'just trials']) - Y_Label = param.ObjectSelector(default='y', objects=['y','z-score', '\u0394F/F']) - save_options = param.ObjectSelector(default='None' , objects=['None', 'save_png_format', 'save_svg_format', 'save_both_format']) - save_options_heatmap = param.ObjectSelector(default='None' , objects=['None', 'save_png_format', 'save_svg_format', 'save_both_format']) - color_map = param.ObjectSelector(default='plasma' , objects=colormaps) - height_heatmap = param.ObjectSelector(default=600, objects=list(np.arange(0,5100,100))[1:]) - width_heatmap = param.ObjectSelector(default=1000, objects=list(np.arange(0,5100,100))[1:]) - Height_Plot = param.ObjectSelector(default=300, objects=list(np.arange(0,5100,100))[1:]) - Width_Plot = param.ObjectSelector(default=1000, objects=list(np.arange(0,5100,100))[1:]) - save_hm = param.Action(lambda x: x.param.trigger('save_hm'), label='Save') - save_psth = param.Action(lambda x: x.param.trigger('save_psth'), label='Save') - X_Limit = param.Range(default=(-5, 10), bounds=(x_min,x_max)) - Y_Limit = param.Range(bounds=(-50, 50.0)) - - #C_Limit = param.Range(bounds=(-20,20.0)) - - results_hm = dict() - results_psth = dict() - - - # function to save heatmaps when save button on heatmap tab is clicked - @param.depends('save_hm', watch=True) - def save_hm_plots(self): - plot = self.results_hm['plot'] - op = self.results_hm['op'] - save_opts = self.save_options_heatmap - logger.info(save_opts) - if save_opts=='save_svg_format': - p = hv.render(plot, backend='bokeh') - p.output_backend = 'svg' - export_svgs(p, filename=op+'.svg') - elif save_opts=='save_png_format': - p = hv.render(plot, backend='bokeh') - export_png(p, filename=op+'.png') - elif save_opts=='save_both_format': - p = hv.render(plot, backend='bokeh') - p.output_backend = 'svg' - export_svgs(p, filename=op+'.svg') - p_png = hv.render(plot, backend='bokeh') - export_png(p_png, filename=op+'.png') - else: - return 0 - - - # function to save PSTH plots when save button on PSTH tab is clicked - @param.depends('save_psth', watch=True) - def save_psth_plot(self): - plot, op = [], [] - plot.append(self.results_psth['plot_combine']) - op.append(self.results_psth['op_combine']) - plot.append(self.results_psth['plot']) - op.append(self.results_psth['op']) - for i in range(len(plot)): - temp_plot, temp_op = plot[i], op[i] - save_opts = self.save_options - if save_opts=='save_svg_format': - p = hv.render(temp_plot, backend='bokeh') - p.output_backend = 'svg' - export_svgs(p, filename=temp_op+'.svg') - elif save_opts=='save_png_format': - p = hv.render(temp_plot, backend='bokeh') - export_png(p, filename=temp_op+'.png') - elif save_opts=='save_both_format': - p = hv.render(temp_plot, backend='bokeh') - p.output_backend = 'svg' - export_svgs(p, filename=temp_op+'.svg') - p_png = hv.render(temp_plot, backend='bokeh') - export_png(p_png, filename=temp_op+'.png') - else: - return 0 - - # function to change Y values based on event selection - @param.depends('event_selector', watch=True) - def 
_update_x_y(self): - x_value = self.columns[self.event_selector] - y_value = self.columns[self.event_selector] - self.param['x'].objects = [x_value[-4]] - self.param['y'].objects = remove_cols(y_value) - self.x = x_value[-4] - self.y = self.param['y'].objects[-2] - - @param.depends('event_selector_heatmap', watch=True) - def _update_df(self): - cols = self.columns[self.event_selector_heatmap] - trial_no = range(1, len(remove_cols(cols)[:-2])+1) - trial_ts = ["{} - {}".format(i,j) for i,j in zip(trial_no, remove_cols(cols)[:-2])] + ['All'] - self.param['heatmap_y'].objects = trial_ts - self.heatmap_y = [trial_ts[-1]] - - @param.depends('event_selector', watch=True) - def _update_psth_y(self): - cols = self.columns[self.event_selector] - trial_no = range(1, len(remove_cols(cols)[:-2])+1) - trial_ts = ["{} - {}".format(i,j) for i,j in zip(trial_no, remove_cols(cols)[:-2])] - self.param['psth_y'].objects = trial_ts - self.psth_y = [trial_ts[0]] - - # function to plot multiple PSTHs into one plot - @param.depends('selector_for_multipe_events_plot', 'Y_Label', 'save_options', 'X_Limit', 'Y_Limit', 'Height_Plot', 'Width_Plot') - def update_selector(self): - data_curve, cols_curve, data_spread, cols_spread = [], [], [], [] - arr = self.selector_for_multipe_events_plot - df1 = self.df_new - for i in range(len(arr)): - if 'bin' in arr[i]: - split = arr[i].rsplit('_',2) - df_name = split[0] #'{}_{}'.format(split[0], split[1]) - col_name_mean = '{}_{}'.format(split[-2], split[-1]) - col_name_err = '{}_err_{}'.format(split[-2], split[-1]) - data_curve.append(df1[df_name][col_name_mean]) - cols_curve.append(arr[i]) - data_spread.append(df1[df_name][col_name_err]) - cols_spread.append(arr[i]) - else: - data_curve.append(df1[arr[i]]['mean']) - cols_curve.append(arr[i]+'_'+'mean') - data_spread.append(df1[arr[i]]['err']) - cols_spread.append(arr[i]+'_'+'mean') - - - - if len(arr)>0: - if self.Y_Limit==None: - self.Y_Limit = (np.nanmin(np.asarray(data_curve))-0.5, np.nanmax(np.asarray(data_curve))+0.5) - - if 'bin' in arr[i]: - split = arr[i].rsplit('_', 2) - df_name = split[0] - data_curve.append(df1[df_name]['timestamps']) - cols_curve.append('timestamps') - data_spread.append(df1[df_name]['timestamps']) - cols_spread.append('timestamps') - else: - data_curve.append(df1[arr[i]]['timestamps']) - cols_curve.append('timestamps') - data_spread.append(df1[arr[i]]['timestamps']) - cols_spread.append('timestamps') - df_curve = pd.concat(data_curve, axis=1) - df_spread = pd.concat(data_spread, axis=1) - df_curve.columns = cols_curve - df_spread.columns = cols_spread - - ts = df_curve['timestamps'] - index = np.arange(0,ts.shape[0], 3) - df_curve = df_curve.loc[index, :] - df_spread = df_spread.loc[index, :] - overlay = hv.NdOverlay({c:hv.Curve((df_curve['timestamps'], df_curve[c]), kdims=['Time (s)']).opts(width=int(self.Width_Plot), height=int(self.Height_Plot), xlim=self.X_Limit, ylim=self.Y_Limit) for c in cols_curve[:-1]}) - spread = hv.NdOverlay({d:hv.Spread((df_spread['timestamps'], df_curve[d], df_spread[d], df_spread[d]), vdims=['y', 'yerrpos', 'yerrneg']).opts(line_width=0, fill_alpha=0.3) for d in cols_spread[:-1]}) - plot_combine = ((overlay * spread).opts(opts.NdOverlay(xlabel='Time (s)', ylabel=self.Y_Label))).opts(shared_axes=False) - #plot_err = new_df.hvplot.area(x='timestamps', y=[], y2=[]) - save_opts = self.save_options - op = make_dir(filepath) - op_filename = os.path.join(op, str(arr)+'_mean') - - self.results_psth['plot_combine'] = plot_combine - self.results_psth['op_combine'] = op_filename 
- #self.save_plots(plot_combine, save_opts, op_filename) - return plot_combine - - - # function to plot mean PSTH, single trial in PSTH and all the trials of PSTH with mean - @param.depends('event_selector', 'x', 'y', 'Y_Label', 'save_options', 'Y_Limit', 'X_Limit', 'Height_Plot', 'Width_Plot') - def contPlot(self): - df1 = self.df_new[self.event_selector] - #height = self.Heigth_Plot - #width = self.Width_Plot - #logger.info(height, width) - if self.y == 'All': - if self.Y_Limit==None: - self.Y_Limit = (np.nanmin(np.asarray(df1))-0.5, np.nanmax(np.asarray(df1))-0.5) - - - options = self.param['y'].objects - regex = re.compile('bin_[(]') - remove_bin_trials = [options[i] for i in range(len(options)) if not regex.match(options[i])] - - ndoverlay = hv.NdOverlay({c:hv.Curve((df1[self.x], df1[c])) for c in remove_bin_trials[:-2]}) - img1 = datashade(ndoverlay, normalization='linear', aggregator=ds.count()) - x_points = df1[self.x] - y_points = df1['mean'] - img2 = hv.Curve((x_points, y_points)) - img = (img1*img2).opts(opts.Curve(width=int(self.Width_Plot), height=int(self.Height_Plot), line_width=4, color='black', xlim=self.X_Limit, ylim=self.Y_Limit, xlabel='Time (s)', ylabel=self.Y_Label)) - - save_opts = self.save_options - - op = make_dir(filepath) - op_filename = os.path.join(op, self.event_selector+'_'+self.y) - self.results_psth['plot'] = img - self.results_psth['op'] = op_filename - #self.save_plots(img, save_opts, op_filename) - - return img - - elif self.y == 'mean' or 'bin' in self.y: - - xpoints = df1[self.x] - ypoints = df1[self.y] - if self.y == 'mean': - err = df1['err'] - else: - split = self.y.split('_') - err = df1['{}_err_{}'.format(split[0],split[1])] - - index = np.arange(0, xpoints.shape[0], 3) - - if self.Y_Limit==None: - self.Y_Limit = (np.nanmin(ypoints)-0.5, np.nanmax(ypoints)+0.5) - - ropts_curve = dict(width=int(self.Width_Plot), height=int(self.Height_Plot), xlim=self.X_Limit, ylim=self.Y_Limit, color='blue', xlabel='Time (s)', ylabel=self.Y_Label) - ropts_spread = dict(width=int(self.Width_Plot), height=int(self.Height_Plot), fill_alpha=0.3, fill_color='blue', line_width=0) - - plot_curve = hv.Curve((xpoints[index], ypoints[index])) #.opts(**ropts_curve) - plot_spread = hv.Spread((xpoints[index], ypoints[index], err[index], err[index])) #.opts(**ropts_spread) #vdims=['y', 'yerrpos', 'yerrneg'] - plot = (plot_curve * plot_spread).opts({'Curve': ropts_curve, - 'Spread': ropts_spread}) - - save_opts = self.save_options - op = make_dir(filepath) - op_filename = os.path.join(op, self.event_selector+'_'+self.y) - self.results_psth['plot'] = plot - self.results_psth['op'] = op_filename - #self.save_plots(plot, save_opts, op_filename) - - return plot - - else: - xpoints = df1[self.x] - ypoints = df1[self.y] - if self.Y_Limit==None: - self.Y_Limit = (np.nanmin(ypoints)-0.5, np.nanmax(ypoints)+0.5) - - ropts_curve = dict(width=int(self.Width_Plot), height=int(self.Height_Plot), xlim=self.X_Limit, ylim=self.Y_Limit, color='blue', xlabel='Time (s)', ylabel=self.Y_Label) - plot = hv.Curve((xpoints, ypoints)).opts({'Curve': ropts_curve}) - - save_opts = self.save_options - op = make_dir(filepath) - op_filename = os.path.join(op, self.event_selector+'_'+self.y) - self.results_psth['plot'] = plot - self.results_psth['op'] = op_filename - #self.save_plots(plot, save_opts, op_filename) - - return plot - - # function to plot specific PSTH trials - @param.depends('event_selector', 'x', 'psth_y', 'select_trials_checkbox', 'Y_Label', 'save_options', 'Y_Limit', 'X_Limit', 
'Height_Plot', 'Width_Plot') - def plot_specific_trials(self): - df_psth = self.df_new[self.event_selector] - #if self.Y_Limit==None: - # self.Y_Limit = (np.nanmin(ypoints)-0.5, np.nanmax(ypoints)+0.5) - - if self.psth_y==None: - return None - else: - selected_trials = [s.split(' - ')[1] for s in list(self.psth_y)] - - index = np.arange(0, df_psth['timestamps'].shape[0], 3) - - if self.select_trials_checkbox==['just trials']: - overlay = hv.NdOverlay({c:hv.Curve((df_psth['timestamps'][index], df_psth[c][index]), kdims=['Time (s)']) for c in selected_trials}) - ropts = dict(width=int(self.Width_Plot), height=int(self.Height_Plot), xlim=self.X_Limit, ylim=self.Y_Limit, xlabel='Time (s)', ylabel=self.Y_Label) - return overlay.opts(**ropts) - elif self.select_trials_checkbox==['mean']: - arr = np.asarray(df_psth[selected_trials]) - mean = np.nanmean(arr, axis=1) - err = np.nanstd(arr, axis=1)/math.sqrt(arr.shape[1]) - ropts_curve = dict(width=int(self.Width_Plot), height=int(self.Height_Plot), xlim=self.X_Limit, ylim=self.Y_Limit, color='blue', xlabel='Time (s)', ylabel=self.Y_Label) - ropts_spread = dict(width=int(self.Width_Plot), height=int(self.Height_Plot), fill_alpha=0.3, fill_color='blue', line_width=0) - plot_curve = hv.Curve((df_psth['timestamps'][index], mean[index])) - plot_spread = hv.Spread((df_psth['timestamps'][index], mean[index], err[index], err[index])) - plot = (plot_curve * plot_spread).opts({'Curve': ropts_curve, - 'Spread': ropts_spread}) - return plot - elif self.select_trials_checkbox==['mean', 'just trials']: - overlay = hv.NdOverlay({c:hv.Curve((df_psth['timestamps'][index], df_psth[c][index]), kdims=['Time (s)']) for c in selected_trials}) - ropts_overlay = dict(width=int(self.Width_Plot), height=int(self.Height_Plot), xlim=self.X_Limit, ylim=self.Y_Limit, xlabel='Time (s)', ylabel=self.Y_Label) - - arr = np.asarray(df_psth[selected_trials]) - mean = np.nanmean(arr, axis=1) - err = np.nanstd(arr, axis=1)/math.sqrt(arr.shape[1]) - ropts_curve = dict(width=int(self.Width_Plot), height=int(self.Height_Plot), xlim=self.X_Limit, ylim=self.Y_Limit, color='black', xlabel='Time (s)', ylabel=self.Y_Label) - ropts_spread = dict(width=int(self.Width_Plot), height=int(self.Height_Plot), fill_alpha=0.3, fill_color='black', line_width=0) - plot_curve = hv.Curve((df_psth['timestamps'][index], mean[index])) - plot_spread = hv.Spread((df_psth['timestamps'][index], mean[index], err[index], err[index])) - - plot = (plot_curve*plot_spread).opts({'Curve': ropts_curve, - 'Spread': ropts_spread}) - return overlay.opts(**ropts_overlay)*plot - - - # function to show heatmaps for each event - @param.depends('event_selector_heatmap', 'color_map', 'height_heatmap', 'width_heatmap', 'heatmap_y') - def heatmap(self): - height = self.height_heatmap - width = self.width_heatmap - df_hm = self.df_new[self.event_selector_heatmap] - cols = list(df_hm.columns) - regex = re.compile('bin_err_*') - drop_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] - drop_cols = ['err', 'mean'] + drop_cols - df_hm = df_hm.drop(drop_cols, axis=1) - cols = list(df_hm.columns) - bin_cols = [cols[i] for i in range(len(cols)) if re.compile('bin_*').match(cols[i])] - time = np.asarray(df_hm['timestamps']) - event_ts_for_each_event = np.arange(1,len(df_hm.columns[:-1])+1) - yticks = list(event_ts_for_each_event) - z_score = np.asarray(df_hm[df_hm.columns[:-1]]).T - - if self.heatmap_y[0]=='All': - indices = np.arange(z_score.shape[0]-len(bin_cols)) - z_score = z_score[indices,:] - event_ts_for_each_event = 
np.arange(1,z_score.shape[0]+1) - yticks = list(event_ts_for_each_event) - else: - remove_all = list(set(self.heatmap_y)-set(['All'])) - indices = sorted([int(s.split('-')[0])-1 for s in remove_all]) - z_score = z_score[indices,:] - event_ts_for_each_event = np.arange(1,z_score.shape[0]+1) - yticks = list(event_ts_for_each_event) - - clim = (np.nanmin(z_score), np.nanmax(z_score)) - font_size = {'labels': 16, 'yticks': 6} - - if event_ts_for_each_event.shape[0]==1: - dummy_image = hv.QuadMesh((time, event_ts_for_each_event, z_score)).opts(colorbar=True, clim=clim) - image = ((dummy_image).opts(opts.QuadMesh(width=int(width), height=int(height), cmap=process_cmap(self.color_map, provider="matplotlib"), colorbar=True, ylabel='Trials', xlabel='Time (s)', fontsize=font_size, yticks=yticks))).opts(shared_axes=False) - - save_opts = self.save_options_heatmap - op = make_dir(filepath) - op_filename = os.path.join(op, self.event_selector_heatmap+'_'+'heatmap') - self.results_hm['plot'] = image - self.results_hm['op'] = op_filename - #self.save_plots(image, save_opts, op_filename) - return image - else: - ropts = dict(width=int(width), height=int(height), ylabel='Trials', xlabel='Time (s)', fontsize=font_size, yticks=yticks, invert_yaxis=True) - dummy_image = hv.QuadMesh((time[0:100], event_ts_for_each_event, z_score[:,0:100])).opts(colorbar=True, cmap=process_cmap(self.color_map, provider="matplotlib"), clim=clim) - actual_image = hv.QuadMesh((time, event_ts_for_each_event, z_score)) - - dynspread_img = datashade(actual_image, cmap=process_cmap(self.color_map, provider="matplotlib")).opts(**ropts) #clims=self.C_Limit, cnorm='log' - image = ((dummy_image * dynspread_img).opts(opts.QuadMesh(width=int(width), height=int(height)))).opts(shared_axes=False) - - save_opts = self.save_options_heatmap - op = make_dir(filepath) - op_filename = os.path.join(op, self.event_selector_heatmap+'_'+'heatmap') - self.results_hm['plot'] = image - self.results_hm['op'] = op_filename - - return image - - - view = Viewer() - - #PSTH plot options - psth_checkbox = pn.Param(view.param.select_trials_checkbox, widgets={ - 'select_trials_checkbox': {'type': pn.widgets.CheckBoxGroup, 'inline': True, - 'name': 'Select mean and/or just trials'}}) - parameters = pn.Param(view.param.selector_for_multipe_events_plot, widgets={ - 'selector_for_multipe_events_plot': {'type': pn.widgets.CrossSelector, 'width':550, 'align':'start'}}) - heatmap_y_parameters = pn.Param(view.param.heatmap_y, widgets={ - 'heatmap_y': {'type':pn.widgets.MultiSelect, 'name':'Trial # - Timestamps', 'width':200, 'size':30}}) - psth_y_parameters = pn.Param(view.param.psth_y, widgets={ - 'psth_y': {'type':pn.widgets.MultiSelect, 'name':'Trial # - Timestamps', 'width':200, 'size':15, 'align':'start'}}) - - event_selector = pn.Param(view.param.event_selector, widgets={ - 'event_selector': {'type':pn.widgets.Select, 'width':400}}) - x_selector = pn.Param(view.param.x, widgets={ - 'x': {'type':pn.widgets.Select, 'width':180}}) - y_selector = pn.Param(view.param.y, widgets={ - 'y': {'type':pn.widgets.Select, 'width':180}}) - - width_plot = pn.Param(view.param.Width_Plot, widgets={ - 'Width_Plot': {'type':pn.widgets.Select, 'width':70}}) - height_plot = pn.Param(view.param.Height_Plot, widgets={ - 'Height_Plot': {'type':pn.widgets.Select, 'width':70}}) - ylabel = pn.Param(view.param.Y_Label, widgets={ - 'Y_Label': {'type':pn.widgets.Select, 'width':70}}) - save_opts = pn.Param(view.param.save_options, widgets={ - 'save_options': {'type':pn.widgets.Select, 
'width':70}}) - - xlimit_plot = pn.Param(view.param.X_Limit, widgets={ - 'X_Limit': {'type':pn.widgets.RangeSlider, 'width':180}}) - ylimit_plot = pn.Param(view.param.Y_Limit, widgets={ - 'Y_Limit': {'type':pn.widgets.RangeSlider, 'width':180}}) - save_psth = pn.Param(view.param.save_psth, widgets={ - 'save_psth': {'type':pn.widgets.Button, 'width':400}}) - - options = pn.Column(event_selector, - pn.Row(x_selector, y_selector), - pn.Row(xlimit_plot, ylimit_plot), - pn.Row(width_plot, height_plot, ylabel, save_opts), - save_psth) - - options_selectors = pn.Row(options, parameters) - - line_tab = pn.Column('## '+basename, - pn.Row(options_selectors, pn.Column(psth_checkbox, psth_y_parameters), width=1200), - view.contPlot, - view.update_selector, - view.plot_specific_trials) - - # Heatmap plot options - event_selector_heatmap = pn.Param(view.param.event_selector_heatmap, widgets={ - 'event_selector_heatmap': {'type':pn.widgets.Select, 'width':150}}) - color_map = pn.Param(view.param.color_map, widgets={ - 'color_map': {'type':pn.widgets.Select, 'width':150}}) - width_heatmap = pn.Param(view.param.width_heatmap, widgets={ - 'width_heatmap': {'type':pn.widgets.Select, 'width':150}}) - height_heatmap = pn.Param(view.param.height_heatmap, widgets={ - 'height_heatmap': {'type':pn.widgets.Select, 'width':150}}) - save_hm = pn.Param(view.param.save_hm, widgets={ - 'save_hm': {'type':pn.widgets.Button, 'width':150}}) - save_options_heatmap = pn.Param(view.param.save_options_heatmap, widgets={ - 'save_options_heatmap': {'type':pn.widgets.Select, 'width':150}}) - - hm_tab = pn.Column('## '+basename, pn.Row(event_selector_heatmap, color_map, - width_heatmap, height_heatmap, - save_options_heatmap, pn.Column(pn.Spacer(height=25), save_hm)), - pn.Row(view.heatmap, heatmap_y_parameters)) # - logger.info('app') - - template = pn.template.MaterialTemplate(title='Visualization GUI') - - number = scanPortsAndFind(start_port=5000, end_port=5200) - app = pn.Tabs(('PSTH', line_tab), - ('Heat Map', hm_tab)) - - template.main.append(app) - - template.show(port=number) - + basename = os.path.basename(filepath) + visualize_zscore_or_dff = inputParameters["visualize_zscore_or_dff"] + + # note when there are no behavior event TTLs + if len(event) == 0: + logger.warning("\033[1m" + "There are no behavior event TTLs present to visualize.".format(event) + "\033[0m") + return 0 + + if os.path.exists(os.path.join(filepath, "cross_correlation_output")): + event_corr, frames = [], [] + if visualize_zscore_or_dff == "z_score": + corr_fp = glob.glob(os.path.join(filepath, "cross_correlation_output", "*_z_score_*")) + elif visualize_zscore_or_dff == "dff": + corr_fp = glob.glob(os.path.join(filepath, "cross_correlation_output", "*_dff_*")) + for i in range(len(corr_fp)): + filename = os.path.basename(corr_fp[i]).split(".")[0] + event_corr.append(filename) + df = pd.read_hdf(corr_fp[i], key="df", mode="r") + frames.append(df) + if len(frames) > 0: + df_corr = pd.concat(frames, keys=event_corr, axis=1) + else: + event_corr = [] + df_corr = [] + else: + event_corr = [] + df_corr = None + + # combine all the event PSTH so that it can be viewed together + if name: + event_name, name = event, name + new_event, frames, bins = [], [], {} + for i in range(len(event_name)): + + for j in range(len(name)): + new_event.append(event_name[i] + "_" + name[j].split("_")[-1]) + new_name = name[j] + temp_df = read_Df(filepath, new_event[-1], new_name) + cols = list(temp_df.columns) + regex = re.compile("bin_[(]") + bins[new_event[-1]] = [cols[i] 
for i in range(len(cols)) if regex.match(cols[i])] + # bins.append(keep_cols) + frames.append(temp_df) + + df = pd.concat(frames, keys=new_event, axis=1) + else: + new_event = list(np.unique(np.array(event))) + frames, bins = [], {} + for i in range(len(new_event)): + temp_df = read_Df(filepath, new_event[i], "") + cols = list(temp_df.columns) + regex = re.compile("bin_[(]") + bins[new_event[i]] = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] + frames.append(temp_df) + + df = pd.concat(frames, keys=new_event, axis=1) + + if isinstance(df_corr, pd.DataFrame): + new_event.extend(event_corr) + df = pd.concat([df, df_corr], axis=1, sort=False).reset_index() + + columns_dict = dict() + for i in range(len(new_event)): + df_1 = df[new_event[i]] + columns = list(df_1.columns) + columns.append("All") + columns_dict[new_event[i]] = columns + + # create a class to make GUI and plot different graphs + class Viewer(param.Parameterized): + + # class_event = new_event + + # make options array for different selectors + multiple_plots_options = [] + heatmap_options = new_event + bins_keys = list(bins.keys()) + if len(bins_keys) > 0: + bins_new = bins + for i in range(len(bins_keys)): + arr = bins[bins_keys[i]] + if len(arr) > 0: + # heatmap_options.append('{}_bin'.format(bins_keys[i])) + for j in arr: + multiple_plots_options.append("{}_{}".format(bins_keys[i], j)) + + multiple_plots_options = new_event + multiple_plots_options + else: + multiple_plots_options = new_event + + # create different options and selectors + event_selector = param.ObjectSelector(default=new_event[0], objects=new_event) + event_selector_heatmap = param.ObjectSelector(default=heatmap_options[0], objects=heatmap_options) + columns = columns_dict + df_new = df + + colormaps = plt.colormaps() + new_colormaps = ["plasma", "plasma_r", "magma", "magma_r", "inferno", "inferno_r", "viridis", "viridis_r"] + set_a = set(colormaps) + set_b = set(new_colormaps) + colormaps = new_colormaps + list(set_a.difference(set_b)) + + x_min = float(inputParameters["nSecPrev"]) - 20 + x_max = float(inputParameters["nSecPost"]) + 20 + selector_for_multipe_events_plot = param.ListSelector( + default=[multiple_plots_options[0]], objects=multiple_plots_options + ) + x = param.ObjectSelector(default=columns[new_event[0]][-4], objects=[columns[new_event[0]][-4]]) + y = param.ObjectSelector( + default=remove_cols(columns[new_event[0]])[-2], objects=remove_cols(columns[new_event[0]]) + ) + + trial_no = range(1, len(remove_cols(columns[heatmap_options[0]])[:-2]) + 1) + trial_ts = ["{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(columns[heatmap_options[0]])[:-2])] + [ + "All" + ] + heatmap_y = param.ListSelector(default=[trial_ts[-1]], objects=trial_ts) + psth_y = param.ListSelector(objects=trial_ts[:-1]) + select_trials_checkbox = param.ListSelector(default=["just trials"], objects=["mean", "just trials"]) + Y_Label = param.ObjectSelector(default="y", objects=["y", "z-score", "\u0394F/F"]) + save_options = param.ObjectSelector( + default="None", objects=["None", "save_png_format", "save_svg_format", "save_both_format"] + ) + save_options_heatmap = param.ObjectSelector( + default="None", objects=["None", "save_png_format", "save_svg_format", "save_both_format"] + ) + color_map = param.ObjectSelector(default="plasma", objects=colormaps) + height_heatmap = param.ObjectSelector(default=600, objects=list(np.arange(0, 5100, 100))[1:]) + width_heatmap = param.ObjectSelector(default=1000, objects=list(np.arange(0, 5100, 100))[1:]) + Height_Plot 
= param.ObjectSelector(default=300, objects=list(np.arange(0, 5100, 100))[1:]) + Width_Plot = param.ObjectSelector(default=1000, objects=list(np.arange(0, 5100, 100))[1:]) + save_hm = param.Action(lambda x: x.param.trigger("save_hm"), label="Save") + save_psth = param.Action(lambda x: x.param.trigger("save_psth"), label="Save") + X_Limit = param.Range(default=(-5, 10), bounds=(x_min, x_max)) + Y_Limit = param.Range(bounds=(-50, 50.0)) + + # C_Limit = param.Range(bounds=(-20,20.0)) + + results_hm = dict() + results_psth = dict() + + # function to save heatmaps when save button on heatmap tab is clicked + @param.depends("save_hm", watch=True) + def save_hm_plots(self): + plot = self.results_hm["plot"] + op = self.results_hm["op"] + save_opts = self.save_options_heatmap + logger.info(save_opts) + if save_opts == "save_svg_format": + p = hv.render(plot, backend="bokeh") + p.output_backend = "svg" + export_svgs(p, filename=op + ".svg") + elif save_opts == "save_png_format": + p = hv.render(plot, backend="bokeh") + export_png(p, filename=op + ".png") + elif save_opts == "save_both_format": + p = hv.render(plot, backend="bokeh") + p.output_backend = "svg" + export_svgs(p, filename=op + ".svg") + p_png = hv.render(plot, backend="bokeh") + export_png(p_png, filename=op + ".png") + else: + return 0 + + # function to save PSTH plots when save button on PSTH tab is clicked + @param.depends("save_psth", watch=True) + def save_psth_plot(self): + plot, op = [], [] + plot.append(self.results_psth["plot_combine"]) + op.append(self.results_psth["op_combine"]) + plot.append(self.results_psth["plot"]) + op.append(self.results_psth["op"]) + for i in range(len(plot)): + temp_plot, temp_op = plot[i], op[i] + save_opts = self.save_options + if save_opts == "save_svg_format": + p = hv.render(temp_plot, backend="bokeh") + p.output_backend = "svg" + export_svgs(p, filename=temp_op + ".svg") + elif save_opts == "save_png_format": + p = hv.render(temp_plot, backend="bokeh") + export_png(p, filename=temp_op + ".png") + elif save_opts == "save_both_format": + p = hv.render(temp_plot, backend="bokeh") + p.output_backend = "svg" + export_svgs(p, filename=temp_op + ".svg") + p_png = hv.render(temp_plot, backend="bokeh") + export_png(p_png, filename=temp_op + ".png") + else: + return 0 + + # function to change Y values based on event selection + @param.depends("event_selector", watch=True) + def _update_x_y(self): + x_value = self.columns[self.event_selector] + y_value = self.columns[self.event_selector] + self.param["x"].objects = [x_value[-4]] + self.param["y"].objects = remove_cols(y_value) + self.x = x_value[-4] + self.y = self.param["y"].objects[-2] + + @param.depends("event_selector_heatmap", watch=True) + def _update_df(self): + cols = self.columns[self.event_selector_heatmap] + trial_no = range(1, len(remove_cols(cols)[:-2]) + 1) + trial_ts = ["{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(cols)[:-2])] + ["All"] + self.param["heatmap_y"].objects = trial_ts + self.heatmap_y = [trial_ts[-1]] + + @param.depends("event_selector", watch=True) + def _update_psth_y(self): + cols = self.columns[self.event_selector] + trial_no = range(1, len(remove_cols(cols)[:-2]) + 1) + trial_ts = ["{} - {}".format(i, j) for i, j in zip(trial_no, remove_cols(cols)[:-2])] + self.param["psth_y"].objects = trial_ts + self.psth_y = [trial_ts[0]] + + # function to plot multiple PSTHs into one plot + + @param.depends( + "selector_for_multipe_events_plot", + "Y_Label", + "save_options", + "X_Limit", + "Y_Limit", + "Height_Plot", + 
"Width_Plot", + ) + def update_selector(self): + data_curve, cols_curve, data_spread, cols_spread = [], [], [], [] + arr = self.selector_for_multipe_events_plot + df1 = self.df_new + for i in range(len(arr)): + if "bin" in arr[i]: + split = arr[i].rsplit("_", 2) + df_name = split[0] #'{}_{}'.format(split[0], split[1]) + col_name_mean = "{}_{}".format(split[-2], split[-1]) + col_name_err = "{}_err_{}".format(split[-2], split[-1]) + data_curve.append(df1[df_name][col_name_mean]) + cols_curve.append(arr[i]) + data_spread.append(df1[df_name][col_name_err]) + cols_spread.append(arr[i]) + else: + data_curve.append(df1[arr[i]]["mean"]) + cols_curve.append(arr[i] + "_" + "mean") + data_spread.append(df1[arr[i]]["err"]) + cols_spread.append(arr[i] + "_" + "mean") + + if len(arr) > 0: + if self.Y_Limit == None: + self.Y_Limit = (np.nanmin(np.asarray(data_curve)) - 0.5, np.nanmax(np.asarray(data_curve)) + 0.5) + + if "bin" in arr[i]: + split = arr[i].rsplit("_", 2) + df_name = split[0] + data_curve.append(df1[df_name]["timestamps"]) + cols_curve.append("timestamps") + data_spread.append(df1[df_name]["timestamps"]) + cols_spread.append("timestamps") + else: + data_curve.append(df1[arr[i]]["timestamps"]) + cols_curve.append("timestamps") + data_spread.append(df1[arr[i]]["timestamps"]) + cols_spread.append("timestamps") + df_curve = pd.concat(data_curve, axis=1) + df_spread = pd.concat(data_spread, axis=1) + df_curve.columns = cols_curve + df_spread.columns = cols_spread + + ts = df_curve["timestamps"] + index = np.arange(0, ts.shape[0], 3) + df_curve = df_curve.loc[index, :] + df_spread = df_spread.loc[index, :] + overlay = hv.NdOverlay( + { + c: hv.Curve((df_curve["timestamps"], df_curve[c]), kdims=["Time (s)"]).opts( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + ) + for c in cols_curve[:-1] + } + ) + spread = hv.NdOverlay( + { + d: hv.Spread( + (df_spread["timestamps"], df_curve[d], df_spread[d], df_spread[d]), + vdims=["y", "yerrpos", "yerrneg"], + ).opts(line_width=0, fill_alpha=0.3) + for d in cols_spread[:-1] + } + ) + plot_combine = ((overlay * spread).opts(opts.NdOverlay(xlabel="Time (s)", ylabel=self.Y_Label))).opts( + shared_axes=False + ) + # plot_err = new_df.hvplot.area(x='timestamps', y=[], y2=[]) + save_opts = self.save_options + op = make_dir(filepath) + op_filename = os.path.join(op, str(arr) + "_mean") + + self.results_psth["plot_combine"] = plot_combine + self.results_psth["op_combine"] = op_filename + # self.save_plots(plot_combine, save_opts, op_filename) + return plot_combine + + # function to plot mean PSTH, single trial in PSTH and all the trials of PSTH with mean + @param.depends( + "event_selector", "x", "y", "Y_Label", "save_options", "Y_Limit", "X_Limit", "Height_Plot", "Width_Plot" + ) + def contPlot(self): + df1 = self.df_new[self.event_selector] + # height = self.Heigth_Plot + # width = self.Width_Plot + # logger.info(height, width) + if self.y == "All": + if self.Y_Limit == None: + self.Y_Limit = (np.nanmin(np.asarray(df1)) - 0.5, np.nanmax(np.asarray(df1)) - 0.5) + + options = self.param["y"].objects + regex = re.compile("bin_[(]") + remove_bin_trials = [options[i] for i in range(len(options)) if not regex.match(options[i])] + + ndoverlay = hv.NdOverlay({c: hv.Curve((df1[self.x], df1[c])) for c in remove_bin_trials[:-2]}) + img1 = datashade(ndoverlay, normalization="linear", aggregator=ds.count()) + x_points = df1[self.x] + y_points = df1["mean"] + img2 = hv.Curve((x_points, y_points)) + img = (img1 * 
img2).opts( + opts.Curve( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + line_width=4, + color="black", + xlim=self.X_Limit, + ylim=self.Y_Limit, + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + ) + + save_opts = self.save_options + + op = make_dir(filepath) + op_filename = os.path.join(op, self.event_selector + "_" + self.y) + self.results_psth["plot"] = img + self.results_psth["op"] = op_filename + # self.save_plots(img, save_opts, op_filename) + + return img + + elif self.y == "mean" or "bin" in self.y: + + xpoints = df1[self.x] + ypoints = df1[self.y] + if self.y == "mean": + err = df1["err"] + else: + split = self.y.split("_") + err = df1["{}_err_{}".format(split[0], split[1])] + + index = np.arange(0, xpoints.shape[0], 3) + + if self.Y_Limit == None: + self.Y_Limit = (np.nanmin(ypoints) - 0.5, np.nanmax(ypoints) + 0.5) + + ropts_curve = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + color="blue", + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + ropts_spread = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + fill_alpha=0.3, + fill_color="blue", + line_width=0, + ) + + plot_curve = hv.Curve((xpoints[index], ypoints[index])) # .opts(**ropts_curve) + plot_spread = hv.Spread( + (xpoints[index], ypoints[index], err[index], err[index]) + ) # .opts(**ropts_spread) #vdims=['y', 'yerrpos', 'yerrneg'] + plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread}) + + save_opts = self.save_options + op = make_dir(filepath) + op_filename = os.path.join(op, self.event_selector + "_" + self.y) + self.results_psth["plot"] = plot + self.results_psth["op"] = op_filename + # self.save_plots(plot, save_opts, op_filename) + + return plot + + else: + xpoints = df1[self.x] + ypoints = df1[self.y] + if self.Y_Limit == None: + self.Y_Limit = (np.nanmin(ypoints) - 0.5, np.nanmax(ypoints) + 0.5) + + ropts_curve = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + color="blue", + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + plot = hv.Curve((xpoints, ypoints)).opts({"Curve": ropts_curve}) + + save_opts = self.save_options + op = make_dir(filepath) + op_filename = os.path.join(op, self.event_selector + "_" + self.y) + self.results_psth["plot"] = plot + self.results_psth["op"] = op_filename + # self.save_plots(plot, save_opts, op_filename) + + return plot + + # function to plot specific PSTH trials + @param.depends( + "event_selector", + "x", + "psth_y", + "select_trials_checkbox", + "Y_Label", + "save_options", + "Y_Limit", + "X_Limit", + "Height_Plot", + "Width_Plot", + ) + def plot_specific_trials(self): + df_psth = self.df_new[self.event_selector] + # if self.Y_Limit==None: + # self.Y_Limit = (np.nanmin(ypoints)-0.5, np.nanmax(ypoints)+0.5) + + if self.psth_y == None: + return None + else: + selected_trials = [s.split(" - ")[1] for s in list(self.psth_y)] + + index = np.arange(0, df_psth["timestamps"].shape[0], 3) + + if self.select_trials_checkbox == ["just trials"]: + overlay = hv.NdOverlay( + { + c: hv.Curve((df_psth["timestamps"][index], df_psth[c][index]), kdims=["Time (s)"]) + for c in selected_trials + } + ) + ropts = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + return overlay.opts(**ropts) + elif self.select_trials_checkbox == ["mean"]: + arr = np.asarray(df_psth[selected_trials]) + mean = 
np.nanmean(arr, axis=1) + err = np.nanstd(arr, axis=1) / math.sqrt(arr.shape[1]) + ropts_curve = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + color="blue", + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + ropts_spread = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + fill_alpha=0.3, + fill_color="blue", + line_width=0, + ) + plot_curve = hv.Curve((df_psth["timestamps"][index], mean[index])) + plot_spread = hv.Spread((df_psth["timestamps"][index], mean[index], err[index], err[index])) + plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread}) + return plot + elif self.select_trials_checkbox == ["mean", "just trials"]: + overlay = hv.NdOverlay( + { + c: hv.Curve((df_psth["timestamps"][index], df_psth[c][index]), kdims=["Time (s)"]) + for c in selected_trials + } + ) + ropts_overlay = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + + arr = np.asarray(df_psth[selected_trials]) + mean = np.nanmean(arr, axis=1) + err = np.nanstd(arr, axis=1) / math.sqrt(arr.shape[1]) + ropts_curve = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + xlim=self.X_Limit, + ylim=self.Y_Limit, + color="black", + xlabel="Time (s)", + ylabel=self.Y_Label, + ) + ropts_spread = dict( + width=int(self.Width_Plot), + height=int(self.Height_Plot), + fill_alpha=0.3, + fill_color="black", + line_width=0, + ) + plot_curve = hv.Curve((df_psth["timestamps"][index], mean[index])) + plot_spread = hv.Spread((df_psth["timestamps"][index], mean[index], err[index], err[index])) + + plot = (plot_curve * plot_spread).opts({"Curve": ropts_curve, "Spread": ropts_spread}) + return overlay.opts(**ropts_overlay) * plot + + # function to show heatmaps for each event + @param.depends("event_selector_heatmap", "color_map", "height_heatmap", "width_heatmap", "heatmap_y") + def heatmap(self): + height = self.height_heatmap + width = self.width_heatmap + df_hm = self.df_new[self.event_selector_heatmap] + cols = list(df_hm.columns) + regex = re.compile("bin_err_*") + drop_cols = [cols[i] for i in range(len(cols)) if regex.match(cols[i])] + drop_cols = ["err", "mean"] + drop_cols + df_hm = df_hm.drop(drop_cols, axis=1) + cols = list(df_hm.columns) + bin_cols = [cols[i] for i in range(len(cols)) if re.compile("bin_*").match(cols[i])] + time = np.asarray(df_hm["timestamps"]) + event_ts_for_each_event = np.arange(1, len(df_hm.columns[:-1]) + 1) + yticks = list(event_ts_for_each_event) + z_score = np.asarray(df_hm[df_hm.columns[:-1]]).T + + if self.heatmap_y[0] == "All": + indices = np.arange(z_score.shape[0] - len(bin_cols)) + z_score = z_score[indices, :] + event_ts_for_each_event = np.arange(1, z_score.shape[0] + 1) + yticks = list(event_ts_for_each_event) + else: + remove_all = list(set(self.heatmap_y) - set(["All"])) + indices = sorted([int(s.split("-")[0]) - 1 for s in remove_all]) + z_score = z_score[indices, :] + event_ts_for_each_event = np.arange(1, z_score.shape[0] + 1) + yticks = list(event_ts_for_each_event) + + clim = (np.nanmin(z_score), np.nanmax(z_score)) + font_size = {"labels": 16, "yticks": 6} + + if event_ts_for_each_event.shape[0] == 1: + dummy_image = hv.QuadMesh((time, event_ts_for_each_event, z_score)).opts(colorbar=True, clim=clim) + image = ( + (dummy_image).opts( + opts.QuadMesh( + width=int(width), + height=int(height), + cmap=process_cmap(self.color_map, provider="matplotlib"), 
+ colorbar=True, + ylabel="Trials", + xlabel="Time (s)", + fontsize=font_size, + yticks=yticks, + ) + ) + ).opts(shared_axes=False) + + save_opts = self.save_options_heatmap + op = make_dir(filepath) + op_filename = os.path.join(op, self.event_selector_heatmap + "_" + "heatmap") + self.results_hm["plot"] = image + self.results_hm["op"] = op_filename + # self.save_plots(image, save_opts, op_filename) + return image + else: + ropts = dict( + width=int(width), + height=int(height), + ylabel="Trials", + xlabel="Time (s)", + fontsize=font_size, + yticks=yticks, + invert_yaxis=True, + ) + dummy_image = hv.QuadMesh((time[0:100], event_ts_for_each_event, z_score[:, 0:100])).opts( + colorbar=True, cmap=process_cmap(self.color_map, provider="matplotlib"), clim=clim + ) + actual_image = hv.QuadMesh((time, event_ts_for_each_event, z_score)) + + dynspread_img = datashade(actual_image, cmap=process_cmap(self.color_map, provider="matplotlib")).opts( + **ropts + ) # clims=self.C_Limit, cnorm='log' + image = ((dummy_image * dynspread_img).opts(opts.QuadMesh(width=int(width), height=int(height)))).opts( + shared_axes=False + ) + + save_opts = self.save_options_heatmap + op = make_dir(filepath) + op_filename = os.path.join(op, self.event_selector_heatmap + "_" + "heatmap") + self.results_hm["plot"] = image + self.results_hm["op"] = op_filename + + return image + + view = Viewer() + + # PSTH plot options + psth_checkbox = pn.Param( + view.param.select_trials_checkbox, + widgets={ + "select_trials_checkbox": { + "type": pn.widgets.CheckBoxGroup, + "inline": True, + "name": "Select mean and/or just trials", + } + }, + ) + parameters = pn.Param( + view.param.selector_for_multipe_events_plot, + widgets={ + "selector_for_multipe_events_plot": {"type": pn.widgets.CrossSelector, "width": 550, "align": "start"} + }, + ) + heatmap_y_parameters = pn.Param( + view.param.heatmap_y, + widgets={ + "heatmap_y": {"type": pn.widgets.MultiSelect, "name": "Trial # - Timestamps", "width": 200, "size": 30} + }, + ) + psth_y_parameters = pn.Param( + view.param.psth_y, + widgets={ + "psth_y": { + "type": pn.widgets.MultiSelect, + "name": "Trial # - Timestamps", + "width": 200, + "size": 15, + "align": "start", + } + }, + ) + + event_selector = pn.Param( + view.param.event_selector, widgets={"event_selector": {"type": pn.widgets.Select, "width": 400}} + ) + x_selector = pn.Param(view.param.x, widgets={"x": {"type": pn.widgets.Select, "width": 180}}) + y_selector = pn.Param(view.param.y, widgets={"y": {"type": pn.widgets.Select, "width": 180}}) + + width_plot = pn.Param(view.param.Width_Plot, widgets={"Width_Plot": {"type": pn.widgets.Select, "width": 70}}) + height_plot = pn.Param(view.param.Height_Plot, widgets={"Height_Plot": {"type": pn.widgets.Select, "width": 70}}) + ylabel = pn.Param(view.param.Y_Label, widgets={"Y_Label": {"type": pn.widgets.Select, "width": 70}}) + save_opts = pn.Param(view.param.save_options, widgets={"save_options": {"type": pn.widgets.Select, "width": 70}}) + + xlimit_plot = pn.Param(view.param.X_Limit, widgets={"X_Limit": {"type": pn.widgets.RangeSlider, "width": 180}}) + ylimit_plot = pn.Param(view.param.Y_Limit, widgets={"Y_Limit": {"type": pn.widgets.RangeSlider, "width": 180}}) + save_psth = pn.Param(view.param.save_psth, widgets={"save_psth": {"type": pn.widgets.Button, "width": 400}}) + + options = pn.Column( + event_selector, + pn.Row(x_selector, y_selector), + pn.Row(xlimit_plot, ylimit_plot), + pn.Row(width_plot, height_plot, ylabel, save_opts), + save_psth, + ) + + options_selectors = 
pn.Row(options, parameters) + + line_tab = pn.Column( + "## " + basename, + pn.Row(options_selectors, pn.Column(psth_checkbox, psth_y_parameters), width=1200), + view.contPlot, + view.update_selector, + view.plot_specific_trials, + ) + + # Heatmap plot options + event_selector_heatmap = pn.Param( + view.param.event_selector_heatmap, widgets={"event_selector_heatmap": {"type": pn.widgets.Select, "width": 150}} + ) + color_map = pn.Param(view.param.color_map, widgets={"color_map": {"type": pn.widgets.Select, "width": 150}}) + width_heatmap = pn.Param( + view.param.width_heatmap, widgets={"width_heatmap": {"type": pn.widgets.Select, "width": 150}} + ) + height_heatmap = pn.Param( + view.param.height_heatmap, widgets={"height_heatmap": {"type": pn.widgets.Select, "width": 150}} + ) + save_hm = pn.Param(view.param.save_hm, widgets={"save_hm": {"type": pn.widgets.Button, "width": 150}}) + save_options_heatmap = pn.Param( + view.param.save_options_heatmap, widgets={"save_options_heatmap": {"type": pn.widgets.Select, "width": 150}} + ) + + hm_tab = pn.Column( + "## " + basename, + pn.Row( + event_selector_heatmap, + color_map, + width_heatmap, + height_heatmap, + save_options_heatmap, + pn.Column(pn.Spacer(height=25), save_hm), + ), + pn.Row(view.heatmap, heatmap_y_parameters), + ) # + logger.info("app") + + template = pn.template.MaterialTemplate(title="Visualization GUI") + + number = scanPortsAndFind(start_port=5000, end_port=5200) + app = pn.Tabs(("PSTH", line_tab), ("Heat Map", hm_tab)) + + template.main.append(app) + + template.show(port=number) # function to combine all the output folders together and preprocess them to use them in helper_plots function def createPlots(filepath, event, inputParameters): - for i in range(len(event)): - event[i] = event[i].replace("\\","_") - event[i] = event[i].replace("/","_") + for i in range(len(event)): + event[i] = event[i].replace("\\", "_") + event[i] = event[i].replace("/", "_") - average = inputParameters['visualizeAverageResults'] - visualize_zscore_or_dff = inputParameters['visualize_zscore_or_dff'] + average = inputParameters["visualizeAverageResults"] + visualize_zscore_or_dff = inputParameters["visualize_zscore_or_dff"] - if average==True: - path = [] - for i in range(len(event)): - if visualize_zscore_or_dff=='z_score': - path.append(glob.glob(os.path.join(filepath, event[i]+'*_z_score_*'))) - elif visualize_zscore_or_dff=='dff': - path.append(glob.glob(os.path.join(filepath, event[i]+'*_dff_*'))) + if average == True: + path = [] + for i in range(len(event)): + if visualize_zscore_or_dff == "z_score": + path.append(glob.glob(os.path.join(filepath, event[i] + "*_z_score_*"))) + elif visualize_zscore_or_dff == "dff": + path.append(glob.glob(os.path.join(filepath, event[i] + "*_dff_*"))) - path = np.concatenate(path) - else: - if visualize_zscore_or_dff=='z_score': - path = glob.glob(os.path.join(filepath, 'z_score_*')) - elif visualize_zscore_or_dff=='dff': - path = glob.glob(os.path.join(filepath, 'dff_*')) + path = np.concatenate(path) + else: + if visualize_zscore_or_dff == "z_score": + path = glob.glob(os.path.join(filepath, "z_score_*")) + elif visualize_zscore_or_dff == "dff": + path = glob.glob(os.path.join(filepath, "dff_*")) - name_arr = [] - event_arr = [] - - indx = [] - for i in range(len(event)): - if 'control' in event[i].lower() or 'signal' in event[i].lower(): - indx.append(i) + name_arr = [] + event_arr = [] - event = np.delete(event, indx) - - for i in range(len(path)): - name = (os.path.basename(path[i])).split('.') - name = 
name[0] - name_arr.append(name) + index = [] + for i in range(len(event)): + if "control" in event[i].lower() or "signal" in event[i].lower(): + index.append(i) + event = np.delete(event, index) - if average==True: - logger.info('average') - helper_plots(filepath, name_arr, '', inputParameters) - else: - helper_plots(filepath, event, name_arr, inputParameters) + for i in range(len(path)): + name = (os.path.basename(path[i])).split(".") + name = name[0] + name_arr.append(name) + if average == True: + logger.info("average") + helper_plots(filepath, name_arr, "", inputParameters) + else: + helper_plots(filepath, event, name_arr, inputParameters) def visualizeResults(inputParameters): - - inputParameters = inputParameters - - - average = inputParameters['visualizeAverageResults'] - logger.info(average) - - folderNames = inputParameters['folderNames'] - folderNamesForAvg = inputParameters['folderNamesForAvg'] - combine_data = inputParameters['combine_data'] - - if average==True and len(folderNamesForAvg)>0: - #folderNames = folderNamesForAvg - filepath_avg = os.path.join(inputParameters['abspath'], 'average') - #filepath = os.path.join(inputParameters['abspath'], folderNames[0]) - storesListPath = [] - for i in range(len(folderNamesForAvg)): - filepath = folderNamesForAvg[i] - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*')))) - storesListPath = np.concatenate(storesListPath) - storesList = np.asarray([[],[]]) - for i in range(storesListPath.shape[0]): - storesList = np.concatenate((storesList, np.genfromtxt(os.path.join(storesListPath[i], 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1)), axis=1) - storesList = np.unique(storesList, axis=1) - - createPlots(filepath_avg, np.unique(storesList[1,:]), inputParameters) - - else: - if combine_data==True: - storesListPath = [] - for i in range(len(folderNames)): - filepath = folderNames[i] - storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*')))) - storesListPath = list(np.concatenate(storesListPath).flatten()) - op = get_all_stores_for_combining_data(storesListPath) - for i in range(len(op)): - storesList = np.asarray([[],[]]) - for j in range(len(op[i])): - storesList = np.concatenate((storesList, np.genfromtxt(os.path.join(op[i][j], 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1)), axis=1) - storesList = np.unique(storesList, axis=1) - filepath = op[i][0] - createPlots(filepath, storesList[1,:], inputParameters) - else: - for i in range(len(folderNames)): - - filepath = folderNames[i] - storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, '*_output_*'))) - for j in range(len(storesListPath)): - filepath = storesListPath[j] - storesList = np.genfromtxt(os.path.join(filepath, 'storesList.csv'), dtype='str', delimiter=',').reshape(2,-1) - - createPlots(filepath, storesList[1,:], inputParameters) - - -#logger.info(sys.argv[1:]) -#visualizeResults(sys.argv[1:][0]) + inputParameters = inputParameters + + average = inputParameters["visualizeAverageResults"] + logger.info(average) + + folderNames = inputParameters["folderNames"] + folderNamesForAvg = inputParameters["folderNamesForAvg"] + combine_data = inputParameters["combine_data"] + + if average == True and len(folderNamesForAvg) > 0: + # folderNames = folderNamesForAvg + filepath_avg = os.path.join(inputParameters["abspath"], "average") + # filepath = os.path.join(inputParameters['abspath'], folderNames[0]) + storesListPath = [] + for i in range(len(folderNamesForAvg)): + filepath = folderNamesForAvg[i] 
+ storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = np.concatenate(storesListPath) + storesList = np.asarray([[], []]) + for i in range(storesListPath.shape[0]): + storesList = np.concatenate( + ( + storesList, + np.genfromtxt( + os.path.join(storesListPath[i], "storesList.csv"), dtype="str", delimiter="," + ).reshape(2, -1), + ), + axis=1, + ) + storesList = np.unique(storesList, axis=1) + + createPlots(filepath_avg, np.unique(storesList[1, :]), inputParameters) + + else: + if combine_data == True: + storesListPath = [] + for i in range(len(folderNames)): + filepath = folderNames[i] + storesListPath.append(takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*")))) + storesListPath = list(np.concatenate(storesListPath).flatten()) + op = get_all_stores_for_combining_data(storesListPath) + for i in range(len(op)): + storesList = np.asarray([[], []]) + for j in range(len(op[i])): + storesList = np.concatenate( + ( + storesList, + np.genfromtxt(os.path.join(op[i][j], "storesList.csv"), dtype="str", delimiter=",").reshape( + 2, -1 + ), + ), + axis=1, + ) + storesList = np.unique(storesList, axis=1) + filepath = op[i][0] + createPlots(filepath, storesList[1, :], inputParameters) + else: + for i in range(len(folderNames)): + + filepath = folderNames[i] + storesListPath = takeOnlyDirs(glob.glob(os.path.join(filepath, "*_output_*"))) + for j in range(len(storesListPath)): + filepath = storesListPath[j] + storesList = np.genfromtxt( + os.path.join(filepath, "storesList.csv"), dtype="str", delimiter="," + ).reshape(2, -1) + + createPlots(filepath, storesList[1, :], inputParameters) + + +# logger.info(sys.argv[1:]) +# visualizeResults(sys.argv[1:][0]) diff --git a/tests/test_step1.py b/tests/test_step1.py index a428531..a832e48 100644 --- a/tests/test_step1.py +++ b/tests/test_step1.py @@ -26,35 +26,13 @@ def default_parameters(): "use_time_or_trials": "Time (min)", "baselineCorrectionStart": -5, "baselineCorrectionEnd": 0, - "peak_startPoint": [ - -5.0, - 0.0, - 5.0, - np.nan, - np.nan, - np.nan, - np.nan, - np.nan, - np.nan, - np.nan - ], - "peak_endPoint": [ - 0.0, - 3.0, - 10.0, - np.nan, - np.nan, - np.nan, - np.nan, - np.nan, - np.nan, - np.nan - ], + "peak_startPoint": [-5.0, 0.0, 5.0, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], + "peak_endPoint": [0.0, 3.0, 10.0, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], "selectForComputePsth": "z_score", "selectForTransientsComputation": "z_score", "moving_window": 15, "highAmpFilt": 2, - "transientsThresh": 3 + "transientsThresh": 3, } diff --git a/tests/test_step2.py b/tests/test_step2.py index 7e11287..76fb0c5 100644 --- a/tests/test_step2.py +++ b/tests/test_step2.py @@ -1,6 +1,6 @@ import csv -import os import glob +import os import shutil from pathlib import Path diff --git a/tests/test_step3.py b/tests/test_step3.py index cf46c42..68424ea 100644 --- a/tests/test_step3.py +++ b/tests/test_step3.py @@ -1,11 +1,11 @@ -import os import csv import glob +import os import shutil +from pathlib import Path import h5py import pytest -from pathlib import Path from guppy.testing.api import step2, step3 @@ -18,6 +18,7 @@ def storenames_map(): "Sample_TTL": "ttl", } + @pytest.mark.parametrize( "session_subdir, storenames_map", [ diff --git a/tests/test_step4.py b/tests/test_step4.py index be83740..760f69e 100644 --- a/tests/test_step4.py +++ b/tests/test_step4.py @@ -1,13 +1,14 @@ -import os import glob +import os import shutil +from pathlib import Path import h5py import pytest 
-from pathlib import Path
 
 from guppy.testing.api import step2, step3, step4
 
+
 @pytest.mark.parametrize(
     "session_subdir, storenames_map, expected_region, expected_ttl",
     [
@@ -93,6 +94,7 @@ def test_step4(tmp_path, monkeypatch, session_subdir, storenames_map, expected_r
 
     # Stub matplotlib.pyplot.show to avoid GUI blocking
     import matplotlib.pyplot as plt  # noqa: F401
+
     monkeypatch.setattr("matplotlib.pyplot.show", lambda *args, **kwargs: None)
 
     # Stage a clean copy of the session into a temporary workspace
diff --git a/tests/test_step5.py b/tests/test_step5.py
index 59ae5ba..b2a9257 100644
--- a/tests/test_step5.py
+++ b/tests/test_step5.py
@@ -1,10 +1,10 @@
-import os
 import glob
+import os
 import shutil
+from pathlib import Path
 
-import pytest
 import pandas as pd
-from pathlib import Path
+import pytest
 
 from guppy.testing.api import step2, step3, step4, step5
 
@@ -96,6 +96,7 @@ def test_step5(tmp_path, monkeypatch, session_subdir, storenames_map, expected_r
 
     # Stub matplotlib.pyplot.show to avoid GUI blocking (used in earlier steps)
     import matplotlib.pyplot as plt  # noqa: F401
+
     monkeypatch.setattr("matplotlib.pyplot.show", lambda *args, **kwargs: None)
 
     # Stage a clean copy of the session into a temporary workspace
@@ -140,7 +141,9 @@ def test_step5(tmp_path, monkeypatch, session_subdir, storenames_map, expected_r
 
     # Expected PSTH outputs (defaults compute z_score PSTH)
     psth_h5 = os.path.join(out_dir, f"{expected_ttl}_{expected_region}_z_score_{expected_region}.h5")
-    psth_baseline_uncorr_h5 = os.path.join(out_dir, f"{expected_ttl}_{expected_region}_baselineUncorrected_z_score_{expected_region}.h5")
+    psth_baseline_uncorr_h5 = os.path.join(
+        out_dir, f"{expected_ttl}_{expected_region}_baselineUncorrected_z_score_{expected_region}.h5"
+    )
     peak_auc_h5 = os.path.join(out_dir, f"peak_AUC_{expected_ttl}_{expected_region}_z_score_{expected_region}.h5")
     peak_auc_csv = os.path.join(out_dir, f"peak_AUC_{expected_ttl}_{expected_region}_z_score_{expected_region}.csv")