Brief introduction to Dask#
Dask is a library that makes it relatively easy to perform parallel and distributed computations in Python.
The key point for us is that it allows computations that require more memory than is available on your machine.
# First some imports.
import dask
import dask.array as da
import numpy as np
import xarray as xr
from dask.distributed import Client
from matplotlib import pyplot as plt
I like to start a client, since this lets me control how many workers are used and how much memory they get. It also starts the dashboard, which can give us some nice insights into our computations.
from dask.distributed import Client
client = Client()
client
Client: Client-06c2a091-1679-11ef-95d4-4efbcbc71aa6
Connection method: Cluster object | Cluster type: distributed.LocalCluster
Dashboard: /user/fc%3Auid%3A8be9434f-5560-4397-97c0-882449c50503/proxy/8787/status
Cluster: LocalCluster with 4 workers | 16 threads total | 16.00 GiB total memory (4 threads and 4.00 GiB per worker) | Status: running | Using processes: True
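If you want to control these resources explicitly, you can pass arguments when creating the client. A minimal sketch (the numbers are arbitrary; pick values that fit your machine):
# Explicitly request 4 worker processes with 4 threads and 4 GiB each.
# These keyword arguments are forwarded to the underlying LocalCluster.
client = Client(n_workers=4, threads_per_worker=4, memory_limit="4GiB")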
Let’s say we have some large array we want to do some work on.
Here we’ll illustrate this with a (100000, 25000) array of random numbers.
rng = np.random.default_rng()
data_np = rng.standard_normal((100000, 25000))
Likely, your kernel is going to crash running this cell: the array needs 100000 × 25000 × 8 bytes ≈ 20 GB of memory, more than the 16 GiB available here.
Let us instead try it with dask:
rng_da = da.random.default_rng()
data_da = rng_da.standard_normal((100000, 25000))
Have a look at the data: this is currently a lazy Dask array.
data_da
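Dask picked the chunk sizes automatically here. If you want to control them yourself, you can pass them explicitly; a small sketch (the chunk shape is arbitrary):
# Same array, but with explicit 5000 x 5000 chunks (~200 MB of float64 each).
data_da_chunked = rng_da.standard_normal((100000, 25000), chunks=(5000, 5000))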
Then we can create the histogram computation
hist, bins = da.histogram(data_da, bins=20, range=(-3, 3))
Note that the histogram hasn’t been computed yet (it is lazy). Here we can also double-check that the result will fit in memory.
hist
To actually compute the result, we call its compute method.
hist = hist.compute()
Plot it to make sure it follows the normal distribution we would expect.
plt.bar(bins[:-1], hist)
<BarContainer object of 20 artists>
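The default bar width in plt.bar does not match the bin spacing; if you want the bars to line up with the histogram bins, you could pass the widths explicitly. A small sketch:
# Make each bar as wide as its histogram bin and anchor it at the left bin edge.
plt.bar(bins[:-1], hist, width=np.diff(bins), align="edge")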
Xarray + Dask#
Xarray is an interface to array data, backed by NumPy or, as in our case, Dask arrays.
Below we load some MERRA data.
Specifying chunks="auto"
is optional, but it ensures the data is read as Dask arrays and lets Dask choose the chunk sizes.
merra_ds = xr.open_mfdataset(
"/mnt/craas1-ns9989k-ns9600k/escience_course/MERRA2/MERRA2_300.inst3_3d_aer_Nv.200*.SUB.nc",
chunks="auto"
)
One of the variables of the dataset
merra_ds.SS001
<xarray.DataArray 'SS001' (time: 8768, lev: 3, lat: 121, lon: 576)> Size: 7GB
dask.array<concatenate, shape=(8768, 3, 121, 576), dtype=float32, chunksize=(8, 3, 121, 576), chunktype=numpy.ndarray>
Coordinates:
  * time     (time) datetime64[ns] 70kB 2007-01-01 ... 2009-12-31T21:00:00
  * lon      (lon) float64 5kB -180.0 -179.4 -178.8 -178.1 ... 178.1 178.8 179.4
  * lat      (lat) float64 968B 30.0 30.5 31.0 31.5 32.0 ... 88.5 89.0 89.5 90.0
  * lev      (lev) float64 24B 56.0 63.0 67.0
Attributes:
    standard_name:   Sea Salt Mixing Ratio (bin 001)
    long_name:       Sea Salt Mixing Ratio (bin 001)
    units:           kg kg-1
    fmissing_value:  1000000000000000.0
    vmax:            1000000000000000.0
    vmin:            -1000000000000000.0
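You can also inspect the chunking directly:
# Chunk sizes per dimension (time, lev, lat, lon); here 8 time steps per chunk,
# matching the daily input files.
merra_ds.SS001.chunks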
Xarray recognizes that the data is in the form of a Dask array. The following will work as we expect, and it will be parallelized in the background using Dask (watch the progress bars in the dashboard).
mean_ds = merra_ds.SS001.mean(dim="time").compute()
mean_ds.isel(lev=0).plot()
<matplotlib.collections.QuadMesh at 0x7f7e8931a0d0>
Working with chunks#
But some computations are a bit more tricky. This won’t work.
merra_ds.SS001.quantile(0.85, dim="time")
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[16], line 1
----> 1 merra_ds.SS001.quantile(0.85, dim="time")
File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/util/deprecation_helpers.py:115, in _deprecate_positional_args.<locals>._decorator.<locals>.inner(*args, **kwargs)
111 kwargs.update({name: arg for name, arg in zip_args})
113 return func(*args[:-n_extra_args], **kwargs)
--> 115 return func(*args, **kwargs)
File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/core/dataarray.py:5187, in DataArray.quantile(self, q, dim, method, keep_attrs, skipna, interpolation)
5077 @_deprecate_positional_args("v2023.10.0")
5078 def quantile(
5079 self,
(...)
5086 interpolation: QuantileMethods | None = None,
5087 ) -> Self:
5088 """Compute the qth quantile of the data along the specified dimension.
5089
5090 Returns the qth quantiles(s) of the array elements.
(...)
5184 The American Statistician, 50(4), pp. 361-365, 1996
5185 """
-> 5187 ds = self._to_temp_dataset().quantile(
5188 q,
5189 dim=dim,
5190 keep_attrs=keep_attrs,
5191 method=method,
5192 skipna=skipna,
5193 interpolation=interpolation,
5194 )
5195 return self._from_temp_dataset(ds)
File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/util/deprecation_helpers.py:115, in _deprecate_positional_args.<locals>._decorator.<locals>.inner(*args, **kwargs)
111 kwargs.update({name: arg for name, arg in zip_args})
113 return func(*args[:-n_extra_args], **kwargs)
--> 115 return func(*args, **kwargs)
File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/core/dataset.py:8183, in Dataset.quantile(self, q, dim, method, numeric_only, keep_attrs, skipna, interpolation)
8177 if name not in self.coords:
8178 if (
8179 not numeric_only
8180 or np.issubdtype(var.dtype, np.number)
8181 or var.dtype == np.bool_
8182 ):
-> 8183 variables[name] = var.quantile(
8184 q,
8185 dim=reduce_dims,
8186 method=method,
8187 keep_attrs=keep_attrs,
8188 skipna=skipna,
8189 )
8191 else:
8192 variables[name] = var
File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/core/variable.py:1907, in Variable.quantile(self, q, dim, method, keep_attrs, skipna, interpolation)
1903 axis = np.arange(-1, -1 * len(dim) - 1, -1)
1905 kwargs = {"q": q, "axis": axis, "method": method}
-> 1907 result = apply_ufunc(
1908 _wrapper,
1909 self,
1910 input_core_dims=[dim],
1911 exclude_dims=set(dim),
1912 output_core_dims=[["quantile"]],
1913 output_dtypes=[np.float64],
1914 dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}),
1915 dask="parallelized",
1916 kwargs=kwargs,
1917 )
1919 # for backward compatibility
1920 result = result.transpose("quantile", ...)
File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/core/computation.py:1280, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, on_missing_core_dim, *args)
1278 # feed Variables directly through apply_variable_ufunc
1279 elif any(isinstance(a, Variable) for a in args):
-> 1280 return variables_vfunc(*args)
1281 else:
1282 # feed anything else through apply_array_ufunc
1283 return apply_array_ufunc(func, *args, dask=dask)
File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/core/computation.py:771, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
769 for axis, dim in enumerate(core_dims, start=-len(core_dims)):
770 if len(data.chunks[axis]) != 1:
--> 771 raise ValueError(
772 f"dimension {dim} on {n}th function argument to "
773 "apply_ufunc with dask='parallelized' consists of "
774 "multiple chunks, but is also a core dimension. To "
775 "fix, either rechunk into a single array chunk along "
776 f"this dimension, i.e., ``.chunk(dict({dim}=-1))``, or "
777 "pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` "
778 "but beware that this may significantly increase memory usage."
779 )
780 dask_gufunc_kwargs["allow_rechunk"] = True
782 output_sizes = dask_gufunc_kwargs.pop("output_sizes", {})
ValueError: dimension time on 0th function argument to apply_ufunc with dask='parallelized' consists of multiple chunks, but is also a core dimension. To fix, either rechunk into a single array chunk along this dimension, i.e., ``.chunk(dict(time=-1))``, or pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` but beware that this may significantly increase memory usage.
If we read the error message, it says that the time dimension on the 0th function argument to apply_ufunc
consists of multiple chunks, but is also a core dimension.
This happens because there is no chunk-parallel implementation of the quantile algorithm: it needs to see all time steps at once.
To solve this, we can rechunk
our data.
In this case, we declare that the data must not be chunked along the time dimension, but can be chunked freely along the other dimensions.
qtile = (
    merra_ds.SS001
    .chunk({"time": -1, "lon": "auto", "lat": "auto", "lev": "auto"})
    .quantile(0.85, dim="time")
    .compute()
)
2024-05-20 09:20:43,990 - distributed.worker.memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 2.60 GiB -- Worker memory limit: 4.00 GiB
2024-05-20 09:20:44,768 - distributed.worker.memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 2.43 GiB -- Worker memory limit: 4.00 GiB
2024-05-20 09:20:58,762 - distributed.worker.memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 2.49 GiB -- Worker memory limit: 4.00 GiB
qtile.isel(lev=0).plot()
<matplotlib.collections.QuadMesh at 0x7f7e88316d90>
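The warnings printed during the computation are about worker memory pressure, not failures. If workers do run out of memory, you could use smaller spatial chunks so each task holds less data at once; a sketch with arbitrary chunk sizes:
# time must stay in one chunk for quantile; shrink the spatial chunks instead.
qtile = (
    merra_ds.SS001
    .chunk({"time": -1, "lev": 1, "lat": 30, "lon": 72})
    .quantile(0.85, dim="time")
    .compute()
)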
Accessing the dask array#
Not all operations available in NumPy/Dask are wrapped by xarray. For instance, to compute a histogram (without plotting it) you could do something like this:
np.histogram(merra_ds.SS001.values)
where calling values
always gives us the underlying NumPy array.
If the dataset is backed by Dask arrays, they will be loaded into memory first, which is a problem if the data is too big.
Instead you could use the data
attribute of the DataArray, which returns the Dask array:
merra_ds.SS001.data
With this we have an array we can pass to Dask-specific functions, such as:
hist, bins = da.histogram(merra_ds.SS001.data, bins=10, range=(0, merra_ds.SS001.max()))
bins
Then we have to compute the histogram
hist.compute()
array([1833234752, 47344, 1173, 202, 80,
24, 5, 2, 1, 1])
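If you want normalized values instead of raw counts, da.histogram also accepts a density argument, mirroring np.histogram; a short sketch:
# density=True scales the counts so the histogram integrates to 1.
hist_d, bins_d = da.histogram(
    merra_ds.SS001.data, bins=10, range=(0, merra_ds.SS001.max()), density=True
)
hist_d.compute()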
Numba#
Numba is a just-in-time compiler for Python, which can speed up your slow Python loops, provided they work on the supported data types (see the documentation).
rng = np.random.default_rng()
test = rng.random((10000, 2))
test2 = rng.random((10000, 2))
The njit
decorator will try to compile a function the first time it is run.
It will raise an error if Numba is not able to compile the function.
from numba import njit
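As an aside, you can also ask Numba to compile eagerly by giving njit an explicit signature, so compilation errors show up at definition time instead of at the first call. A minimal sketch (the function and signature are made up for illustration; it reuses the np import from above):
# Eagerly compiled: takes two 2-D float64 arrays, returns a 1-D float64 array.
# A compilation error would be raised here, not when the function is first called.
@njit("float64[::1](float64[:, :], float64[:, :])")
def eager_example(a, b):
    out = np.zeros(a.shape[0])
    for i in range(a.shape[0]):
        out[i] = a[i, 0] + b[i, 1]
    return out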
Here, we have two identical functions that search for matching coordinate pairs in two relatively large arrays. In reality these could be real coordinates; here we just use random numbers.
def slow_coord_isin(ds_locs, merra_locs):
    # For each coordinate pair in ds_locs, flag it if an identical pair exists in merra_locs.
    mask = np.zeros(ds_locs.shape[0])
    for i, ds_coord in enumerate(ds_locs):
        for merra_coord in merra_locs:
            if ds_coord[0] == merra_coord[0] and ds_coord[1] == merra_coord[1]:
                mask[i] = 1
                break
    return mask
# Here we add the njit decorator.
@njit
def fast_coord_isin(ds_locs, merra_locs):
    mask = np.zeros(ds_locs.shape[0])
    for i, ds_coord in enumerate(ds_locs):
        for merra_coord in merra_locs:
            if ds_coord[0] == merra_coord[0] and ds_coord[1] == merra_coord[1]:
                mask[i] = 1
                break
    return mask
Use the %%time
cell magic to time the functions.
For more robust timings use the %%timeit
cell magic.
%%time
slow_coord_isin(test, test)
CPU times: user 10.9 s, sys: 72.2 ms, total: 10.9 s
Wall time: 10.9 s
array([1., 1., 1., ..., 1., 1., 1.])
%%time
fast_coord_isin(test, test)
CPU times: user 639 ms, sys: 60.8 ms, total: 700 ms
Wall time: 690 ms
array([1., 1., 1., ..., 1., 1., 1.])
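Note that the timing of the first call to fast_coord_isin includes the JIT compilation itself; the compiled machine code is cached, so subsequent calls are faster still. A quick check (timings will of course vary):
%%time
# Second call: no compilation overhead this time.
fast_coord_isin(test, test)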