This notebook accompanies the YieldSAT paper:
YieldSAT: A Multimodal Benchmark Dataset for High-Resolution Crop Yield Prediction
CVPR 2026 — Project Page
It illustrates the dataset's ML-readiness: using the preprocessed xarray.Dataset data structure, we can build fast and flexible models on multimodal input data.
This tutorial shows how to use the preprocessed Germany release for a small LSTM baseline in a single-crop setting.
It covers five steps: loading the preprocessed dataset, filtering to a single crop, wrapping the data in a PyTorch Dataset, creating leakage-free grouped splits, and training and evaluating the model. The example uses only Sentinel-2 bands to keep the model and feature tensor small.
xarray for ML Development
The preprocessed YieldSAT release is the model-ready version of the dataset. It is stored as an xarray.Dataset, which makes it convenient to train machine learning models, inspect metadata, and select subsets by country, crop, year, farm, or field.
In this format, each valid 10 m pixel is represented as a season-aligned temporal sample. The temporal representation contains 24 time steps spanning a two-calendar-year window defined relative to seeding and harvesting, such that the harvest month always lies in the second year. For each time step, the selected observation corresponds to the least-cloudy available image, and observations outside the growing season are masked. All input modalities are aligned via concatenation and temporal and spatial repetition (input fusion), resulting in a standardized tensor representation.
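As a rough sketch of this representation (with an assumed band count of 12 and illustrative mask positions, not values from the real release), one season-aligned pixel sample can be pictured as a (time_step, band) array with out-of-season steps masked:

```python
import numpy as np

# Hypothetical sketch of one season-aligned pixel sample: 24 time steps x 12 bands.
# Band count and mask positions are illustrative, not taken from the real release.
n_steps, n_bands = 24, 12
rng = np.random.default_rng(0)
sample = rng.uniform(0, 10000, size=(n_steps, n_bands)).astype(np.float32)

# Steps outside the growing season are masked (here: the first 3 and last 2 steps).
out_of_season = np.zeros(n_steps, dtype=bool)
out_of_season[:3] = True
out_of_season[-2:] = True
sample[out_of_season] = np.nan

print(sample.shape)                        # (24, 12)
print(int(np.isnan(sample).any(axis=1).sum()))  # 5 masked time steps
```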
We provide detailed information about the xarray data structure in the data-overview notebook.
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
import xarray as xr
from pprint import pprint
import pandas as pd
import warnings
warnings.filterwarnings('ignore')
plt.style.use("seaborn-v0_8-whitegrid")
pd.options.display.max_columns = 200
pd.options.display.width = 120
preprocessed-24-ts/ contains one model-ready NetCDF file per country. For this tutorial, we select the .nc file for Germany because of its smaller size.
The xarray (netCDF) format is a processed version of the YieldSAT dataset, ready for training deep learning models at the pixel level. The data uses a unified time series of 24 time steps covering all data modalities. Each modality is aligned in time and space using concatenation and spatial repetition at the input level (input fusion). The processing is described and referenced in the main paper. We provide the processed data for each country separately; each country file contains all crop types available in that country.
For more information on Xarray, see: https://docs.xarray.dev/en/latest/index.html
The netCDF file has the following structure:
Coordinates:
Data Variables
Attributes
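To make this structure concrete, here is a tiny synthetic xarray.Dataset with the same three components (coordinates, data variables, attributes); all names and sizes are illustrative stand-ins, not the real YieldSAT contents:

```python
import numpy as np
import xarray as xr

# Minimal synthetic dataset mirroring the three components above;
# names and sizes are illustrative only.
toy = xr.Dataset(
    data_vars={
        "target": ("index", np.array([2.9, 3.1], dtype=np.float32)),
        "sample": (("index", "time_step", "band"),
                   np.zeros((2, 24, 3), dtype=np.float32)),
    },
    coords={
        "index": ["pixel-a", "pixel-b"],
        "time_step": np.arange(24),
        "band": ["B02", "B03", "B04"],
    },
    attrs={"toy_field_<>_yield_ground_truth": 3.0},
)
print(sorted(toy.coords))     # ['band', 'index', 'time_step']
print(sorted(toy.data_vars))  # ['sample', 'target']
print(toy.attrs)
```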
DATA_ROOT = Path("")  # update this to the directory where the data is stored
PREPROCESSED_ROOT = DATA_ROOT / "preprocessed-24-ts/"
COUNTRY = "Germany"
TARGET_CROP = "rapeseed"
preprocessed_path = PREPROCESSED_ROOT / COUNTRY / "merged" / "merge_s2-soil-dem-weather-coords.nc"
assert PREPROCESSED_ROOT.exists(), PREPROCESSED_ROOT
assert preprocessed_path.exists(), preprocessed_path
# read the netCDF file
ds = xr.open_dataset(preprocessed_path)
ds
<xarray.Dataset>
Dimensions: (index: 609645, time_step: 24, band: 120)
Coordinates:
* index (index) object '5d35849b-ace1-4dd4-962d-da80c6c56bac' ...
* time_step (time_step) int64 0 1 2 3 4 5 6 ... 17 18 19 20 21 22 23
* band (band) object 'B01' 'B02' 'B03' ... 'coord_y' 'coord_z'
Data variables: (12/17)
target (index) float32 ...
times (index, time_step) datetime64[ns] ...
seeding_date (index) uint8 ...
harvesting_date (index) uint8 ...
farm_identifier (index) uint8 ...
country (index) uint8 ...
... ...
col (index) uint8 ...
stats-mean (band) float32 469.2 566.0 864.8 ... -0.2598 0.443
stats-min (band) float32 0.0 1.0 1.0 ... -0.8557 -0.9154 -0.5302
stats-max (band) float32 8.976e+03 4.396e+03 ... 0.4852 0.9992
stats-std (band) float32 273.9 260.7 327.6 ... 0.2491 0.4512 0.6131
sample (index, time_step, band) float32 ...
Attributes: (12/299)
Germany_DUP3_farm5_field265_rapeseed_2020_<>_yield_ground_truth: 2.892
Germany_DUP3_farm2_field170_wheat_2019_<>_yield_ground_truth: 9.16
Germany_DUP3_farm1_field13_rapeseed_2016_<>_yield_ground_truth: 3.03555...
Germany_DUP3_farm1_field52_wheat_2019_<>_yield_ground_truth: 7.45555...
Germany_DUP3_farm2_field128_wheat_2019_<>_yield_ground_truth: 7.94
Germany_DUP3_farm6_field285_rapeseed_2020_<>_yield_ground_truth: 3.87
... ...
Germany_DUP3_farm5_field269_rapeseed_2017_<>_yield_ground_truth: 2.205
Germany_DUP3_farm2_field105_wheat_2019_<>_yield_ground_truth: 9.84
Germany_DUP3_farm6_field281_rapeseed_2019_<>_yield_ground_truth: 3.74
Germany_DUP3_farm2_field99_wheat_2018_<>_yield_ground_truth: 7.65
Germany_DUP3_farm5_field275_rapeseed_2018_<>_yield_ground_truth: 4.454
Germany_DUP3_farm6_field278_rapeseed_2018_<>_yield_ground_truth: 3.8
print(f"Dataset sizes: \n{dict(ds.sizes)} \n")
print("Available crops:")
print({int(k): v for k, v in ds["crop"].attrs.items()})
print(f"\nThe dataset contains two crop types (rapeseed and wheat). Since we want to train a single-crop model ({TARGET_CROP}), we must filter the dataset.")
Dataset sizes:
{'index': 609645, 'time_step': 24, 'band': 120}
Available crops:
{0: 'rapeseed', 1: 'wheat'}
The dataset contains two crop types (rapeseed and wheat). Since we want to train a single-crop model (rapeseed), we must filter the dataset.
ds["crop"]  # The crop types are stored encoded; the code-to-name mapping lives in the attrs.
<xarray.DataArray 'crop' (index: 609645)>
[609645 values with dtype=uint8]
Coordinates:
* index (index) object '5d35849b-ace1-4dd4-962d-da80c6c56bac' ... 'fc0f3...
Attributes:
0: rapeseed
1: wheat
def filter_ds_by_attribute(dataset: xr.Dataset, filter_variable: str = "crop", condition=None) -> xr.Dataset:
"""Filter a dataset by attributes.
Args:
-----
dataset: xr.Dataset to filter
filter_variable: (str) name of the variable to filter on
condition: (str) value to keep (e.g., a crop name)
Returns:
--------
filtered xr.Dataset
"""
mapping = dict((v, k) for k, v in dataset[filter_variable].attrs.items())
filter_val = [int(mapping[condition])]
selected_ind = dataset.data_vars[filter_variable].isin(filter_val)
dataset = dataset.sel(index=selected_ind)
for k in mapping.keys(): # drop other attributes
if k != condition:
del dataset[filter_variable].attrs[mapping[k]]
field_id = dataset["field_shared_name"]
unique_field_id = np.unique(field_id.values)  # fields that remain in the data
attrs_field_id = [int(v) for v in field_id.attrs.keys()]
fields_to_delete_name = []
for fields_delete in list(set(attrs_field_id) - set(unique_field_id) ):
fields_to_delete_name.append(field_id.attrs[str(fields_delete)])
del field_id.attrs[str(fields_delete)]
for info in dataset.attrs.copy():
if info.split("_<>_")[0] in fields_to_delete_name:
del dataset.attrs[info]
return dataset
#filter dataset by crop type
filtered_ds = filter_ds_by_attribute(ds, filter_variable="crop", condition=TARGET_CROP)
filtered_ds
<xarray.Dataset>
Dimensions: (index: 302802, time_step: 24, band: 120)
Coordinates:
* index (index) object '5d35849b-ace1-4dd4-962d-da80c6c56bac' ...
* time_step (time_step) int64 0 1 2 3 4 5 6 ... 17 18 19 20 21 22 23
* band (band) object 'B01' 'B02' 'B03' ... 'coord_y' 'coord_z'
Data variables: (12/17)
target (index) float32 ...
times (index, time_step) datetime64[ns] ...
seeding_date (index) uint8 ...
harvesting_date (index) uint8 ...
farm_identifier (index) uint8 ...
country (index) uint8 ...
... ...
col (index) uint8 ...
stats-mean (band) float32 469.2 566.0 864.8 ... -0.2598 0.443
stats-min (band) float32 0.0 1.0 1.0 ... -0.8557 -0.9154 -0.5302
stats-max (band) float32 8.976e+03 4.396e+03 ... 0.4852 0.9992
stats-std (band) float32 273.9 260.7 327.6 ... 0.2491 0.4512 0.6131
sample (index, time_step, band) float32 ...
Attributes: (12/111)
Germany_DUP3_farm5_field265_rapeseed_2020_<>_yield_ground_truth: 2.892
Germany_DUP3_farm1_field13_rapeseed_2016_<>_yield_ground_truth: 3.03555...
Germany_DUP3_farm6_field285_rapeseed_2020_<>_yield_ground_truth: 3.87
Germany_DUP3_farm2_field169_rapeseed_2018_<>_yield_ground_truth: 3.59
Germany_DUP3_farm1_field14_rapeseed_2016_<>_yield_ground_truth: 1.35666...
Germany_DUP3_farm6_field290_rapeseed_2021_<>_yield_ground_truth: 2.38
... ...
Germany_DUP3_farm6_field280_rapeseed_2019_<>_yield_ground_truth: 2.87
Germany_DUP3_farm6_field294_rapeseed_2016_<>_yield_ground_truth: 3.97
Germany_DUP3_farm5_field269_rapeseed_2017_<>_yield_ground_truth: 2.205
Germany_DUP3_farm6_field281_rapeseed_2019_<>_yield_ground_truth: 3.74
Germany_DUP3_farm5_field275_rapeseed_2018_<>_yield_ground_truth: 4.454
Germany_DUP3_farm6_field278_rapeseed_2018_<>_yield_ground_truth: 3.8
print(f"Filtered Dataset size: \n{dict(filtered_ds.sizes)} \n")
print("Available crops after filtering:")
print({int(k): v for k, v in filtered_ds["crop"].attrs.items()})
Filtered Dataset size:
{'index': 302802, 'time_step': 24, 'band': 120}
Available crops after filtering:
{0: 'rapeseed'}
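The filtering idea can be reduced to a plain-numpy sketch: invert the code-to-name mapping stored in the attrs and boolean-mask the encoded values. (The real `filter_ds_by_attribute` above operates on the xarray Dataset and additionally prunes stale attributes; the values here are illustrative.)

```python
import numpy as np

# Sketch of attrs-based filtering with plain numpy; values are illustrative.
attrs = {"0": "rapeseed", "1": "wheat"}           # code -> crop name, as in ds["crop"].attrs
crop_codes = np.array([0, 1, 0, 0, 1], dtype=np.uint8)

# Invert the mapping so a crop name can be translated into its integer code.
name_to_code = {v: int(k) for k, v in attrs.items()}
mask = crop_codes == name_to_code["rapeseed"]
print(int(mask.sum()))  # 3 rapeseed pixels kept
```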
The next cells define a PyTorch YieldSATDataset class and perform a grouped split by field_shared_name for robust cross-validation. The goal is to prevent pixels from the same field leaking between train and validation folds.
from torch.utils.data import Dataset, DataLoader
import torch
# Check if cuda is available for faster GPU training
torch.cuda.is_available()
True
class YieldSATDataset(Dataset):
def __init__(self,
data: str,
fill_value=-1,
band=['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'],  # S2 bands only for demonstration; include other available features as required
indices=None,
):
super().__init__()
self.data_path = data
self.indices = indices
self.fill_value = fill_value
self.band = band
if isinstance(data, (Path, str)):
assert os.path.exists(data), f"Dataset file {data} not found."
self.ds = xr.open_dataset(data).load()
else:
self.ds = data.load()
self.ds["sample"] = self.ds["sample"].fillna(fill_value)
# filter dataset to the selected indices (for CV training)
if self.indices is not None:
self.ds = self.ds.sel(index=self.indices)
# filter dataset to the selected bands
if self.band:
self.ds = self.ds.sel(band=self.band)
# feature and sequence sizes, taken after band selection
self.num_features = self.ds.band.size
self.len_sequence = self.ds.time_step.size
self.index_ids = list(self.ds.coords["index"].values)
def __len__(self):
return len(self.ds["sample"])
def __getitem__(self, idx):
sample = self.ds['sample'].isel(index=idx).values.astype(np.float32)
target = self.ds['target'].isel(index=idx).values.astype(np.float32)
index = str(self.ds['index'].isel(index=idx).values)
sample = np.nan_to_num(sample, nan=-1.0, posinf=-1.0, neginf=-1.0)
return {"sample": torch.tensor(sample, dtype=torch.float32), "target": torch.tensor(target, dtype=torch.float32), "index": index}
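Since `__getitem__` returns a dict, PyTorch's default collate function stacks each field across the batch. A minimal sketch with a stand-in dataset (not the real YieldSATDataset; shapes and values are illustrative):

```python
import torch
from torch.utils.data import Dataset, DataLoader

# Stand-in dataset mimicking YieldSATDataset's dict output; shapes are illustrative.
class ToyPixelDataset(Dataset):
    def __init__(self, n=8, n_steps=24, n_bands=12):
        self.x = torch.randn(n, n_steps, n_bands)
        self.y = torch.rand(n)
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        return {"sample": self.x[idx], "target": self.y[idx], "index": str(idx)}

loader = DataLoader(ToyPixelDataset(), batch_size=4, shuffle=False)
batch = next(iter(loader))
print(batch["sample"].shape)  # torch.Size([4, 24, 12])
print(len(batch["index"]))    # 4 (strings are collated into a list)
```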
We create a simple dataset split for CV training, ensuring that pixels (index) from the same field land either in train or in val to avoid information leakage. More advanced splits (e.g., stratifying by region, leave-one-year-out, or leave-one-region-out) can be created to ensure an even more realistic evaluation setup.
from sklearn.model_selection import GroupKFold
import random
random.seed(0)
def split_dataset_by_group(data: str, n_splits: int=5, group_key: str=None) -> list:
"""Create train/val splits for CV training based on group k-fold CV.
This ensures that pixels from the same field are either in train or in val.
Parameters:
-----------
data: xr.Dataset or path to a .nc dataset
n_splits: number of folds
group_key: variable name used for grouping (e.g., "field_shared_name")
Returns:
--------
list of (train_indices, val_indices) tuples
"""
if isinstance(data, (str, Path)):
assert os.path.exists(data), f"Dataset file {data} not found."
data = xr.open_dataset(data)
indices = data.coords["index"].values.tolist()
random.shuffle(indices)
if group_key is None:
raise ValueError("group_key is required for a leakage-free grouped split.")
groups = data[group_key].sel(index=indices).values
inst = GroupKFold(n_splits=n_splits)
splits = [
(np.take(indices, train_ind).tolist(), np.take(indices, val_ind).tolist())
for train_ind, val_ind in inst.split(indices, groups=groups)
]
for split_i in splits:
assert set(split_i[0]).isdisjoint(set(split_i[1])), "sets overlap"
return splits
splits = split_dataset_by_group(data=filtered_ds, n_splits=2, group_key="field_shared_name")
print(f"Number of splits for CV training: {len(splits)}")
Number of splits for CV training: 2
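As an example of one of the more advanced splits mentioned above, a leave-one-year-out split can be built in the same spirit with scikit-learn's `LeaveOneGroupOut`. The pixel IDs and per-pixel years below are synthetic placeholders for the dataset's real coordinates:

```python
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut

# Synthetic pixel indices and a hypothetical per-pixel "year" coordinate
# (the real dataset exposes such metadata as coordinates along "index").
indices = np.array([f"px-{i}" for i in range(8)])
years = np.array([2019, 2019, 2020, 2020, 2020, 2021, 2021, 2021])

logo = LeaveOneGroupOut()
splits = [
    (indices[tr].tolist(), indices[va].tolist())
    for tr, va in logo.split(indices, groups=years)
]

for train_ids, val_ids in splits:
    # every validation fold holds out exactly one year
    held_out = {y for i, y in zip(indices, years) if i in val_ids}
    assert len(held_out) == 1
print(len(splits))  # 3 folds, one per year
```

Each fold validates on a year the model never saw during training, which probes temporal generalization rather than spatial memorization.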
We define a minimal PyTorch LSTM model and a training pipeline. We train the model on a group-k-fold cross-validation split and evaluate on a separate validation set using common regression metrics. You can replace it with your model of choice.
class LSTM_Model(torch.nn.Module):
def __init__(self, input_dim, hidden_dim=64, num_layers=1):
super().__init__()
self.lstm = torch.nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
self.fc = torch.nn.Linear(hidden_dim, 1)
def forward(self, x):
out, _ = self.lstm(x)
out = out[:, -1, :]
out = self.fc(out)
return out.squeeze(1)
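A quick shape check (with the model restated here in self-contained form) confirms that a batch of pixel time series maps to one yield value per pixel:

```python
import torch

class TinyLSTM(torch.nn.Module):
    """Minimal restatement of the LSTM model above, for a shape check."""
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        out, _ = self.lstm(x)                      # out: (batch, time, hidden)
        return self.fc(out[:, -1, :]).squeeze(1)   # keep last time step -> (batch,)

x = torch.randn(4, 24, 12)  # 4 pixels, 24 time steps, 12 Sentinel-2 bands
y = TinyLSTM(input_dim=12)(x)
print(tuple(y.shape))  # (4,)
```

Using only the last hidden state is a deliberately simple readout; pooling over all time steps or attention would be natural extensions.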
from sklearn.metrics import mean_squared_error, r2_score
def model_training(model, train_loader, criterion, optimizer, device):
    model.train()  # enable training mode
    train_loss = 0
for idx, sample_batch in enumerate(train_loader):
sample = sample_batch["sample"].to(device)
target = sample_batch["target"].to(device)
optimizer.zero_grad()
preds = model(sample)
loss = criterion(preds, target)
loss.backward()
optimizer.step()
train_loss += loss.item() * preds.size(0)
    total_loss = train_loss / len(train_loader.dataset)  # mean loss per sample
print(f"total train loss: {total_loss:.2f}")
def model_validation(model, val_loader, criterion, device, BEST_R2_SCORE):
    model.eval()  # disable training-specific behavior during validation
    val_loss = 0
predictions = []
targets = []
indices = []
for idx, sample_batch in enumerate(val_loader):
sample = sample_batch["sample"].to(device)
target = sample_batch["target"].to(device)
index = sample_batch["index"]
with torch.no_grad():
preds = model(sample)
loss = criterion(preds, target)
val_loss += loss.item() * preds.size(0)
predictions.extend(preds.detach().cpu().numpy())
targets.extend(target.detach().cpu().numpy())
indices.extend(index)
    total_val_loss = val_loss / len(val_loader.dataset)  # mean loss per sample
print(f"total val loss: {total_val_loss:.2f}")
ep_r2 = r2_score(targets, predictions)
if ep_r2 > BEST_R2_SCORE:
BEST_R2_SCORE = ep_r2
print(f"New best R2-Score: {BEST_R2_SCORE:.2f}")
val_df = pd.DataFrame({
"index": indices,
"target": targets,
"prediction": predictions,
})
return val_df, BEST_R2_SCORE
# CV training
n_splits = 2  # only 2 folds for demonstration
n_epochs = 15
cv_predictions = []
for cv_i in range(n_splits):
train_indices, val_indices = splits[cv_i]
print(f"xxxxxxx Start CV Run : {cv_i} xxxxxxx")
print("Nr. training indices:", len(train_indices))
print("Nr. validation indices:", len(val_indices))
    train_dataset = YieldSATDataset(data=filtered_ds, indices=train_indices)
    val_dataset = YieldSATDataset(data=filtered_ds, indices=val_indices)
    # use a quarter of the available CPU cores; os.sched_getaffinity is Linux-only
    try:
        num_workers = len(os.sched_getaffinity(0)) // 4
    except AttributeError:
        num_workers = (os.cpu_count() or 1) // 4
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
shuffle=True,
batch_size=1028,
pin_memory=True,
num_workers=num_workers
)
val_loader = torch.utils.data.DataLoader(
dataset=val_dataset,
shuffle=False,
batch_size=1028,
pin_memory=True,
num_workers=num_workers
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTM_Model(input_dim=12).to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
BEST_R2_SCORE = -1e6
for epoch in range(n_epochs):
print(f"------- Epoch: {epoch} -------")
model_training(model=model, train_loader=train_loader, criterion=criterion, optimizer=optimizer, device=device)
val_df, BEST_R2_SCORE = model_validation(
model=model,
val_loader=val_loader,
criterion=criterion,
device=device,
BEST_R2_SCORE=BEST_R2_SCORE,
)
val_df["cv_fold"] = cv_i
val_df["epoch"] = epoch
cv_predictions.append(val_df)
# concatenate per-fold prediction results and save them
cv_predictions_df = pd.concat(cv_predictions, ignore_index=True)
xxxxxxx Start CV Run : 0 xxxxxxx
Nr. training indices: 151165
Nr. validation indices: 151637
------- Epoch: 0 -------
total train loss: 4548.08
total val loss: 2993.71
New best R2-Score: 0.03
------- Epoch: 1 -------
total train loss: 2006.82
total val loss: 2560.84
New best R2-Score: 0.17
------- Epoch: 2 -------
total train loss: 1589.69
total val loss: 2476.71
New best R2-Score: 0.20
------- Epoch: 3 -------
total train loss: 1421.30
total val loss: 2351.50
New best R2-Score: 0.24
------- Epoch: 4 -------
total train loss: 1337.11
total val loss: 2353.96
------- Epoch: 5 -------
total train loss: 1266.23
total val loss: 2245.00
New best R2-Score: 0.27
------- Epoch: 6 -------
total train loss: 1251.52
total val loss: 2358.66
------- Epoch: 7 -------
total train loss: 1216.00
total val loss: 2286.11
------- Epoch: 8 -------
total train loss: 1180.61
total val loss: 2580.89
------- Epoch: 9 -------
total train loss: 1201.00
total val loss: 2339.74
------- Epoch: 10 -------
total train loss: 1167.83
total val loss: 2292.98
------- Epoch: 11 -------
total train loss: 1150.64
total val loss: 2372.72
------- Epoch: 12 -------
total train loss: 1135.26
total val loss: 2458.28
------- Epoch: 13 -------
total train loss: 1112.53
total val loss: 2362.46
------- Epoch: 14 -------
total train loss: 1110.14
total val loss: 2267.49
xxxxxxx Start CV Run : 1 xxxxxxx
Nr. training indices: 151637
Nr. validation indices: 151165
------- Epoch: 0 -------
total train loss: 4979.16
total val loss: 2593.18
New best R2-Score: 0.01
------- Epoch: 1 -------
total train loss: 2586.59
total val loss: 2437.88
New best R2-Score: 0.07
------- Epoch: 2 -------
total train loss: 2088.74
total val loss: 2281.43
New best R2-Score: 0.13
------- Epoch: 3 -------
total train loss: 1951.98
total val loss: 2426.43
------- Epoch: 4 -------
total train loss: 1840.27
total val loss: 2432.72
------- Epoch: 5 -------
total train loss: 1728.92
total val loss: 2358.15
------- Epoch: 6 -------
total train loss: 1635.89
total val loss: 2264.06
New best R2-Score: 0.13
------- Epoch: 7 -------
total train loss: 1563.77
total val loss: 2336.92
------- Epoch: 8 -------
total train loss: 1519.08
total val loss: 2270.87
------- Epoch: 9 -------
total train loss: 1496.94
total val loss: 2007.62
New best R2-Score: 0.23
------- Epoch: 10 -------
total train loss: 1478.89
total val loss: 1964.07
New best R2-Score: 0.25
------- Epoch: 11 -------
total train loss: 1459.96
total val loss: 1985.62
------- Epoch: 12 -------
total train loss: 1432.80
total val loss: 2082.75
------- Epoch: 13 -------
total train loss: 1447.36
total val loss: 1983.69
------- Epoch: 14 -------
total train loss: 1406.08
total val loss: 1950.60
New best R2-Score: 0.25
#stored results
cv_predictions_df
| index | target | prediction | cv_fold | epoch | |
|---|---|---|---|---|---|
| 0 | 82fbc894-6ce3-462a-ac56-906f053416f8 | 4.460000 | 3.881447 | 0 | 0 |
| 1 | 90542626-23d3-4432-ae67-ced2b7dd93a0 | 4.633333 | 4.266084 | 0 | 0 |
| 2 | 96e1d447-8c55-411b-991c-54d3b7f3b89e | 4.365000 | 4.010424 | 0 | 0 |
| 3 | fd97ab25-00ee-4707-a0cd-d6b70fd41d09 | 7.495000 | 3.897541 | 0 | 0 |
| 4 | fd7e68d8-fa07-41dd-b375-a70e3c18b12d | 3.680000 | 4.050245 | 0 | 0 |
| ... | ... | ... | ... | ... | ... |
| 4542025 | fa0ddbe5-55dc-4ab9-b887-07f1c57da4a3 | 3.020000 | 3.367503 | 1 | 14 |
| 4542026 | d1c87a3f-b690-423e-8ccd-37eddd999518 | 1.524000 | 3.638667 | 1 | 14 |
| 4542027 | 6dce6c67-cc67-45ee-9ed2-4a8f5bb5cacf | 5.525000 | 5.277243 | 1 | 14 |
| 4542028 | bdaa9629-7468-4aac-b54c-8be6a1f9dd94 | 0.666667 | 1.536965 | 1 | 14 |
| 4542029 | 8b01960b-a8eb-4ad1-a171-baddeac026e7 | 4.397778 | 4.782866 | 1 | 14 |
4542030 rows × 5 columns
We can evaluate the predictions at the individual pixel level and over all CV folds.
# Average predictions per index over CV folds
pixel_avg = cv_predictions_df.groupby('index').agg({'prediction': 'mean', 'target': 'first'}).reset_index()
rmse_pixel = np.sqrt(mean_squared_error(pixel_avg['target'], pixel_avg['prediction']))
r2_pixel = r2_score(pixel_avg['target'], pixel_avg['prediction'])
print(f'Pixel-level RMSE: {rmse_pixel:.4f}')
print(f'Pixel-level R2: {r2_pixel:.4f}')
# Scatterplot
plt.figure(figsize=(7,7))
plt.scatter(pixel_avg['target'], pixel_avg['prediction'], alpha=0.5)
max_value = max(np.max(pixel_avg['target']), np.max(pixel_avg['prediction']))
min_value = min(np.min(pixel_avg['target']), np.min(pixel_avg['prediction']), 0)
plt.xlim(left=min_value, right=max_value)
plt.ylim(bottom=min_value, top=max_value)
plt.plot([min_value, max_value], [min_value, max_value], color="k", ls="-")
plt.xlabel('Target (t/ha)')
plt.ylabel('Prediction (t/ha)')
plt.title('Pixel-level: Prediction vs Target')
plt.show()
# Histograms
plt.figure(figsize=(8, 5))
plt.hist(pixel_avg['prediction'], bins=50, fc=(0.14, 0.38, 0.64, 0.7), label='Prediction (t/ha)')
plt.hist(pixel_avg['target'], bins=50, fc=(0.21, 0.59, 0.44, 0.7), label='Target (t/ha)')
plt.title('Prediction and Target Distribution')
plt.legend()
plt.show()
Pixel-level RMSE: 1.4234 Pixel-level R2: 0.2747
We can also group predictions by field and compare them against ground-truth field-level statistics.
# Load dataset to get field mapping
ds = xr.open_dataset(preprocessed_path)
field_mapping = ds[['index', 'field_shared_name']].to_dataframe().reset_index()
# Merge with pixel_avg to get field_shared_name
pixel_avg = pixel_avg.merge(field_mapping[['index', 'field_shared_name']], on='index', how='left')
# Group by field_shared_name and average predictions and targets
field_avg = pixel_avg.groupby('field_shared_name').agg({'prediction': 'mean', 'target': 'mean'}).reset_index()
# Compute RMSE and R2 at field level
rmse_field = np.sqrt(mean_squared_error(field_avg['target'], field_avg['prediction']))
r2_field = r2_score(field_avg['target'], field_avg['prediction'])
print(f'Field-level RMSE: {rmse_field:.4f}')
print(f'Field-level R2: {r2_field:.4f}')
# Scatterplot
plt.figure(figsize=(8,6))
plt.scatter(field_avg['target'], field_avg['prediction'], alpha=0.8)
max_value = max(np.max(field_avg['target']), np.max(field_avg['prediction']))
min_value = min(np.min(field_avg['target']), np.min(field_avg['prediction']), 0)
plt.xlim(left=min_value, right=max_value)
plt.ylim(bottom=min_value, top=max_value)
plt.plot([min_value, max_value], [min_value, max_value], color="k", ls="-")
plt.xlabel('Target (t/ha)')
plt.ylabel('Prediction (t/ha)')
plt.title('Field-level: Prediction vs Target')
plt.show()
# Histograms
plt.figure(figsize=(8, 5))
plt.hist(field_avg['prediction'], bins=50, fc=(0.14, 0.38, 0.64, 0.7), label='Prediction (t/ha)')
plt.hist(field_avg['target'], bins=50, fc=(0.21, 0.59, 0.44, 0.7), label='Target (t/ha)')
plt.title('Prediction and Target Distribution')
plt.legend()
plt.show()
Field-level RMSE: 0.9666 Field-level R2: 0.4815
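The pixel-to-field aggregation above boils down to a single groupby-mean. This toy sketch uses synthetic pixel values for two hypothetical fields:

```python
import pandas as pd

# Synthetic pixel-level predictions for two hypothetical fields
pixels = pd.DataFrame({
    "index": ["a", "b", "c", "d"],
    "field_shared_name": ["F1", "F1", "F2", "F2"],
    "prediction": [4.0, 6.0, 2.0, 4.0],
    "target": [5.0, 5.0, 3.0, 3.0],
})

# Same aggregation as above: average all pixels within each field
fields = pixels.groupby("field_shared_name").agg(
    {"prediction": "mean", "target": "mean"}
).reset_index()
print(fields["prediction"].tolist())  # [5.0, 3.0]
```

Averaging pixel errors within a field cancels part of the pixel-level noise, which is one reason the field-level R2 above is higher than the pixel-level R2.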
Finally, we can evaluate predictions spatially by rendering the pixel-wise predictions of an entire field as an image.
# Choose a field to visualize (here, the third one)
field_name = field_avg['field_shared_name'].iloc[2]
print(f"Visualizing predictions for field: {field_name}")
# Get pixels for this field
field_pixels = pixel_avg[pixel_avg['field_shared_name'] == field_name]
# Load dataset to get row and col coordinates
ds = xr.open_dataset(preprocessed_path)
field_indices = field_pixels['index'].tolist()
field_ds = ds.sel(index=field_indices)
# Extract row, col, predictions, and targets
rows = field_ds['row'].values
cols = field_ds['col'].values
preds = field_pixels['prediction'].values
targets = field_pixels['target'].values
# Create 2D image arrays
# Shift coordinates to start from 0
min_row = rows.min()
min_col = cols.min()
row_shifted = rows - min_row
col_shifted = cols - min_col
# Determine image dimensions
max_row = int(row_shifted.max()) + 1
max_col = int(col_shifted.max()) + 1
# Initialize images with NaN
pred_image = np.full((max_row, max_col), np.nan)
target_image = np.full((max_row, max_col), np.nan)
mse_image = np.full((max_row, max_col), np.nan)
# Fill in values
for r, c, p, t in zip(row_shifted, col_shifted, preds, targets):
pred_image[int(r), int(c)] = p
target_image[int(r), int(c)] = t
    mse_image[int(r), int(c)] = abs(p - t)  # absolute error per pixel, not squared
# Plot the field predictions, targets, and MSE
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
# Prediction
im1 = axes[0].imshow(pred_image, cmap='viridis', origin='upper')
axes[0].set_title(f'Predictions for Field {field_name}')
axes[0].set_xlabel('Column')
axes[0].set_ylabel('Row')
plt.colorbar(im1, ax=axes[0], label='Prediction (t/ha)')
# Target
im2 = axes[1].imshow(target_image, cmap='viridis', origin='upper')
axes[1].set_title(f'Targets for Field {field_name}')
axes[1].set_xlabel('Column')
axes[1].set_ylabel('Row')
plt.colorbar(im2, ax=axes[1], label='Target (t/ha)')
# Absolute error
im3 = axes[2].imshow(mse_image, cmap='plasma', origin='upper')
axes[2].set_title(f'Pixel-wise Absolute Error for Field {field_name}')
axes[2].set_xlabel('Column')
axes[2].set_ylabel('Row')
plt.colorbar(im3, ax=axes[2], label='Absolute error (t/ha)')
plt.tight_layout()
plt.show()
Visualizing predictions for field: 4
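The per-pixel fill loop above can also be written as a single vectorized assignment. This toy sketch (with synthetic coordinates) shows the same shift-and-fill pattern:

```python
import numpy as np

# Synthetic pixel coordinates and values, as in the field rasterization above
rows = np.array([10, 10, 11])
cols = np.array([5, 6, 5])
vals = np.array([1.0, 2.0, 3.0])

# Shift to a 0-based local window and allocate a NaN canvas
r = rows - rows.min()
c = cols - cols.min()
img = np.full((r.max() + 1, c.max() + 1), np.nan)

# Vectorized equivalent of the per-pixel loop; unobserved cells stay NaN
img[r, c] = vals
print(img[0, 0], np.isnan(img[1, 1]))  # 1.0 True
```

NaN-filled cells are rendered transparent by `imshow`, so the field outline emerges automatically from the valid pixels.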