# Source code for rompy.core.boundary

"""Boundary classes."""

import logging
from pathlib import Path
from typing import Literal, Optional, Union

import numpy as np
import wavespectra
import xarray as xr
from pydantic import Field, field_validator, model_validator

from rompy.core.data import (
    DataGrid,
    SourceBase,
    SourceDatamesh,
    SourceDataset,
    SourceFile,
    SourceIntake,
)
from rompy.core.grid import RegularGrid
from rompy.core.time import TimeRange
from rompy.settings import BOUNDARY_SOURCE_TYPES, SPEC_BOUNDARY_SOURCE_TYPES
from rompy.utils import process_setting

logger = logging.getLogger(__name__)


def find_minimum_distance(points: list[tuple[float, float]]) -> float:
    """Find the minimum distance between a set of points.

    Parameters
    ----------
    points: list[tuple[float, float]]
        List of points as (x, y) tuples.

    Returns
    -------
    min_distance: float
        Minimum distance between all points.

    """

    def calculate_distance(x1, y1, x2, y2):
        return np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)

    n = len(points)
    if n <= 1:
        return float("inf")

    # Sort points by x-coordinate
    points.sort()

    # Recursive step
    if n == 2:
        return calculate_distance(*points[0], *points[1])

    mid = n // 2
    left_points = points[:mid]
    right_points = points[mid:]

    # Divide and conquer
    left_min = find_minimum_distance(left_points)
    right_min = find_minimum_distance(right_points)

    min_distance = min(left_min, right_min)

    # Find the closest pair across the dividing line
    strip = []
    for point in points:
        if abs(point[0] - points[mid][0]) < min_distance:
            strip.append(point)

    strip_min = min_distance
    strip_len = len(strip)
    for i in range(strip_len - 1):
        j = i + 1
        while j < strip_len and (strip[j][1] - strip[i][1]) < strip_min:
            distance = calculate_distance(*strip[i], *strip[j])
            if distance < strip_min:
                strip_min = distance
            j += 1

    return min(min_distance, strip_min)


class SourceWavespectra(SourceBase):
    """Wavespectra dataset from wavespectra reader.

    The dataset is opened by looking up the function named by `reader` in the
    `wavespectra` package and calling it with `uri` and `kwargs`.
    """

    model_type: Literal["wavespectra"] = Field(
        default="wavespectra",
        description="Model type discriminator",
    )
    uri: str | Path = Field(description="Path to the dataset")
    reader: str = Field(
        description="Name of the wavespectra reader to use, e.g., read_swan",
    )
    kwargs: dict = Field(
        default={},
        description="Keyword arguments to pass to the wavespectra reader",
    )

    def __str__(self) -> str:
        """Return a short description showing the uri and reader in use."""
        return f"SourceWavespectra(uri={self.uri}, reader={self.reader})"

    def _open(self):
        """Open the dataset with the configured wavespectra reader."""
        # Resolve the reader by name; an unknown name raises AttributeError
        read_func = getattr(wavespectra, self.reader)
        return read_func(self.uri, **self.kwargs)
BOUNDARY_SOURCE_MODELS = process_setting(BOUNDARY_SOURCE_TYPES)
SPEC_BOUNDARY_SOURCE_MODELS = process_setting(SPEC_BOUNDARY_SOURCE_TYPES)


