"""Boundary classes."""

import logging
from pathlib import Path
from typing import Literal, Optional, Union

import numpy as np
import wavespectra
import xarray as xr
from pydantic import Field, field_validator, model_validator

from import (
from rompy.core.grid import RegularGrid
from rompy.core.time import TimeRange
from rompy.utils import process_setting

# Module-level logger named after this module so it slots into the logging hierarchy.
logger = logging.getLogger(__name__)

def find_minimum_distance(points: list[tuple[float, float]]) -> float:
    """Find the minimum distance between a set of points.

    Uses the divide-and-conquer closest-pair algorithm on a sorted copy of
    ``points``; the caller's list is not modified.

    Parameters
    ----------
    points: list[tuple[float, float]]
        List of points as (x, y) tuples.

    Returns
    -------
    min_distance: float
        Minimum Euclidean distance between any two points, or ``inf`` when fewer
        than two points are given.

    """

    def calculate_distance(p, q):
        """Euclidean distance between two (x, y) points."""
        return float(np.sqrt((q[0] - p[0]) ** 2 + (q[1] - p[1]) ** 2))

    def brute_force(pts):
        """Exhaustive search, used for the small base cases."""
        best = float("inf")
        for i in range(len(pts)):
            for j in range(i + 1, len(pts)):
                best = min(best, calculate_distance(pts[i], pts[j]))
        return best

    def closest_pair(pts):
        """Closest-pair distance for ``pts`` already sorted by x-coordinate."""
        n = len(pts)
        if n <= 3:
            return brute_force(pts)

        # Divide and conquer on the x-sorted halves
        mid = n // 2
        x_mid = pts[mid][0]
        best = min(closest_pair(pts[:mid]), closest_pair(pts[mid:]))

        # Candidates across the dividing line. They must be ordered by y so the
        # inner scan can stop as soon as the vertical gap alone reaches `best`;
        # scanning in x-order with that break condition misses closer pairs.
        strip = sorted(
            (p for p in pts if abs(p[0] - x_mid) < best), key=lambda p: p[1]
        )
        for i in range(len(strip)):
            for j in range(i + 1, len(strip)):
                if strip[j][1] - strip[i][1] >= best:
                    break
                best = min(best, calculate_distance(strip[i], strip[j]))
        return best

    # Sort a copy by x-coordinate so the input list is left untouched
    return closest_pair(sorted(points, key=lambda p: p[0]))

class SourceWavespectra(SourceBase):
    """Wavespectra dataset from wavespectra reader."""

    model_type: Literal["wavespectra"] = Field(
        default="wavespectra",
        description="Model type discriminator",
    )
    uri: str | Path = Field(description="Path to the dataset")
    reader: str = Field(
        description="Name of the wavespectra reader to use, e.g., read_swan",
    )
    kwargs: dict = Field(
        default={},
        description="Keyword arguments to pass to the wavespectra reader",
    )

    def __str__(self) -> str:
        return f"SourceWavespectra(uri={self.uri}, reader={self.reader})"

    def _open(self):
        """Open ``uri`` with the wavespectra reader function named by ``reader``.

        The reader is looked up by name on the ``wavespectra`` package and called
        as ``reader(uri, **kwargs)``; an unknown name raises ``AttributeError``.
        """
        reader_func = getattr(wavespectra, self.reader)
        return reader_func(self.uri, **self.kwargs)
BOUNDARY_SOURCE_MODELS = process_setting(BOUNDARY_SOURCE_TYPES)
SPEC_BOUNDARY_SOURCE_MODELS = process_setting(SPEC_BOUNDARY_SOURCE_TYPES)


class DataBoundary(DataGrid):
    """Data source sampled at points along the boundary of a model grid."""

    # The default "data_boundary" must itself be an allowed literal value, otherwise
    # a dumped default instance cannot be validated back. "boundary" is kept so
    # existing configurations remain valid.
    model_type: Literal["boundary", "data_boundary"] = Field(
        default="data_boundary",
        description="Model type discriminator",
    )
    id: str = Field(description="Unique identifier for this data source")
    spacing: Optional[Union[float, Literal["parent"]]] = Field(
        default=None,
        description=(
            "Spacing between points along the grid boundary to retrieve data for. If "
            "None (default), points are defined from the the actual grid object "
            "passed to the `get` method. If 'parent', the resolution of the parent "
            "dataset is used to define the spacing."
        ),
    )
    sel_method: Literal["sel", "interp"] = Field(
        default="sel",
        description=(
            "Xarray method to use for selecting boundary points from the dataset"
        ),
    )
    sel_method_kwargs: dict = Field(
        default={}, description="Keyword arguments for sel_method"
    )
    crop_data: bool = Field(
        default=True,
        description="Update crop filter from Time object if passed to get method",
    )

    @field_validator("spacing")
    @classmethod
    def spacing_gt_zero(cls, v):
        """Reject non-positive numeric spacing; None and "parent" pass through."""
        if v not in (None, "parent") and v <= 0.0:
            raise ValueError("Spacing must be greater than zero")
        return v

    def _source_grid_spacing(self) -> float:
        """Return the lowest grid spacing in the source dataset.

        In a gridded dataset this is defined as the lowest spacing between adjacent
        points in the dataset. In other dataset types such as a station dataset this
        method needs to be overriden to return the lowest spacing between points.

        """
        # np.unique sorts and drops repeated values, so duplicated coordinates
        # cannot produce a zero spacing.
        dx = np.diff(np.unique(self.ds[self.coords.x].values)).min()
        dy = np.diff(np.unique(self.ds[self.coords.y].values)).min()
        return min(dx, dy)

    def _set_spacing(self) -> Optional[float]:
        """Define spacing from the parent dataset if required.

        Returns the parent dataset's grid spacing when ``spacing == "parent"``,
        otherwise the configured ``spacing`` (possibly None).
        """
        if self.spacing == "parent":
            return self._source_grid_spacing()
        return self.spacing

    def _boundary_points(self, grid) -> tuple:
        """Returns the x and y arrays representing the boundary points to select.

        This method can be overriden to define custom boundary points.

        """
        xbnd, ybnd = grid.boundary_points(spacing=self._set_spacing())
        return xbnd, ybnd

    def _sel_boundary(self, grid) -> xr.Dataset:
        """Select the boundary points from the dataset.

        The x/y boundary arrays are wrapped on a shared ``site`` dimension so that
        ``sel``/``interp`` perform pointwise (vectorised) selection.
        """
        xbnd, ybnd = self._boundary_points(grid=grid)
        coords = {
            self.coords.x: xr.DataArray(xbnd, dims=("site",)),
            self.coords.y: xr.DataArray(ybnd, dims=("site",)),
        }
        return getattr(self.ds, self.sel_method)(coords, **self.sel_method_kwargs)

    def get(
        self, destdir: str | Path, grid: RegularGrid, time: Optional[TimeRange] = None
    ) -> Path:
        """Write the selected boundary data to a netcdf file.

        Parameters
        ----------
        destdir : str | Path
            Destination directory for the netcdf file.
        grid : RegularGrid
            Grid instance to use for selecting the boundary points.
        time: TimeRange, optional
            The times to filter the data to, only used if `self.crop_data` is True.

        Returns
        -------
        outfile : Path
            Path to the netcdf file, named after this source's ``id``.

        """
        if self.crop_data and time is not None:
            self._filter_time(time)
        ds = self._sel_boundary(grid)
        outfile = Path(destdir) / f"{self.id}.nc"
        ds.to_netcdf(outfile)
        return outfile

    def plot(self, model_grid=None, cmap="turbo", fscale=10, ax=None, **kwargs):
        """Scatter-plot the dataset points on a map.

        NOTE(review): `model_grid` and `cmap` are forwarded to `scatter_plot`, whose
        implementation in this module does not use extra keyword arguments.
        """
        return scatter_plot(
            self, model_grid=model_grid, cmap=cmap, fscale=fscale, ax=ax, **kwargs
        )

    def plot_boundary(self, grid=None, fscale=10, ax=None, **kwargs):
        """Plot the boundary points on a map.

        The grid is drawn first via ``grid.plot`` and the selected boundary points
        are scattered on the same axes. ``grid`` is required despite its default.
        """
        ds = self._sel_boundary(grid)
        _, ax = grid.plot(ax=ax, fscale=fscale, **kwargs)
        return scatter_plot(
            self,
            ds=ds,
            fscale=fscale,
            ax=ax,
            **kwargs,
        )
class BoundaryWaveStation(DataBoundary):
    """Wave boundary data from station datasets.

    Note
    ----
    The `tolerance` behaves differently with sel_methods `idw` and `nearest`; in `idw`
    sites with not enough neighbours within `tolerance` are masked whereas in
    `nearest` an exception is raised (see wavespectra documentation for more details).

    Note
    ----
    Be aware that when using `idw` missing values will be returned for sites with less
    than 2 neighbours within `tolerance` in the original dataset. This is okay for
    land mask areas but could cause boundary issues when on an open boundary location.
    To avoid this either use `nearest` or increase `tolerance` to include more
    neighbours.

    """

    # NOTE(review): this field is named `grid_type` although it is described as the
    # model type discriminator; the `model_type` field inherited from DataBoundary
    # is still present on this class.
    grid_type: Literal["boundary_wave_station"] = Field(
        default="boundary_wave_station",
        description="Model type discriminator",
    )
    source: SPEC_BOUNDARY_SOURCE_MODELS = Field(
        description=(
            "Dataset source reader, must return a wavespectra-enabled "
            "xarray dataset in the open method"
        ),
        discriminator="model_type",
    )
    sel_method: Literal["idw", "nearest"] = Field(
        default="idw",
        description=(
            "Wavespectra method to use for selecting boundary points from the dataset"
        ),
    )
    buffer: float = Field(
        default=2.0,
        description="Space to buffer the grid bounding box if `filter_grid` is True",
    )

    def model_post_init(self, __context):
        # Only the spectra and site coordinates are needed for boundary selection;
        # presumably `variables` limits what is read from the source (see DataGrid).
        self.variables = ["efth", "lon", "lat"]

    def _source_grid_spacing(self, grid) -> float:
        """Return the lowest spacing between points in the source dataset.

        Only source points inside a small buffer around ``grid`` are considered, to
        keep the closest-pair search cheap.
        """
        # Select dataset points just outside the actual grid to optimise the search.
        # Assuming grid.boundary() is a shapely polygon, its exterior ring is closed,
        # so the signed coordinate steps sum to zero and their minimum is never
        # positive. Use the smallest non-zero absolute step so the box grows.
        xbnd, ybnd = grid.boundary().exterior.coords.xy
        steps = np.abs(np.concatenate([np.diff(xbnd), np.diff(ybnd)]))
        steps = steps[steps > 0]
        buffer = 2 * steps.min() if steps.size else 0.0
        x0, y0, x1, y1 = grid.bbox(buffer=buffer)
        ds = self.ds.spec.sel([x0, x1], [y0, y1], method="bbox")
        # Return the closest distance between any two points in the cropped dataset
        points = list(zip(ds.lon.values, ds.lat.values))
        return find_minimum_distance(points)

    def _set_spacing(self, grid) -> Optional[float]:
        """Define spacing from the parent dataset if required."""
        if self.spacing == "parent":
            return self._source_grid_spacing(grid)
        return self.spacing

    def _boundary_points(self, grid) -> tuple:
        """Returns the x and y arrays representing the boundary points to select.

        Override the default method to use grid when setting the default spacing.

        """
        xbnd, ybnd = grid.boundary_points(spacing=self._set_spacing(grid))
        return xbnd, ybnd

    def _sel_boundary(self, grid) -> xr.Dataset:
        """Select the boundary points from the dataset using wavespectra."""
        xbnd, ybnd = self._boundary_points(grid=grid)
        return self.ds.spec.sel(
            lons=xbnd,
            lats=ybnd,
            method=self.sel_method,
            **self.sel_method_kwargs,
        )

    @property
    def ds(self):
        """Return the filtered xarray dataset instance.

        Raises
        ------
        ValueError
            If no spectra remain after applying the filter.
        """
        dset = super().ds
        if dset.efth.size == 0:
            raise ValueError(f"Empty dataset after applying filter {self.filter}")
        return dset

    def get(
        self, destdir: str | Path, grid: RegularGrid, time: Optional[TimeRange] = None
    ) -> Path:
        """Write the selected boundary data to a netcdf file.

        Parameters
        ----------
        destdir : str | Path
            Destination directory for the netcdf file.
        grid : RegularGrid
            Grid instance to use for selecting the boundary points.
        time: TimeRange, optional
            The times to filter the data to, only used if `self.crop_data` is True.

        Returns
        -------
        outfile : Path
            Path to the netcdf file, named after this source's ``id``.

        """
        if self.crop_data:
            if time is not None:
                self._filter_time(time)
            if grid is not None:
                self._filter_grid(grid)
        ds = self._sel_boundary(grid)
        outfile = Path(destdir) / f"{self.id}.nc"
        ds.spec.to_netcdf(outfile)
        return outfile
def scatter_plot(bnd, ds=None, fscale=10, ax=None, **kwargs):
    """Scatter the points of a boundary dataset on a PlateCarree map.

    Parameters
    ----------
    bnd
        Boundary object; ``bnd.coords.x`` / ``bnd.coords.y`` name the coordinates
        to plot and ``bnd.ds`` is used when ``ds`` is not given.
    ds : xr.Dataset, optional
        Dataset whose points are plotted. Defaults to ``bnd.ds``.
    fscale : float
        Figure width (inches) when a new figure is created.
    ax : optional
        Cartopy axes to draw on. A new figure and axes are created if None.
    **kwargs
        Accepted for call compatibility; currently ignored.

    Returns
    -------
    ax
        The axes the points were drawn on.

    """
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature
    import matplotlib.pyplot as plt
    from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER

    if ds is None:
        ds = bnd.ds

    xcoord = ds[bnd.coords.x]
    ycoord = ds[bnd.coords.y]

    if ax is None:
        # Size the figure to the data extent; fall back to a square figure when the
        # points span no longitude or latitude range (avoids dividing by zero).
        lon_span = float(xcoord.max()) - float(xcoord.min())
        lat_span = float(ycoord.max()) - float(ycoord.min())
        aspect = lat_span / lon_span if lon_span > 0 and lat_span > 0 else 1.0
        figsize = (fscale, fscale * aspect)
        subplot_kw = {"projection": ccrs.PlateCarree()}
        _, ax = plt.subplots(1, 1, figsize=figsize, subplot_kw=subplot_kw)

    coastline = cfeature.GSHHSFeature(
        scale="auto", edgecolor="black", facecolor=cfeature.COLORS["land"]
    )
    ax.add_feature(coastline)
    ax.add_feature(cfeature.BORDERS, linewidth=2)
    gl = ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=True,
        linewidth=2,
        color="gray",
        alpha=0.5,
        linestyle="--",
    )
    gl.xformatter = LONGITUDE_FORMATTER
    gl.yformatter = LATITUDE_FORMATTER
    ax.scatter(xcoord, ycoord, transform=ccrs.PlateCarree())
    return ax