diff --git a/.travis.yml b/.travis.yml index 22c97253..becaab62 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,6 +12,7 @@ install: - pip install -r requirements.txt - pip install pytest flake8 codecov pytest-cov wheel setuptools - pip install jupyterlab nbconvert nbformat matplotlib seaborn mne + - pip install ray || true - python setup.py sdist bdist_wheel - pip install --upgrade dist/* diff --git a/neurolib/models/model.py b/neurolib/models/model.py index e6ce622d..d75e6abf 100644 --- a/neurolib/models/model.py +++ b/neurolib/models/model.py @@ -7,6 +7,20 @@ from ..utils.collections import dotdict +RAY_NOT_FOUND = None +try: + import ray +except ImportError: + RAY_NOT_FOUND = True + logging.warning("`ray` module not found. BOLD simulation won't be run in parallel.") + + class Object(object): + pass + + ray = Object() + ray.remote = lambda f: f + + class Model: """The Model superclass manages inputs and outputs of all models. """ @@ -23,6 +37,7 @@ def __init__( bold=False, normalize_bold_input=False, normalize_bold_input_max=50, + parallelize_bold=True, name=None, description=None, ): @@ -56,9 +71,22 @@ def __init__( self.bold_initialized = False self.normalize_bold_input = normalize_bold_input self.normalize_bold_input_max = normalize_bold_input_max + if bold: self.initialize_bold(self.normalize_bold_input, self.normalize_bold_input_max) + self.parallelize_bold = parallelize_bold + if RAY_NOT_FOUND: + self.parallelize_bold = False + self.ray_initialized = False + if self.parallelize_bold and not RAY_NOT_FOUND: + try: + ray.shutdown() + ray.init() + self.ray_initialized = True + except: + logging.warning("`ray` module initialization failed, falling back to serial BOLD simulation.") + self.parallelize_bold = False logging.info(f"{name}: Model initialized.") def initialize_bold(self, normalize_bold_input, normalize_bold_input_max): @@ -80,16 +108,20 @@ def initialize_bold(self, normalize_bold_input, normalize_bold_input_max): self.bold_initialized = True 
logging.info(f"{self.name}: BOLD model initialized.") + @ray.remote + def bold_ray(self, bold_model): + bold_model.run(self.state[self.default_output]) + return bold_model + def simulate_bold(self): """Gets the default output of the model and simulates the BOLD model. Adds the simulated BOLD signal to outputs. """ if self.bold_initialized: - self.boldModel.run(self.state[self.default_output]) - t_BOLD = self.boldModel.t_BOLD - BOLD = self.boldModel.BOLD - self.setOutput("BOLD.t", t_BOLD) - self.setOutput("BOLD.BOLD", BOLD) + if self.parallelize_bold and self.ray_initialized: + self.boldModel = self.bold_ray.remote(self, self.boldModel) + else: + self.boldModel.run(self.state[self.default_output]) else: logging.warn("BOLD model not initialized, not simulating BOLD. Use `run(bold=True)`") @@ -142,16 +174,25 @@ def run(self, inputs=None, onedt=False, chunkwise=False, chunksize=10000, bold=F if chunkwise is False: self.integrate() if bold: + # run the bold simulation. Will be serial or parallel, depending on self.parallelize_bold self.simulate_bold() - return else: # check if model is safe for chunkwise integration self.check_chunkwise() if bold and not self.bold_initialized: logging.warn(f"{self.name}: BOLD model not initialized, not simulating BOLD. 
Use `run(bold=True)`") bold = False + # run the chunkwise integrator, will run bold simulation internally if bold==True self.integrate_chunkwise(chunksize=chunksize, bold=bold, append_outputs=append_outputs) - return + + # gather all parallel simulation results (if self.parallelize_bold) and set BOLD outputs + if bold: + # when done with everything, gather all parallel bold simulation results + if self.parallelize_bold and self.ray_initialized: + self.boldModel = ray.get(self.boldModel) + # set the BOLD outputs for easy access + self.setOutput("BOLD.t", self.boldModel.t_BOLD) + self.setOutput("BOLD.BOLD", self.boldModel.BOLD) def integrate(self, append_outputs=False): """Calls each models `integration` function and saves the state and the outputs of the model. @@ -200,6 +241,9 @@ def integrate_chunkwise(self, chunksize, bold=False, append_outputs=False): # we save the last simulated time step lastT += self.state["t"][-1] + # set the duration back to its original value + self.params["duration"] = totalDuration + def autochunk(self, inputs=None, duration=None, append_outputs=False): """Executes a single chunk of integration, either for a given duration or a single timestep `dt`. Gathers all inputs to the model and resets