class DataBoundary(DataGrid):
    """Data selected at points along the boundary of a model grid.

    The boundary points are obtained from the grid passed to `get` (optionally
    resampled at `spacing`) and are selected from the source dataset with the
    xarray method named by `sel_method`.
    """

    # The default used to fall outside the allowed literal, so a default
    # instance could not be re-validated after being dumped. Both tags are
    # accepted to keep existing configurations working.
    model_type: Literal["boundary", "data_boundary"] = Field(
        default="data_boundary",
        description="Model type discriminator",
    )
    id: str = Field(description="Unique identifier for this data source")
    spacing: Optional[Union[float, Literal["parent"]]] = Field(
        default=None,
        description=(
            "Spacing between points along the grid boundary to retrieve data for. If "
            "None (default), points are defined from the the actual grid object "
            "passed to the `get` method. If 'parent', the resolution of the parent "
            "dataset is used to define the spacing."
        ),
    )
    sel_method: Literal["sel", "interp"] = Field(
        default="sel",
        description=(
            "Xarray method to use for selecting boundary points from the dataset"
        ),
    )
    sel_method_kwargs: dict = Field(
        default={}, description="Keyword arguments for sel_method"
    )
    crop_data: bool = Field(
        default=True,
        description="Update crop filter from Time object if passed to get method",
    )

    @field_validator("spacing")
    @classmethod
    def spacing_gt_zero(cls, v):
        """Ensure a numeric spacing is strictly positive."""
        if v not in (None, "parent") and v <= 0.0:
            raise ValueError("Spacing must be greater than zero")
        return v

    def _source_grid_spacing(self) -> float:
        """Return the lowest grid spacing in the source dataset.

        In a gridded dataset this is defined as the lowest spacing between adjacent
        points in the dataset. In other dataset types such as a station dataset this
        method needs to be overridden to return the lowest spacing between points.

        """
        dx = np.diff(sorted(self.ds[self.coords.x].values)).min()
        dy = np.diff(sorted(self.ds[self.coords.y].values)).min()
        return min(dx, dy)

    def _set_spacing(self) -> float:
        """Define spacing from the parent dataset if required."""
        if self.spacing == "parent":
            return self._source_grid_spacing()
        else:
            return self.spacing

    def _boundary_points(self, grid) -> tuple:
        """Returns the x and y arrays representing the boundary points to select.

        This method can be overridden to define custom boundary points.

        """
        xbnd, ybnd = grid.boundary_points(spacing=self._set_spacing())
        return xbnd, ybnd

    def _sel_boundary(self, grid) -> xr.Dataset:
        """Select the boundary points from the dataset."""
        xbnd, ybnd = self._boundary_points(grid=grid)
        # Both coordinates share the "site" dimension so that the selection is
        # done point by point rather than as an outer product
        coords = {
            self.coords.x: xr.DataArray(xbnd, dims=("site",)),
            self.coords.y: xr.DataArray(ybnd, dims=("site",)),
        }
        return getattr(self.ds, self.sel_method)(coords, **self.sel_method_kwargs)

    def get(
        self, destdir: str | Path, grid: RegularGrid, time: Optional[TimeRange] = None
    ) -> Path:
        """Write the selected boundary data to a netcdf file.

        Parameters
        ----------
        destdir : str | Path
            Destination directory for the netcdf file.
        grid : RegularGrid
            Grid instance to use for selecting the boundary points.
        time: TimeRange, optional
            The times to filter the data to, only used if `self.crop_data` is True.

        Returns
        -------
        outfile : Path
            Path to the netcdf file.

        """
        if self.crop_data and time is not None:
            self._filter_time(time)
        ds = self._sel_boundary(grid)
        outfile = Path(destdir) / f"{self.id}.nc"
        ds.to_netcdf(outfile)
        return outfile

    def plot(self, model_grid=None, cmap="turbo", fscale=10, ax=None, **kwargs):
        """Scatter plot the locations of the source dataset on a map."""
        return scatter_plot(
            self, model_grid=model_grid, cmap=cmap, fscale=fscale, ax=ax, **kwargs
        )

    def plot_boundary(self, grid=None, fscale=10, ax=None, **kwargs):
        """Plot the boundary points on a map.

        Parameters
        ----------
        grid :
            Grid defining the boundary points; it is also plotted underneath the
            points. Required despite the None default kept in the signature.
        fscale : float
            Figure scale passed on to the grid and scatter plots.
        ax : optional
            Existing axis to draw into.
        **kwargs
            Passed on to both `grid.plot` and `scatter_plot`.

        Raises
        ------
        ValueError
            If `grid` is None.

        """
        if grid is None:
            raise ValueError("A grid is required to plot the boundary points")
        ds = self._sel_boundary(grid)
        _fig, ax = grid.plot(ax=ax, fscale=fscale, **kwargs)
        return scatter_plot(
            self,
            ds=ds,
            fscale=fscale,
            ax=ax,
            **kwargs,
        )
