Using masks and computing weighted average#

This example is based on an xarray example.

Import python packages#

import xarray as xr
import intake
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
%matplotlib inline
cat_url = ""
col = intake.open_esm_datastore(cat_url)

pangeo-cmip6 catalog with 7674 dataset(s) from 514818 asset(s):

activity_id 18
institution_id 36
source_id 88
experiment_id 170
member_id 657
table_id 37
variable_id 700
grid_label 10
zstore 514818
dcpp_init_year 60
version 736
derived_variable_id 0

Search data#

cat = col.search(source_id=['NorESM2-LM'], experiment_id=['historical'], table_id=['Amon'], variable_id=['tas'], member_id=['r1i1p1f1'])
cat.df
activity_id institution_id source_id experiment_id member_id table_id variable_id grid_label zstore dcpp_init_year version
0 CMIP NCC NorESM2-LM historical r1i1p1f1 Amon tas gn gs://cmip6/CMIP6/CMIP/NCC/NorESM2-LM/historica... NaN 20190815

Create dictionary from the list of datasets we found#

  • This step may take several minutes so be patient!

dset_dict = cat.to_dataset_dict(zarr_kwargs={'use_cftime':True})
--> The keys in the returned dictionary of datasets are constructed as follows:
100.00% [1/1 00:12<00:00]
dset = dset_dict[list(dset_dict.keys())[0]]
<xarray.Dataset> Size: 110MB
Dimensions:         (lat: 96, bnds: 2, lon: 144, member_id: 1,
                     dcpp_init_year: 1, time: 1980)
Coordinates:
    height          float64 8B ...
  * lat             (lat) float64 768B -90.0 -88.11 -86.21 ... 86.21 88.11 90.0
    lat_bnds        (lat, bnds) float64 2kB dask.array<chunksize=(96, 2), meta=np.ndarray>
  * lon             (lon) float64 1kB 0.0 2.5 5.0 7.5 ... 352.5 355.0 357.5
    lon_bnds        (lon, bnds) float64 2kB dask.array<chunksize=(144, 2), meta=np.ndarray>
  * time            (time) object 16kB 1850-01-16 12:00:00 ... 2014-12-16 12:...
    time_bnds       (time, bnds) object 32kB dask.array<chunksize=(1980, 2), meta=np.ndarray>
  * member_id       (member_id) object 8B 'r1i1p1f1'
  * dcpp_init_year  (dcpp_init_year) float64 8B nan
Dimensions without coordinates: bnds
Data variables:
    tas             (member_id, dcpp_init_year, time, lat, lon) float32 109MB dask.array<chunksize=(1, 1, 990, 96, 144), meta=np.ndarray>
Attributes: (12/65)
    Conventions:                      CF-1.7 CMIP-6.2
    activity_id:                      CMIP
    branch_method:                    Hybrid-restart from year 1600-01-01 of ...
    branch_time:                      0.0
    branch_time_in_child:             0.0
    branch_time_in_parent:            430335.0
    ...                               ...
    intake_esm_attrs:variable_id:     tas
    intake_esm_attrs:grid_label:      gn
    intake_esm_attrs:zstore:          gs://cmip6/CMIP6/CMIP/NCC/NorESM2-LM/hi...
    intake_esm_attrs:version:         20190815
    intake_esm_attrs:_data_format_:   zarr

Plot the first timestep

projection = ccrs.Mercator(central_longitude=-10)

f, ax = plt.subplots(subplot_kw=dict(projection=projection))

dset['tas'].isel(time=0).plot(transform=ccrs.PlateCarree(), cbar_kwargs=dict(shrink=0.7), cmap='coolwarm')

Compute weighted mean#

  1. Create weights: on a regular latitude-longitude grid, the area of a grid cell is proportional to the cosine of its latitude.

  2. Compute weighted mean values
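As a quick sanity check of step 1, the sketch below compares cosine weights with exact cell areas on a sphere. It assumes a hypothetical regular grid with 2° latitude bands (not the NorESM2-LM grid) and uses the fact that the area of a latitude band is proportional to the difference of the sines of its bounding latitudes.

```python
import numpy as np

# Hypothetical regular grid: 2-degree latitude bands centred at -89, -87, ..., 89
dlat = 2.0
lat_centres = np.arange(-89.0, 90.0, dlat)
lat_south = np.deg2rad(lat_centres - dlat / 2)
lat_north = np.deg2rad(lat_centres + dlat / 2)

# Exact area of each latitude band on a unit sphere, per unit of longitude
band_area = np.sin(lat_north) - np.sin(lat_south)

# Cosine-of-latitude weights, as described in step 1
cos_weights = np.cos(np.deg2rad(lat_centres))

# Normalise both sets of weights so that each sums to 1, then compare
band_area_norm = band_area / band_area.sum()
cos_weights_norm = cos_weights / cos_weights.sum()
max_abs_diff = np.abs(band_area_norm - cos_weights_norm).max()
print(f"largest difference between normalised weights: {max_abs_diff:.2e}")
```

On such a grid the two sets of weights agree to rounding error, because sin(φ + δ) − sin(φ − δ) = 2 cos φ sin δ. On grids with unevenly spaced latitudes, or with points placed exactly at the poles, the cosine is only an approximation of the cell area.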

def computeWeightedMean(ds):
    # Compute cosine-of-latitude weights from the lat coordinate of the dataset you pass
    weights = np.cos(np.deg2rad(ds.lat))
    weights.name = "weights"
    # Compute weighted mean
    air_weighted = ds.weighted(weights)
    weighted_mean = air_weighted.mean(("lon", "lat"))
    return weighted_mean
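Written out, the function above computes the following quantity for each time step. The symbols are introduced here for illustration only: $T_{ij}$ is the value at latitude $\phi_i$ and longitude $\lambda_j$, and $w_i$ is the weight of latitude row $i$. Grid cells with missing values are left out of both sums.

$$
\overline{T} = \frac{\sum_{i}\sum_{j} w_i \, T_{ij}}{\sum_{i}\sum_{j} w_i}, \qquad w_i = \cos\phi_i
$$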

Compute weighted average over the entire globe#

weighted_mean = computeWeightedMean(dset)

Comparison with unweighted mean#

  • We select a time range

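  • Note how the weighted mean temperature is higher than the unweighted mean: the unweighted mean gives the many small, cold grid cells near the poles the same influence as the larger cells in the tropics.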

weighted_mean['tas'].sel(time=slice('2000-01-01', '2010-01-01')).plot(label="weighted")
dset['tas'].sel(time=slice('2000-01-01', '2010-01-01')).mean(("lon", "lat")).plot(label="unweighted")
plt.legend()
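To see where the difference between the two curves comes from without downloading any data, here is a minimal NumPy sketch on a synthetic field. The grid and the temperature profile (300 K at the equator, about 250 K at the poles) are assumptions made up for illustration. They are not taken from NorESM2-LM.

```python
import numpy as np

# Hypothetical grid and zonally uniform temperature field (K)
lat = np.arange(-89.0, 90.0, 2.0)
lon = np.arange(0.0, 360.0, 2.5)
temp_by_lat = 250.0 + 50.0 * np.cos(np.deg2rad(lat))  # warm equator, cold poles
temp = np.repeat(temp_by_lat[:, np.newaxis], lon.size, axis=1)  # shape (lat, lon)

# Unweighted mean: every grid cell counts equally
unweighted = temp.mean()

# Weighted mean: each cell counts in proportion to cos(latitude)
weights = np.cos(np.deg2rad(lat))[:, np.newaxis] * np.ones_like(temp)
weighted = (temp * weights).sum() / weights.sum()

print(f"unweighted: {unweighted:.1f} K, weighted: {weighted:.1f} K")
```

The polar rows hold as many cells as the equatorial rows but cover far less area, so down-weighting them raises the mean of any field that is colder at the poles.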

Compute weighted Arctic average#

Let’s also restrict the data to latitudes north of 60$^\circ$N.

weighted_mean = computeWeightedMean(dset.where(dset['lat']>60.))
weighted_mean['tas'].sel(time=slice('2000-01-01', '2010-01-01')).plot(label="weighted")
dset['tas'].where(dset['lat']>60.).sel(time=slice('2000-01-01', '2010-01-01')).mean(("lon", "lat")).plot(label="unweighted")
plt.legend()
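The masking step can also be sketched in plain NumPy. The example below reuses a hypothetical synthetic field (again an assumption, not model output), sets everything at or south of 60°N to NaN, and drops the masked cells from both the numerator and the denominator of the weighted mean. As far as I know, this matches the default behaviour of xarray's weighted mean with missing values.

```python
import numpy as np

# Hypothetical grid and zonally uniform temperature field (K)
lat = np.arange(-89.0, 90.0, 2.0)
lon = np.arange(0.0, 360.0, 2.5)
temp_by_lat = 250.0 + 50.0 * np.cos(np.deg2rad(lat))
temp = np.repeat(temp_by_lat[:, np.newaxis], lon.size, axis=1)  # shape (lat, lon)

# Mask: keep cells north of 60N, replace everything else with NaN
lat_2d = lat[:, np.newaxis] * np.ones_like(temp)
arctic = np.where(lat_2d > 60.0, temp, np.nan)

# Cosine weights; only cells that survived the mask contribute
weights = np.cos(np.deg2rad(lat_2d))
valid = ~np.isnan(arctic)
weighted = np.nansum(arctic * weights) / weights[valid].sum()
unweighted = np.nanmean(arctic)

print(f"Arctic unweighted: {unweighted:.1f} K, weighted: {weighted:.1f} K")
```

Normalising by the weights of the valid cells only is what keeps the result a true regional average. Dividing by the sum of all weights would scale the Arctic mean down by the fraction of the globe that was masked out.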

