Skip to content

Commit b5e7385

Browse files
authored
global backend instance cache (#1691)
1 parent 18cd683 commit b5e7385

File tree

29 files changed

+124
-106
lines changed

29 files changed

+124
-106
lines changed

.github/workflows/tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ jobs:
103103
test-suite: "unit_tests/!(dynamics)"
104104
fail-fast: false
105105
runs-on: ${{ matrix.platform }}
106-
timeout-minutes: ${{ startsWith(matrix.platform, 'windows-') && 40 || 30 }}
106+
timeout-minutes: ${{ startsWith(matrix.platform, 'ubuntu-') && 25 || 40 }}
107107
steps:
108108
- uses: actions/[email protected]
109109
with:
@@ -187,7 +187,7 @@ jobs:
187187
test-suite: [ "chemistry_extraterrestrial", "freezing", "isotopes", "condensation_a", "condensation_b", "condensation_c", "coagulation", "breakup", "multi-process_a", "multi-process_b", "multi-process_c", "multi-process_d", "multi-process_e"]
188188
fail-fast: false
189189
runs-on: ${{ matrix.platform }}
190-
timeout-minutes: ${{ startsWith(matrix.platform, 'windows-') && 40 || 30 }}
190+
timeout-minutes: ${{ startsWith(matrix.platform, 'ubuntu-') && 20 || 35 }}
191191
steps:
192192
- uses: actions/[email protected]
193193
with:

PySDM/backends/__init__.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
"""
2-
Backend classes: CPU=`PySDM.backends.numba.Numba`
3-
and GPU=`PySDM.backends.thrust_rtc.ThrustRTC`
2+
Number-crunching backends
43
"""
54

65
import ctypes
6+
from functools import partial
77
import os
88
import sys
99
import warnings
1010

1111
from numba import cuda
1212

13-
from .numba import Numba
13+
from . import numba as _numba
14+
15+
# for pdoc
16+
CPU = None
17+
GPU = None
18+
Numba = _numba.Numba
19+
ThrustRTC = None
1420

1521

1622
# https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
@@ -72,12 +78,18 @@ def __call__(self, storage):
7278

7379
ThrustRTC.Random = Random
7480

75-
CPU = Numba
76-
"""
77-
alias for Numba
78-
"""
81+
_BACKEND_CACHE = {}
7982

80-
GPU = ThrustRTC
81-
"""
82-
alias for ThrustRTC
83-
"""
83+
84+
def _cached_backend(formulae=None, backend_class=None, **kwargs):
85+
key = backend_class.__name__ + ":" + str(formulae) + ":" + str(kwargs)
86+
if key not in _BACKEND_CACHE:
87+
_BACKEND_CACHE[key] = backend_class(formulae=formulae, **kwargs)
88+
return _BACKEND_CACHE[key]
89+
90+
91+
CPU = partial(_cached_backend, backend_class=Numba)
92+
""" returns a cached instance of the Numba backend (cache key including formulae parameters) """
93+
94+
GPU = partial(_cached_backend, backend_class=ThrustRTC)
95+
""" returns a cached instance of the ThrustRTC backend (cache key including formulae parameters) """

PySDM/backends/numba.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ class Numba( # pylint: disable=too-many-ancestors,duplicate-code
3636

3737
default_croupier = "local"
3838

39-
def __init__(self, formulae=None, double_precision=True, override_jit_flags=None):
39+
def __init__(
40+
self, formulae=None, *, double_precision=True, override_jit_flags=None
41+
):
4042
if not double_precision:
4143
raise NotImplementedError()
4244
self.formulae = formulae or Formulae()

PySDM/backends/thrust_rtc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class ThrustRTC( # pylint: disable=duplicate-code,too-many-ancestors
4747
default_croupier = "global"
4848

4949
def __init__(
50-
self, formulae=None, double_precision=False, debug=False, verbose=False
50+
self, formulae=None, *, double_precision=False, debug=False, verbose=False
5151
):
5252
self.formulae = formulae or Formulae()
5353

PySDM/initialisation/init_fall_momenta.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import numpy as np
77

88
from PySDM.dynamics.terminal_velocity import GunnKinzer1949
9-
from PySDM.formulae import Formulae
109
from PySDM.particulator import Particulator
1110

1211

@@ -32,7 +31,7 @@ def init_fall_momenta(
3231

3332
from PySDM.backends import CPU # pylint: disable=import-outside-toplevel
3433

35-
particulator = Particulator(0, CPU(Formulae())) # TODO #1155
34+
particulator = Particulator(0, CPU()) # TODO #1155
3635

3736
approximation = terminal_velocity_approx(particulator=particulator)
3837

docs/markdown/pysdm_landing.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -663,8 +663,8 @@ mindmap
663663
...
664664
(...)
665665
((backends))
666-
CPU
667-
GPU
666+
Numba
667+
ThrustRTC
668668
((dynamics))
669669
AqueousChemistry
670670
Collision

examples/PySDM_examples/Arabas_and_Shima_2017/simulation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

33
import PySDM.products as PySDM_products
4-
from PySDM.backends import CPU
4+
from PySDM.backends import Numba
55
from PySDM.builder import Builder
66
from PySDM.dynamics import AmbientThermodynamics, Condensation
77
from PySDM.environments import Parcel
@@ -10,7 +10,7 @@
1010

1111

1212
class Simulation:
13-
def __init__(self, settings, backend=CPU):
13+
def __init__(self, settings, backend=Numba):
1414
t_half = settings.z_half / settings.w_avg
1515

1616
dt_output = (2 * t_half) / settings.n_output
@@ -23,7 +23,7 @@ def __init__(self, settings, backend=CPU):
2323
formulae=settings.formulae,
2424
**(
2525
{"override_jit_flags": {"parallel": False}}
26-
if backend == CPU
26+
if backend is Numba
2727
else {}
2828
),
2929
),

examples/PySDM_examples/Arabas_et_al_2015/example_benchmark.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import PySDM.backends.impl_numba.conf
88
from PySDM import Formulae
9-
from PySDM.backends import CPU, GPU
9+
from PySDM.backends import Numba, ThrustRTC
1010
from PySDM.products import WallTime
1111

1212

@@ -42,11 +42,11 @@ def main():
4242
n_sd = range(14, 16, 1)
4343

4444
times = {}
45-
backends = [(CPU, "sync"), (CPU, "async")]
46-
if GPU.ENABLE:
47-
backends.append((GPU, "async"))
45+
backends = [(Numba, "sync"), (Numba, "async")]
46+
if ThrustRTC.ENABLE:
47+
backends.append((ThrustRTC, "async"))
4848
for backend, mode in backends:
49-
if backend is CPU:
49+
if backend is Numba:
5050
PySDM.backends.impl_numba.conf.NUMBA_PARALLEL = mode
5151
reload_cpu_backend()
5252
key = f"{backend} (mode={mode})"

examples/PySDM_examples/Berry_1967/figs_5_8_10.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
"source": [
5252
"import os\n",
5353
"from numpy import errstate\n",
54-
"from PySDM.backends import CPU, GPU\n",
54+
"from PySDM.backends import CPU, GPU, ThrustRTC\n",
5555
"from PySDM.dynamics.collisions.collision_kernels import Geometric, Hydrodynamic, Electric\n",
5656
"from PySDM_examples.Berry_1967.spectrum_plotter import SpectrumPlotter\n",
5757
"from PySDM_examples.Berry_1967.settings import Settings\n",
@@ -2222,7 +2222,7 @@
22222222
"smooth = widgets.Checkbox(value=True, description='smooth plot')\n",
22232223
"gpu = widgets.Checkbox(value=False, description='GPU')\n",
22242224
"options = [adaptive, smooth]\n",
2225-
"if GPU.ENABLE:\n",
2225+
"if ThrustRTC.ENABLE:\n",
22262226
" options.append(gpu)\n",
22272227
"kernel = widgets.Select(\n",
22282228
" options=['geometric sweep-out', 'electric field 3000V/cm', 'hydrodynamic capture'],\n",

examples/PySDM_examples/Bulenok_2023_MasterThesis/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def go_benchmark(
9595
backend_configs.append((GPU, None))
9696

9797
for backend_class, n_threads in backend_configs:
98-
backend_name = backend_class.__name__
98+
backend_name = backend_class().__class__.__name__
9999
if n_threads:
100100
numba.set_num_threads(n_threads)
101101
backend_name += "_" + str(numba.get_num_threads())

0 commit comments

Comments
 (0)