class BoundaryWaveStation(DataBoundary):
    """Wave boundary data from station datasets.

    Note
    ----
    The `tolerance` behaves differently with sel_methods `idw` and `nearest`; in `idw`
    sites with not enough neighbours within `tolerance` are masked whereas in `nearest`
    an exception is raised (see wavespectra documentation for more details).

    Note
    ----
    Be aware that when using `idw` missing values will be returned for sites with less
    than 2 neighbours within `tolerance` in the original dataset. This is okay for land
    mask areas but could cause boundary issues when on an open boundary location. To
    avoid this either use `nearest` or increase `tolerance` to include more neighbours.

    """

    # NOTE(review): described as the model type discriminator but named
    # `grid_type` rather than `model_type` - confirm this is intended.
    grid_type: Literal["boundary_wave_station"] = Field(
        default="boundary_wave_station",
        description="Model type discriminator",
    )
    source: SPEC_BOUNDARY_SOURCE_MODELS = Field(
        description=(
            "Dataset source reader, must return a wavespectra-enabled "
            "xarray dataset in the open method"
        ),
        discriminator="model_type",
    )
    sel_method: Literal["idw", "nearest"] = Field(
        default="idw",
        description=(
            "Wavespectra method to use for selecting boundary points from the dataset"
        ),
    )
    buffer: float = Field(
        default=2.0,
        description="Space to buffer the grid bounding box if `filter_grid` is True",
    )

    def model_post_init(self, __context):
        """Set the variables to keep from the source dataset after validation."""
        self.variables = ["efth", "lon", "lat"]

    # @model_validator(mode="after")
    # def assert_has_wavespectra_accessor(self) -> "BoundaryWaveStation":
    #     dset = self.source.open()
    #     if not hasattr(dset, "spec"):
    #         raise ValueError(f"Wavespectra compatible source is required")
    #     return self

    def _source_grid_spacing(self, grid) -> float:
        """Return the lowest spacing between points in the source dataset."""
        # Select dataset points just outside the actual grid to optimise the search
        xbnd, ybnd = grid.boundary().exterior.coords.xy
        # NOTE(review): np.diff yields signed steps, so if the exterior ring is
        # closed (as for a shapely polygon) these minima are <= 0 and the
        # buffer below is not positive - confirm this is intended.
        dx = np.diff(xbnd).min()
        dy = np.diff(ybnd).min()
        buffer = 2 * min(dx, dy)
        x0, y0, x1, y1 = grid.bbox(buffer=buffer)
        ds = self.ds.spec.sel([x0, x1], [y0, y1], method="bbox")
        # Return the closest distance between adjacent points in cropped dataset
        points = list(zip(ds.lon.values, ds.lat.values))
        return find_minimum_distance(points)

    def _set_spacing(self, grid) -> float:
        """Define spacing from the parent dataset if required."""
        if self.spacing == "parent":
            return self._source_grid_spacing(grid)
        else:
            return self.spacing

    def _boundary_points(self, grid) -> tuple:
        """Returns the x and y arrays representing the boundary points to select.

        Override the default method to use grid when setting the default spacing.

        """
        xbnd, ybnd = grid.boundary_points(spacing=self._set_spacing(grid))
        return xbnd, ybnd

    def _sel_boundary(self, grid) -> xr.Dataset:
        """Select the boundary points from the dataset."""
        xbnd, ybnd = self._boundary_points(grid=grid)
        ds = self.ds.spec.sel(
            lons=xbnd,
            lats=ybnd,
            method=self.sel_method,
            **self.sel_method_kwargs,
        )
        return ds

    @property
    def ds(self):
        """Return the filtered xarray dataset instance.

        Raises
        ------
        ValueError
            If the `efth` variable is empty after the filter is applied.

        """
        dset = super().ds
        if dset.efth.size == 0:
            raise ValueError(f"Empty dataset after applying filter {self.filter}")
        return dset

    def get(
        self, destdir: str | Path, grid: RegularGrid, time: Optional[TimeRange] = None
    ) -> Path:
        """Write the selected boundary data to a netcdf file.

        Parameters
        ----------
        destdir : str | Path
            Destination directory for the netcdf file.
        grid : RegularGrid
            Grid instance to use for selecting the boundary points.
        time: TimeRange, optional
            The times to filter the data to, only used if `self.crop_data` is True.

        Returns
        -------
        outfile : Path
            Path to the netcdf file.

        """
        # Cropping in time and space is only applied when crop_data is set
        if self.crop_data:
            if time is not None:
                self._filter_time(time)
            if grid is not None:
                self._filter_grid(grid)
        ds = self._sel_boundary(grid)
        outfile = Path(destdir) / f"{self.id}.nc"
        # Written through the wavespectra accessor rather than plain xarray
        ds.spec.to_netcdf(outfile)
        return outfile
