from __future__ import annotations
import concurrent.futures
import logging
import os
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from hydrobricks._exceptions import ConfigurationError, DataError, DependencyError
from hydrobricks._optional import (
HAS_NETCDF,
HAS_RASTERIO,
HAS_RIOXARRAY,
pyproj,
rasterio,
rxr,
xr,
)
from hydrobricks._utils import date_as_mjd
# Module-level logger; all diagnostics of this module go through it.
logger = logging.getLogger(__name__)
class TimeSeries:
    """Class for generic time series data."""

    def __init__(self) -> None:
        """Initialize TimeSeries with empty time and data containers."""
        # Time axis shared by the series stored in ``data``.
        self.time: list[Any] | pd.Series | pd.DatetimeIndex = []
        # One array per stored variable.
        self.data: list[np.ndarray] = []
        # Names of the stored variables (filled by loaders that record them).
        self.data_name: list[str] = []

    def get_dates_as_mjd(self) -> float | np.ndarray:
        """
        Convert time series dates to modified Julian dates.

        Returns
        -------
        float | np.ndarray
            Modified Julian dates. Returns float if single date,
            array if multiple dates.
        """
        return date_as_mjd(self.time)
class TimeSeries1D(TimeSeries):
    """Class for generic 1D time series data."""

    def __init__(self) -> None:
        """Initialize 1D TimeSeries."""
        super().__init__()

    def load_from_csv(
        self,
        path: str | Path,
        column_time: str,
        time_format: str,
        content: dict[str, str],
        start_date: datetime | pd.Timestamp | None = None,
        end_date: datetime | pd.Timestamp | None = None,
    ) -> None:
        """
        Read time series data from a CSV file.

        The (optionally filtered) time column is stored in ``self.time`` with a
        fresh 0-based index. For every entry of ``content``, the name is appended
        to ``self.data_name`` and the column values to ``self.data``.

        Parameters
        ----------
        path
            Path to the CSV file containing the time series data.
        column_time
            Column name containing the time values.
        time_format
            Format string for parsing time values (e.g., '%Y-%m-%d').
        content
            Dictionary mapping variable names/enums to column names in the CSV.
            Example: {'precipitation': 'Precipitation (mm)',
            'temperature': 'Temperature (C)'}
        start_date
            Start date of the time series (inclusive), used to select the period
            of interest. If None, the first date of the file is used.
        end_date
            End date of the time series (inclusive), used to select the period
            of interest. If None, the last date of the file is used.

        Raises
        ------
        FileNotFoundError
            If the specified file does not exist.
        KeyError
            If required columns are not found in the CSV file.
        """
        file_content = pd.read_csv(
            path, parse_dates=[column_time], date_format=time_format
        )

        # Apply each bound independently: giving only one of them must still
        # restrict the period, as documented above.
        dates = file_content[column_time]
        keep = pd.Series(True, index=file_content.index)
        if start_date is not None:
            keep &= dates >= start_date
        if end_date is not None:
            keep &= dates <= end_date

        # Reset the index so that label access (self.time[0]) and positional
        # access refer to the same rows after filtering.
        file_content = file_content.loc[keep].reset_index(drop=True)

        self.time = file_content[column_time]
        for name, column in content.items():
            self.data_name.append(name)
            self.data.append(file_content[column].to_numpy())
