Brief introduction to Dask#

Dask is a library that makes it relatively easy to perform parallel and distributed computations in Python.

The key concept for us is that it allows for computations that require more memory than is available on your machine.

# First some imports.
import dask
import dask.array as da
import numpy as np
import xarray as xr

from dask.distributed import Client
from matplotlib import pyplot as plt

I like to start a client, since this allows me to control how many workers it uses and how much memory it gets. It also starts the dashboard, which can give us some nice insights into our computations.
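
For example, if you want explicit control, you can pass arguments such as n_workers, threads_per_worker and memory_limit to Client (the values below are only illustrative; pick numbers that fit your machine):

from dask.distributed import Client

# A sketch of a client with explicit resource limits (illustrative values).
client = Client(
    n_workers=4,            # number of worker processes
    threads_per_worker=2,   # threads per worker process
    memory_limit="4GB",     # memory limit per worker
)

Below we just start a client with the defaults: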

from dask.distributed import Client

client = Client()
client

Client

Client-06c2a091-1679-11ef-95d4-4efbcbc71aa6
Connection method: Cluster object
Cluster type: distributed.LocalCluster
Dashboard: /user/fc%3Auid%3A8be9434f-5560-4397-97c0-882449c50503/proxy/8787/status
Info: If you are running multiple notebooks at the same time that utilise Dask, it is a good idea to start the client through the Dask extension available in the left-side menu. This enables you to run all computations on the same client.

Let’s say we have some large array we want to do some work on. Here we’ll illustrate this with a (100000, 25000) array of random numbers.

rng = np.random.default_rng()
data_np = rng.standard_normal((100000, 25000))

Your kernel is likely to crash when running this cell.
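
A quick back-of-the-envelope calculation shows why: the array holds 100000 × 25000 float64 values, each taking 8 bytes.

# 100000 * 25000 float64 values at 8 bytes each, converted to GiB.
100000 * 25000 * 8 / 2**30
# ≈ 18.6 GiB, more memory than many machines (or notebook servers) have available.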

Let us instead try it with dask:

rng_da = da.random.default_rng()
data_da = rng_da.standard_normal((100000, 25000))
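
By default Dask picks the chunk sizes for us. If you want to control them yourself, passing a chunks argument (or rechunking afterwards) should work; a sketch, not used below:

# Illustrative only; data_da above keeps its automatic chunking.
data_custom = rng_da.standard_normal((100000, 25000), chunks=(10000, 25000))
data_custom = data_custom.rechunk((20000, 25000))  # rechunk an existing array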

Have a look at the data; this is currently a lazy Dask array.

data_da
            Array                        Chunk
Bytes       18.63 GiB                    128.00 MiB
Shape       (100000, 25000)              (4096, 4096)
Dask graph  175 chunks in 1 graph layer
Data type   float64 numpy.ndarray

Then we can create the histogram computation:

hist, bins = da.histogram(data_da, bins=20, range=(-3, 3))

Note that the histogram hasn’t been computed yet (it is lazy). Here we can also double-check that the result will fit in memory.

hist
            Array                      Chunk
Bytes       160 B                      160 B
Shape       (20,)                      (20,)
Dask graph  1 chunks in 7 graph layers
Data type   int64 numpy.ndarray
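
If you are curious about what Dask will actually run, you can also render the task graph before computing; this needs the optional graphviz dependency installed:

# Renders the task graph to an image (requires graphviz; the graph for a
# large array can get big).
hist.visualize(rankdir="LR")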

To actually compute the result, we call its compute method.

hist = hist.compute()

Plot it to make sure it follows the normal distribution we would expect.

plt.bar(bins[:-1], hist)
<BarContainer object of 20 artists>

Xarray + Dask#

Xarray is an interface to array data such as NumPy arrays or, often, Dask arrays.

Below we load some MERRA-2 data. Specifying chunks="auto" is optional, but it makes sure the data is read as Dask arrays.

merra_ds = xr.open_mfdataset(
    "/mnt/craas1-ns9989k-ns9600k/escience_course/MERRA2/MERRA2_300.inst3_3d_aer_Nv.200*.SUB.nc",
    chunks="auto"
)
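
You can also pass an explicit chunk specification instead of "auto"; a sketch with arbitrary chunk sizes, not used below:

# Illustrative only: chunk 8 time steps at a time and keep the other
# dimensions unchunked.
merra_explicit = xr.open_mfdataset(
    "/mnt/craas1-ns9989k-ns9600k/escience_course/MERRA2/MERRA2_300.inst3_3d_aer_Nv.200*.SUB.nc",
    chunks={"time": 8, "lev": -1, "lat": -1, "lon": -1},
)
merra_explicit.chunks  # mapping of dimension name to chunk sizes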

One of the variables of the dataset:

merra_ds.SS001
<xarray.DataArray 'SS001' (time: 8768, lev: 3, lat: 121, lon: 576)> Size: 7GB
dask.array<concatenate, shape=(8768, 3, 121, 576), dtype=float32, chunksize=(8, 3, 121, 576), chunktype=numpy.ndarray>
Coordinates:
  * time     (time) datetime64[ns] 70kB 2007-01-01 ... 2009-12-31T21:00:00
  * lon      (lon) float64 5kB -180.0 -179.4 -178.8 -178.1 ... 178.1 178.8 179.4
  * lat      (lat) float64 968B 30.0 30.5 31.0 31.5 32.0 ... 88.5 89.0 89.5 90.0
  * lev      (lev) float64 24B 56.0 63.0 67.0
Attributes:
    standard_name:   Sea Salt Mixing Ratio (bin 001)
    long_name:       Sea Salt Mixing Ratio (bin 001)
    units:           kg kg-1
    fmissing_value:  1000000000000000.0
    vmax:            1000000000000000.0
    vmin:            -1000000000000000.0

Xarray recognizes that the data is in the form of a Dask array. The following will work as we expect, and will be “parallelized” in the background using Dask (watch the progress in the dashboard).

mean_ds = merra_ds.SS001.mean(dim="time").compute()
mean_ds.isel(lev=0).plot()
<matplotlib.collections.QuadMesh at 0x7f7e8931a0d0>
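
If you are not watching the dashboard, you can also get an inline progress bar from dask.distributed; a sketch:

from dask.distributed import progress

# persist starts the computation on the workers and keeps the result in
# their memory; progress then shows an inline progress bar in the notebook.
mean_lazy = merra_ds.SS001.mean(dim="time").persist()
progress(mean_lazy)

Calling compute() on the persisted object afterwards is fast, since the chunks are already in worker memory.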

Working with chunks#

But some computations are a bit more tricky. This won’t work.

merra_ds.SS001.quantile(0.85, dim="time")
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[16], line 1
----> 1 merra_ds.SS001.quantile(0.85, dim="time")

File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/util/deprecation_helpers.py:115, in _deprecate_positional_args.<locals>._decorator.<locals>.inner(*args, **kwargs)
    111     kwargs.update({name: arg for name, arg in zip_args})
    113     return func(*args[:-n_extra_args], **kwargs)
--> 115 return func(*args, **kwargs)

File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/core/dataarray.py:5187, in DataArray.quantile(self, q, dim, method, keep_attrs, skipna, interpolation)
   5077 @_deprecate_positional_args("v2023.10.0")
   5078 def quantile(
   5079     self,
   (...)
   5086     interpolation: QuantileMethods | None = None,
   5087 ) -> Self:
   5088     """Compute the qth quantile of the data along the specified dimension.
   5089 
   5090     Returns the qth quantiles(s) of the array elements.
   (...)
   5184        The American Statistician, 50(4), pp. 361-365, 1996
   5185     """
-> 5187     ds = self._to_temp_dataset().quantile(
   5188         q,
   5189         dim=dim,
   5190         keep_attrs=keep_attrs,
   5191         method=method,
   5192         skipna=skipna,
   5193         interpolation=interpolation,
   5194     )
   5195     return self._from_temp_dataset(ds)

File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/util/deprecation_helpers.py:115, in _deprecate_positional_args.<locals>._decorator.<locals>.inner(*args, **kwargs)
    111     kwargs.update({name: arg for name, arg in zip_args})
    113     return func(*args[:-n_extra_args], **kwargs)
--> 115 return func(*args, **kwargs)

File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/core/dataset.py:8183, in Dataset.quantile(self, q, dim, method, numeric_only, keep_attrs, skipna, interpolation)
   8177     if name not in self.coords:
   8178         if (
   8179             not numeric_only
   8180             or np.issubdtype(var.dtype, np.number)
   8181             or var.dtype == np.bool_
   8182         ):
-> 8183             variables[name] = var.quantile(
   8184                 q,
   8185                 dim=reduce_dims,
   8186                 method=method,
   8187                 keep_attrs=keep_attrs,
   8188                 skipna=skipna,
   8189             )
   8191 else:
   8192     variables[name] = var

File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/core/variable.py:1907, in Variable.quantile(self, q, dim, method, keep_attrs, skipna, interpolation)
   1903 axis = np.arange(-1, -1 * len(dim) - 1, -1)
   1905 kwargs = {"q": q, "axis": axis, "method": method}
-> 1907 result = apply_ufunc(
   1908     _wrapper,
   1909     self,
   1910     input_core_dims=[dim],
   1911     exclude_dims=set(dim),
   1912     output_core_dims=[["quantile"]],
   1913     output_dtypes=[np.float64],
   1914     dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}),
   1915     dask="parallelized",
   1916     kwargs=kwargs,
   1917 )
   1919 # for backward compatibility
   1920 result = result.transpose("quantile", ...)

File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/core/computation.py:1280, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, on_missing_core_dim, *args)
   1278 # feed Variables directly through apply_variable_ufunc
   1279 elif any(isinstance(a, Variable) for a in args):
-> 1280     return variables_vfunc(*args)
   1281 else:
   1282     # feed anything else through apply_array_ufunc
   1283     return apply_array_ufunc(func, *args, dask=dask)

File /opt/conda/envs/pangeo-notebook/lib/python3.11/site-packages/xarray/core/computation.py:771, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    769             for axis, dim in enumerate(core_dims, start=-len(core_dims)):
    770                 if len(data.chunks[axis]) != 1:
--> 771                     raise ValueError(
    772                         f"dimension {dim} on {n}th function argument to "
    773                         "apply_ufunc with dask='parallelized' consists of "
    774                         "multiple chunks, but is also a core dimension. To "
    775                         "fix, either rechunk into a single array chunk along "
    776                         f"this dimension, i.e., ``.chunk(dict({dim}=-1))``, or "
    777                         "pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` "
    778                         "but beware that this may significantly increase memory usage."
    779                     )
    780     dask_gufunc_kwargs["allow_rechunk"] = True
    782 output_sizes = dask_gufunc_kwargs.pop("output_sizes", {})

ValueError: dimension time on 0th function argument to apply_ufunc with dask='parallelized' consists of multiple chunks, but is also a core dimension. To fix, either rechunk into a single array chunk along this dimension, i.e., ``.chunk(dict(time=-1))``, or pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` but beware that this may significantly increase memory usage.

If we read the error message, it says that the time dimension on the 0th function argument to apply_ufunc consists of multiple chunks, but is also a core dimension. This is because there is no chunk-parallel implementation of the quantile algorithm: it needs to see all time steps at once. To solve this, we can rechunk our data. In this case, we declare that the data should not be chunked along the time dimension, but can be chunked freely along the other dimensions.

Note that rechunking adds some extra computation.

qtile = merra_ds.SS001.chunk({"time": -1, "lon": "auto", "lat": "auto", "lev": "auto"}).quantile(0.85, dim="time").compute()
2024-05-20 09:20:43,990 - distributed.worker.memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 2.60 GiB -- Worker memory limit: 4.00 GiB
2024-05-20 09:20:44,768 - distributed.worker.memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 2.43 GiB -- Worker memory limit: 4.00 GiB
2024-05-20 09:20:58,762 - distributed.worker.memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 2.49 GiB -- Worker memory limit: 4.00 GiB
qtile.isel(lev=0).plot()
<matplotlib.collections.QuadMesh at 0x7f7e88316d90>
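
Before launching a big rechunk-and-quantile computation, it can be worth checking the resulting chunk layout first, for example:

# Inspect how the rechunked array would be laid out before computing anything.
rechunked = merra_ds.SS001.chunk({"time": -1, "lon": "auto", "lat": "auto", "lev": "auto"})
rechunked.chunks   # tuple of chunk sizes per dimension
rechunked.data     # the Dask array repr also shows the per-chunk size in bytes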

Accessing the dask array#

Not all operations available in NumPy/Dask are wrapped by Xarray. For instance, to compute a histogram (without plotting it) you could do something like this:

np.histogram(merra_ds.SS001.values)

where calling values always gives us the underlying NumPy array. If the dataset is backed by Dask arrays, they will be loaded into memory, which is a problem if the data is too big. Instead you could use the data attribute of the DataArray, which will return the Dask array:

merra_ds.SS001.data
            Array                            Chunk
Bytes       6.83 GiB                         6.38 MiB
Shape       (8768, 3, 121, 576)              (8, 3, 121, 576)
Dask graph  1096 chunks in 2193 graph layers
Data type   float32 numpy.ndarray

With this we have an array we can pass to Dask-array-specific functions, such as:

hist, bins = da.histogram(merra_ds.SS001.data, bins=10, range=(0, merra_ds.SS001.max()))
bins
            Array                       Chunk
Bytes       88 B                        88 B
Shape       (11,)                       (11,)
Dask graph  1 chunks in 2207 graph layers
Data type   float64 numpy.ndarray

Then we have to compute the histogram:

hist.compute()
array([1833234752,      47344,       1173,        202,         80,
               24,          5,          2,          1,          1])

Numba#

Numba is a just-in-time compiler for Python, which can speed up your slow Python loops, provided they work on the supported data types (see the Numba documentation).

rng = np.random.default_rng()
test = rng.random((10000, 2))
test2 = rng.random((10000, 2))

The njit decorator will try to compile a function the first time it is run. It will raise an error if Numba is not able to compile the function.

from numba import njit

Here, we have two identical functions that search for matching coordinate pairs in two relatively large arrays. In reality these could be real coordinates; here we just use random numbers.

# Pure-Python version: for each row in ds_locs, check whether it matches any row in merra_locs.
def slow_coord_isin(ds_locs, merra_locs):
    mask = np.zeros(ds_locs.shape[0])
    for i, ds_coord in enumerate(ds_locs):
        for merra_coord in merra_locs:
            if ds_coord[0] == merra_coord[0] and ds_coord[1] == merra_coord[1]:
                mask[i] = 1
                break
    return mask

# Here we add the njit decorator.
@njit
def fast_coord_isin(ds_locs, merra_locs):
    mask = np.zeros(ds_locs.shape[0])
    for i, ds_coord in enumerate(ds_locs):
        for merra_coord in merra_locs:
            if ds_coord[0] == merra_coord[0] and ds_coord[1] == merra_coord[1]:
                mask[i] = 1
                break
    return mask

Use the %%time cell magic to time the functions. For more robust timings use the %%timeit cell magic.

%%time
slow_coord_isin(test, test)
CPU times: user 10.9 s, sys: 72.2 ms, total: 10.9 s
Wall time: 10.9 s
array([1., 1., 1., ..., 1., 1., 1.])
%%time
fast_coord_isin(test, test)
CPU times: user 639 ms, sys: 60.8 ms, total: 700 ms
Wall time: 690 ms
array([1., 1., 1., ..., 1., 1., 1.])
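
Note that the timing of the first call to fast_coord_isin includes the compilation itself, since Numba compiles a specialised version for the argument types it sees. To time only the compiled code, warm the function up first; a minimal sketch:

# Warm-up call: triggers compilation for this argument type (2D float64 arrays).
_ = fast_coord_isin(test[:10], test2[:10])

Timing fast_coord_isin again after the warm-up then measures only the compiled loop.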