Source code for rompy.core.types

"""Rompy types."""
import re
import struct
from datetime import datetime
from typing import Any, Optional, Union

from pydantic import (
    field_validator,
    model_validator,
    ConfigDict,
    BaseModel,
    Field,
)


class RompyBaseModel(BaseModel):
    """Shared pydantic base class for the rompy models defined in this module."""

    # Clearing the protected namespaces avoids the problem discussed in
    # https://github.com/pydantic/pydantic/discussions/7121
    model_config = ConfigDict(protected_namespaces=())
class Latitude(BaseModel):
    """Latitude

    A single latitude value restricted to the closed range [-90, 90].
    """

    lat: float = Field(description="Latitude", ge=-90, le=90)

    @field_validator("lat")
    @classmethod
    def validate_lat(cls, v: float) -> float:
        """Return ``v`` unchanged if it lies in [-90, 90].

        Raises:
            ValueError: if ``v`` is outside the allowed range.
        """
        if not (-90 <= v <= 90):
            raise ValueError("latitude must be between -90 and 90")
        return v

    def __repr__(self) -> str:
        return f"Latitude(lat={self.lat})"

    def __str__(self) -> str:
        return f"{self.lat}"

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented (instead of blindly reading ``other.lat``)
        # lets ``Latitude(...) == None`` and similar comparisons evaluate to
        # False rather than raising AttributeError.
        if not isinstance(other, Latitude):
            return NotImplemented
        return self.lat == other.lat

    def __hash__(self) -> int:
        # Kept consistent with __eq__: equal latitudes hash equally.
        return hash((self.lat,))
class Longitude(BaseModel):
    """Longitude

    A single longitude value restricted to the closed range [-180, 180].
    """

    lon: float = Field(description="Longitude", ge=-180, le=180)

    @field_validator("lon")
    @classmethod
    def validate_lon(cls, v: float) -> float:
        """Return ``v`` unchanged if it lies in [-180, 180].

        Raises:
            ValueError: if ``v`` is outside the allowed range.
        """
        if not (-180 <= v <= 180):
            raise ValueError("longitude must be between -180 and 180")
        return v

    def __repr__(self) -> str:
        return f"Longitude(lon={self.lon})"

    def __str__(self) -> str:
        return f"{self.lon}"

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented (instead of blindly reading ``other.lon``)
        # lets comparisons with unrelated objects evaluate to False rather
        # than raising AttributeError.
        if not isinstance(other, Longitude):
            return NotImplemented
        return self.lon == other.lon

    def __hash__(self) -> int:
        # Kept consistent with __eq__: equal longitudes hash equally.
        return hash((self.lon,))