class TimeSeries2D(TimeSeries):
    """Class for generic 2D (gridded) time series data."""

    # Gradient types supported by the elevation-based correction.
    _GRADIENT_TYPES = ("additive", "multiplicative")

    def __init__(self) -> None:
        """Initialize 2D TimeSeries."""
        super().__init__()

    def regrid_from_netcdf(
        self,
        path: str | Path,
        file_pattern: str | None = None,
        data_crs: int | None = None,
        var_name: str | None = None,
        dim_time: str = "time",
        dim_x: str = "x",
        dim_y: str = "y",
        hydro_units: Any | None = None,
        raster_hydro_units: str | Path | None = None,
        weights_block_size: int = 100,
        apply_data_gradient: bool = True,
        gradient_type: str = "additive",
        dem_path: str | Path | None = None,
    ) -> None:
        """
        Regrid time series data from netCDF files. The spatialization is done using a
        raster of hydro unit IDs. The meteorological data is resampled to the DEM
        resolution.

        The regridded array (one row per time step, one column per hydro unit) is
        appended to ``self.data`` only once the whole computation has succeeded.

        Parameters
        ----------
        path
            Path to a netCDF file containing the data or to a folder containing
            multiple files.
        file_pattern
            Glob pattern of the files to read (e.g., '*.nc'). If None, the path is
            considered to be a single file.
        data_crs
            CRS of the netCDF file (as EPSG code).
            If None, the CRS is read from the file.
        var_name
            Name of the variable to read from the netCDF file. If None, the
            dataset must contain exactly one data variable, which is used.
        dim_time
            Name of the time dimension. Default: 'time'
        dim_x
            Name of the x/longitude dimension. Default: 'x'
        dim_y
            Name of the y/latitude dimension. Default: 'y'
        hydro_units
            Hydro units providing the 'id' column (always required) and the
            'elevation' column (used when apply_data_gradient is True).
        raster_hydro_units
            Path to a raster file containing the hydro unit IDs to use for the
            spatialization. Required.
        weights_block_size
            Number of time steps processed per task. Default: 100
        apply_data_gradient
            If True, elevation-based gradients will be retrieved from the data and
            applied to the hydro units (e.g., for temperature and precipitation).
            If False, the data will be regridded without applying any gradient.
            Default: True
        gradient_type
            Type of gradient to apply: 'additive' or 'multiplicative'.
            Default: 'additive'
        dem_path
            Path to DEM raster file for spatialization (gradient computation).
            Required if apply_data_gradient is True.

        Raises
        ------
        DependencyError
            If rasterio, rioxarray or netCDF4 is not installed.
        DataError
            If a required input is missing, if the time axes do not match, if the
            CRS cannot be determined, or if a hydro unit is not covered by the
            gridded data.
        ConfigurationError
            If gradient_type, weights_block_size or var_name is invalid.
        """
        self._check_regrid_dependencies()

        if raster_hydro_units is None:
            raise DataError(
                "You must provide a raster of the hydro units.",
                data_type="raster hydro units",
                reason="Missing required data",
            )
        if hydro_units is None:
            raise DataError(
                "You must provide the hydro units (option 'hydro_units').",
                data_type="hydro units",
                reason="Missing required data",
            )
        if weights_block_size < 1:
            raise ConfigurationError(
                f"weights_block_size must be a positive integer "
                f"(got {weights_block_size}).",
                item_name="weights_block_size",
                item_value=weights_block_size,
                reason="Invalid block size",
            )
        if apply_data_gradient:
            if gradient_type not in self._GRADIENT_TYPES:
                raise ConfigurationError(
                    f"Unknown gradient type: {gradient_type}. "
                    "Use 'additive' or 'multiplicative'.",
                    item_name="gradient_type",
                    item_value=gradient_type,
                )
            if dem_path is None:
                raise DataError(
                    "A DEM (option 'dem_path') is required when "
                    "apply_data_gradient is True.",
                    data_type="DEM",
                    reason="Missing required data",
                )

        # Get unit ids
        unit_ids = self._open_raster(raster_hydro_units)

        logger.debug(
            f"Starting regridding from netCDF: "
            f"apply_data_gradient={apply_data_gradient}"
        )

        nc_source = self._open_netcdf(path, file_pattern)
        try:
            result = self._regrid_dataset(
                nc_source,
                unit_ids,
                data_crs=data_crs,
                var_name=var_name,
                dim_time=dim_time,
                dim_x=dim_x,
                dim_y=dim_y,
                hydro_units=hydro_units,
                weights_block_size=weights_block_size,
                apply_data_gradient=apply_data_gradient,
                gradient_type=gradient_type,
                dem_path=dem_path,
            )
        finally:
            # Release the netCDF file handle(s), also when regridding fails.
            nc_source.close()

        # Publish only a complete result, so a failure never leaves a
        # half-filled array in self.data.
        self.data.append(result)

    def _regrid_dataset(
        self,
        nc_data: xr.Dataset,
        unit_ids: xr.DataArray,
        *,
        data_crs: int | None,
        var_name: str | None,
        dim_time: str,
        dim_x: str,
        dim_y: str,
        hydro_units: Any,
        weights_block_size: int,
        apply_data_gradient: bool,
        gradient_type: str,
        dem_path: str | Path | None,
    ) -> np.ndarray:
        """
        Regrid an already opened dataset onto the hydro units.

        Returns
        -------
        np.ndarray
            Array of shape (len(self.time), number of hydro units).
        """
        # Get the CRS of the netcdf data and of the unit ids raster
        data_crs = self._parse_crs(nc_data, data_crs)
        logger.debug(f"NetCDF CRS: {data_crs}")
        unit_ids_crs = self._parse_crs(unit_ids, None)
        logger.debug(f"Raster CRS: {unit_ids_crs}")
        if data_crs != unit_ids_crs:
            logger.warning(
                "The CRS of the netcdf file does not match the CRS of the "
                "hydro unit ids raster. Reprojection will be done from "
                f"{unit_ids_crs} to {data_crs}."
            )
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)  # pyproj
                unit_ids = unit_ids.rio.reproject(f"epsg:{data_crs}")

        # ravel (rather than squeeze) keeps a 1D array for a single hydro unit.
        unit_ids_list = np.ravel(hydro_units["id"].values)
        unit_id_count = len(unit_ids_list)
        logger.debug(f"Processing {unit_id_count} hydro units")

        time_method, day_of_year, n_rows = self._resolve_time_axis(nc_data, dim_time)

        # Drop other coordinates
        other_coords = [
            v
            for v in nc_data.coords
            if v not in [dim_time, dim_x, dim_y, "day_of_year"]
        ]
        nc_data = nc_data.drop_vars(other_coords)

        # Extract variable
        data_var = nc_data[self._resolve_var_name(nc_data, var_name)]

        # Specify the CRS if not specified
        if data_var.rio.crs is None:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)  # pyproj
                data_var.rio.write_crs(f"epsg:{data_crs}", inplace=True)

        # Open the DEM if needed; it then also defines the extent of interest.
        dem = self._open_raster(dem_path) if apply_data_gradient else None
        ref_data = dem if apply_data_gradient else unit_ids
        data_var = self._select_relevant_extent(
            data_var, data_crs, dim_x, dim_y, ref_data
        )

        # Rename spatial dimensions
        if dim_x != "x":
            data_var = data_var.rename({dim_x: "x"})
        if dim_y != "y":
            data_var = data_var.rename({dim_y: "y"})

        # Time the computation
        start_time = time.time()
        num_threads = os.cpu_count() or 1

        # With the gradient option, the reference elevation of each data cell
        # is needed first.
        dem_data, dem_dx, dem_dy, hu_elevation = None, None, None, None
        if apply_data_gradient:
            dem_data, dem_dx, dem_dy = self._compute_dem_gradients(dem, data_var)
            # If both gradients contain more than 60% NaN values, fall back.
            if (
                dem_dx.isnull().sum() / dem_dx.size > 0.6
                and dem_dy.isnull().sum() / dem_dy.size > 0.6
            ):
                logger.warning(
                    "More than 60% of the DEM gradients are too small. "
                    "Defaulting to apply_data_gradient=False."
                )
                apply_data_gradient = False
            # Extract the elevation for each hydro unit
            hu_elevation = np.ravel(hydro_units["elevation"].to_numpy())

        unit_weights = self._compute_unit_weights(
            data_var, data_crs, unit_ids, unit_ids_list
        )

        # One row per data time step (366 rows in 'day_of_year' mode).
        result = np.zeros((n_rows, unit_id_count))
        n_blocks = -(-n_rows // weights_block_size)  # ceiling division

        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            if apply_data_gradient:
                futures = [
                    executor.submit(
                        self._extract_time_step_data_weights_with_gradient,
                        data_var,
                        unit_weights,
                        i_block,
                        weights_block_size,
                        gradient_type,
                        dem_data,
                        dem_dx,
                        dem_dy,
                        hu_elevation,
                        out=result,
                    )
                    for i_block in range(n_blocks)
                ]
            else:
                futures = [
                    executor.submit(
                        self._extract_time_step_data_weights,
                        data_var,
                        unit_weights,
                        i_block,
                        weights_block_size,
                        out=result,
                    )
                    for i_block in range(n_blocks)
                ]
            # concurrent.futures.wait() would silently drop worker exceptions;
            # Future.result() re-raises them here.
            try:
                for future in concurrent.futures.as_completed(futures):
                    future.result()
            except Exception:
                for future in futures:
                    future.cancel()
                raise

        # If the time method is 'day_of_year', expand to the full time series
        if time_method == "day_of_year":
            logger.info("Converting to the full time series...")
            jd = pd.DatetimeIndex(self.time).dayofyear.to_numpy()
            indices = np.searchsorted(day_of_year, jd)
            result = result[indices, :]

        elapsed_time = time.time() - start_time
        logger.info(
            f"Elapsed time: {elapsed_time:.2f} seconds "
            f"(using {num_threads} threads)"
        )
        return result

    @staticmethod
    def _check_regrid_dependencies() -> None:
        """Raise DependencyError if an optional package needed here is missing."""
        for available, package in (
            (HAS_RASTERIO, "rasterio"),
            (HAS_RIOXARRAY, "rioxarray"),
            (HAS_NETCDF, "netCDF4"),
        ):
            if not available:
                raise DependencyError(
                    f"{package} is required for regridding from netCDF.",
                    package_name=package,
                    operation="TimeSeries2D.regrid_from_netcdf",
                    install_command=f"pip install {package}",
                )

    @staticmethod
    def _open_raster(path: str | Path) -> xr.DataArray:
        """Load a single-band raster fully into memory, without its band axis."""
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)  # pyproj
            with rxr.open_rasterio(path) as raw:
                # Squeeze only 'band' so a 1-pixel-wide raster keeps x and y.
                return raw.squeeze("band", drop=True).load()

    @staticmethod
    def _open_netcdf(path: str | Path, file_pattern: str | None) -> xr.Dataset:
        """Lazily open a netCDF file, or all files of *path* matching the pattern."""
        logger.debug(f"Reading netcdf file(s) from {path}...")
        if file_pattern is None:
            return xr.open_dataset(path, chunks={})
        files = sorted(Path(path).glob(file_pattern))
        logger.debug(f"Found {len(files)} files matching pattern '{file_pattern}'")
        if not files:
            raise DataError(
                f"No file matching the pattern '{file_pattern}' found in {path}.",
                data_type="netCDF files",
                reason="No matching files",
            )
        return xr.open_mfdataset(files, chunks={})

    @staticmethod
    def _resolve_var_name(nc_data: xr.Dataset, var_name: str | None) -> Any:
        """Return *var_name*, or the dataset's only data variable when None."""
        if var_name is not None:
            return var_name
        data_vars = list(nc_data.data_vars)
        if len(data_vars) != 1:
            raise ConfigurationError(
                "The name of the variable to read (option 'var_name') must be "
                "provided when the dataset does not contain exactly one "
                f"variable (found: {data_vars}).",
                item_name="var_name",
                item_value=var_name,
                reason="Ambiguous variable",
            )
        return data_vars[0]

    def _time_at(self, index: int) -> Any:
        """Return the time value at a position, whatever the container type."""
        if isinstance(self.time, pd.Series):
            return self.time.iloc[index]
        return self.time[index]

    def _resolve_time_axis(
        self, nc_data: xr.Dataset, dim_time: str
    ) -> tuple[str, np.ndarray | None, int]:
        """
        Check the time axis of the dataset against ``self.time``.

        If ``self.time`` is empty and the dataset has a full time axis, that axis
        becomes ``self.time``.

        Returns
        -------
        tuple
            (time_method, day_of_year, n_rows). time_method is 'day_of_year' or
            'full'. day_of_year holds the day-of-year coordinate values (None for
            'full'). n_rows is the number of rows to compute from the data.
        """
        if "day_of_year" in nc_data.dims:
            day_of_year = np.asarray(nc_data.variables["day_of_year"].values)
            logger.debug(f"Using day_of_year time method with {len(day_of_year)} days")
            if len(self.time) == 0:
                raise DataError(
                    "Other forcing data with a full temporal array have "
                    "to be loaded and spatialized before data based "
                    "on 'day_of_year'.",
                    data_type="time series",
                    reason="Missing preceding forcing data",
                )
            if len(day_of_year) != 366:
                raise DataError(
                    f"The time series based on 'day_of_year' must have a length "
                    f"of 366 (got {len(day_of_year)}).",
                    data_type="time series",
                    reason="Invalid time series length for day_of_year",
                )
            return "day_of_year", day_of_year, len(day_of_year)

        time_nc = np.asarray(nc_data.variables[dim_time].values)
        logger.debug(f"Using full time series with {len(time_nc)} time steps")
        if len(time_nc) == 0:
            raise DataError(
                "The netcdf time series is empty.",
                data_type="time series",
                reason="Empty time series",
            )
        if len(self.time) == 0:
            self.time = pd.Series(time_nc)

        # Check if the time steps are the same
        if len(self.time) != len(time_nc):
            raise DataError(
                f"The length of the netcdf time series ({len(time_nc)}) "
                f"does not match the hydro units data ({len(self.time)}).",
                data_type="time series",
                reason="Mismatched time series length",
            )
        if self._time_at(0) != time_nc[0]:
            raise DataError(
                f"The first time step of the netcdf time series "
                f"({time_nc[0]}) does not match the one from the "
                f"hydro units data ({self._time_at(0)}).",
                data_type="time series",
                reason="Mismatched start date",
            )
        if self._time_at(-1) != time_nc[-1]:
            raise DataError(
                f"The last time step of the netcdf time series "
                f"({time_nc[-1]}) does not match "
                f"the one from the hydro units data "
                f"({self._time_at(-1)}).",
                data_type="time series",
                reason="Mismatched end date",
            )
        return "full", None, len(self.time)

    @staticmethod
    def _compute_dem_gradients(
        dem: xr.DataArray, data_var: xr.DataArray
    ) -> tuple[np.ndarray, xr.DataArray, xr.DataArray]:
        """
        Aggregate the DEM onto the data grid and difference it along x and y.

        Returns
        -------
        tuple
            (dem_data, dem_dx, dem_dy). dem_data is the averaged DEM on the data
            grid. dem_dx and dem_dy are its differences between neighbouring
            cells, with absolute values below 50 set to NaN.
        """
        dem_reproj = dem.rio.reproject_match(
            data_var[0].copy(),
            resampling=rasterio.enums.Resampling.average,
            nodata=np.nan,
        )
        dem_dx = dem_reproj.diff("x")
        dem_dy = dem_reproj.diff("y")
        # The data differences are later divided by these; tiny elevation
        # steps would give irrelevant, exploding gradients.
        dem_dx = xr.where(np.abs(dem_dx) < 50, np.nan, dem_dx).compute()
        dem_dy = xr.where(np.abs(dem_dy) < 50, np.nan, dem_dy).compute()
        return dem_reproj.values, dem_dx, dem_dy

    @staticmethod
    def _compute_unit_weights(
        data_var: xr.DataArray,
        data_crs: int,
        unit_ids: xr.DataArray,
        unit_ids_list: np.ndarray,
    ) -> list[np.ndarray]:
        """
        Compute, for each hydro unit, the weights of the data cells covering it.

        Each mask has the shape of one time slice of *data_var*. Its values are
        the fractions of the unit's raster pixels falling in each data cell, so
        they sum to 1.

        Raises
        ------
        DataError
            If the data slice is not 2D, or if a hydro unit has no pixel inside
            the data grid.
        """
        template = data_var[0]
        if template.ndim != 2 or 0 in template.shape:
            raise DataError(
                f"The gridded data must have two non-empty spatial dimensions "
                f"(got shape {template.shape}).",
                data_type="gridded data",
                reason="Invalid data shape",
            )
        grid_shape = template.shape
        n_cells = template.size

        # Number every data cell (as float, so reprojection can mark uncovered
        # pixels with a nodata value), then map the numbers onto the unit raster.
        data_idx = template.copy(
            data=np.arange(n_cells, dtype=float).reshape(grid_shape)
        )
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)  # pyproj
            data_idx.rio.write_crs(f"epsg:{data_crs}", inplace=True)
            data_idx_reproj = data_idx.rio.reproject_match(
                unit_ids, resampling=rasterio.enums.Resampling.nearest
            )

        cell_values = data_idx_reproj.to_numpy()
        # Anything that is not a valid cell number is a nodata pixel.
        valid = np.isfinite(cell_values) & (cell_values >= 0) & (cell_values < n_cells)
        cell_idx = np.where(valid, cell_values, -1).astype(int)
        unit_raster = unit_ids.to_numpy()

        unit_weights = []
        for unit_id in unit_ids_list:
            contributing = cell_idx[(unit_raster == unit_id) & valid]
            cells, counts = np.unique(contributing, return_counts=True)
            if counts.size == 0:
                raise DataError(
                    f"Hydro unit {unit_id} is not covered by the gridded data "
                    f"(no pixel of the hydro unit raster falls inside the data "
                    f"grid).",
                    data_type="raster hydro units",
                    reason="Hydro unit not covered",
                )
            weights_mask = np.zeros(grid_shape)
            weights_mask[np.unravel_index(cells, grid_shape)] = counts / counts.sum()
            unit_weights.append(weights_mask)
        return unit_weights

    @staticmethod
    def _block_bounds(
        i_block: int, block_size: int, n_rows: int, n_data: int
    ) -> tuple[int, int]:
        """Return the [start, end) row range of a block, clipped to both arrays."""
        i_start = i_block * block_size
        i_end = min(i_start + block_size, n_rows, n_data)
        return i_start, i_end

    def _extract_time_step_data_weights_with_gradient(
        self,
        data_var: xr.DataArray,
        unit_weights: list,
        i_block: int,
        block_size: int,
        gradient_type: str,
        dem: np.ndarray,
        dem_dx: xr.DataArray,
        dem_dy: xr.DataArray,
        hu_elevation: np.ndarray,
        out: np.ndarray | None = None,
    ) -> None:
        """
        Extract time step data with elevation-based gradient corrections.

        Applies elevation-based gradients to gridded meteorological data to account
        for the effect of elevation differences on temperature, precipitation, etc.
        Processes one block of time steps.

        Parameters
        ----------
        data_var
            3D xarray DataArray with dimensions (time, y, x) containing gridded data.
        unit_weights
            List of 2D weight arrays for each hydro unit, summing to 1.
        i_block
            Block index for this processing step.
        block_size
            Number of time steps to process in this block.
        gradient_type
            Type of gradient to apply: 'additive' or 'multiplicative'.
        dem
            2D array of DEM elevations on the data grid.
        dem_dx
            xarray DataArray with DEM differences in x direction.
        dem_dy
            xarray DataArray with DEM differences in y direction.
        hu_elevation
            1D array of elevation for each hydro unit.
        out
            Array (time step, hydro unit) to fill. Defaults to ``self.data[-1]``.

        Raises
        ------
        ConfigurationError
            If gradient_type is not supported.
        DataError
            If the gradient or DEM shape does not match the data grid.
        """
        target = self.data[-1] if out is None else out
        i_start, i_end = self._block_bounds(
            i_block, block_size, target.shape[0], data_var.shape[0]
        )
        if i_start >= i_end:
            return
        logger.debug(f"Extracting time steps {i_start} to {i_end - 1}")

        if gradient_type not in self._GRADIENT_TYPES:
            raise ConfigurationError(
                f"Unknown gradient type: {gradient_type}. "
                "Use 'additive' or 'multiplicative'.",
                item_name="gradient_type",
                item_value=gradient_type,
            )

        block = data_var[i_start:i_end]
        dat_dx = block.diff("x")
        dat_dy = block.diff("y")
        if gradient_type == "multiplicative":
            # Relative differences: divide by the value of the second cell of
            # each pair (0 where that value is ~0).
            zero_mask = np.abs(block) < 1e-10
            with np.errstate(invalid="ignore", divide="ignore"):
                dat_dx.data = np.where(
                    zero_mask.isel(x=slice(1, None)).values,
                    0,
                    dat_dx.values / block.isel(x=slice(1, None)).values,
                )
                dat_dy.data = np.where(
                    zero_mask.isel(y=slice(1, None)).values,
                    0,
                    dat_dy.values / block.isel(y=slice(1, None)).values,
                )

        # Data change per unit of elevation difference along each axis
        dat_dx = dat_dx / dem_dx
        dat_dy = dat_dy / dem_dy

        # Fill NaN values
        dat_dx = self._fill_nan_gradients(dat_dx)
        dat_dy = self._fill_nan_gradients(dat_dy)

        # Smooth the gradients with a 3x3 rolling mean
        dat_dx = dat_dx.rolling(x=3, y=3, center=True, min_periods=1).mean()
        dat_dy = dat_dy.rolling(x=3, y=3, center=True, min_periods=1).mean()

        # Combine the gradients into a single gradient
        dat_dxy = self._mean_xy_gradient(dat_dx, dat_dy).to_numpy()

        if dat_dxy.shape[1:] != data_var.shape[1:]:
            raise DataError(
                f"The shape of the data gradient ({dat_dxy.shape[1:]}) does not match "
                f"the shape of the data variable ({data_var.shape[1:]}).",
                data_type="time series",
                reason="Shape mismatch",
            )
        if dem.shape != data_var.shape[1:]:
            raise DataError(
                f"The shape of the DEM ({dem.shape}) does not match the "
                f"shape of the data variable ({data_var.shape[1:]}).",
                data_type="time series",
                reason="Shape mismatch",
            )

        # Extract data for each unit
        data_array = block.to_numpy()
        for u, unit_weight in enumerate(unit_weights):
            in_unit = unit_weight > 0
            values = data_array[:, in_unit]
            weights = unit_weight[in_unit]
            dh = hu_elevation[u] - dem[in_unit]
            dh[np.isnan(dh)] = 0
            grads = dat_dxy[:, in_unit]
            if gradient_type == "additive":
                values = values + grads * dh
            else:
                # Cap the relative correction at -100% so the factor stays >= 0.
                grads_dh = grads * dh
                grads_dh[grads_dh < -1] = -1
                values = values * (1 + grads_dh)
            target[i_start:i_end, u] = np.nansum(values * weights, axis=1)

    def _extract_time_step_data_weights(
        self,
        data_var: xr.DataArray,
        unit_weights: list,
        i_block: int,
        block_size: int,
        out: np.ndarray | None = None,
    ) -> None:
        """
        Extract time step data and apply spatial weights.

        Extracts meteorological data for a block of time steps and applies weighted
        averaging based on the spatial distribution of data cells within each
        hydro unit.

        Parameters
        ----------
        data_var
            3D xarray DataArray with dimensions (time, y, x) containing gridded data.
        unit_weights
            List of 2D weight arrays for each hydro unit, summing to 1.
        i_block
            Block index for this processing step.
        block_size
            Number of time steps to process in this block.
        out
            Array (time step, hydro unit) to fill. Defaults to ``self.data[-1]``.
        """
        target = self.data[-1] if out is None else out
        i_start, i_end = self._block_bounds(
            i_block, block_size, target.shape[0], data_var.shape[0]
        )
        if i_start >= i_end:
            return
        logger.debug(f"Extracting time steps {i_start} to {i_end - 1}")

        # Load the block once; it is shared by all hydro units.
        block = data_var[i_start:i_end].to_numpy()
        for u, unit_weight in enumerate(unit_weights):
            target[i_start:i_end, u] = np.nansum(block * unit_weight, axis=(1, 2))

    def _select_relevant_extent(
        self,
        data_var: xr.DataArray,
        data_crs: int,
        dim_x: str,
        dim_y: str,
        ref_data: xr.DataArray | xr.Dataset,
    ) -> xr.DataArray:
        """
        Select the spatial extent of gridded data relevant to the reference data.

        Clips the input data to the bounding box of the reference data
        (DEM or hydro units), handling CRS transformations if necessary.

        Parameters
        ----------
        data_var
            The gridded data variable to clip.
        data_crs
            CRS of the gridded data (as EPSG code).
        dim_x
            Name of the x/longitude dimension in data_var.
        dim_y
            Name of the y/latitude dimension in data_var.
        ref_data
            Reference dataset (DEM or hydro units raster) to determine extent.

        Returns
        -------
        xr.DataArray
            Clipped data variable containing only the relevant spatial extent.
        """
        x_min, x_max, y_min, y_max = self._get_spatial_bounds(ref_data)

        # Convert the spatial extent to the data CRS
        src_crs = self._parse_crs(ref_data)
        if src_crs != data_crs:
            transformer = pyproj.Transformer.from_crs(src_crs, data_crs, always_xy=True)
            # Densified bounds: transforming only two opposite corners can miss
            # parts of the footprint when the CRSs are rotated w.r.t. each other.
            x_min, y_min, x_max, y_max = transformer.transform_bounds(
                x_min, y_min, x_max, y_max
            )

        x_sel = self._covering_slice(data_var[dim_x].values, x_min, x_max)
        y_sel = self._covering_slice(data_var[dim_y].values, y_min, y_max)
        return data_var.sel({dim_x: x_sel, dim_y: y_sel})

    @staticmethod
    def _covering_slice(coords: np.ndarray, v_min: float, v_max: float) -> slice:
        """
        Return a label slice of *coords* whose cells cover [v_min, v_max].

        The slice starts at the last coordinate <= v_min and ends at the first
        coordinate >= v_max (both clipped to the available range). It is
        ordered like *coords* (ascending or descending).
        """
        descending = len(coords) > 1 and coords[0] > coords[1]
        ascending = coords[::-1] if descending else coords
        i_start = np.searchsorted(ascending, v_min, side="right") - 1
        i_end = np.searchsorted(ascending, v_max, side="left")
        c_start = ascending[max(i_start, 0)]
        c_end = ascending[min(i_end, len(ascending) - 1)]
        low, high = min(c_start, c_end), max(c_start, c_end)
        return slice(high, low) if descending else slice(low, high)

    @staticmethod
    def _parse_crs(data: xr.DataArray | xr.Dataset, file_crs: int | None = None) -> int:
        """
        Extract CRS information from xarray data.

        Sources, in order: the explicit parameter, the 'crs' attribute of the
        data, then the rioxarray crs property.

        Parameters
        ----------
        data
            xarray DataArray or Dataset to extract CRS from.
        file_crs
            Explicit CRS as EPSG code. If provided, this value is returned directly.

        Returns
        -------
        int
            CRS as EPSG code.

        Raises
        ------
        DataError
            If no CRS is found, or if it cannot be expressed as an EPSG code.
        """
        if file_crs is not None:
            return file_crs
        if "crs" in data.attrs:
            # Try to get it from the global attributes
            return TimeSeries2D._crs_to_epsg(data.attrs["crs"])
        if data.rio.crs:
            # Try to get it from the rio crs
            epsg = data.rio.crs.to_epsg()
            if epsg is None:
                raise DataError(
                    f"The CRS of the data ({data.rio.crs}) has no EPSG code. "
                    "Please provide a CRS (option 'data_crs').",
                    data_type="spatial data",
                    reason="CRS without EPSG code",
                )
            return epsg
        raise DataError(
            "Could not determine the CRS from the data. "
            "Please provide a CRS (option 'data_crs').",
            data_type="spatial data",
            reason="Missing CRS information",
        )

    @staticmethod
    def _crs_to_epsg(value: Any) -> int:
        """Normalize a CRS attribute (int, '2056', 'EPSG:2056', WKT...) to EPSG."""
        try:
            return int(value)
        except (TypeError, ValueError):
            pass
        epsg = pyproj.CRS.from_user_input(value).to_epsg()
        if epsg is None:
            raise DataError(
                f"Could not convert the CRS ({value}) to an EPSG code.",
                data_type="spatial data",
                reason="CRS without EPSG code",
            )
        return epsg

    @staticmethod
    def _get_spatial_bounds(ref_data: xr.DataArray | xr.Dataset) -> tuple:
        """
        Extract spatial bounds from xarray data.

        Determines the minimum and maximum coordinates in x and y dimensions,
        automatically detecting dimension names (x/lon/longitude, y/lat/latitude).

        Parameters
        ----------
        ref_data
            xarray DataArray or Dataset containing spatial data.

        Returns
        -------
        tuple
            Tuple of (x_min, x_max, y_min, y_max) spatial bounds.

        Raises
        ------
        DataError
            If spatial dimensions cannot be found in the data.
        """
        # Possible names for spatial dimensions
        x_names = ["x", "lon", "longitude"]
        y_names = ["y", "lat", "latitude"]

        # Find the actual dimension names
        x_dim = next((name for name in x_names if name in ref_data.dims), None)
        y_dim = next((name for name in y_names if name in ref_data.dims), None)

        if x_dim is None or y_dim is None:
            raise DataError(
                f"Could not find spatial dimensions in the reference data. "
                f"Available dimensions: {list(ref_data.dims)}",
                data_type="spatial data",
                reason="Missing x/y or lon/lat dimensions",
            )

        x_ref_min = ref_data[x_dim].min().item()
        x_ref_max = ref_data[x_dim].max().item()
        y_ref_min = ref_data[y_dim].min().item()
        y_ref_max = ref_data[y_dim].max().item()
        return x_ref_min, x_ref_max, y_ref_min, y_ref_max

    @staticmethod
    def _fill_nan_gradients(dat: xr.DataArray) -> xr.DataArray:
        """
        Fill NaN values in gradient arrays through interpolation and edge extension.

        Uses linear interpolation along x and then y, then copies the inner
        neighbour into NaN edge cells. Up to two passes are made.

        Parameters
        ----------
        dat
            xarray DataArray containing gradient data with potential NaN values.

        Returns
        -------
        xr.DataArray
            DataArray with NaN values filled where possible. The input array is
            not modified (interpolation produces a new array).
        """
        for _ in range(2):
            if not np.isnan(dat).any():
                break
            # Interpolate NaN values in the gradients
            dat = dat.interpolate_na(dim="x").interpolate_na(dim="y")

            # Replace NaN values remaining at the edges (x first, then y)
            values = dat.values
            for axis in (dat.get_axis_num("x"), dat.get_axis_num("y")):
                if values.shape[axis] < 2:
                    # No inner neighbour to copy from along this axis.
                    continue
                for edge, neighbour in ((0, 1), (-1, -2)):
                    edge_index = [slice(None)] * values.ndim
                    neighbour_index = [slice(None)] * values.ndim
                    edge_index[axis] = edge
                    neighbour_index[axis] = neighbour
                    edge_view = values[tuple(edge_index)]  # view: writes go to values
                    missing = np.isnan(edge_view)
                    edge_view[missing] = values[tuple(neighbour_index)][missing]
            dat.data = values
        return dat

    @staticmethod
    def _mean_xy_gradient(dat_dx: xr.DataArray, dat_dy: xr.DataArray) -> xr.DataArray:
        """
        Compute the mean of x and y gradients on a common grid.

        Creates a new grid whose 'x' coordinate comes from dat_dy and 'y' coordinate
        comes from dat_dx. Reindexes both arrays to this grid using nearest neighbour
        for missing edges. Returns the mean.

        Parameters
        ----------
        dat_dx
            xarray DataArray containing gradients in x direction.
            Must have 'x' and 'y' dimensions.
        dat_dy
            xarray DataArray containing gradients in y direction.
            Must have 'x' and 'y' dimensions.

        Returns
        -------
        xr.DataArray
            Mean of the x and y gradients on the common grid.

        Raises
        ------
        DataError
            If required 'x' and 'y' dimensions are missing from input arrays.
        """
        # Ensure both have x and y dims
        if "x" not in dat_dx.dims or "y" not in dat_dx.dims:
            raise DataError(
                f"dat_dx must have 'x' and 'y' dimensions. "
                f"Got dimensions: {dat_dx.dims}",
                data_type="time series",
                reason="Missing x or y dimension",
            )
        if "x" not in dat_dy.dims or "y" not in dat_dy.dims:
            raise DataError(
                f"dat_dy must have 'x' and 'y' dimensions. "
                f"Got dimensions: {dat_dy.dims}",
                data_type="time series",
                reason="Missing x or y dimension",
            )

        # Target coordinates: full x axis from dat_dy, full y axis from dat_dx
        coords = {}
        for dim in dat_dx.dims:
            if dim == "x":
                coords["x"] = dat_dy["x"]
            elif dim == "y":
                coords["y"] = dat_dx["y"]
            else:
                coords[dim] = dat_dx[dim] if dim in dat_dx.coords else dat_dy[dim]

        # Create target grid and reindex both arrays onto it
        dat_dxy = xr.DataArray(dims=dat_dx.dims, coords=coords)
        dat_dx_expanded = dat_dx.reindex_like(dat_dxy, method="nearest")
        dat_dy_expanded = dat_dy.reindex_like(dat_dxy, method="nearest")

        return (dat_dx_expanded + dat_dy_expanded) / 2