Source code for models.bold_model_pytorch

import os
import time

import numpy as np
import torch


class BOLD:
    """
    Balloon model, transforms the fire rate (e.g., generated from :std:ref:`block`) to a BOLD signal
    through an approximate convolutional dynamic equation.

    More detail, ref `here <https://www.sciencedirect.com/science/article/pii/S105381190090630X>`_.

    Parameters
    ----------
    epsilon, tao_s, tao_f, tao_0, alpha, E_0, V_0 :
        Coefficients of the dynamic equations evaluated in :meth:`run`.
        ``alpha`` is only used through its reciprocal (``div_alpha``).
    delta_t : float
        Step size of the explicit Euler update applied by :meth:`update`.
    init_f_in, init_s, init_v, init_q :
        Optional initial values of the four state variables. A state left as
        ``None`` is initialised on the first call to :meth:`run`.
    device : str or torch.device, optional
        Device that NumPy inputs and loaded states are moved to. Defaults to
        CUDA when it is available and to the CPU otherwise.
    """

    def __init__(self, epsilon, tao_s, tao_f, tao_0, alpha, E_0, V_0, delta_t=1e-3, init_f_in=None, init_s=None,
                 init_v=None, init_q=None, device=None):
        self.epsilon = epsilon
        self.tao_s = tao_s
        self.tao_f = tao_f
        self.tao_0 = tao_0
        self.E_0 = E_0
        self.V_0 = V_0
        self.delta_t = delta_t
        self.div_alpha = 1 / alpha
        self.f_in = init_f_in
        self.s = init_s
        self.v = init_v
        self.q = init_q
        # Prefer the GPU (the historical behaviour) but stay usable on CPU-only hosts.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

    def _as_tensor(self, array):
        """Convert a NumPy array to a float64 tensor on ``self.device``; pass anything else through."""
        if isinstance(array, np.ndarray):
            # ``np.float`` was removed in NumPy 1.24; it was an alias of the builtin float (float64).
            return torch.from_numpy(array.astype(np.float64)).to(self.device)
        return array

    def update(self, f_str, df):
        """
        Advance the state variable named ``f_str`` by one Euler step ``df * delta_t``.

        Parameters
        ----------
        f_str : str
            Attribute name of the state variable (e.g. ``'s'``, ``'q'``, ``'v'``, ``'f_in'``).
        df :
            Time derivative of that state variable.
        """
        increment = df * self.delta_t
        current = getattr(self, f_str)
        if current is None:
            setattr(self, f_str, increment)
            return
        # Tensors/arrays are updated in place; for immutable scalars ``+=`` only rebinds
        # the local name, so the result is always written back to the attribute.
        current += increment
        setattr(self, f_str, current)

    def state_update(self, w):
        """
        Overwrite the four state variables from ``w``.

        ``w`` is indexed as ``w[:, :, k]`` with ``k = 0..3`` for ``s``, ``q``, ``v`` and
        ``f_in`` respectively; each slice is flattened. ``q``, ``v`` and ``f_in`` are
        floored at ``1e-05`` while ``s`` is taken as is.

        Parameters
        ----------
        w : ndarray or torch.Tensor
            New state values. NumPy input is converted to a float64 tensor on ``self.device``.
        """
        w = self._as_tensor(w)
        floor = 1e-05
        self.s = w[:, :, 0].reshape(-1)
        self.q = torch.clamp(w[:, :, 1], min=floor).reshape(-1)
        self.v = torch.clamp(w[:, :, 2], min=floor).reshape(-1)
        # A two-sided clamp of f_in to [-15, 10] was used previously; the one-sided
        # floor below is the usual setting in zhangwenyong's code.
        self.f_in = torch.clamp(w[:, :, 3], min=floor).reshape(-1)

    def save_temp_state(self, flag: int, write_path: str):
        """
        Save the current state to ``<write_path>/state_<flag>.npy``.

        The four state variables are stacked along dim 1 in the order
        ``s, q, v, f_in`` before saving.

        Parameters
        ----------
        flag : int
            Identifier embedded in the file name.
        write_path : str
            Directory to write into (a trailing ``/`` is accepted).
        """
        state = torch.stack([self.s, self.q, self.v, self.f_in], dim=1)
        np.save(os.path.join(write_path, f"state_{flag}.npy"), state.detach().cpu().numpy())

    def load_temp_state(self, flag: int, load_path: str):
        """
        Restore the state variables from a file in ``load_path``.

        ``bold_temp_state_<flag>.npy`` is tried first; if it does not exist, the
        ``state_<flag>.npy`` file written by :meth:`save_temp_state` is used instead.

        Parameters
        ----------
        flag : int
            Identifier embedded in the file name.
        load_path : str
            Directory to read from (a trailing ``/`` is accepted).
        """
        candidates = (
            os.path.join(load_path, f"bold_temp_state_{flag}.npy"),
            os.path.join(load_path, f"state_{flag}.npy"),
        )
        # Fall back to the first name so that a missing file is reported under that name.
        path = next((p for p in candidates if os.path.exists(p)), candidates[0])
        state = torch.from_numpy(np.load(path)).to(self.device)
        self.s = state[:, 0]
        self.q = state[:, 1]
        self.v = state[:, 2]
        self.f_in = state[:, 3]

    def run(self, u):
        """
        the main method in this class to evolve this balloon model.

        Performs one Euler step of size ``delta_t`` on the four state variables and
        returns the resulting BOLD signal. States that are still ``None`` are
        initialised first: ``s`` and ``q`` to zeros, ``v`` and ``f_in`` to ones.

        Parameters
        ----------
        u : ndarray or torch.Tensor
            spike activity.

        Returns
        -------
        bold : torch.Tensor
            the shape is the same as input.
        """
        u = self._as_tensor(u)
        if self.s is None:
            self.s = torch.zeros_like(u)
        if self.q is None:
            self.q = torch.zeros_like(u)
        if self.v is None:
            self.v = torch.ones_like(u)
        if self.f_in is None:
            self.f_in = torch.ones_like(u)

        # d_s must be evaluated before f_in is advanced, but s itself is advanced last
        # because the f_in update below reads the old s.
        d_s = self.epsilon * u - self.s / self.tao_s - (self.f_in - 1) / self.tao_f
        # Where f_in is not positive the extraction term is replaced by 1.
        q_part = torch.where(self.f_in > 0,
                             1 - (1 - self.E_0) ** (1 / self.f_in),
                             torch.ones_like(self.f_in))
        self.update('q', (self.f_in * q_part / self.E_0 - self.q * self.v ** (self.div_alpha - 1)) / self.tao_0)
        self.update('v', (self.f_in - self.v ** self.div_alpha) / self.tao_0)
        self.update('f_in', self.s)
        # Keep f_in non-negative after the step.
        self.f_in = torch.where(self.f_in > 0, self.f_in, torch.zeros_like(self.f_in))
        self.update('s', d_s)

        out = self.V_0 * (7 * self.E_0 * (1 - self.q) + 2 * (1 - self.q / self.v)
                          + (2 * self.E_0 - 0.2) * (1 - self.v))
        return out