class Coordinate(BaseModel):
    """Coordinate

    A (lon, lat) pair. ``lon`` must lie in [-180, 180] and ``lat`` in
    [-90, 90]; both limits are enforced by the field validators below.
    """

    lon: float = Field(description="Longitude")
    lat: float = Field(description="Latitude")

    @field_validator("lon")
    @classmethod
    def validate_lon(cls, v: float) -> float:
        """Return ``v`` unchanged if it lies in [-180, 180].

        Raises:
            ValueError: if ``v`` is outside the allowed range.
        """
        if not (-180 <= v <= 180):
            raise ValueError("longitude must be between -180 and 180")
        return v

    @field_validator("lat")
    @classmethod
    def validate_lat(cls, v: float) -> float:
        """Return ``v`` unchanged if it lies in [-90, 90].

        Raises:
            ValueError: if ``v`` is outside the allowed range.
        """
        if not (-90 <= v <= 90):
            raise ValueError("latitude must be between -90 and 90")
        return v

    def __repr__(self) -> str:
        return f"Coordinate(lon={self.lon}, lat={self.lat})"

    def __str__(self) -> str:
        return f"({self.lon}, {self.lat})"

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented (instead of blindly reading attributes of
        # ``other``) lets comparisons with unrelated objects evaluate to
        # False rather than raising AttributeError.
        if not isinstance(other, Coordinate):
            return NotImplemented
        return self.lon == other.lon and self.lat == other.lat

    def __hash__(self) -> int:
        # Kept consistent with __eq__: equal coordinates hash equally.
        return hash((self.lon, self.lat))
class Bbox(BaseModel):
    """Bounding box

    Axis-aligned box described by its minimum and maximum longitude and
    latitude. The four corners are stored as ``Longitude`` / ``Latitude``
    models; bare numbers supplied for them are wrapped automatically, and all
    geometry below is computed on the underlying floats. Construction fails
    unless ``minlon < maxlon`` and ``minlat < maxlat``.
    """

    minlon: Longitude = Field(description="Minimum longitude")
    minlat: Latitude = Field(description="Minimum latitude")
    maxlon: Longitude = Field(description="Maximum longitude")
    maxlat: Latitude = Field(description="Maximum latitude")

    @model_validator(mode="before")
    @classmethod
    def coerce_bounds(cls, data: Any) -> Any:
        """Wrap bare numbers given for the corners into their model form.

        The methods of this class build new boxes from plain floats, so a
        number supplied for ``minlon``/``maxlon`` becomes ``{"lon": value}``
        and one supplied for ``minlat``/``maxlat`` becomes ``{"lat": value}``.
        Any other input is passed through untouched for normal validation.
        """
        if not isinstance(data, dict):
            return data
        coerced = dict(data)
        for name, key in (
            ("minlon", "lon"),
            ("minlat", "lat"),
            ("maxlon", "lon"),
            ("maxlat", "lat"),
        ):
            value = coerced.get(name)
            # bool is an int subclass but is never a meaningful coordinate.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                coerced[name] = {key: value}
        return coerced

    @model_validator(mode="after")
    def validate_coords(self) -> "Bbox":
        """Check that the minimum corner lies strictly below the maximum corner.

        Raises:
            ValueError: if ``minlon >= maxlon`` or ``minlat >= maxlat``.
        """
        # Compare the float values: Longitude/Latitude define no ordering, so
        # comparing the models themselves raises TypeError.
        west, south, east, north = self._bounds()
        if west >= east:
            raise ValueError("minlon must be less than maxlon")
        if south >= north:
            raise ValueError("minlat must be less than maxlat")
        return self

    def _bounds(self) -> tuple:
        """Return ``(minlon, minlat, maxlon, maxlat)`` as plain floats."""
        return (self.minlon.lon, self.minlat.lat, self.maxlon.lon, self.maxlat.lat)

    def __repr__(self) -> str:
        return f"Bbox(minlon={self.minlon}, minlat={self.minlat}, maxlon={self.maxlon}, maxlat={self.maxlat})"

    def __str__(self) -> str:
        return f"({self.minlon}, {self.minlat}, {self.maxlon}, {self.maxlat})"

    def __eq__(self, other: object) -> bool:
        # NotImplemented lets comparisons with unrelated objects evaluate to
        # False instead of raising AttributeError.
        if not isinstance(other, Bbox):
            return NotImplemented
        return self._bounds() == other._bounds()

    def __hash__(self) -> int:
        # Kept consistent with __eq__: equal boxes hash equally.
        return hash(self._bounds())

    @property
    def width(self) -> float:
        """Extent of the box along longitude (``maxlon - minlon``)."""
        return self.maxlon.lon - self.minlon.lon

    @property
    def height(self) -> float:
        """Extent of the box along latitude (``maxlat - minlat``)."""
        return self.maxlat.lat - self.minlat.lat

    @property
    def center(self) -> Coordinate:
        """Midpoint of the box as a ``Coordinate``."""
        west, south, east, north = self._bounds()
        return Coordinate(lon=(west + east) / 2, lat=(south + north) / 2)

    @property
    def area(self) -> float:
        """Product of ``width`` and ``height`` (in squared coordinate units)."""
        return self.width * self.height

    def contains(self, other) -> bool:
        """Return True if ``other`` lies entirely inside this box (edges included).

        Args:
            other: a ``Bbox`` or a ``Coordinate``.

        Raises:
            TypeError: if ``other`` is neither a ``Bbox`` nor a ``Coordinate``.
        """
        west, south, east, north = self._bounds()
        if isinstance(other, Bbox):
            o_west, o_south, o_east, o_north = other._bounds()
            return (
                west <= o_west
                and south <= o_south
                and east >= o_east
                and north >= o_north
            )
        elif isinstance(other, Coordinate):
            return (
                west <= other.lon
                and south <= other.lat
                and east >= other.lon
                and north >= other.lat
            )
        else:
            raise TypeError("other must be a Bbox or a Coordinate")

    def intersects(self, other) -> bool:
        """Return True if ``other`` overlaps or touches this box.

        Args:
            other: a ``Bbox`` or a ``Coordinate``.

        Raises:
            TypeError: if ``other`` is neither a ``Bbox`` nor a ``Coordinate``.
        """
        west, south, east, north = self._bounds()
        if isinstance(other, Bbox):
            o_west, o_south, o_east, o_north = other._bounds()
            return (
                west <= o_east
                and south <= o_north
                and east >= o_west
                and north >= o_south
            )
        elif isinstance(other, Coordinate):
            return (
                west <= other.lon
                and south <= other.lat
                and east >= other.lon
                and north >= other.lat
            )
        else:
            raise TypeError("other must be a Bbox or a Coordinate")

    def union(self, other) -> "Bbox":
        """Return the smallest box enclosing both this box and ``other``.

        Args:
            other: a ``Bbox`` or a ``Coordinate``.

        Raises:
            TypeError: if ``other`` is neither a ``Bbox`` nor a ``Coordinate``.
        """
        west, south, east, north = self._bounds()
        if isinstance(other, Bbox):
            o_west, o_south, o_east, o_north = other._bounds()
            return Bbox(
                minlon=min(west, o_west),
                minlat=min(south, o_south),
                maxlon=max(east, o_east),
                maxlat=max(north, o_north),
            )
        elif isinstance(other, Coordinate):
            return Bbox(
                minlon=min(west, other.lon),
                minlat=min(south, other.lat),
                maxlon=max(east, other.lon),
                maxlat=max(north, other.lat),
            )
        else:
            raise TypeError("other must be a Bbox or a Coordinate")

    def intersection(self, other):
        """Return the part of ``other`` that lies inside this box.

        Args:
            other: a ``Bbox`` or a ``Coordinate``.

        Returns:
            For a ``Bbox``: the overlapping box, or None when the two boxes do
            not share a region of positive width and height (a ``Bbox`` cannot
            represent an empty or zero-size extent). For a ``Coordinate``: the
            coordinate itself if it is contained, otherwise None.

        Raises:
            TypeError: if ``other`` is neither a ``Bbox`` nor a ``Coordinate``.
        """
        if isinstance(other, Bbox):
            west, south, east, north = self._bounds()
            o_west, o_south, o_east, o_north = other._bounds()
            minlon, minlat = max(west, o_west), max(south, o_south)
            maxlon, maxlat = min(east, o_east), min(north, o_north)
            if minlon >= maxlon or minlat >= maxlat:
                return None
            return Bbox(minlon=minlon, minlat=minlat, maxlon=maxlon, maxlat=maxlat)
        elif isinstance(other, Coordinate):
            if self.contains(other):
                return other
            else:
                return None
        else:
            raise TypeError("other must be a Bbox or a Coordinate")

    def expand(self, amount) -> "Bbox":
        """Return a new box grown by ``amount`` on every side."""
        west, south, east, north = self._bounds()
        return Bbox(
            minlon=west - amount,
            minlat=south - amount,
            maxlon=east + amount,
            maxlat=north + amount,
        )

    def shrink(self, amount) -> "Bbox":
        """Return a new box reduced by ``amount`` on every side.

        Validation fails if the result would no longer have positive width
        and height.
        """
        west, south, east, north = self._bounds()
        return Bbox(
            minlon=west + amount,
            minlat=south + amount,
            maxlon=east - amount,
            maxlat=north - amount,
        )

    def scale(self, amount) -> "Bbox":
        """Return a new box with every bound multiplied by ``amount``.

        Note that this scales the bounds about the origin (0, 0), not about
        the centre of the box.
        """
        west, south, east, north = self._bounds()
        return Bbox(
            minlon=west * amount,
            minlat=south * amount,
            maxlon=east * amount,
            maxlat=north * amount,
        )

    def translate(self, amount) -> "Bbox":
        """Return a new box shifted by ``amount.lon`` and ``amount.lat``.

        Args:
            amount: object exposing numeric ``lon`` and ``lat`` offsets, for
                example a ``Coordinate``.
        """
        west, south, east, north = self._bounds()
        return Bbox(
            minlon=west + amount.lon,
            minlat=south + amount.lat,
            maxlon=east + amount.lon,
            maxlat=north + amount.lat,
        )

    def to_geojson(self) -> dict:
        """Return the box as a GeoJSON-style ``Polygon`` dictionary.

        The single ring starts at the minimum corner, visits the four
        corners and repeats the first position to close the ring.
        Coordinates are plain floats.
        """
        west, south, east, north = self._bounds()
        return {
            "type": "Polygon",
            "coordinates": [
                [
                    [west, south],
                    [west, north],
                    [east, north],
                    [east, south],
                    [west, south],
                ]
            ],
        }

    def to_wkt(self) -> str:
        """Return the box as a WKT ``POLYGON`` string."""
        return f"POLYGON(({self.minlon} {self.minlat}, {self.minlon} {self.maxlat}, {self.maxlon} {self.maxlat}, {self.maxlon} {self.minlat}, {self.minlon} {self.minlat}))"

    def to_wkb(self) -> bytes:
        """Return the box as little-endian 2D WKB ``Polygon`` bytes."""
        ring = self.to_geojson()["coordinates"][0]
        # Byte order 1 = little-endian, geometry type 3 = Polygon, one ring.
        header = struct.pack("<BII", 1, 3, 1)
        points = b"".join(struct.pack("<dd", lon, lat) for lon, lat in ring)
        return header + struct.pack("<I", len(ring)) + points

    @classmethod
    def from_geojson(cls, geojson):
        """Build the bounding box of a GeoJSON-style geometry dictionary.

        Args:
            geojson: mapping with a ``"type"`` of ``"Polygon"`` or
                ``"MultiPolygon"`` and matching ``"coordinates"``.

        Raises:
            ValueError: for any other geometry type.
        """
        if geojson["type"] == "Polygon":
            return cls.from_polygon(geojson["coordinates"])
        elif geojson["type"] == "MultiPolygon":
            return cls.from_multipolygon(geojson["coordinates"])
        else:
            raise ValueError("geojson must be a Polygon or a MultiPolygon")

    @classmethod
    def from_wkt(cls, wkt):
        """Build the bounding box of a 2D WKT ``POLYGON`` or ``MULTIPOLYGON``.

        Args:
            wkt: the WKT text.

        Raises:
            ValueError: if the text is not a POLYGON / MULTIPOLYGON or a
                position cannot be parsed.
        """
        header = re.fullmatch(
            r"\s*(MULTIPOLYGON|POLYGON)\s*(\(.*\))\s*",
            wkt,
            flags=re.IGNORECASE | re.DOTALL,
        )
        if header is None:
            raise ValueError("wkt must be a POLYGON or a MULTIPOLYGON")
        kind, body = header.group(1).upper(), header.group(2)
        # Innermost parentheses hold one ring each: "x y, x y, ...".
        ring_pattern = r"\(([^()]*)\)"
        if kind == "POLYGON":
            rings = [
                cls._parse_wkt_ring(ring) for ring in re.findall(ring_pattern, body)
            ]
            return cls.from_geojson({"type": "Polygon", "coordinates": rings})
        # One level up, each parenthesised group of rings is one polygon.
        polygon_pattern = r"\(\s*(\([^()]*\)(?:\s*,\s*\([^()]*\))*)\s*\)"
        polygons = [
            [cls._parse_wkt_ring(ring) for ring in re.findall(ring_pattern, polygon)]
            for polygon in re.findall(polygon_pattern, body)
        ]
        return cls.from_geojson({"type": "MultiPolygon", "coordinates": polygons})

    @staticmethod
    def _parse_wkt_ring(text):
        """Turn the inside of one WKT ring, ``"x y, x y, ..."``, into positions."""
        ring = []
        for position in text.split(","):
            numbers = [float(token) for token in position.split()]
            if len(numbers) < 2:
                raise ValueError("wkt positions must have at least two numbers")
            ring.append(numbers)
        return ring

    @classmethod
    def from_wkb(cls, wkb):
        """Build the bounding box of a 2D WKB ``Polygon`` or ``MultiPolygon``.

        Args:
            wkb: bytes-like WKB, or its hexadecimal string form.

        Raises:
            ValueError: if the data is truncated, malformed, or not a 2D
                Polygon / MultiPolygon.
        """
        data = bytes.fromhex(wkb) if isinstance(wkb, str) else bytes(wkb)
        try:
            geojson, _ = cls._read_wkb_geometry(data, 0)
        except (struct.error, IndexError) as err:
            raise ValueError("wkb is truncated or malformed") from err
        return cls.from_geojson(geojson)

    @classmethod
    def _read_wkb_geometry(cls, data, offset):
        """Decode one WKB geometry at ``offset``; return ``(geojson, next_offset)``."""
        byte_order = data[offset]
        if byte_order not in (0, 1):
            raise ValueError("wkb has an invalid byte-order flag")
        order = "<" if byte_order == 1 else ">"
        (geom_type,) = struct.unpack_from(order + "I", data, offset + 1)
        offset += 5
        if geom_type == 3:  # Polygon
            (n_rings,) = struct.unpack_from(order + "I", data, offset)
            offset += 4
            rings = []
            for _ in range(n_rings):
                (n_points,) = struct.unpack_from(order + "I", data, offset)
                offset += 4
                ring = []
                for _ in range(n_points):
                    ring.append(list(struct.unpack_from(order + "dd", data, offset)))
                    offset += 16
                rings.append(ring)
            return {"type": "Polygon", "coordinates": rings}, offset
        if geom_type == 6:  # MultiPolygon: a count followed by full Polygons
            (n_polygons,) = struct.unpack_from(order + "I", data, offset)
            offset += 4
            polygons = []
            for _ in range(n_polygons):
                member, offset = cls._read_wkb_geometry(data, offset)
                if member["type"] != "Polygon":
                    raise ValueError("wkb MultiPolygon members must be Polygons")
                polygons.append(member["coordinates"])
            return {"type": "MultiPolygon", "coordinates": polygons}, offset
        raise ValueError("wkb must be a Polygon or a MultiPolygon")

    @classmethod
    def from_polygon(cls, polygon):
        """Build the bounding box of a polygon given as a list of rings.

        Raises:
            ValueError: if the polygon does not consist of exactly one ring.
        """
        if len(polygon) != 1:
            raise ValueError("polygon must have exactly one ring")
        return cls.from_ring(polygon[0])

    @classmethod
    def from_multipolygon(cls, multipolygon):
        """Build the box enclosing every polygon of a multipolygon.

        Each member polygon must consist of exactly one ring, matching
        ``from_polygon``.

        Raises:
            ValueError: if there are no polygons or a polygon does not have
                exactly one ring.
        """
        positions = []
        for polygon in multipolygon:
            if len(polygon) != 1:
                raise ValueError("polygon must have exactly one ring")
            positions.extend(polygon[0])
        if not positions:
            raise ValueError("multipolygon must contain at least one position")
        return cls.from_ring(positions)

    @classmethod
    def from_ring(cls, ring):
        """Build the bounding box of a ring of ``(lon, lat, ...)`` positions.

        Raises:
            ValueError: if the ring is empty.
        """
        positions = list(ring)
        if not positions:
            raise ValueError("ring must contain at least one position")
        lons = [position[0] for position in positions]
        lats = [position[1] for position in positions]
        return cls(
            minlon=min(lons), minlat=min(lats), maxlon=max(lons), maxlat=max(lats)
        )
class Spectrum(RompyBaseModel):
    """A class to represent a wave spectrum.

    Holds the frequency limits and the number of frequency and directional
    components of the spectrum.
    """

    fmin: float = Field(default=0.0464, description="Minimum frequency in Hz")
    fmax: float = Field(default=1.0, description="Maximum frequency in Hz")
    nfreqs: int = Field(default=31, description="Number of frequency components")
    ndirs: int = Field(default=36, description="Number of directional components")
# TODO make more flexible.
class DatasetCoords(RompyBaseModel):
    """Coordinates representation.

    Maps the t, x, y and z axes to the coordinate names used by a dataset.
    """

    t: Optional[str] = Field(default="time", description="Name of the time coordinate")
    x: Optional[str] = Field(
        default="longitude", description="Name of the x coordinate"
    )
    y: Optional[str] = Field(default="latitude", description="Name of the y coordinate")
    z: Optional[str] = Field(default="depth", description="Name of the z coordinate")
class Slice(BaseModel):
    """Basic float or datetime slice representation

    Stores optional ``start`` and ``stop`` bounds and converts to and from the
    builtin ``slice``. A builtin ``slice`` passed wherever a ``Slice`` is
    validated is converted automatically; its ``step`` is not kept.
    """

    start: Optional[Union[float, datetime, str]] = Field(
        default=None, description="Slice start"
    )
    stop: Optional[Union[float, datetime, str]] = Field(
        default=None, description="Slice stop"
    )

    @model_validator(mode="before")
    @classmethod
    def coerce_builtin_slice(cls, data: Any) -> Any:
        """Accept a builtin ``slice`` as input by turning it into field data.

        The ``__get_validators__`` hook below is the pydantic v1 mechanism and
        is not consulted for ``BaseModel`` subclasses under pydantic v2, so
        the conversion is done here instead.
        """
        if isinstance(data, slice):
            return {"start": data.start, "stop": data.stop}
        return data

    @classmethod
    def from_dict(cls, d: dict):
        """Build a ``Slice`` from a mapping with ``start`` / ``stop`` keys.

        Missing keys fall back to None, matching the field defaults.
        """
        return cls(start=d.get("start"), stop=d.get("stop"))

    @classmethod
    def from_slice(cls, s: slice):
        """Build a ``Slice`` from a builtin ``slice`` (``s.step`` is ignored)."""
        return cls(start=s.start, stop=s.stop)

    def to_slice(self) -> slice:
        """Return the equivalent builtin ``slice(start, stop)``."""
        return slice(self.start, self.stop)

    @classmethod
    def __get_validators__(cls):
        # pydantic v1 style hook, kept for backward compatibility.
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Return ``v`` as a ``Slice``.

        Args:
            v: a builtin ``slice`` or an existing ``Slice`` instance.

        Raises:
            ValueError: if ``v`` is of any other type.
        """
        if isinstance(v, slice):
            return cls.from_slice(v)
        elif isinstance(v, cls):
            return v
        else:
            raise ValueError(f"Cannot convert {type(v)} to Slice")