{ "cells": [ { "cell_type": "markdown", "id": "d52e464a3b45c1ba", "metadata": {}, "source": [ "# Read CESM-PPE data and emulate with a Gaussian Process" ] }, { "cell_type": "markdown", "id": "7f45a45c-df17-4238-b264-8b7a9aec1953", "metadata": {}, "source": [ "\n", "**Adjusted for the eScience course from Duncan Watson-Parris' example here: [gist.github.com/duncanwp](https://gist.github.com/duncanwp/89175a17b7221e4d3639765621c7f7f9)**" ] }, { "cell_type": "markdown", "id": "b58e60b8-51f6-4034-8bf3-526d5da67751", "metadata": {}, "source": [ "You have to use the env:ml-notebook to run this example. " ] }, { "cell_type": "code", "execution_count": 31, "id": "d98ac0ce-2f20-4e3d-8620-450325c76c25", "metadata": { "tags": [] }, "outputs": [], "source": [ "import xarray as xr\n", "import numpy as np\n", "import pandas as pd\n", "from esem import gp_model\n", "from esem.utils import validation_plot, get_param_mask\n", "from pathlib import Path\n", "xr.set_options(display_style='html')\n", "import intake\n", "import cftime\n", "import matplotlib.pyplot as plt\n", "import cartopy.crs as ccrs\n", "import datetime\n", "import seaborn as sns" ] }, { "cell_type": "code", "execution_count": 32, "id": "82ad6e05-9c2c-4c63-bc19-91eedd386cfa", "metadata": {}, "outputs": [], "source": [ "def global_mean(ds):\n", " weights = np.cos(np.deg2rad(ds.lat))\n", " return ds.weighted(weights).mean(['lat', 'lon'])\n", "\n", "def get_ensemble_member(ds):\n", " fname = ds.encoding['source']\n", " member = int(fname.split('.')[-4])\n", " return ds.assign_coords(member=member).expand_dims('member')" ] }, { "cell_type": "markdown", "id": "fe2babd9-e578-4f7e-84e7-0b663f382b5c", "metadata": {}, "source": [ "## Open the overview of the parameters in the CAM6 CESM PPE" ] }, { "cell_type": "code", "execution_count": 33, "id": "92417e17-40ad-4f5b-82fb-3eb2f7939743", "metadata": {}, "outputs": [], "source": [ "data_path = Path('~/shared-craas1-ns9989k-ns9560k/CAM6_CESM_PPE/')\n", "\n", "params = 
(xr.open_dataset(data_path / \"parameter_262_w_control.nc\")\n", " .to_pandas()\n", " .drop(columns = ['Sample_nmb'])\n", " )\n" ] }, { "cell_type": "markdown", "id": "ad02937f-4434-47c3-b722-7b4fd03bc5c5", "metadata": {}, "source": [ "### Open the local CESM PPE catalog" ] }, { "cell_type": "code", "execution_count": 34, "id": "263acb2b-8f49-48ed-acc7-072f91c4edfb", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "

cesm-ppe catalog with 2 dataset(s) from 32571 asset(s):

\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
unique
experiment1
ensemble262
frequency2
variable124
units27
long_name124
vertical_levels3
start_time2
end_time3
time_range3
path32571
derived_variable0
\n", "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "cat_url = '/mnt/craas1-ns9989k-geo4992/data/catalogs/cesm-ppe.json'\n", "col = intake.open_esm_datastore(cat_url)\n", "col" ] }, { "cell_type": "markdown", "id": "16661a25-147f-40e3-b96a-a3c170b3ffc9", "metadata": {}, "source": [ "\n", "### Search corresponding data " ] }, { "cell_type": "markdown", "id": "a94ff4c4-8f03-46db-b19c-be4abe6d6b29", "metadata": {}, "source": [ "Please check [here](https://pangeo-data.github.io/escience-2022/pangeo101/data_discovery.html?highlight=cmip6) for info about CMIP and variables :) \n", "\n", "Particularly useful is maybe the variable search which you find here: https://clipc-services.ceda.ac.uk/dreq/mipVars.html " ] }, { "cell_type": "code", "execution_count": 35, "id": "2549450a-4a44-4e04-bfda-1b8eeade4f27", "metadata": { "scrolled": true, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
experimentensemblefrequencyvariableunitslong_namevertical_levelsstart_timeend_timetime_rangepath
0present-day0.0monthlySWCFW/m2Shortwave cloud forcing1.00001-01-160003-12-160001-01-16-0003-12-16/mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m...
1present-day1.0monthlySWCFW/m2Shortwave cloud forcing1.00001-01-160003-12-160001-01-16-0003-12-16/mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m...
2present-day2.0monthlySWCFW/m2Shortwave cloud forcing1.00001-01-160003-12-160001-01-16-0003-12-16/mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m...
3present-day3.0monthlySWCFW/m2Shortwave cloud forcing1.00001-01-160003-12-160001-01-16-0003-12-16/mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m...
4present-day4.0monthlySWCFW/m2Shortwave cloud forcing1.00001-01-160003-12-160001-01-16-0003-12-16/mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m...
....................................
257present-day258.0monthlySWCFW/m2Shortwave cloud forcing1.00001-01-160003-12-160001-01-16-0003-12-16/mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m...
258present-day259.0monthlySWCFW/m2Shortwave cloud forcing1.00001-01-160003-12-160001-01-16-0003-12-16/mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m...
259present-day260.0monthlySWCFW/m2Shortwave cloud forcing1.00001-01-160003-12-160001-01-16-0003-12-16/mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m...
260present-day261.0monthlySWCFW/m2Shortwave cloud forcing1.00001-01-160003-12-160001-01-16-0003-12-16/mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m...
261present-day262.0monthlySWCFW/m2Shortwave cloud forcing1.00001-01-160003-12-160001-01-16-0003-12-16/mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m...
\n", "

262 rows × 11 columns

\n", "
" ], "text/plain": [ " experiment ensemble frequency variable units long_name \\\n", "0 present-day 0.0 monthly SWCF W/m2 Shortwave cloud forcing \n", "1 present-day 1.0 monthly SWCF W/m2 Shortwave cloud forcing \n", "2 present-day 2.0 monthly SWCF W/m2 Shortwave cloud forcing \n", "3 present-day 3.0 monthly SWCF W/m2 Shortwave cloud forcing \n", "4 present-day 4.0 monthly SWCF W/m2 Shortwave cloud forcing \n", ".. ... ... ... ... ... ... \n", "257 present-day 258.0 monthly SWCF W/m2 Shortwave cloud forcing \n", "258 present-day 259.0 monthly SWCF W/m2 Shortwave cloud forcing \n", "259 present-day 260.0 monthly SWCF W/m2 Shortwave cloud forcing \n", "260 present-day 261.0 monthly SWCF W/m2 Shortwave cloud forcing \n", "261 present-day 262.0 monthly SWCF W/m2 Shortwave cloud forcing \n", "\n", " vertical_levels start_time end_time time_range \\\n", "0 1.0 0001-01-16 0003-12-16 0001-01-16-0003-12-16 \n", "1 1.0 0001-01-16 0003-12-16 0001-01-16-0003-12-16 \n", "2 1.0 0001-01-16 0003-12-16 0001-01-16-0003-12-16 \n", "3 1.0 0001-01-16 0003-12-16 0001-01-16-0003-12-16 \n", "4 1.0 0001-01-16 0003-12-16 0001-01-16-0003-12-16 \n", ".. ... ... ... ... \n", "257 1.0 0001-01-16 0003-12-16 0001-01-16-0003-12-16 \n", "258 1.0 0001-01-16 0003-12-16 0001-01-16-0003-12-16 \n", "259 1.0 0001-01-16 0003-12-16 0001-01-16-0003-12-16 \n", "260 1.0 0001-01-16 0003-12-16 0001-01-16-0003-12-16 \n", "261 1.0 0001-01-16 0003-12-16 0001-01-16-0003-12-16 \n", "\n", " path \n", "0 /mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m... \n", "1 /mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m... \n", "2 /mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m... \n", "3 /mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m... \n", "4 /mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m... \n", ".. ... \n", "257 /mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m... \n", "258 /mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m... \n", "259 /mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m... 
\n", "260 /mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m... \n", "261 /mnt/craas1-ns9989k-ns9560k/CAM6_CESM_PPE/PD/m... \n", "\n", "[262 rows x 11 columns]" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cat = col.search(\n", " experiment=['present-day'], \n", " variable = ['SWCF'], \n", " frequency='monthly'\n", ")\n", "\n", "cat.df\n" ] }, { "cell_type": "code", "execution_count": 36, "id": "4befa75c-7f95-4e1d-b5a2-f604fe54d12f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['SWCF'], dtype=object)" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cat.df['variable'].unique()" ] }, { "cell_type": "markdown", "id": "2d7858ef-8c0b-4d9d-9127-d1cef642b7ec", "metadata": {}, "source": [ "### Create dictionary from the list of datasets we found\n", "- This step may take several minutes so be patient!" ] }, { "cell_type": "code", "execution_count": 37, "id": "66132009-2292-4985-ad47-edf940999491", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "--> The keys in the returned dictionary of datasets are constructed as follows:\n", "\t'experiment.frequency'\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [1/1 00:25<00:00]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dset_dict = cat.to_dataset_dict()#preprocess = get_ensemble_member,)" ] }, { "cell_type": "code", "execution_count": 38, "id": "b469a9fd-6e3f-4880-b5c3-4857147a9833", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset> Size: 2GB\n",
       "Dimensions:   (lat: 192, lon: 288, time: 36, ensemble: 262)\n",
       "Coordinates:\n",
       "  * lat       (lat) float64 2kB -90.0 -89.06 -88.12 -87.17 ... 88.12 89.06 90.0\n",
       "  * lon       (lon) float64 2kB 0.0 1.25 2.5 3.75 ... 355.0 356.2 357.5 358.8\n",
       "  * time      (time) object 288B 0001-01-16 12:00:00 ... 0003-12-16 12:00:00\n",
       "  * ensemble  (ensemble) float64 2kB 0.0 1.0 2.0 3.0 ... 259.0 260.0 261.0 262.0\n",
       "Data variables:\n",
       "    SWCF      (ensemble, time, lat, lon) float32 2GB dask.array<chunksize=(1, 36, 192, 288), meta=np.ndarray>\n",
       "Attributes:\n",
       "    intake_esm_vars:                   ['SWCF']\n",
       "    intake_esm_attrs:experiment:       present-day\n",
       "    intake_esm_attrs:frequency:        monthly\n",
       "    intake_esm_attrs:variable:         SWCF\n",
       "    intake_esm_attrs:units:            W/m2\n",
       "    intake_esm_attrs:long_name:        Shortwave cloud forcing\n",
       "    intake_esm_attrs:vertical_levels:  1.0\n",
       "    intake_esm_attrs:start_time:       0001-01-16\n",
       "    intake_esm_attrs:end_time:         0003-12-16\n",
       "    intake_esm_attrs:time_range:       0001-01-16-0003-12-16\n",
       "    intake_esm_attrs:_data_format_:    netcdf\n",
       "    intake_esm_dataset_key:            present-day.monthly
" ], "text/plain": [ " Size: 2GB\n", "Dimensions: (lat: 192, lon: 288, time: 36, ensemble: 262)\n", "Coordinates:\n", " * lat (lat) float64 2kB -90.0 -89.06 -88.12 -87.17 ... 88.12 89.06 90.0\n", " * lon (lon) float64 2kB 0.0 1.25 2.5 3.75 ... 355.0 356.2 357.5 358.8\n", " * time (time) object 288B 0001-01-16 12:00:00 ... 0003-12-16 12:00:00\n", " * ensemble (ensemble) float64 2kB 0.0 1.0 2.0 3.0 ... 259.0 260.0 261.0 262.0\n", "Data variables:\n", " SWCF (ensemble, time, lat, lon) float32 2GB dask.array\n", "Attributes:\n", " intake_esm_vars: ['SWCF']\n", " intake_esm_attrs:experiment: present-day\n", " intake_esm_attrs:frequency: monthly\n", " intake_esm_attrs:variable: SWCF\n", " intake_esm_attrs:units: W/m2\n", " intake_esm_attrs:long_name: Shortwave cloud forcing\n", " intake_esm_attrs:vertical_levels: 1.0\n", " intake_esm_attrs:start_time: 0001-01-16\n", " intake_esm_attrs:end_time: 0003-12-16\n", " intake_esm_attrs:time_range: 0001-01-16-0003-12-16\n", " intake_esm_attrs:_data_format_: netcdf\n", " intake_esm_dataset_key: present-day.monthly" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dset_dict['present-day.monthly']" ] }, { "cell_type": "code", "execution_count": 39, "id": "ede46016-c40e-41e5-ab92-f75e9e96b9d8", "metadata": {}, "outputs": [], "source": [ "ds = dset_dict['present-day.monthly']\n", "SWCF = global_mean(ds['SWCF']).mean('time').compute()" ] }, { "cell_type": "code", "execution_count": 40, "id": "fea91992-0e87-4560-a526-8bfb18b1df3d", "metadata": {}, "outputs": [], "source": [ "# Some of the PPE ensemble members are missing data so just select the params we actually have\n", "sub_params = params.iloc[SWCF.ensemble.values]\n", "# Unit normalise all the parameters\n", "ppe_params = (sub_params - sub_params.min()) / (sub_params.max() - sub_params.min())" ] }, { "cell_type": "code", "execution_count": 41, "id": "df911922-a001-4ad2-8e25-eb19f5ae97c7", "metadata": {}, "outputs": [ { "name": "stderr", 
"output_type": "stream", "text": [ "/opt/conda/envs/ml-notebook/lib/python3.11/site-packages/sklearn/linear_model/_least_angle.py:688: ConvergenceWarning: Regressors in active set degenerate. Dropping a regressor, after 9 iterations, i.e. alpha=4.397e-01, with an active set of 9 regressors, and the smallest cholesky pivot element being 5.960e-08. Reduce max_iter or increase eps parameters.\n", " warnings.warn(\n", "/opt/conda/envs/ml-notebook/lib/python3.11/site-packages/sklearn/linear_model/_least_angle.py:688: ConvergenceWarning: Regressors in active set degenerate. Dropping a regressor, after 20 iterations, i.e. alpha=1.975e-01, with an active set of 20 regressors, and the smallest cholesky pivot element being 5.960e-08. Reduce max_iter or increase eps parameters.\n", " warnings.warn(\n", "/opt/conda/envs/ml-notebook/lib/python3.11/site-packages/sklearn/linear_model/_least_angle.py:688: ConvergenceWarning: Regressors in active set degenerate. Dropping a regressor, after 22 iterations, i.e. alpha=1.077e-01, with an active set of 22 regressors, and the smallest cholesky pivot element being 5.960e-08. Reduce max_iter or increase eps parameters.\n", " warnings.warn(\n", "/opt/conda/envs/ml-notebook/lib/python3.11/site-packages/sklearn/linear_model/_least_angle.py:688: ConvergenceWarning: Regressors in active set degenerate. Dropping a regressor, after 35 iterations, i.e. alpha=2.229e-02, with an active set of 35 regressors, and the smallest cholesky pivot element being 5.960e-08. Reduce max_iter or increase eps parameters.\n", " warnings.warn(\n", "/opt/conda/envs/ml-notebook/lib/python3.11/site-packages/sklearn/linear_model/_least_angle.py:688: ConvergenceWarning: Regressors in active set degenerate. Dropping a regressor, after 39 iterations, i.e. alpha=4.559e-03, with an active set of 39 regressors, and the smallest cholesky pivot element being 5.960e-08. 
Reduce max_iter or increase eps parameters.\n", " warnings.warn(\n", "/opt/conda/envs/ml-notebook/lib/python3.11/site-packages/sklearn/linear_model/_least_angle.py:688: ConvergenceWarning: Regressors in active set degenerate. Dropping a regressor, after 41 iterations, i.e. alpha=2.061e-03, with an active set of 41 regressors, and the smallest cholesky pivot element being 5.960e-08. Reduce max_iter or increase eps parameters.\n", " warnings.warn(\n", "/opt/conda/envs/ml-notebook/lib/python3.11/site-packages/sklearn/linear_model/_least_angle.py:688: ConvergenceWarning: Regressors in active set degenerate. Dropping a regressor, after 42 iterations, i.e. alpha=5.161e-04, with an active set of 42 regressors, and the smallest cholesky pivot element being 5.960e-08. Reduce max_iter or increase eps parameters.\n", " warnings.warn(\n", "/opt/conda/envs/ml-notebook/lib/python3.11/site-packages/sklearn/linear_model/_least_angle.py:688: ConvergenceWarning: Regressors in active set degenerate. Dropping a regressor, after 42 iterations, i.e. alpha=1.493e-04, with an active set of 42 regressors, and the smallest cholesky pivot element being 5.960e-08. Reduce max_iter or increase eps parameters.\n", " warnings.warn(\n", "/opt/conda/envs/ml-notebook/lib/python3.11/site-packages/sklearn/linear_model/_least_angle.py:688: ConvergenceWarning: Regressors in active set degenerate. Dropping a regressor, after 43 iterations, i.e. alpha=6.496e-05, with an active set of 43 regressors, and the smallest cholesky pivot element being 5.960e-08. 
Reduce max_iter or increase eps parameters.\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "Index(['micro_mg_autocon_nd_exp', 'micro_mg_dcs', 'cldfrc_dp1', 'cldfrc_dp2',\n", " 'clubb_C6thlb', 'clubb_C8', 'clubb_c1', 'clubb_c11', 'clubb_c14',\n", " 'dust_emis_fact', 'micro_mg_accre_enhan_fact', 'micro_mg_autocon_fact',\n", " 'micro_mg_autocon_lwp_exp', 'micro_mg_berg_eff_factor',\n", " 'micro_mg_homog_size', 'micro_mg_iaccr_factor', 'micro_mg_vtrmi_factor',\n", " 'microp_aero_wsub_scale', 'microp_aero_wsubi_scale',\n", " 'seasalt_emis_scale', 'zmconv_capelmt', 'zmconv_ke',\n", " 'zmconv_tiedke_add'],\n", " dtype='object')" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# We can use an information criterion to choose the best parameters automatically:\n", "best_params = ppe_params[ppe_params.columns[get_param_mask(ppe_params, SWCF)]]\n", "best_params.columns" ] }, { "cell_type": "code", "execution_count": 42, "id": "22a8ba21-2ad6-4ee4-9e28-7f9cb07a56ec", "metadata": {}, "outputs": [], "source": [ "n_test = 25\n", "\n", "X_test, X_train = best_params[:n_test], best_params[n_test:]\n", "Y_test, Y_train = SWCF[:n_test], SWCF[n_test:]" ] }, { "cell_type": "markdown", "id": "733d85db-b666-42be-8db5-196f978670b4", "metadata": {}, "source": [ "## Global mean GP Model\n" ] }, { "cell_type": "code", "execution_count": 43, "id": "268df2dc-a02d-4559-932b-4a6caa152747", "metadata": {}, "outputs": [], "source": [ "# Can try different kernels here\n", "gp = gp_model(X_train, Y_train, kernel=['Linear', 'RBF'])" ] }, { "cell_type": "code", "execution_count": 44, "id": "f6ee5f93-0fa2-412f-a540-943cd13cee21", "metadata": {}, "outputs": [], "source": [ "gp.train()\n" ] }, { "cell_type": "code", "execution_count": 45, "id": "da3f2342-ef8c-4684-a28d-205222695a52", "metadata": {}, "outputs": [], "source": [ "m, v = gp.predict(X_test)\n" ] }, { "cell_type": "code", "execution_count": null, "id": 
"9b22deab-48de-4ac1-9834-392b616a0646", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 46, "id": "42f0c7e2-bfef-4f65-b663-97e7fa2e55c7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Proportion of 'Bad' estimates : 8.00%\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "validation_plot(Y_test.data, m, v, figsize=(4,4))\n" ] }, { "cell_type": "markdown", "id": "922c6192-05da-485b-9f12-ef4a6c52d739", "metadata": {}, "source": [ "## Calibrate" ] }, { "cell_type": "code", "execution_count": 47, "id": "624c3fae-0ec3-498b-b8f1-f60adc2d4cad", "metadata": {}, "outputs": [], "source": [ "from esem.utils import get_random_params\n", "from esem.abc_sampler import ABCSampler, constrain" ] }, { "cell_type": "code", "execution_count": 48, "id": "a59b9590-cb49-41e2-834d-ac96512d3ed7", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
micro_mg_autocon_nd_expmicro_mg_dcscldfrc_dp1cldfrc_dp2clubb_C6thlbclubb_C8clubb_c1clubb_c11clubb_c14dust_emis_fact...micro_mg_berg_eff_factormicro_mg_homog_sizemicro_mg_iaccr_factormicro_mg_vtrmi_factormicrop_aero_wsub_scalemicrop_aero_wsubi_scaleseasalt_emis_scalezmconv_capelmtzmconv_kezmconv_tiedke_add
nmb_sim
250.9715000.8058470.2064240.6714710.6842290.7939500.7797540.2276610.3005340.974161...0.1063400.5562510.4691440.5054580.9317980.0494840.9553140.0744890.7989510.634364
260.6498660.8946490.6181480.9383610.5183240.5382490.6655060.8129900.5670410.865573...0.0323050.3019090.4995150.0884590.7315000.0927610.4320790.0585860.0849650.833772
270.7037680.7675920.0511990.3489670.0864440.5492570.2732960.9793760.7156800.107133...0.7830430.6448650.8118430.9058360.9912890.2687220.4684690.8237170.6986780.404756
280.9483120.0000000.8882820.4815790.0584400.9642240.3793810.5302530.0697770.730504...0.7563680.7497650.7728500.6559380.4893640.4155440.1743670.4695520.5911950.371451
290.1235600.3652550.6749390.9795950.0720700.2719050.7065850.6542030.9235450.517474...0.8871580.4386710.9605350.4335220.6822040.2556070.7571540.7020680.8356470.854877
..................................................................
2580.7256040.9248490.0924770.6316730.7054860.7478390.6052310.5861430.3778380.898157...0.3575680.3207850.1393840.9896300.3797180.6670790.7515090.9236230.5225420.081568
2590.6336060.3411120.0767880.2648900.9774080.6683250.0178420.8786020.9769710.009132...0.9715030.0630190.1323570.0349850.0571830.4949460.3973520.8186350.4912360.056177
2600.6198600.5681490.7659960.9150390.3387420.9832730.2435260.0993710.3141780.179294...0.6934110.4488540.2050110.3710560.2736510.1853710.7721560.3230950.9594860.484692
2610.9926540.1260660.2554410.6986650.2495640.6835990.7187310.9176750.6277530.204757...0.4163350.6674220.2263510.1780840.9447020.4523940.8031220.0491630.2476090.997869
2620.2978730.8687200.8300210.9616610.9979730.3751010.0326010.5867630.0507060.727236...0.1590030.8560630.9251820.1082610.1058630.5604360.7200560.6497580.2054580.675950
\n", "

237 rows × 23 columns

\n", "
" ], "text/plain": [ " micro_mg_autocon_nd_exp micro_mg_dcs cldfrc_dp1 cldfrc_dp2 \\\n", "nmb_sim \n", "25 0.971500 0.805847 0.206424 0.671471 \n", "26 0.649866 0.894649 0.618148 0.938361 \n", "27 0.703768 0.767592 0.051199 0.348967 \n", "28 0.948312 0.000000 0.888282 0.481579 \n", "29 0.123560 0.365255 0.674939 0.979595 \n", "... ... ... ... ... \n", "258 0.725604 0.924849 0.092477 0.631673 \n", "259 0.633606 0.341112 0.076788 0.264890 \n", "260 0.619860 0.568149 0.765996 0.915039 \n", "261 0.992654 0.126066 0.255441 0.698665 \n", "262 0.297873 0.868720 0.830021 0.961661 \n", "\n", " clubb_C6thlb clubb_C8 clubb_c1 clubb_c11 clubb_c14 \\\n", "nmb_sim \n", "25 0.684229 0.793950 0.779754 0.227661 0.300534 \n", "26 0.518324 0.538249 0.665506 0.812990 0.567041 \n", "27 0.086444 0.549257 0.273296 0.979376 0.715680 \n", "28 0.058440 0.964224 0.379381 0.530253 0.069777 \n", "29 0.072070 0.271905 0.706585 0.654203 0.923545 \n", "... ... ... ... ... ... \n", "258 0.705486 0.747839 0.605231 0.586143 0.377838 \n", "259 0.977408 0.668325 0.017842 0.878602 0.976971 \n", "260 0.338742 0.983273 0.243526 0.099371 0.314178 \n", "261 0.249564 0.683599 0.718731 0.917675 0.627753 \n", "262 0.997973 0.375101 0.032601 0.586763 0.050706 \n", "\n", " dust_emis_fact ... micro_mg_berg_eff_factor micro_mg_homog_size \\\n", "nmb_sim ... \n", "25 0.974161 ... 0.106340 0.556251 \n", "26 0.865573 ... 0.032305 0.301909 \n", "27 0.107133 ... 0.783043 0.644865 \n", "28 0.730504 ... 0.756368 0.749765 \n", "29 0.517474 ... 0.887158 0.438671 \n", "... ... ... ... ... \n", "258 0.898157 ... 0.357568 0.320785 \n", "259 0.009132 ... 0.971503 0.063019 \n", "260 0.179294 ... 0.693411 0.448854 \n", "261 0.204757 ... 0.416335 0.667422 \n", "262 0.727236 ... 
0.159003 0.856063 \n", "\n", " micro_mg_iaccr_factor micro_mg_vtrmi_factor microp_aero_wsub_scale \\\n", "nmb_sim \n", "25 0.469144 0.505458 0.931798 \n", "26 0.499515 0.088459 0.731500 \n", "27 0.811843 0.905836 0.991289 \n", "28 0.772850 0.655938 0.489364 \n", "29 0.960535 0.433522 0.682204 \n", "... ... ... ... \n", "258 0.139384 0.989630 0.379718 \n", "259 0.132357 0.034985 0.057183 \n", "260 0.205011 0.371056 0.273651 \n", "261 0.226351 0.178084 0.944702 \n", "262 0.925182 0.108261 0.105863 \n", "\n", " microp_aero_wsubi_scale seasalt_emis_scale zmconv_capelmt \\\n", "nmb_sim \n", "25 0.049484 0.955314 0.074489 \n", "26 0.092761 0.432079 0.058586 \n", "27 0.268722 0.468469 0.823717 \n", "28 0.415544 0.174367 0.469552 \n", "29 0.255607 0.757154 0.702068 \n", "... ... ... ... \n", "258 0.667079 0.751509 0.923623 \n", "259 0.494946 0.397352 0.818635 \n", "260 0.185371 0.772156 0.323095 \n", "261 0.452394 0.803122 0.049163 \n", "262 0.560436 0.720056 0.649758 \n", "\n", " zmconv_ke zmconv_tiedke_add \n", "nmb_sim \n", "25 0.798951 0.634364 \n", "26 0.084965 0.833772 \n", "27 0.698678 0.404756 \n", "28 0.591195 0.371451 \n", "29 0.835647 0.854877 \n", "... ... ... 
\n", "258 0.522542 0.081568 \n", "259 0.491236 0.056177 \n", "260 0.959486 0.484692 \n", "261 0.247609 0.997869 \n", "262 0.205458 0.675950 \n", "\n", "[237 rows x 23 columns]" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train" ] }, { "cell_type": "code", "execution_count": 49, "id": "bce220b1-d3b8-407b-b04f-282e242dc9db", "metadata": {}, "outputs": [], "source": [ "# Set up the sampler with 1 million candidate points\n", "# Use one column per selected parameter (taken from X_train) rather than a hard-coded count,\n", "# so this still works if the parameter selection above keeps a different number of columns\n", "sample_points = pd.DataFrame(data=get_random_params(X_train.shape[1], int(1e6)), columns=X_train.columns)\n", "sampler = ABCSampler(gp, np.asarray([-40.5]), obs_uncertainty=0.5)" ] }, { "cell_type": "code", "execution_count": 50, "id": "1612dc29-7e37-4b3e-9c5d-bdfd906bb808", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e2e1d419a6d54186bca265f5b0417216", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1000000 [00:00" ] } ], "source": [ "valid_samples = sampler.batch_constrain(sample_points, batch_size=10000)\n" ] }, { "cell_type": "code", "execution_count": 51, "id": "f576276d-6729-4d04-977a-8d0283aff792", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Remaining points: 100\n" ] } ], "source": [ "print(\"Remaining points: {}\".format(valid_samples.sum()))\n" ] }, { "cell_type": "code", "execution_count": 52, "id": "6d556619-d81c-42b3-bf7d-e72b485b3741", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "numpy.ndarray" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(valid_samples)" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:ml-notebook]", "language": "python", "name": "conda-env-ml-notebook-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, 
"nbformat": 4, "nbformat_minor": 5 }