diff --git a/notebooks/install.txt b/notebooks/install.txt new file mode 100644 index 00000000..b5e47885 --- /dev/null +++ b/notebooks/install.txt @@ -0,0 +1,16 @@ +module load python3/new + +python3 -m venv /home/erpl/venvs/aifs-inference +source /home/erpl/venvs/aifs-inference/bin/activate + +pip install ipykernel +pip install anemoi-inference[huggingface]==0.4.9 anemoi-models==0.3.1 +pip install earthkit-regrid==0.4.0 ecmwf-opendata + +pip install psutil +pip install flash_attn==2.7.2.post1 --no-build-isolation + +pip install matplotlib +pip install cartopy + +python -m ipykernel install --user --name=kernel-aifs-inference diff --git a/notebooks/run_AIFS_v1.ipynb b/notebooks/run_AIFS_v1.ipynb new file mode 100644 index 00000000..53a0e908 --- /dev/null +++ b/notebooks/run_AIFS_v1.ipynb @@ -0,0 +1,3028 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "nf6PiMYmLEtZ" + }, + "source": [ + "This notebook runs ECMWF's aifs-single-v1 data-driven model, using ECMWF's [open data](https://www.ecmwf.int/en/forecasts/datasets/open-data) dataset and the [anemoi-inference](https://anemoi-inference.readthedocs.io/en/latest/apis/level1.html) package." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YDZrZ8HVxLfU" + }, + "source": [ + "# 1. Install Required Packages and Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 44979, + "status": "ok", + "timestamp": 1733384436074, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": -60 + }, + "id": "_sT8Re5jLRAH", + "outputId": "384c13e8-e739-4e31-90c5-11970da5808f" + }, + "outputs": [], + "source": [ + "# Uncomment the lines below to install the required packages\n", + "#!pip install -q anemoi-inference[huggingface]==0.4.9 anemoi-models==0.3.1\n", + "#!pip install -q earthkit-regrid==0.4.0 ecmwf-opendata \n", + "#!pip install -q flash_attn" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "executionInfo": { + "elapsed": 3143, + "status": "ok", + "timestamp": 1733384445221, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": -60 + }, + "id": "VBJmsrqGLEtb" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import datetime\n", + "from collections import defaultdict\n", + "import copy\n", + "\n", + "from typing import Callable\n", + "import earthkit.data as ekd\n", + "import earthkit.regrid as ekr\n", + "\n", + "from anemoi.inference.runners.simple import SimpleRunner\n", + "\n", + "from ecmwf.opendata import Client as OpendataClient" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "from torch import Tensor\n", + "import torch\n", + "from torch.autograd.functional import vjp" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "x-JJcbnhLEtc" + }, + "source": [ + "# 2. 
Retrieve Initial Conditions from ECMWF Open Data\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "B8DkdKfUxaPr" + }, + "source": [ + "### List of parameters to retrieve form ECMWF open data" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "executionInfo": { + "elapsed": 195, + "status": "ok", + "timestamp": 1733384595021, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": -60 + }, + "id": "yZOyITo6LEtd" + }, + "outputs": [], + "source": [ + "PARAM_SFC = [\"10u\", \"10v\", \"2d\", \"2t\", \"msl\", \"skt\", \"sp\", \"tcw\", \"lsm\", \"z\", \"slor\", \"sdor\"]\n", + "PARAM_SOIL =[\"vsw\",\"sot\"]\n", + "PARAM_PL = [\"gh\", \"t\", \"u\", \"v\", \"w\", \"q\"]\n", + "LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]\n", + "SOIL_LEVELS = [1,2]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-CdbpgHLxczB" + }, + "source": [ + "### Select a date" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "executionInfo": { + "elapsed": 2109, + "status": "ok", + "timestamp": 1733384601142, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": -60 + }, + "id": "XwWLA0OcLEtd" + }, + "outputs": [], + "source": [ + "DATE = OpendataClient().latest()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 388, + "status": "ok", + "timestamp": 1733384601842, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": -60 + }, + "id": "3_Fy3a0WLEte", + "outputId": "96799389-eb2d-4b1d-f5f3-60e8a309c9ef" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initial date is 2025-04-11 06:00:00\n" + ] + } + ], + "source": [ + "print(\"Initial date is\", DATE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UMmVhb5uxiB9" + }, + "source": [ + "### Get the data from the ECMWF Open Data API" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "executionInfo": { + "elapsed": 215, + "status": "ok", + "timestamp": 1733384619515, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": -60 + }, + "id": "8X2ShHMeLEtf" + }, + "outputs": [], + "source": [ + "def get_open_data(param, levelist=[]):\n", + " fields = defaultdict(list)\n", + " # Get the data for the current date and the previous date\n", + " for date in [DATE - datetime.timedelta(hours=6), DATE]:\n", + " data = ekd.from_source(\"ecmwf-open-data\", date=date, param=param, levelist=levelist)\n", + " for f in data:\n", + " # Open data is between -180 and 180, we need to shift it to 0-360\n", + " assert f.to_numpy().shape == (721,1440)\n", + " values = np.roll(f.to_numpy(), -f.shape[1] // 2, axis=1)\n", + " # Interpolate the data to from 0.25 to N320\n", + " values = ekr.interpolate(values, {\"grid\": (0.25, 0.25)}, {\"grid\": \"N320\"})\n", + " # Add the values to the list\n", + " name = f\"{f.metadata('param')}_{f.metadata('levelist')}\" if levelist else f.metadata(\"param\")\n", + " fields[name].append(values)\n", + "\n", + " # Create a single matrix for each parameter\n", + " for param, values in fields.items():\n", + " fields[param] = np.stack(values)\n", + "\n", + " return fields" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_yjso9wvxli0" + }, + "source": [ + "### Get Input Fields" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "executionInfo": { + 
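For reference, here is the regridding step from `get_open_data` as a standalone sketch. The random array is only a stand-in for a real 0.25° open-data field; the `np.roll` shift and the `earthkit.regrid` call are the same ones used in the function above.

```python
# Sketch of the regridding performed inside get_open_data:
# shift longitudes from [-180, 180) to [0, 360), then interpolate
# from the regular 0.25-degree grid to the N320 reduced Gaussian grid.
import numpy as np
import earthkit.regrid as ekr

values = np.random.rand(721, 1440)                       # stand-in for a 0.25-degree open-data field
values = np.roll(values, -values.shape[1] // 2, axis=1)  # open data runs -180..180; the model expects 0..360
values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
print(values.shape)                                      # flat array of N320 grid points
```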
"elapsed": 186, + "status": "ok", + "timestamp": 1733384638318, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": -60 + }, + "id": "as23RAZiLEtf" + }, + "outputs": [], + "source": [ + "fields = {}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VJs1usVsxq5Q" + }, + "source": [ + "#### Add the single levels fields" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17, + "referenced_widgets": [ + "a7c36aeb00074c6c92735e169b765251", + "7ba380ae6b76488299d17aa1bffece67", + "b695ad246f574c3993d4ca3925be2853", + "6c51d6cbcf924425824f8c4390164834", + "ba596bf1fa9545eda887a2439af3c70c", + "d3f5816d005b48fa981d2e9e947fe020", + "7d9d217ab1bd4010be8eee17e335083a", + "ea41eba5fbed4832b97520086a567d67", + "f362732783d741cd982a21f73b647a18", + "ee9c57247f1d4e63b255784b4342b7b6", + "fa7fbe6608b1416a8b3ad52bea2a06eb", + "ee8fecacdf934cacb4c5966da66f68ba", + "b8f734cd22254a559cbf0d4270ceacb4", + "7487efa049164d9a92b01d41deac2a1d", + "00131dfdac6c4787b3c6aefb08f13beb", + "769c57208a284dd3a281d225ec9276f3", + "a53a407a6f6e451497336766244135b1", + "caf04b00d2ee45de9d5667ecac91cf1d", + "cbbd457e67034ce58f82a31228a85ff3", + "9db824f32e8c49f298aa110df69465bd", + "38181e60ddf14b2ab0b98a46d04c5892", + "a8d1c1ad0e40429a85bce2f7d2ba15be", + "7fc7795fd4b14f9f912629eaee028df3", + "305be4f74c6c42a6abd1e83649b5e635", + "9f5ed1c328ef49b783f8e3f62711fb5f", + "52b672a66d2a475394bdef455a4cebf8", + "5f91729034eb47fab123411dc3353e22", + "d6974595ebd043039970461f78f0195e", + "ebee72a8dae74add85d1e58c5a7ecc0b", + "46015c2b0e4d4788b07468f874345557", + "29607e1346af4a3188a33f7324bd2ff2", + "36ff6fe16f494b6e8d75fd46569724e7", + "3d271316983f41e0b2d0ef9548003f55" + ] + }, + "executionInfo": { + "elapsed": 32963, + "status": "ok", + "timestamp": 1733384688321, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": -60 + }, + "id": "b8sjaN5SLEtf", + "outputId": "228d3a4d-ba64-4500-9b7e-64cfe1430401", + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \r" + ] + } + ], + "source": [ + "fields.update(get_open_data(param=PARAM_SFC))" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "20250411000000-0h-oper-fc.grib2: 0%| | 0.00/2.12M [00:00 Callable:\n", + "def model_wrapper(input_state:dict) -> Callable: \n", + " \n", + " #\n", + " #t_input:tuple[Tensor]=()\n", + " #t_input_keys=()\n", + " #for k, v in input_state[\"fields\"].items():\n", + " # t_input += (Tensor(v),)\n", + " # t_input_keys += (k,)\n", + " \n", + " t_input_keys=()\n", + " for k in input_state[\"fields\"].keys():\n", + " t_input_keys += (k,)\n", + "\n", + " def model_step(*args) -> tuple[Tensor]:\n", + " print(\"ARGUMENTS\")\n", + " print(\" model_wrapper args: number of Tensor in tuple \",len(args))\n", + " #print(\" model_wrapper args: norm of 2t \",np.linalg.norm(args[3].detach().numpy())) #index 3 is 2t in t_input\n", + " #Convert tuple[Tensor] in dictionary\n", + " for k,v in zip(t_input_keys, args):\n", + " input_state[\"fields\"][k] = v\n", + "\n", + " #Call runner.run\n", + " generator_state=runner.run(input_state=input_state, lead_time=12) \n", + " output_state=next(generator_state) \n", + " \n", + " print(output_state)\n", + " \n", + " #Convert Dictionary to tuple[Tensor]\n", + " t_output:tuple[Tensor]=()\n", + " t_output_keys=()\n", + " for k,v in 
output_state[\"fields\"].items():\n", + " t_output += (Tensor(v),)\n", + " t_output_keys += (k,)\n", + " #print(\" t_output: number of Tensor in tuple \",len(t_output))\n", + " #print(\" t_output_keys \",t_output_keys)\n", + " \n", + " return t_output\n", + " \n", + " return model_step" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "LOAD INPUT IN DICTIONARY\n", + " input_state: number of fields in fields 94\n", + " input_state: norm of 2t 299821.67010851036\n", + " convert dictionary to tuple[Tensor]\n", + " t_input: number of Tensor in tuple 94\n", + " t_input: norm of 2t 299821.75\n", + "\n", + "CREATE PERTURBATION\n", + " Integrate the model\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 96420.78it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " output_state: number of fields in fields 102\n", + " output_state: norm of 2t 212410.98\n", + " Perturbing field 2t index 81\n", + " output_pert: number of fields in fields 102\n", + " output_pert: norm of 2t 2124.1091\n", + " convert dictionary to tuple[Tensor]\n", + " t_pert: number of Tensor in tuple 102\n", + " t_pert: norm of 2t 2124.1091\n", + "\n", + "COMPUTE DERIVATIVES\n", + "ARGUMENTS\n", + " model_wrapper args: number of Tensor in tuple 94\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[36]\u001b[39m\u001b[32m, line 60\u001b[39m\n\u001b[32m 57\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mCOMPUTE DERIVATIVES\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 58\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m torch.autograd.set_detect_anomaly(\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[32m---> \u001b[39m\u001b[32m60\u001b[39m t_output, t_dx_output = \u001b[43mvjp\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 61\u001b[39m \u001b[43m \u001b[49m\u001b[43mmodel_wrapper\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_state\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 62\u001b[39m \u001b[43m \u001b[49m\u001b[43mt_input\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# \"flattened\" inputs\u001b[39;49;00m\n\u001b[32m 63\u001b[39m \u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m=\u001b[49m\u001b[43mt_pert\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# output perturbations\u001b[39;49;00m\n\u001b[32m 64\u001b[39m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 65\u001b[39m \u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 66\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 70\u001b[39m \u001b[38;5;28mprint\u001b[39m(np.linalg.norm(t_output[\u001b[32m81\u001b[39m]))\n\u001b[32m 71\u001b[39m \u001b[38;5;28mprint\u001b[39m(np.linalg.norm(t_dx_output[\u001b[32m3\u001b[39m])) \n", + "\u001b[36mFile 
\u001b[39m\u001b[32m~/venvs/aifs-inference/lib/python3.11/site-packages/torch/autograd/functional.py:327\u001b[39m, in \u001b[36mvjp\u001b[39m\u001b[34m(func, inputs, v, create_graph, strict)\u001b[39m\n\u001b[32m 324\u001b[39m is_inputs_tuple, inputs = _as_tuple(inputs, \u001b[33m\"\u001b[39m\u001b[33minputs\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mvjp\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 325\u001b[39m inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m--> \u001b[39m\u001b[32m327\u001b[39m outputs = \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 328\u001b[39m is_outputs_tuple, outputs = _as_tuple(\n\u001b[32m 329\u001b[39m outputs, \u001b[33m\"\u001b[39m\u001b[33moutputs of the user-provided function\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mvjp\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 330\u001b[39m )\n\u001b[32m 331\u001b[39m _check_requires_grad(outputs, \u001b[33m\"\u001b[39m\u001b[33moutputs\u001b[39m\u001b[33m\"\u001b[39m, strict=strict)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[35]\u001b[39m\u001b[32m, line 25\u001b[39m, in \u001b[36mmodel_wrapper..model_step\u001b[39m\u001b[34m(*args)\u001b[39m\n\u001b[32m 23\u001b[39m \u001b[38;5;66;03m#Call runner.run\u001b[39;00m\n\u001b[32m 24\u001b[39m generator_state=runner.run(input_state=input_state, lead_time=\u001b[32m12\u001b[39m) \n\u001b[32m---> \u001b[39m\u001b[32m25\u001b[39m output_state=\u001b[38;5;28mnext\u001b[39m(generator_state) \n\u001b[32m 27\u001b[39m \u001b[38;5;28mprint\u001b[39m(output_state)\n\u001b[32m 29\u001b[39m \u001b[38;5;66;03m#Convert Dictionary to tuple[Tensor]\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/venvs/aifs-inference/lib/python3.11/site-packages/anemoi/inference/runner.py:128\u001b[39m, in \u001b[36mRunner.run\u001b[39m\u001b[34m(self, input_state, lead_time)\u001b[39m\n\u001b[32m 125\u001b[39m \u001b[38;5;28mself\u001b[39m.lead_time = lead_time\n\u001b[32m 126\u001b[39m \u001b[38;5;28mself\u001b[39m.time_step = \u001b[38;5;28mself\u001b[39m.checkpoint.timestep\n\u001b[32m--> \u001b[39m\u001b[32m128\u001b[39m input_tensor = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mprepare_input_tensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_state\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 130\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 131\u001b[39m \u001b[38;5;28;01myield from\u001b[39;00m \u001b[38;5;28mself\u001b[39m.postprocess(\u001b[38;5;28mself\u001b[39m.forecast(lead_time, input_tensor, input_state))\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/venvs/aifs-inference/lib/python3.11/site-packages/anemoi/inference/runner.py:213\u001b[39m, in \u001b[36mRunner.prepare_input_tensor\u001b[39m\u001b[34m(self, input_state, dtype)\u001b[39m\n\u001b[32m 211\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m check:\n\u001b[32m 212\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mDuplicate variable \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mvar\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m in input fields\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m213\u001b[39m 
\u001b[43minput_tensor_numpy\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m = field\n\u001b[32m 214\u001b[39m check.add(i)\n\u001b[32m 216\u001b[39m \u001b[38;5;28mself\u001b[39m._input_tensor_by_name[i] = var\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/venvs/aifs-inference/lib/python3.11/site-packages/torch/_tensor.py:1196\u001b[39m, in \u001b[36mTensor.__array__\u001b[39m\u001b[34m(self, dtype)\u001b[39m\n\u001b[32m 1194\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.numpy()\n\u001b[32m 1195\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1196\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mnumpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m.astype(dtype, copy=\u001b[38;5;28;01mFalse\u001b[39;00m)\n", + "\u001b[31mRuntimeError\u001b[39m: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead." + ] + } + ], + "source": [ + "#Load input as dictionary\n", + "print(\"\")\n", + "print(\"LOAD INPUT IN DICTIONARY\")\n", + "input_state = dict(date=DATE, fields=fields)\n", + "print(\" input_state: number of fields in fields \",len(input_state[\"fields\"]))\n", + "print(\" input_state: norm of 2t \",np.linalg.norm(input_state[\"fields\"][\"2t\"]))\n", + "\n", + "#Not sure why input_state has a different structure when called a second time\n", + "#t_input_keys: first call gives ('10u', '10v', '2d', '2t', 'msl',..., 'z_150', 'z_100', 'z_50')\n", + "#t_input_keys: second call gives ('10u', '10v', '2d', '2t', 'msl',..., 'z_150', 'z_100', 'z_50', 'cos_latitude', 'cos_longitude', 'sin_latitude', 'sin_longitude', 'cos_julian_day', 'cos_local_time', 'sin_julian_day', 'sin_local_time', 'insolation')\n", + "\n", + "\n", + "print(\" convert dictionary to tuple[Tensor]\")\n", + "t_input:tuple[Tensor]=()\n", + "t_input_keys=()\n", + "for k,v in input_state[\"fields\"].items():\n", + " t_input += (Tensor(v),)\n", + " t_input_keys += (k,)\n", + "print(\" t_input: number of Tensor in tuple \",len(t_input))\n", + "print(\" t_input: norm of 2t \",np.linalg.norm(t_input[3])) #index 3 is 2t in t_input\n", + "\n", + "#Create perturbation (from the output of the forecast model)\n", + "print(\"\")\n", + "print(\"CREATE PERTURBATION\")\n", + "pert_name=\"2t\"\n", + "print(\" Integrate the model\")\n", + "generator_state=runner.run(input_state=input_state, lead_time=12) \n", + "output_state=next(generator_state)\n", + "print(\" output_state: number of fields in fields \",len(output_state[\"fields\"]))\n", + "print(\" output_state: norm of 2t \",np.linalg.norm(output_state[\"fields\"][\"2t\"]))\n", + "\n", + "output_pert=copy.deepcopy(output_state)\n", + "for idx,(k,v) in enumerate(output_state[\"fields\"].items()):\n", + " if k==pert_name:\n", + " print(\" Perturbing field %s index %s\"%(pert_name,idx))\n", + " output_pert[\"fields\"][k]=0.01*output_state[\"fields\"][k]\n", + " else:\n", + " output_pert[\"fields\"][k]=0.0*output_state[\"fields\"][k]\n", + " \n", + "print(\" output_pert: number of fields in fields \",len(output_pert[\"fields\"]))\n", + "print(\" output_pert: norm of 2t \",np.linalg.norm(output_pert[\"fields\"][\"2t\"]))\n", + "\n", + "#Convert dictionary to tuple[Tensor]\n", + "print(\" convert dictionary to tuple[Tensor]\")\n", + "t_pert:tuple[Tensor]=()\n", + "t_pert_keys=()\n", + "for k,v in output_pert[\"fields\"].items():\n", + " t_pert += 
(Tensor(v),) \n", + " t_pert_keys += (k,)\n", + "print(\" t_pert: number of Tensor in tuple \",len(t_pert)) \n", + "print(\" t_pert: norm of 2t \",np.linalg.norm(t_pert[81])) #index 81 is 2t in t_output\n", + "\n", + "\n", + "\n", + "#Compute derivatives\n", + "print(\"\")\n", + "print(\"COMPUTE DERIVATIVES\")\n", + "with torch.autograd.set_detect_anomaly(False):\n", + " \n", + " t_output, t_dx_output = vjp(\n", + " model_wrapper(input_state),\n", + " t_input, # \"flattened\" inputs\n", + " v=t_pert, # output perturbations\n", + " create_graph=False,\n", + " strict=False,\n", + " )\n", + " \n", + " \n", + " \n", + "print(np.linalg.norm(t_output[81]))\n", + "print(np.linalg.norm(t_dx_output[3])) \n", + "print(len(t_dx_output))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note** \n", + "Due to the non-determinism of GPUs, users will be unable to exactly reproduce an official AIFS forecast when running AIFS Single themselves.\n", + "If you want to enforece determinism at GPU level, you can do so enforcing the following settings:\n", + "\n", + "```\n", + "#First in your terminal\n", + "export CUBLAS_WORKSPACE_CONFIG=:4096:8\n", + "\n", + "#And then before running inference:\n", + "import torch\n", + "torch.backends.cudnn.benchmark = False\n", + "torch.backends.cudnn.deterministic = True\n", + "torch.use_deterministic_algorithms(True)\n", + "\n", + "```\n", + "Using the above approach will significantly increase runtime. Additionally, the input conditions come from open data, which we reproject from o1280 (the original projection of IFS initial conditions) to n320 (AIFS resolution) by first converting them to a 0.25-degree grid. In the operational setup, however, data is reprojected directly from o1280 to n320. This difference in reprojection methods may lead to variations in the resulting input conditions, causing minor differences in the forecast." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 4. 
Inspect the generated forecast" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Plot a field" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# To be able to run the plotting section below you need to install additional dependencies\n", + "\n", + "# !pip install -q matplotlib\n", + "# !pip install -q cartopy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6C9NyonKLEth" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import cartopy.crs as ccrs\n", + "import cartopy.feature as cfeature\n", + "import matplotlib.tri as tri" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 446 + }, + "id": "-5PkTWaFLEth", + "outputId": "0fe9d1ff-14bc-42b6-91d7-0f3c24b3fc42" + }, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'state' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[21]\u001b[39m\u001b[32m, line 5\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mfix\u001b[39m(lons):\n\u001b[32m 2\u001b[39m \u001b[38;5;66;03m# Shift the longitudes from 0-360 to -180-180\u001b[39;00m\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m np.where(lons > \u001b[32m180\u001b[39m, lons - \u001b[32m360\u001b[39m, lons)\n\u001b[32m----> \u001b[39m\u001b[32m5\u001b[39m latitudes = \u001b[43mstate\u001b[49m[\u001b[33m\"\u001b[39m\u001b[33mlatitudes\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 6\u001b[39m longitudes = state[\u001b[33m\"\u001b[39m\u001b[33mlongitudes\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 7\u001b[39m values = state[\u001b[33m\"\u001b[39m\u001b[33mfields\u001b[39m\u001b[33m\"\u001b[39m][\u001b[33m\"\u001b[39m\u001b[33m100u\u001b[39m\u001b[33m\"\u001b[39m]\n", + "\u001b[31mNameError\u001b[39m: name 'state' is not defined" + ] + } + ], + "source": [ + "def fix(lons):\n", + " # Shift the longitudes from 0-360 to -180-180\n", + " return np.where(lons > 180, lons - 360, lons)\n", + "\n", + "latitudes = state[\"latitudes\"]\n", + "longitudes = state[\"longitudes\"]\n", + "values = state[\"fields\"][\"100u\"]\n", + "\n", + "fig, ax = plt.subplots(figsize=(11, 6), subplot_kw={\"projection\": ccrs.PlateCarree()})\n", + "ax.coastlines()\n", + "ax.add_feature(cfeature.BORDERS, linestyle=\":\")\n", + "\n", + "triangulation = tri.Triangulation(fix(longitudes), latitudes)\n", + "\n", + "contour=ax.tricontourf(triangulation, values, levels=20, transform=ccrs.PlateCarree(), cmap=\"RdBu\")\n", + "cbar = fig.colorbar(contour, ax=ax, orientation=\"vertical\", shrink=0.7, label=\"100u\")\n", + "\n", + "plt.title(\"100m winds (100u) at {}\".format(state[\"date\"]))\n", + "plt.show()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [ + { + "file_id": "https://huggingface.co/ecmwf/aifs-single/blob/main/run_AIFS_v0_2_1.ipynb", + "timestamp": 1733385064965 + } + ], + "toc_visible": true + }, + "kernelspec": { + "display_name": "aifs-inference", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": 
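The `RuntimeError` in the COMPUTE DERIVATIVES cell above is a direct consequence of how `vjp` and the runner interact: `vjp` marks the input tensors as requiring grad, and `Runner.prepare_input_tensor` then copies them into a NumPy array, which triggers the implicit `numpy()` call that PyTorch refuses on grad-tracking tensors. Calling `.detach()` would silence the error but also sever the gradient, so differentiating through the forecast would need a code path that stays in PyTorch end to end. The toy example below only illustrates the calling convention that `model_wrapper` follows (a function mapping a tuple of tensors to a tuple of tensors, with one perturbation vector per output); it does not use the anemoi runner.

```python
# Toy illustration of the vjp pattern used by model_wrapper:
# tuple of tensors in, tuple of tensors out, one cotangent per output.
import torch
from torch import Tensor
from torch.autograd.functional import vjp

def toy_step(*args: Tensor) -> tuple[Tensor, ...]:
    # stand-in for one forecast step: any differentiable map from inputs to outputs
    return tuple(2.0 * a + a.sum() for a in args)

t_input = (torch.randn(4), torch.randn(4))
t_pert = (torch.ones(4), torch.zeros(4))  # perturbation applied to the outputs

t_output, t_dx = vjp(toy_step, t_input, v=t_pert, create_graph=False, strict=False)
print(len(t_output), len(t_dx))           # one output and one input sensitivity per input tensor
```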
".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "00131dfdac6c4787b3c6aefb08f13beb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_38181e60ddf14b2ab0b98a46d04c5892", + "placeholder": "​", + "style": "IPY_MODEL_a8d1c1ad0e40429a85bce2f7d2ba15be", + "value": " 7.00M/7.11M [00:06<00:00, 1.86MB/s]" + } + }, + "04d53ab404ea4571938fa1049296a322": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bd4aa43801dd48c88a6354bbb0b80115", + "placeholder": "​", + "style": "IPY_MODEL_32341120752344898e2932b6ce5f4dfb", + "value": "aifs_single_v0.2.1.ckpt: 100%" + } + }, + "0ffec54991a94c10b6e455696389f725": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "12ddda55fc884a48b8d4563843da20fe": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "16b342b766a840e7b2466add51d83e1a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "29607e1346af4a3188a33f7324bd2ff2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "305be4f74c6c42a6abd1e83649b5e635": { + "model_module": 
"@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d6974595ebd043039970461f78f0195e", + "placeholder": "​", + "style": "IPY_MODEL_ebee72a8dae74add85d1e58c5a7ecc0b", + "value": "<multiple>: 100%" + } + }, + "32341120752344898e2932b6ce5f4dfb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "36ff6fe16f494b6e8d75fd46569724e7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "38181e60ddf14b2ab0b98a46d04c5892": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + 
}, + "3d271316983f41e0b2d0ef9548003f55": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "46015c2b0e4d4788b07468f874345557": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4e3e32358ece4093af2e06615531ed77": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "50215614639344198a7b297ba05b75a0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_04d53ab404ea4571938fa1049296a322", + "IPY_MODEL_b7eaa9417e2144a3a91a4defa4883a28", + "IPY_MODEL_5d53b4f6d81c4d63a8f717cc4575d1f5" + ], + "layout": "IPY_MODEL_d4a1986815394711bf612145074cfa7e" + } + }, + "52b672a66d2a475394bdef455a4cebf8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_36ff6fe16f494b6e8d75fd46569724e7", + "placeholder": "​", + "style": "IPY_MODEL_3d271316983f41e0b2d0ef9548003f55", + "value": " 7.29M/7.32M [00:07<00:00, 1.25MB/s]" + } + }, + 
"5d53b4f6d81c4d63a8f717cc4575d1f5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_726b5cbb17504bbcbd8053b52208a72d", + "placeholder": "​", + "style": "IPY_MODEL_c250ee19694e4b08855f1a851f4b44d0", + "value": " 1.01G/1.01G [00:23<00:00, 42.5MB/s]" + } + }, + "5f91729034eb47fab123411dc3353e22": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": null + } + }, + "611bc2338a1343d1b534eada92fd3545": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fb25c8b6d79446dc94c6b977a1b99572", + "max": 57274181, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c5903fc789d54233a5250705b0267337", + "value": 57274181 + } + }, + "612a1882b3184d41ba337f0da2c513be": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_83d83d3305d84c98a871fc8559e2e177", + "max": 57147358, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_12ddda55fc884a48b8d4563843da20fe", + "value": 57147358 + } + }, + "68473001b7a746a5ae5a819071839a70": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { 
+ "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d3be2aaf4f27438e82ce7433a865b9ba", + "placeholder": "​", + "style": "IPY_MODEL_f8074f08f07c485c9de25abcc50c8456", + "value": " 54.6M/54.6M [00:46<00:00, 928kB/s]" + } + }, + "6c51d6cbcf924425824f8c4390164834": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ee9c57247f1d4e63b255784b4342b7b6", + "placeholder": "​", + "style": "IPY_MODEL_fa7fbe6608b1416a8b3ad52bea2a06eb", + "value": " 7.09M/7.31M [00:07<00:00, 1.13MB/s]" + } + }, + "6d5a830229a049babef7b6fed8b3da2a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": null + } + }, + "726b5cbb17504bbcbd8053b52208a72d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + 
"overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7487efa049164d9a92b01d41deac2a1d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_cbbd457e67034ce58f82a31228a85ff3", + "max": 7453633, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_9db824f32e8c49f298aa110df69465bd", + "value": 7453633 + } + }, + "769c57208a284dd3a281d225ec9276f3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": null + } + }, + "7a27c1645f924655805568959f6256c1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7ba380ae6b76488299d17aa1bffece67": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + 
"_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d3f5816d005b48fa981d2e9e947fe020", + "placeholder": "​", + "style": "IPY_MODEL_7d9d217ab1bd4010be8eee17e335083a", + "value": "<multiple>:  97%" + } + }, + "7d9d217ab1bd4010be8eee17e335083a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "7fc7795fd4b14f9f912629eaee028df3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_305be4f74c6c42a6abd1e83649b5e635", + "IPY_MODEL_9f5ed1c328ef49b783f8e3f62711fb5f", + "IPY_MODEL_52b672a66d2a475394bdef455a4cebf8" + ], + "layout": "IPY_MODEL_5f91729034eb47fab123411dc3353e22" + } + }, + "83d83d3305d84c98a871fc8559e2e177": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "85fa8cffc2a54c34a0a5661c309ccd95": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": 
null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": null + } + }, + "95909816a021463082be3fbd717b91f8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fdb25a1d306e40539295ea995d59d011", + "placeholder": "​", + "style": "IPY_MODEL_4e3e32358ece4093af2e06615531ed77", + "value": "<multiple>: 100%" + } + }, + "9db824f32e8c49f298aa110df69465bd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "9f5ed1c328ef49b783f8e3f62711fb5f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_46015c2b0e4d4788b07468f874345557", + "max": 7672864, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_29607e1346af4a3188a33f7324bd2ff2", + "value": 7672864 + } + }, + "a53a407a6f6e451497336766244135b1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": 
null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a7c36aeb00074c6c92735e169b765251": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7ba380ae6b76488299d17aa1bffece67", + "IPY_MODEL_b695ad246f574c3993d4ca3925be2853", + "IPY_MODEL_6c51d6cbcf924425824f8c4390164834" + ], + "layout": "IPY_MODEL_ba596bf1fa9545eda887a2439af3c70c" + } + }, + "a8d1c1ad0e40429a85bce2f7d2ba15be": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b0691e09b17944879def85aa0ed8397b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b695ad246f574c3993d4ca3925be2853": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ea41eba5fbed4832b97520086a567d67", + "max": 7668008, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f362732783d741cd982a21f73b647a18", + "value": 7668008 + } + }, + "b7eaa9417e2144a3a91a4defa4883a28": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": 
"FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b0691e09b17944879def85aa0ed8397b", + "max": 1006672855, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_b9543d271edf446aa4d0b1c4c61b62e7", + "value": 1006672855 + } + }, + "b8f734cd22254a559cbf0d4270ceacb4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a53a407a6f6e451497336766244135b1", + "placeholder": "​", + "style": "IPY_MODEL_caf04b00d2ee45de9d5667ecac91cf1d", + "value": "9533e90f8433424400ab53c7fafc87ba1a04453093311c0b5bd0b35fedc1fb83.npz:  98%" + } + }, + "b9543d271edf446aa4d0b1c4c61b62e7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "ba596bf1fa9545eda887a2439af3c70c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": null + } + }, + "bd4aa43801dd48c88a6354bbb0b80115": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + 
"grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c250ee19694e4b08855f1a851f4b44d0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "c5903fc789d54233a5250705b0267337": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "caf04b00d2ee45de9d5667ecac91cf1d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "cbbd457e67034ce58f82a31228a85ff3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cf80c790c9994f168bd349fd93180fa8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": 
"HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7a27c1645f924655805568959f6256c1", + "placeholder": "​", + "style": "IPY_MODEL_0ffec54991a94c10b6e455696389f725", + "value": " 54.5M/54.5M [00:49<00:00, 867kB/s]" + } + }, + "d3be2aaf4f27438e82ce7433a865b9ba": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d3f5816d005b48fa981d2e9e947fe020": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d4a1986815394711bf612145074cfa7e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + 
"grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d6974595ebd043039970461f78f0195e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "dbd3608f85634dcba7242d7e2672fa0f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f197a28a70164daeb51f84247056f00d", + "placeholder": "​", + "style": "IPY_MODEL_16b342b766a840e7b2466add51d83e1a", + "value": "<multiple>: 100%" + } + }, + "e172d23a055d4184bc00a7f7a5f59a8a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_dbd3608f85634dcba7242d7e2672fa0f", + "IPY_MODEL_611bc2338a1343d1b534eada92fd3545", + "IPY_MODEL_68473001b7a746a5ae5a819071839a70" + ], + "layout": "IPY_MODEL_6d5a830229a049babef7b6fed8b3da2a" + } + }, + "e85831a585b4474284eeccc7403c6d6a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + 
"_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_95909816a021463082be3fbd717b91f8", + "IPY_MODEL_612a1882b3184d41ba337f0da2c513be", + "IPY_MODEL_cf80c790c9994f168bd349fd93180fa8" + ], + "layout": "IPY_MODEL_85fa8cffc2a54c34a0a5661c309ccd95" + } + }, + "ea41eba5fbed4832b97520086a567d67": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ebee72a8dae74add85d1e58c5a7ecc0b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ee8fecacdf934cacb4c5966da66f68ba": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_b8f734cd22254a559cbf0d4270ceacb4", + "IPY_MODEL_7487efa049164d9a92b01d41deac2a1d", + "IPY_MODEL_00131dfdac6c4787b3c6aefb08f13beb" + ], + "layout": "IPY_MODEL_769c57208a284dd3a281d225ec9276f3" + } + }, + "ee9c57247f1d4e63b255784b4342b7b6": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + 
"justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f197a28a70164daeb51f84247056f00d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f362732783d741cd982a21f73b647a18": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "f8074f08f07c485c9de25abcc50c8456": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "fa7fbe6608b1416a8b3ad52bea2a06eb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "fb25c8b6d79446dc94c6b977a1b99572": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + 
"bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fdb25a1d306e40539295ea995d59d011": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py index 77299af2..8e7e2087 100644 --- a/src/anemoi/inference/checkpoint.py +++ b/src/anemoi/inference/checkpoint.py @@ -183,6 +183,11 @@ def variable_to_input_tensor_index(self) -> Any: """Get the variable to input tensor index.""" return self._metadata.variable_to_input_tensor_index + @property + def variable_to_output_tensor_index(self) -> Any: + """Get the variable to input tensor index.""" + return self._metadata.variable_to_output_tensor_index + @property def model_computed_variables(self) -> Any: """Get the model computed variables.""" @@ -213,6 +218,11 @@ def prognostic_input_mask(self) -> Any: """Get the prognostic input mask.""" return self._metadata.prognostic_input_mask + @property + def input_tensor_index_to_variable(self) -> Any: + """Get the output tensor index to variable.""" + return self._metadata.input_tensor_index_to_variable + @property def output_tensor_index_to_variable(self) -> Any: """Get the output tensor index to variable.""" diff --git a/src/anemoi/inference/metadata.py b/src/anemoi/inference/metadata.py index 0c3210df..bb06032e 100644 --- a/src/anemoi/inference/metadata.py +++ b/src/anemoi/inference/metadata.py @@ -207,6 +207,25 @@ def variable_to_input_tensor_index(self) -> frozendict: return frozendict({v: mapping[i] for i, v in enumerate(self.variables) if i in mapping}) + @property + def variable_to_output_tensor_index(self) -> frozendict: + """Return the mapping 
between variable name and output tensor index.""" + mapping = self._make_indices_mapping( + self._indices.data.output.full, + self._indices.model.output.full, + ) + + return frozendict({v: mapping[i] for i, v in enumerate(self.variables) if i in mapping}) + + @cached_property + def input_tensor_index_to_variable(self) -> frozendict: + """Return the mapping between input tensor index and variable name.""" + mapping = self._make_indices_mapping( + self._indices.model.input.full, + self._indices.data.input.full, + ) + return frozendict({k: self.variables[v] for k, v in mapping.items()}) + + @cached_property + def output_tensor_index_to_variable(self) -> frozendict: + """Return the mapping between output tensor index and variable name.""" diff --git a/src/anemoi/inference/perturbation.py b/src/anemoi/inference/perturbation.py new file mode 100644 index 00000000..59f2cc5d --- /dev/null +++ b/src/anemoi/inference/perturbation.py @@ -0,0 +1,102 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from typing import Any + +import torch + +from .checkpoint import Checkpoint + +LOG = logging.getLogger(__name__) +R = 6371.0 # Earth radius in km + + +def haversine(coords: torch.Tensor, point: torch.Tensor) -> torch.Tensor: + """Compute haversine distance between multiple coordinates and a single point. + + Args: + coords: (N, 2) tensor of [lat, lon] in degrees + point: (2,) tensor of [lat, lon] in degrees + + Returns: + (N,) tensor of distances in kilometers + """ + coords_rad = torch.deg2rad(coords) + point_rad = torch.deg2rad(point) + + lat1, lon1 = coords_rad[:, 0], coords_rad[:, 1] + lat2, lon2 = point_rad[0], point_rad[1] + + dlat, dlon = lat2 - lat1, lon2 - lon1 + + a = torch.sin(dlat / 2) ** 2 + torch.cos(lat1) * torch.cos(lat2) * torch.sin(dlon / 2) ** 2 + c = 2 * torch.atan2(torch.sqrt(a), torch.sqrt(1 - a)) + + return R * c + + +class Perturbation: + """Perturbation class.""" + + def __init__( + self, + checkpoint: str, + perturbed_variable: str, + perturbation_location: tuple[float, float], + perturbation_radius_km: float = 100.0, + patch_metadata: dict[str, Any] = {}, + ) -> None: + """Initialize the Perturbation. + + Parameters + ---------- + checkpoint : str + The checkpoint (path or specification) used to resolve variable indices and grid coordinates. + perturbed_variable : str + The variable to perturb. + perturbation_location : tuple[float, float] + The (lat, lon) of the perturbation centre, in degrees (lat in [-90, 90], lon in [0, 360]). + perturbation_radius_km : float + Grid points within this haversine distance (in km) of the centre are perturbed. + patch_metadata : dict[str, Any] + Optional metadata patch passed to the Checkpoint.
+ """ + assert len(perturbation_location) == 2, "perturbation_location must be a tuple of (lat, lon)" + assert perturbation_location[0] >= -90 and perturbation_location[0] <= 90, "Latitude must be between -90 and 90" + assert perturbation_location[1] >= 0 and perturbation_location[1] <= 360, "Longitude must be between 0 and 360" + self.perturbed_variable = perturbed_variable + self.perturbation_location = torch.tensor(perturbation_location) + self.perturbation_radius_km = perturbation_radius_km + self._checkpoint = Checkpoint(checkpoint, patch_metadata=patch_metadata) + + @property + def variable_to_output_tensor_index(self) -> dict[str, int]: + return self._checkpoint._metadata.variable_to_output_tensor_index + + @property + def output_shape(self) -> tuple[int, ...]: + return ( + 1, + 1, + self._checkpoint._metadata.number_of_grid_points, + len(self._checkpoint._metadata.variable_to_output_tensor_index), + ) + + @property + def coords(self) -> torch.Tensor: + lats = torch.from_numpy(self._checkpoint._metadata._supporting_arrays["latitudes"]) + lons = torch.from_numpy(self._checkpoint._metadata._supporting_arrays["longitudes"]) + return torch.stack([lats, lons], dim=-1) + + def create(self, *args, **kwargs) -> torch.Tensor: + """Get the perturbation data.""" + var_idx = self.variable_to_output_tensor_index[self.perturbed_variable] + perturbation = torch.zeros(self.output_shape) + + # Get index of the closest point + dists = haversine(self.coords, self.perturbation_location) + closest_idx = torch.where(dists < self.perturbation_radius_km)[0] + + assert len(closest_idx) > 0, "No grid points found within the specified perturbation radius." + + perturbation[..., closest_idx, var_idx] = 1.0 + return perturbation diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index e5ead0c0..86a5e793 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -91,6 +91,7 @@ def __init__( write_initial_state: bool = True, use_profiler: bool = False, typed_variables: dict[str, dict] = {}, + variables_to_perturb: list = [], ) -> None: """Parameters ------------- @@ -146,6 +147,12 @@ def __init__( self.output_frequency = output_frequency self.write_initial_state = write_initial_state self.use_profiler = use_profiler + self.variables_to_perturb = variables_to_perturb + if len(self.variables_to_perturb): + for v in self.variables_to_perturb: + assert ( + v in self.checkpoint._metadata.variable_to_output_tensor_index + ), f"The variable {v} is not in the output state." # For the moment, until we have a better solution self.typed_variables = {k: VariableFromMarsVocabulary(k, v) for k, v in typed_variables.items()} diff --git a/src/anemoi/inference/runners/sensitivities.py b/src/anemoi/inference/runners/sensitivities.py new file mode 100644 index 00000000..76d23698 --- /dev/null +++ b/src/anemoi/inference/runners/sensitivities.py @@ -0,0 +1,254 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import datetime +import logging +from collections.abc import Generator +from typing import Any +from typing import Callable + +import numpy as np +import torch +from anemoi.utils.dates import frequency_to_timedelta as to_timedelta +from anemoi.utils.timer import Timer + +from anemoi.inference.types import FloatArray +from anemoi.inference.types import State + +from ..perturbation import Perturbation +from ..profiler import ProfilingLabel +from ..profiler import ProfilingRunner +from .simple import SimpleRunner + +LOG = logging.getLogger(__name__) + + +class SensitivitiesRunner(SimpleRunner): + """Sensitivities runner.""" + + def __init__(self, *args: Any, perturb_normalised_space: bool = False, **kwargs: Any) -> None: + """Initialize the SensitivitiesRunner. + + Parameters + ---------- + *args : tuple + Positional arguments. + perturb_normalised_space : bool + If True, the output perturbation is applied in the model's normalised space (the post-processors are skipped in the wrapped model); otherwise it is applied to the post-processed, physical-space output. + **kwargs : dict + Keyword arguments. + """ + super().__init__(*args, **kwargs) + self.perturb_normalised_space = perturb_normalised_space + + def wrap_model(self, model: torch.nn.Module) -> Callable: + """Wrap the model to be used for sensitivities.""" + + def model_wrapper(x: torch.Tensor) -> torch.Tensor: + x = x[:, :, None, ...] # add dummy ensemble dimension as 3rd index + x = model.pre_processors(x, in_place=False) + y_hat = model.model(x) + if not self.perturb_normalised_space: + y_hat = model.post_processors(y_hat, in_place=False) + return y_hat + + return model_wrapper + + def perturb_prediction_linearly( + self, output: torch.Tensor, idx: int, perturbation_perc: float = 0.01 + ) -> torch.Tensor: + """Build an output perturbation equal to perturbation_perc of the forecast value of variable idx.""" + # By default, use a perturbation of 1% of the forecasted value + pert = torch.zeros_like(output) + pert[..., idx] = perturbation_perc * output[..., idx] + + return pert + + def predict_step( + self, model: torch.nn.Module, input_tensor_torch: torch.Tensor, perturbation: torch.Tensor, **kwargs: Any + ) -> torch.Tensor: + """Predict sensitivities.""" + model_func = self.wrap_model(model) + + # Compute the sensitivities + input_tensor_torch.requires_grad_(True) + + # The try/except works around activation checkpointing: the first time the function + # is called it may raise a CheckpointError, in which case the VJP is simply retried. + try: + with torch.enable_grad(): + with torch.autocast(device_type=self.device.type, dtype=self.autocast): + y_pred, t_dx_output = torch.autograd.functional.vjp( + model_func, + input_tensor_torch, + v=perturbation, + create_graph=False, + strict=False, + ) + except torch.utils.checkpoint.CheckpointError: + LOG.warning("Checkpointing error occurred.") + + with torch.enable_grad(): + with torch.autocast(device_type=self.device.type, dtype=self.autocast): + y_pred, t_dx_output = torch.autograd.functional.vjp( + model_func, + input_tensor_torch, + v=perturbation, + create_graph=False, + strict=False, + ) + + return t_dx_output[0, ...] # (time, values, variables) + + def forecast( + self, lead_time: str, input_tensor_numpy: FloatArray, input_state: State, perturbation: Perturbation + ) -> Generator[State, None, None]: + """Step the model forward and yield the sensitivities of the perturbed output with respect to the input state. + + Parameters + ---------- + lead_time : str + The lead time. + input_tensor_numpy : FloatArray + The input tensor. + input_state : State + The input state. + perturbation : Perturbation + The output-space perturbation defining which variable and which grid points the sensitivities are computed for. + + Returns + ------- + Generator[State, None, None] + The states containing the computed sensitivities.
+ """ + self.model.eval() + + torch.set_grad_enabled(False) + + # Create pytorch input tensor + input_tensor_torch = torch.from_numpy(np.swapaxes(input_tensor_numpy, -2, -1)[np.newaxis, ...]).to(self.device) + + lead_time = to_timedelta(lead_time) + + new_state = input_state.copy() # We should not modify the input state + new_state["fields"] = dict() + new_state["step"] = to_timedelta(0) + + start = input_state["date"] + + # The variable `check` is used to keep track of which variables have been updated + # In the input tensor. `reset` is used to reset `check` to False except + # when the values are of the constant in time variables + + reset = np.full((input_tensor_torch.shape[-1],), False) + variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index + typed_variables = self.checkpoint.typed_variables + for variable, i in variable_to_input_tensor_index.items(): + if typed_variables[variable].is_constant_in_time: + reset[i] = True + + check = reset.copy() + + if self.verbosity > 0: + self._print_input_tensor("First input tensor", input_tensor_torch) + + output_perturbation = perturbation.create(self.model).to(self.device) + + for s, (step, date, next_date, is_last_step) in enumerate(self.forecast_stepper(start, lead_time)): + title = f"Forecasting step {step} ({date})" + + new_state["date"] = date + new_state["previous_step"] = new_state.get("step") + new_state["step"] = step + + if self.trace: + self.trace.write_input_tensor( + date, s, input_tensor_torch.cpu().numpy(), variable_to_input_tensor_index, self.checkpoint.timestep + ) + + # Predict next state of atmosphere + with ( + torch.autocast(device_type=self.device.type, dtype=self.autocast), + ProfilingLabel("Predict step", self.use_profiler), + Timer(title), + ): + y_pred = self.predict_step( + self.model, input_tensor_torch, perturbation=output_perturbation, fcstep=s, step=step, date=date + ) + + # Update state + with ProfilingLabel("Updating state (CPU)", self.use_profiler): + for i in range(y_pred.shape[-1]): + new_state["fields"][self.checkpoint.input_tensor_index_to_variable[i]] = y_pred[:, :, i] + + if (s == 0 and self.verbosity > 0) or self.verbosity > 1: + self._print_input_tensor("Sensitivities tensor", y_pred) + + yield new_state + + def run( + self, + *, + input_state: State, + perturbation: Perturbation, + lead_time: str | int | datetime.timedelta, + return_numpy: bool = True, + ) -> Generator[State, None, None]: + """Run the model. + + Parameters + ---------- + input_state : State + The input state. + lead_time : Union[str, int, datetime.timedelta] + The lead time. + return_numpy : bool, optional + Whether to return the output state fields as numpy arrays, by default True. + Otherwise, it will return torch tensors. + + Returns + ------- + Generator[State, None, None] + The forecasted state. 
+ """ + # Shallow copy to avoid modifying the user's input state + input_state = input_state.copy() + input_state["fields"] = input_state["fields"].copy() + + self.constant_forcings_inputs = self.create_constant_forcings_inputs(input_state) + self.dynamic_forcings_inputs = self.create_dynamic_forcings_inputs(input_state) + self.boundary_forcings_inputs = self.create_boundary_forcings_inputs(input_state) + + LOG.info("-" * 80) + LOG.info("Input state:") + LOG.info(f" {list(input_state['fields'].keys())}") + + LOG.info("Constant forcings inputs:") + for f in self.constant_forcings_inputs: + LOG.info(f" {f}") + + LOG.info("Dynamic forcings inputs:") + for f in self.dynamic_forcings_inputs: + LOG.info(f" {f}") + + LOG.info("Boundary forcings inputs:") + for f in self.boundary_forcings_inputs: + LOG.info(f" {f}") + LOG.info("-" * 80) + + lead_time = to_timedelta(lead_time) + + with ProfilingRunner(self.use_profiler): + with ProfilingLabel("Prepare input tensor", self.use_profiler): + input_tensor = self.prepare_input_tensor(input_state) + + try: + yield from self.prepare_output_state( + self.forecast(lead_time, input_tensor, input_state, perturbation), return_numpy + ) + except (TypeError, ModuleNotFoundError, AttributeError): + if self.report_error: + self.checkpoint.report_error() + raise diff --git a/test_XAI.py b/test_XAI.py new file mode 100644 index 00000000..d4e4e675 --- /dev/null +++ b/test_XAI.py @@ -0,0 +1,137 @@ +import datetime +import logging +from collections import defaultdict +from pathlib import Path + +import earthkit.data as ekd +import earthkit.regrid as ekr +import matplotlib.pyplot as plt +import numpy as np +from ecmwf.opendata import Client as OpendataClient + +from anemoi.inference.outputs.printer import print_state +from anemoi.inference.runners.sensitivities import Perturbation +from anemoi.inference.runners.sensitivities import SensitivitiesRunner + +LOGGER = logging.getLogger(__name__) + + +GRID_RESOLUTION = "O96" +PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"] +PARAM_SOIL = ["vsw", "sot"] +PARAM_PL = ["gh", "t", "u", "v", "w", "q"] +LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50] +SOIL_LEVELS = [1, 2] + +DATE = OpendataClient().latest() + + +def load_state(file) -> dict: + with np.load(file, allow_pickle=False) as data: + fields = {k: data[k] for k in data.files} + state = {"date": datetime.datetime(2025, 8, 29, 6, 0), "fields": fields} + return state + + +def get_open_data(param, levelist=[]): + fields = defaultdict(list) + # Get the data for the current date and the previous date + for date in [DATE - datetime.timedelta(hours=6), DATE]: + data = ekd.from_source("ecmwf-open-data", date=date, param=param, levelist=levelist) + for f in data: + # Open data is between -180 and 180, we need to shift it to 0-360 + assert f.to_numpy().shape == (721, 1440) + values = np.roll(f.to_numpy(), -f.shape[1] // 2, axis=1) + # Interpolate the data to from 0.25 to grid + values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": GRID_RESOLUTION}) + # Add the values to the list + name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param") + fields[name].append(values) + + # Create a single matrix for each parameter + for param, values in fields.items(): + fields[param] = np.stack(values) + + return fields + + +def rename_keys(state: dict, mapping: dict) -> dict: + for old_key, new_key in mapping.items(): + state[new_key] = state.pop(old_key) + + return state + + +def 
transform_GH_to_Z(fields: dict, levels: list[int]) -> dict: + for level in levels: + fields[f"z_{level}"] = fields.pop(f"gh_{level}") * 9.80665 + + return fields + + +def load_current_state() -> dict: + fields = {} + fields.update(get_open_data(param=PARAM_SFC)) + # fields.update(get_open_data(param=PARAM_SOIL,levelist=SOIL_LEVELS)) + fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS)) + + # fields = rename_keys(fields, {'sot_1': 'stl1', 'sot_2': 'stl2', 'vsw_1': 'swvl1', 'vsw_2': 'swvl2'}) + fields = transform_GH_to_Z(fields, LEVELS) + + return dict(date=DATE, fields=fields) + + +def save_state(state, outfile): + np.savez(outfile, **state["fields"]) + + +def plot_sensitivities(state: dict, field: str): + num_times = state["fields"][field].shape[0] + fig, axs = plt.subplots(num_times, 1, figsize=(6 * num_times, 8)) + + # Get the combined min/max for color normalization + vmin = state["fields"][field].min() + vmax = state["fields"][field].max() + lim = max(abs(vmin), abs(vmax)) + cmap_kwargs = dict(cmap="PuOr", vmin=-lim, vmax=lim) + + for i in range(num_times): + axs[i].set_title(f"{field} (at -{(num_times-i)*6}H)") + axs[i].scatter(state["longitudes"], state["latitudes"], c=state["fields"][field][i], **cmap_kwargs) + + # Remove x and y axes + for ax in axs: + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_xlabel("") + ax.set_ylabel("") + + fig.savefig(f"sensitivities_{field}.png") + + +def main(initial_conditions_file: Path, ckpt: str | dict = {"huggingface": "ecmwf/aifs-single-1.0"}): + # Load initial conditions + if initial_conditions_file.exists(): + input_state = load_state(initial_conditions_file) + LOGGER.info("Loaded cached initial conditions; the global DATE may not match the cached state") + else: + input_state = load_current_state() + LOGGER.info("State created") + save_state(input_state, initial_conditions_file) + + # Load model + runner = SensitivitiesRunner(ckpt, device="cuda", perturb_normalised_space=True) + + perturbation = Perturbation( + ckpt, perturbed_variable="2t", perturbation_location=(40, 120), perturbation_radius_km=150.0 + ) + + # Compute sensitivities + for state in runner.run(input_state=input_state, perturbation=perturbation, lead_time="6h"): + print_state(state) + plot_sensitivities(state, "2t") + plot_sensitivities(state, "z") + + +if __name__ == "__main__": + main(Path("input_state-o96.npz"), ckpt="../inference-aifs-o96.ckpt")
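To clarify what `predict_step` computes, here is a toy, standalone illustration (not part of this change set) of `torch.autograd.functional.vjp`: for a function f, an input x and an output perturbation v, `vjp(f, x, v)` returns f(x) together with the vector-Jacobian product vᵀ·∂f/∂x, i.e. the gradient of the scalar ⟨v, f(x)⟩ with respect to x. A small two-layer network stands in for the AIFS model; all shapes and names below are illustrative.

import torch
from torch.autograd.functional import vjp

torch.manual_seed(0)
# Toy "model": 8 input variables -> 4 output variables, applied point-wise to 3 grid points.
f = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.Tanh(), torch.nn.Linear(16, 4))

x = torch.randn(3, 8)   # toy input state: 3 grid points, 8 input variables
v = torch.zeros(3, 4)   # output perturbation: variable 2 at grid point 1 only
v[1, 2] = 1.0

y, sensitivities = vjp(f, x, v)  # y equals f(x); sensitivities has the same shape as x

# Cross-check against an explicit backward pass on the scalar <v, f(x)>.
x_req = x.clone().requires_grad_(True)
(f(x_req) * v).sum().backward()
assert torch.allclose(sensitivities, x_req.grad, atol=1e-6)
print(sensitivities.shape)  # torch.Size([3, 8])

This is also why the perturbation is defined in output space: a single VJP (one backward pass) yields the sensitivity of the chosen output quantity with respect to every input variable, grid point and input time at once.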