def scatter_plot(bnd, ds=None, fscale=10, ax=None, **kwargs):
    """Scatter plot the point locations of a boundary dataset on a map.

    Parameters
    ----------
    bnd :
        Boundary data object providing the coordinate names in `coords.x` and
        `coords.y` and, when `ds` is not given, the dataset in `ds`.
    ds : xr.Dataset, optional
        Dataset holding the points to plot, defaults to `bnd.ds`.
    fscale : float
        Width of a newly created figure; the height is scaled by the
        latitude / longitude extent ratio of the points.
    ax : optional
        Existing axis to draw into. If None, a new PlateCarree map is created.
    **kwargs
        Accepted so callers can forward extra options; currently unused.

    Returns
    -------
    ax :
        The axis the points were drawn into.

    """
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature
    import matplotlib.pyplot as plt
    from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER

    if ds is None:
        ds = bnd.ds

    xcoord = ds[bnd.coords.x]
    ycoord = ds[bnd.coords.y]

    if ax is None:
        # Vectorised reductions instead of the builtin min/max, which iterate
        # the array element by element in Python
        lon_span = float(xcoord.max()) - float(xcoord.min())
        lat_span = float(ycoord.max()) - float(ycoord.min())
        # Fall back to a square figure when the points have no extent
        aspect = lat_span / lon_span if lon_span > 0 and lat_span > 0 else 1.0

        # create figure and plot/map
        figsize = (fscale, fscale * aspect)
        subplot_kw = {"projection": ccrs.PlateCarree()}
        _fig, ax = plt.subplots(1, 1, figsize=figsize, subplot_kw=subplot_kw)

        # Decorate only axes created here; a caller-supplied axis is assumed
        # to be already set up (e.g. by a grid plot) - TODO confirm
        coastline = cfeature.GSHHSFeature(
            scale="auto", edgecolor="black", facecolor=cfeature.COLORS["land"]
        )
        ax.add_feature(coastline)
        ax.add_feature(cfeature.BORDERS, linewidth=2)

        gl = ax.gridlines(
            crs=ccrs.PlateCarree(),
            draw_labels=True,
            linewidth=2,
            color="gray",
            alpha=0.5,
            linestyle="--",
        )
        gl.xformatter = LONGITUDE_FORMATTER
        gl.yformatter = LATITUDE_FORMATTER

    ax.scatter(xcoord, ycoord, transform=ccrs.PlateCarree())
    return ax