Source code for CPAC.nuisance.utils.compcor

# Copyright (C) 2019-2024  C-PAC Developers

# This file is part of C-PAC.

# C-PAC is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the
# Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.

# C-PAC is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.

# You should have received a copy of the GNU Lesser General Public
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
import os

import numpy as np
import nibabel as nib
from nibabel.filebasedimages import ImageFileError
from scipy import signal
from scipy.linalg import svd

from CPAC.utils import safe_shape
from CPAC.utils.monitoring import IFLOGGER


def calc_compcor_components(data_filename, num_components, mask_filename):
    if num_components < 1:
        msg = f"Improper value for num_components ({num_components}), should be >= 1."
        raise ValueError(msg)

    try:
        image_data = nib.load(data_filename).get_fdata().astype(np.float64)
    except (ImageFileError, MemoryError, OSError, TypeError, ValueError) as e:
        msg = f"Unable to load data from {data_filename}"
        raise ImageFileError(msg) from e

    try:
        binary_mask = nib.load(mask_filename).get_fdata().astype(np.int16)
    except (ImageFileError, MemoryError, OSError, TypeError, ValueError) as e:
        msg = f"Unable to load data from {mask_filename}"
        raise ImageFileError(msg) from e

    if not safe_shape(image_data, binary_mask):
        msg = (
            f"The data in {data_filename} and {mask_filename} do not have a"
            " consistent shape"
        )
        raise ValueError(msg)

    # make sure that the values in binary_mask are binary
    binary_mask[binary_mask > 0] = 1
    binary_mask[binary_mask != 1] = 0

    # reduce the image data to only the voxels in the binary mask
    image_data = image_data[binary_mask == 1, :]

    # filter out any voxels whose variance equals 0
    IFLOGGER.info("Removing zero variance components")
    image_data = image_data[image_data.std(1) != 0, :]

    if image_data.shape.count(0):
        err = (
            "\n\n[!] No wm or csf signals left after removing those "
            "with zero variance.\n\n"
        )
        raise Exception(err)

    IFLOGGER.info("Detrending and centering data")
    Y = signal.detrend(image_data, axis=1, type="linear").T
    Yc = Y - np.tile(Y.mean(0), (Y.shape[0], 1))
    Yc = Yc / np.tile(np.array(Yc.std(0)).reshape(1, Yc.shape[1]), (Yc.shape[0], 1))
    IFLOGGER.info("Calculating SVD decomposition of Y*Y'")
    U, S, Vh = np.linalg.svd(Yc, full_matrices=False)

    # write out the resulting regressor file
    regressor_file = os.path.join(os.getcwd(), "compcor_regressors.1D")
    np.savetxt(regressor_file, U[:, :num_components], delimiter="\t", fmt="%16g")

    return regressor_file


[docs] def cosine_filter( input_image_path, timestep, period_cut=128, remove_mean=True, axis=-1, failure_mode="error", ): """ `cosine_filter` adapted from Nipype. https://github.com/nipy/nipype/blob/d353f0d/nipype/algorithms/confounds.py#L1086-L1107 Parameters ---------- input_image_path : string Bold image to be filtered. timestep : float 'Repetition time (TR) of series (in sec) - derived from image header if unspecified' period_cut : float Minimum period (in sec) for DCT high-pass filter, nipype default value: 128. """ # STATEMENT OF CHANGES: # This function is derived from sources licensed under the Apache-2.0 terms, # and this function has been changed. # CHANGES: # * Refactored to take and return filepaths instead of loaded data # * Removed caluclation and return of `non_constant_regressors` # * Modified docstring to reflect local changes # * Updated style to match C-PAC codebase # ORIGINAL WORK'S ATTRIBUTION NOTICE: # Copyright (c) 2009-2016, Nipype developers # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Prior to release 0.12, Nipype was licensed under a BSD license. # Modifications copyright (C) 2019 - 2024 C-PAC Developers from nipype.algorithms.confounds import _cosine_drift, _full_rank input_img = nib.load(input_image_path) input_data = input_img.get_fdata() datashape = input_data.shape timepoints = datashape[axis] if datashape[0] == 0 and failure_mode != "error": return input_data, np.array([]) input_data = input_data.reshape((-1, timepoints)) frametimes = timestep * np.arange(timepoints) X = _full_rank(_cosine_drift(period_cut, frametimes))[0] betas = np.linalg.lstsq(X, input_data.T)[0] if not remove_mean: X = X[:, :-1] betas = betas[:-1] residuals = input_data - X.dot(betas).T output_data = residuals.reshape(datashape) hdr = input_img.header output_img = nib.Nifti1Image(output_data, header=hdr, affine=input_img.affine) file_name = input_image_path[input_image_path.rindex("/") + 1 :] cosfiltered_img = os.path.join(os.getcwd(), file_name) output_img.to_filename(cosfiltered_img) return cosfiltered_img
def fallback_svd(a, full_matrices=True, compute_uv=True): try: return np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) except np.linalg.LinAlgError: pass return svd( a, full_matrices=full_matrices, compute_uv=compute_uv, lapack_driver="gesvd" ) def TR_string_to_float(tr): """ Convert TR string to seconds (float). Suffixes 's' or 'ms' to indicate seconds or milliseconds. Parameters ---------- tr : TR string representation. May use suffixes 's' or 'ms' to indicate seconds or milliseconds. Returns ------- tr in seconds (float) """ if not isinstance(tr, str): msg = f"Improper type for TR_string_to_float ({tr})." raise TypeError(msg) tr_str = tr.replace(" ", "") try: if tr_str.endswith("ms"): tr_numeric = float(tr_str[:-2]) * 0.001 elif tr.endswith("s"): tr_numeric = float(tr_str[:-1]) else: tr_numeric = float(tr_str) except Exception as exc: msg = f'Can not convert TR string to float: "{tr}".' raise ValueError(msg) from exc return tr_numeric