from datetime import datetime
from enum import IntEnum
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union
import numpy as np
import pandas as pd
import scipy as sp
import scipy.interpolate  # explicit submodule import so sp.interpolate is available below
import xarray as xr
from cloudpathlib import AnyPath
from pydantic import ConfigDict, Field, model_validator
from pylib import compute_zcor, read_schism_bpfile, read_schism_hgrid, read_schism_vgrid
from rompy.core.boundary import BoundaryWaveStation, DataBoundary
from rompy.core.data import DataBlob, DataGrid
from rompy.core.time import TimeRange
from rompy.core.types import RompyBaseModel
from rompy.schism.bctides import Bctides
from rompy.schism.boundary import Boundary3D, BoundaryData
from rompy.schism.boundary_core import (
BoundaryHandler,
ElevationType,
TidalDataset,
TracerType,
VelocityType,
create_tidal_boundary,
)
from rompy.schism.grid import SCHISMGrid
from rompy.logging import get_logger
from rompy.formatting import ARROW
from rompy.schism.tides_enhanced import BoundarySetup
from rompy.utils import total_seconds
from .namelists import Sflux_Inputs
logger = get_logger(__name__)
def to_python_type(value):
"""Convert numpy types to Python native types."""
if isinstance(value, np.ndarray):
return value.tolist()
elif isinstance(value, np.integer):
return int(value)
elif isinstance(value, np.floating):
return float(value)
elif isinstance(value, np.bool_):
return bool(value)
else:
return value
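# Illustrative behaviour of the helper above:
#     to_python_type(np.float64(3.5))    -> 3.5          (Python float)
#     to_python_type(np.array([1, 2]))   -> [1, 2]       (Python list)
#     to_python_type(np.bool_(True))     -> True         (Python bool)
#     to_python_type("unchanged")        -> "unchanged"  (passed through)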
class SfluxSource(DataGrid):
"""This is a single variable source for and sflux input"""
data_type: Literal["sflux"] = Field(
default="sflux",
description="Model type discriminator",
)
id: str = Field(default="sflux_source", description="id of the source")
relative_weight: float = Field(
1.0,
description="relative weight of the source file if two files are provided",
)
max_window_hours: float = Field(
120.0,
description="maximum number of hours (offset from start time in each file) in each file of set 1",
)
fail_if_missing: bool = Field(
True, description="Fail if the source file is missing"
)
time_buffer: list[int] = Field(
default=[0, 1],
description="Number of source data timesteps to buffer the time range if `filter_time` is True",
)
# The source field needs special handling
source: Any = None
_variable_names = []
model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
def __init__(self, **data):
# Special handling for the DataGrid source field
# Pydantic v2 is strict about union tag validation, so we need to handle it manually
source_obj = None
if "source" in data:
source_obj = data.pop("source") # Remove source to avoid validation errors
# Initialize without the source field
try:
super().__init__(**data)
# Set the source object after initialization
if source_obj is not None:
self.source = source_obj
except Exception as e:
logger.error(f"Error initializing SfluxSource: {e}")
logger.error(f"Input data: {data}")
raise
# Initialize variable names
self._set_variables()
@property
def outfile(self) -> str:
# TODO - the file number is hardcoded to 1 for now.
return f'{self.id}.{str(1).rjust(4, "0")}.nc'
def _set_variables(self) -> None:
for variable in self._variable_names:
if getattr(self, variable) is not None:
self.variables.append(getattr(self, variable))
@property
def namelist(self) -> dict:
# ret = self.model_dump()
ret = {}
for key, value in self.model_dump().items():
if key in ["relative_weight", "max_window_hours", "fail_if_missing"]:
ret.update({f"{self.id}_{key}": value})
for varname in self._variable_names:
var = getattr(self, varname)
if var is not None:
ret.update({varname: var})
else:
ret.update({varname: varname.replace("_name", "")})
ret.update({f"{self.id}_file": self.id})
return ret
@property
def ds(self):
"""Return the xarray dataset for this data source."""
ds = self.source.open(
variables=self.variables, filters=self.filter, coords=self.coords
)
# Define a dictionary for potential renaming
rename_dict = {self.coords.y: "ny_grid", self.coords.x: "nx_grid"}
# Construct a valid renaming dictionary
valid_rename_dict = get_valid_rename_dict(ds, rename_dict)
# Perform renaming if necessary
if valid_rename_dict:
ds = ds.rename_dims(valid_rename_dict)
lon, lat = np.meshgrid(ds[self.coords.x], ds[self.coords.y])
ds["lon"] = (("ny_grid", "nx_grid"), lon)
ds["lat"] = (("ny_grid", "nx_grid"), lat)
basedate = pd.to_datetime(ds.time.values[0])
unit = f"days since {basedate.strftime('%Y-%m-%d %H:%M:%S')}"
ds.time.attrs = {
"long_name": "Time",
"standard_name": "time",
"base_date": np.int32(
np.array(
[
basedate.year,
basedate.month,
basedate.day,
basedate.hour,
basedate.minute,
basedate.second,
]
)
),
# "units": unit,
}
ds.time.encoding["units"] = unit
ds.time.encoding["calendar"] = "proleptic_gregorian"
# SCHISM doesn't like scale_factor and add_offset attributes and requires Float64 values
for var in ds.data_vars:
# If the variable has scale_factor or add_offset attributes, remove them
if "scale_factor" in ds[var].encoding:
del ds[var].encoding["scale_factor"]
if "add_offset" in ds[var].encoding:
del ds[var].encoding["add_offset"]
# set the data variable encoding to Float64
ds[var].encoding["dtype"] = np.dtypes.Float64DType()
return ds
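# A minimal usage sketch for SfluxSource (illustrative only: the file path is a
# placeholder and SourceFile is the rompy.core.source class already used in this
# module; the actual required fields depend on the DataGrid base class):
#
#     from rompy.core.source import SourceFile
#
#     src = SfluxSource(
#         id="air_1",
#         source=SourceFile(uri="./era5.nc"),
#         relative_weight=1.0,
#         max_window_hours=120.0,
#     )
#     src.outfile    # -> "air_1.0001.nc"
#     src.namelist   # -> {"air_1_relative_weight": 1.0, "air_1_max_window_hours": 120.0,
#                    #     "air_1_fail_if_missing": True, "air_1_file": "air_1"}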
class SfluxAir(SfluxSource):
"""This is a single variable source for and sflux input"""
data_type: Literal["sflux_air"] = Field(
default="sflux_air",
description="Model type discriminator",
)
uwind_name: Optional[str] = Field(
None,
description="name of zonal wind variable in source",
)
vwind_name: Optional[str] = Field(
None,
description="name of meridional wind variable in source",
)
prmsl_name: Optional[str] = Field(
None,
description="name of mean sea level pressure variable in source",
)
stmp_name: Optional[str] = Field(
None,
description="name of surface air temperature variable in source",
)
spfh_name: Optional[str] = Field(
None,
description="name of specific humidity variable in source",
)
# Allow extra fields during validation but exclude them from the model
model_config = ConfigDict(
arbitrary_types_allowed=True,
validate_assignment=True,
extra="allow", # Allow extra fields during validation
populate_by_name=True, # Enable population by field name
)
def __init__(self, **data):
# Initialize logger at the beginning
# Pre-process parameters before passing to pydantic
# Map parameters without _name suffix to ones with suffix
name_mappings = {
"uwind": "uwind_name",
"vwind": "vwind_name",
"prmsl": "prmsl_name",
"stmp": "stmp_name",
"spfh": "spfh_name",
}
for old_name, new_name in name_mappings.items():
if old_name in data and new_name not in data:
data[new_name] = data.pop(old_name)
# Extract source to handle it separately (avoiding validation problems)
source_obj = None
if "source" in data:
source_obj = data.pop("source") # Remove source to avoid validation errors
# Import here to avoid circular import
from rompy.core.source import SourceFile, SourceIntake
# If source is a dictionary, convert it to a proper source object
if isinstance(source_obj, dict):
logger.info(
f"Converting source dictionary to source object: {source_obj}"
)
# Handle different source types based on what's in the dictionary
if "uri" in source_obj:
# Create a SourceFile or SourceIntake based on the URI
uri = source_obj["uri"]
if uri.startswith("intake://") or uri.endswith(".yaml"):
source_obj = SourceIntake(uri=uri)
else:
source_obj = SourceFile(uri=uri)
logger.info(f"Created source object from URI: {uri}")
else:
# If no URI, create a minimal valid source
logger.warning(
f"Source dictionary does not contain URI, creating a minimal source"
)
# Default to a sample data source for testing
source_obj = SourceFile(
uri="../../tests/schism/test_data/sample.nc"
)
else:
raise ValueError("SfluxAir requires a 'source' parameter")
# Call the parent constructor with the processed data (without source)
try:
super().__init__(**data)
except Exception as e:
logger.error(f"Error initializing SfluxAir: {e}")
logger.error(f"Input data: {data}")
raise
# Set source manually after initialization
self.source = source_obj
logger.info(
f"Successfully created SfluxAir instance with source type: {type(self.source)}"
)
_variable_names = [
"uwind_name",
"vwind_name",
"prmsl_name",
"stmp_name",
"spfh_name",
]
@property
def ds(self):
"""Return the xarray dataset for this data source."""
ds = super().ds
for variable in self._variable_names:
data_var = getattr(self, variable)
if data_var is None:
proxy_var = variable.replace("_name", "")
ds[proxy_var] = ds[self.uwind_name].copy()
if variable == "spfh_name":
missing = 0.01
else:
missing = -999
ds[proxy_var][:, :, :] = missing
ds.data_vars[proxy_var].attrs["long_name"] = proxy_var
return ds
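# A sketch of an SfluxAir configuration (assumptions: the URI and the variable
# names u10/v10/msl are placeholders for whatever the source dataset provides;
# unset *_name fields are filled with constant proxy values by the ds property
# above, -999 for most fields and 0.01 for specific humidity):
#
#     from rompy.core.source import SourceFile
#
#     air = SfluxAir(
#         id="air_1",
#         source=SourceFile(uri="./era5.nc"),
#         uwind_name="u10",
#         vwind_name="v10",
#         prmsl_name="msl",
#     )
#     ds = air.ds   # stmp and spfh are added as constant proxy variables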
class SfluxRad(SfluxSource):
"""This is a single variable source for and sflux input"""
data_type: Literal["sflux_rad"] = Field(
default="sflux_rad",
description="Model type discriminator",
)
dlwrf_name: Optional[str] = Field(
None,
description="name of downward long wave radiation variable in source",
)
dswrf_name: Optional[str] = Field(
None,
description="name of downward short wave radiation variable in source",
)
_variable_names = ["dlwrf_name", "dswrf_name"]
class SfluxPrc(SfluxSource):
"""This is a single variable source for and sflux input"""
data_type: Literal["sflux_prc"] = Field(
default="sflux_rad",
description="Model type discriminator",
)
prate_name: Optional[str] = Field(
None,
description="name of precipitation rate variable in source",
)
_variable_names = ["prate_name"]
class SCHISMDataSflux(RompyBaseModel):
data_type: Literal["sflux"] = Field(
default="sflux",
description="Model type discriminator",
)
air_1: Optional[Any] = Field(None, description="sflux air source 1")
air_2: Optional[Any] = Field(None, description="sflux air source 2")
rad_1: Optional[Union[DataBlob, SfluxRad]] = Field(
None, description="sflux rad source 1"
)
rad_2: Optional[Union[DataBlob, SfluxRad]] = Field(
None, description="sflux rad source 2"
)
prc_1: Optional[Union[DataBlob, SfluxPrc]] = Field(
None, description="sflux prc source 1"
)
prc_2: Optional[Union[DataBlob, SfluxPrc]] = Field(
None, description="sflux prc source 2"
)
model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
def __init__(self, **data):
# Handle 'air' parameter by mapping it to 'air_1'
if "air" in data:
air_value = data.pop("air")
# If air is a dict, convert it to a SfluxAir instance
if isinstance(air_value, dict):
try:
# Import here to avoid circular import
from rompy.schism.data import SfluxAir
air_value = SfluxAir(**air_value)
logger.info(
f"Successfully created SfluxAir instance from dictionary"
)
except Exception as e:
logger.error(f"Failed to create SfluxAir instance: {e}")
# Fall back to passing the dictionary directly
logger.info(f"Falling back to dictionary: {air_value}")
data["air_1"] = air_value
# Call the parent constructor with the processed data
super().__init__(**data)
@model_validator(mode="after")
def validate_air_fields(self):
"""Validate air fields after model creation."""
# Convert dictionary to SfluxAir if needed
if isinstance(self.air_1, dict):
try:
# Import here to avoid circular import
from rompy.schism.data import SfluxAir
logger.info(
f"Converting air_1 dictionary to SfluxAir object: {self.air_1}"
)
self.air_1 = SfluxAir(**self.air_1)
logger.info(f"Successfully converted air_1 to SfluxAir instance")
except Exception as e:
logger.error(f"Error converting air_1 dictionary to SfluxAir: {e}")
logger.error(f"Input data: {self.air_1}")
# We'll let validation continue with the dictionary
if isinstance(self.air_2, dict):
try:
from rompy.schism.data import SfluxAir
logger.info(
f"Converting air_2 dictionary to SfluxAir object: {self.air_2}"
)
self.air_2 = SfluxAir(**self.air_2)
logger.info(f"Successfully converted air_2 to SfluxAir instance")
except Exception as e:
logger.error(f"Error converting air_2 dictionary to SfluxAir: {e}")
logger.error(f"Input data: {self.air_2}")
return self
def get(
self,
destdir: str | Path,
grid: Optional[SCHISMGrid] = None,
time: Optional[TimeRange] = None,
) -> Path:
"""Writes SCHISM sflux data from a dataset.
Args:
destdir (str | Path): The destination directory to write the sflux data.
grid (Optional[SCHISMGrid], optional): The grid type. Defaults to None.
time (Optional[TimeRange], optional): The time range. Defaults to None.
Returns:
Path: The path to the written sflux data.
"""
ret = {}
destdir = Path(destdir) / "sflux"
destdir.mkdir(parents=True, exist_ok=True)
namelistargs = {}
# Collect information about active variables for logging
active_variables = []
source_info = {}
for variable in ["air_1", "air_2", "rad_1", "rad_2", "prc_1", "prc_2"]:
data = getattr(self, variable)
if data is None:
continue
data.id = variable
active_variables.append(variable)
# Get source information
if hasattr(data, 'source') and hasattr(data.source, 'uri'):
source_info[variable] = str(data.source.uri)
logger.debug(f"Processing {variable}")
namelistargs.update(data.namelist)
# Expand the requested time range by one day at each end without mutating the
# caller's TimeRange (otherwise the buffer would grow again for every variable)
if time is not None:
buffered_time = TimeRange(start=time.start - pd.Timedelta(days=1), end=time.end + pd.Timedelta(days=1))
else:
buffered_time = None
ret[variable] = data.get(destdir, grid, buffered_time)
# Log summary of atmospheric data processing
if active_variables:
logger.info(f" • Variables: {', '.join(active_variables)}")
if source_info:
unique_sources = list(set(source_info.values()))
if len(unique_sources) == 1:
logger.info(f" • Source: {unique_sources[0]}")
else:
logger.info(f" • Sources: {len(unique_sources)} files")
logger.info(f" • Output: {destdir}")
ret["nml"] = Sflux_Inputs(**namelistargs).write_nml(destdir)
return ret
@model_validator(mode="after")
def check_weights(v):
"""Check that relative weights for each pair add to 1.
Args:
v: The validated SCHISMDataSflux instance.
Raises:
ValueError: If the relative weights for any variable do not add up to 1.0.
"""
for variable in ["air", "rad", "prc"]:
weight = 0
active = False
for i in [1, 2]:
data = getattr(v, f"{variable}_{i}")
if data is None:
continue
if data.fail_if_missing:
continue
weight += data.relative_weight
active = True
if active and weight != 1.0:
raise ValueError(
f"Relative weights for {variable} do not add to 1.0: {weight}"
)
return v
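# A sketch of a complete sflux block (illustrative: `air` is an SfluxAir instance
# as sketched above; `air` is also accepted as an alias for `air_1` by __init__,
# and check_weights only enforces the weight sum for sources with
# fail_if_missing=False):
#
#     sflux = SCHISMDataSflux(air_1=air)
#     outputs = sflux.get(destdir="./simulation", grid=schism_grid, time=time_range)
#     # writes ./simulation/sflux/air_1.0001.nc plus the namelist produced by
#     # Sflux_Inputs.write_nml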
class SCHISMDataWave(BoundaryWaveStation):
"""This class is used to write wave spectral boundary data. Spectral data is extracted
from the nearest points along the grid boundary"""
data_type: Literal["wave"] = Field(
default="wave",
description="Model type discriminator",
)
sel_method: Literal["idw", "nearest"] = Field(
default="nearest",
description="Method for selecting boundary points",
)
sel_method_kwargs: dict = Field(
default={"unique": True},
description="Keyword arguments for sel_method",
)
time_buffer: list[int] = Field(
default=[0, 1],
description="Number of source data timesteps to buffer the time range if `filter_time` is True",
)
def get(
self,
destdir: str | Path,
grid: SCHISMGrid,
time: Optional[TimeRange] = None,
) -> Path:
"""Write the selected boundary data to a netcdf file.
Parameters
----------
destdir : str | Path
Destination directory for the netcdf file.
grid : SCHISMGrid
Grid instance to use for selecting the boundary points.
time: TimeRange, optional
The times to filter the data to, only used if `self.crop_data` is True.
Returns
-------
outfile : Path
Path to the netcdf file.
"""
logger.debug(f"Processing wave data: {self.id}")
if self.crop_data and time is not None:
self._filter_time(time)
ds = self._sel_boundary(grid)
outfile = Path(destdir) / f"{self.id}.nc"
ds.spec.to_ww3(outfile)
logger.debug(f"Saved wave data to {outfile}")
return outfile
@property
def ds(self):
"""Return the filtered xarray dataset instance."""
ds = super().ds
for var in ds.data_vars:
# If the variable has scale_factor or add_offset attributes, remove them
if "scale_factor" in ds[var].encoding:
del ds[var].encoding["scale_factor"]
if "add_offset" in ds[var].encoding:
del ds[var].encoding["add_offset"]
# set the data variable encoding to Float64
ds[var].encoding["dtype"] = np.dtypes.Float64DType()
return ds
def __str__(self):
return f"SCHISMDataWave"
class SCHISMDataBoundary(DataBoundary):
"""This class is used to extract ocean boundary data from a griddd dataset at all open
boundary nodes."""
data_type: Literal["boundary"] = Field(
default="boundary",
description="Model type discriminator",
)
id: str = Field(
"bnd",
description="SCHISM th id of the source",
json_schema_extra={"choices": ["elev2D", "uv3D", "TEM_3D", "SAL_3D", "bnd"]},
)
# This field is used to handle DataGrid sources in Pydantic v2
data_grid_source: Optional[DataGrid] = Field(
None, description="DataGrid source for boundary data"
)
variables: list[str] = Field(
default_factory=list, description="variable name in the dataset"
)
sel_method: Literal["sel", "interp"] = Field(
default="interp",
description=(
"Xarray method to use for selecting boundary points from the dataset"
),
)
time_buffer: list[int] = Field(
default=[0, 1],
description="Number of source data timesteps to buffer the time range if `filter_time` is True",
)
def get(
self,
destdir: str | Path,
grid: SCHISMGrid,
time: Optional[TimeRange] = None,
) -> Path:
"""Write the selected boundary data to a netcdf file.
Parameters
----------
destdir : str | Path
Destination directory for the netcdf file.
grid : SCHISMGrid
Grid instance to use for selecting the boundary points.
time: TimeRange, optional
The times to filter the data to, only used if `self.crop_data` is True.
Returns
-------
outfile : Path
Path to the netcdf file.
"""
# prepare xarray.Dataset and save forcing netCDF file
outfile = Path(destdir) / f"{self.id}.th.nc"
boundary_ds = self.boundary_ds(grid, time)
boundary_ds.to_netcdf(outfile, "w", "NETCDF3_CLASSIC", unlimited_dims="time")
# Log file details with dimensions
if "time_series" in boundary_ds.data_vars:
shape = boundary_ds.time_series.shape
logger.debug(f"Saved {self.id} to {outfile} (shape: {shape})")
else:
logger.debug(f"Saved boundary data to {outfile}")
return outfile
def boundary_ds(self, grid: SCHISMGrid, time: Optional[TimeRange]) -> xr.Dataset:
"""Generate SCHISM boundary dataset from source data.
This function extracts and formats boundary data for SCHISM from a source dataset.
For 3D models, it handles vertical interpolation to the SCHISM sigma levels.
Parameters
----------
grid : SCHISMGrid
The SCHISM grid to extract boundary data for
time : Optional[TimeRange]
The time range to filter data to, if crop_data is True
Returns
-------
xr.Dataset
Dataset formatted for SCHISM boundary input
"""
logger.debug(f"Fetching {self.id}")
if self.crop_data and time is not None:
self._filter_time(time)
# Extract boundary data from source
ds = self._sel_boundary(grid)
# Calculate time step
if len(ds.time) > 1:
dt = total_seconds((ds.time[1] - ds.time[0]).values)
else:
dt = 3600
# Get the variable data - handle multiple variables (e.g., u,v for velocity)
num_components = len(self.variables)
# Process all variables and stack them
variable_data = []
for var in self.variables:
variable_data.append(ds[var].values)
# Stack variables along a new component axis (last axis)
if num_components == 1:
data = variable_data[0]
else:
data = np.stack(variable_data, axis=-1)
# Determine if we're working with 3D data
is_3d_data = grid.is_3d and self.coords.z is not None
# Handle different data dimensions based on 2D or 3D
if is_3d_data:
# Try to determine the dimension order
if hasattr(ds[self.variables[0]], "dims"):
# Get dimension names
dims = list(ds[self.variables[0]].dims)
# Find indices of time, z, and x dimensions
time_dim_idx = dims.index(ds.time.dims[0])
z_dim_idx = (
dims.index(ds[self.coords.z].dims[0])
if self.coords and self.coords.z and self.coords.z in ds
else 1
)
x_dim_idx = (
dims.index(ds[self.coords.x].dims[0])
if self.coords and self.coords.x and self.coords.x in ds
else 2
)
logger.debug(
f"Dimension order: time={time_dim_idx}, z={z_dim_idx}, x={x_dim_idx}"
)
# Reshape data to expected format if needed (time, x, z, [components])
if num_components == 1:
# Single component case - need to transpose to (time, x, z)
if not (time_dim_idx == 0 and x_dim_idx == 1 and z_dim_idx == 2):
trans_dims = list(range(data.ndim))
trans_dims[time_dim_idx] = 0
trans_dims[x_dim_idx] = 1
trans_dims[z_dim_idx] = 2
data = np.transpose(data, trans_dims)
logger.debug(f"Transposed data shape: {data.shape}")
# Add the component dimension for SCHISM
time_series = np.expand_dims(data, axis=3)
else:
# Multiple component case - data is already (time, x, z, components)
# Need to transpose the first 3 dimensions to (time, x, z) if needed
if not (time_dim_idx == 0 and x_dim_idx == 1 and z_dim_idx == 2):
trans_dims = list(
range(data.ndim - 1)
) # Exclude component axis
trans_dims[time_dim_idx] = 0
trans_dims[x_dim_idx] = 1
trans_dims[z_dim_idx] = 2
# Keep component axis at the end
trans_dims.append(data.ndim - 1)
data = np.transpose(data, trans_dims)
logger.debug(f"Transposed data shape: {data.shape}")
# Data already has component dimension from stacking
time_series = data
else:
# Fallback: add component dimension if needed
if num_components == 1:
time_series = np.expand_dims(data, axis=3)
else:
time_series = data
# Calculate zcor for 3D
# For PyLibs vgrid, extract sigma coordinates differently
gd = grid.pylibs_hgrid
vgd = grid.pylibs_vgrid
# Make sure boundaries are computed
if hasattr(gd, "compute_bnd") and not hasattr(gd, "nob"):
gd.compute_bnd()
# Extract boundary information
if not hasattr(gd, "nob") or gd.nob is None or gd.nob == 0:
raise ValueError("No open boundary nodes found in the grid")
# Collect all boundary nodes
boundary_indices = []
for i in range(gd.nob):
boundary_indices.extend(gd.iobn[i])
# Get bathymetry for boundary nodes
boundary_depths = gd.dp[boundary_indices]
# Get sigma levels from vgrid
# Note: This assumes a simple sigma or SZ grid format
# For more complex vgrids, more sophisticated extraction would be needed
if vgd is not None:
if hasattr(vgd, "sigma"):
sigma_levels = vgd.sigma.copy()
num_sigma_levels = len(sigma_levels)
else:
# Default sigma levels if not available
sigma_levels = np.array([-1.0, 0.0])
num_sigma_levels = 2
# Get fixed z levels if available
if hasattr(vgd, "ztot"):
z_levels = vgd.ztot
else:
z_levels = np.array([])
# For each boundary point, determine the total number of vertical levels
# and create appropriate zcor arrays
all_zcors = []
all_nvrt = []
for i, (node_idx, depth) in enumerate(
zip(boundary_indices, boundary_depths)
):
# Check if we're in deep water (depth > first z level)
if z_levels.size > 0 and depth > z_levels[0]:
# In deep water, find applicable z levels (between first z level and actual depth)
first_z_level = z_levels[0]
z_mask = (z_levels > first_z_level) & (z_levels < depth)
applicable_z = z_levels[z_mask] if np.any(z_mask) else []
# Total levels = sigma levels + applicable z levels
total_levels = num_sigma_levels + len(applicable_z)
# Create zcor for this boundary point
node_zcor = np.zeros(total_levels)
# First, calculate sigma levels using the first z level as the "floor"
for j in range(num_sigma_levels):
node_zcor[j] = first_z_level * sigma_levels[j]
# Then, add the fixed z levels below the sigma levels
for j, z_val in enumerate(applicable_z):
node_zcor[num_sigma_levels + j] = z_val
else:
# In shallow water, just use sigma levels scaled to the actual depth
total_levels = num_sigma_levels
# Create zcor for this boundary point
node_zcor = np.zeros(total_levels)
for j in range(total_levels):
node_zcor[j] = depth * sigma_levels[j]
# Store this boundary point's zcor and number of levels
all_zcors.append(node_zcor)
all_nvrt.append(total_levels)
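# Worked example of the shallow-water branch above: with sigma_levels = [-1.0, -0.5, 0.0]
# and a boundary node of depth 15 m, node_zcor = [-15.0, -7.5, 0.0] (bottom to surface).
# In the deep-water branch the sigma part is scaled by the first fixed z level instead of
# the local depth, and the applicable fixed z levels are appended after it.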
# Now we have a list of zcor arrays with potentially different lengths
# Find the maximum number of levels across all boundary points
max_nvrt = max(all_nvrt) if all_nvrt else num_sigma_levels
# Create a uniform zcor array with the maximum number of levels
zcor = np.zeros((len(boundary_indices), max_nvrt))
# Fill in the values, leaving zeros for levels beyond a particular boundary point's total
for i, (node_zcor, nvrt_i) in enumerate(zip(all_zcors, all_nvrt)):
zcor[i, :nvrt_i] = node_zcor
# Get source z-levels and prepare for interpolation
sigma_values = (
ds[self.coords.z].values
if self.coords and self.coords.z
else np.array([0])
)
data_shape = time_series.shape
# Initialize interpolated data array with the maximum number of vertical levels
if num_components == 1:
interpolated_data = np.zeros((data_shape[0], data_shape[1], max_nvrt))
else:
interpolated_data = np.zeros(
(data_shape[0], data_shape[1], max_nvrt, data_shape[3])
)
# For each time step and boundary point
for t in range(data_shape[0]): # time
for n in range(data_shape[1]): # boundary points
# Get z-coordinates for this point
z_dest = zcor[n, :]
nvrt_n = all_nvrt[
n
] # Get the number of vertical levels for this point
if num_components == 1:
# Extract vertical profile for single component
profile = time_series[t, n, :, 0]
# Create interpolator for this profile
interp = sp.interpolate.interp1d(
sigma_values,
profile,
kind="linear",
bounds_error=False,
fill_value="extrapolate",
)
# Interpolate to SCHISM levels for this boundary point
# Only interpolate up to the actual number of levels for this point
interpolated_data[t, n, :nvrt_n] = interp(z_dest[:nvrt_n])
else:
# Handle multiple components (e.g., u,v for velocity)
for c in range(num_components):
# Extract vertical profile for this component
profile = time_series[t, n, :, c]
# Create interpolator for this profile
interp = sp.interpolate.interp1d(
sigma_values,
profile,
kind="linear",
bounds_error=False,
fill_value="extrapolate",
)
# Interpolate to SCHISM levels for this boundary point
# Only interpolate up to the actual number of levels for this point
interpolated_data[t, n, :nvrt_n, c] = interp(
z_dest[:nvrt_n]
)
# Replace data with interpolated values
data = interpolated_data
if num_components == 1:
time_series = np.expand_dims(data, axis=3)
else:
time_series = data
# Store the variable vertical levels in the output dataset
# Create a 2D array where each row contains the vertical levels for a boundary node
# For nodes with fewer levels, pad with NaN
vert_levels = np.full((len(boundary_indices), max_nvrt), np.nan)
for i, (node_zcor, nvrt_i) in enumerate(zip(all_zcors, all_nvrt)):
vert_levels[i, :nvrt_i] = node_zcor
# Create output dataset
schism_ds = xr.Dataset(
coords={
"time": ds.time,
"nOpenBndNodes": np.arange(time_series.shape[1]),
"nLevels": np.arange(max_nvrt),
"nComponents": np.arange(num_components),
"one": np.array([1]),
},
data_vars={
"time_step": (("one"), np.array([dt])),
"time_series": (
("time", "nOpenBndNodes", "nLevels", "nComponents"),
time_series,
),
"vertical_levels": (
("nOpenBndNodes", "nLevels"),
vert_levels,
),
"num_levels": (
("nOpenBndNodes"),
np.array(all_nvrt),
),
},
)
else:
# 2D case - simpler handling
# Add level and component dimensions for SCHISM
if num_components == 1:
time_series = np.expand_dims(data, axis=(2, 3))
else:
# Multiple components: add level dimension but keep component dimension
time_series = np.expand_dims(data, axis=2)
# Create output dataset
schism_ds = xr.Dataset(
coords={
"time": ds.time,
"nOpenBndNodes": np.arange(time_series.shape[1]),
"nLevels": np.array([0]), # Single level for 2D
"nComponents": np.arange(num_components),
"one": np.array([1]),
},
data_vars={
"time_step": (("one"), np.array([dt])),
"time_series": (
("time", "nOpenBndNodes", "nLevels", "nComponents"),
time_series,
),
},
)
# Set attributes and encoding
schism_ds["time_step"] = schism_ds.time_step.assign_attrs({"long_name": "time_step"})
basedate = pd.to_datetime(ds.time.values[0])
unit = f"days since {basedate.strftime('%Y-%m-%d %H:%M:%S')}"
schism_ds.time.attrs = {
"long_name": "Time",
"standard_name": "time",
"base_date": np.int32(
np.array(
[
basedate.year,
basedate.month,
basedate.day,
basedate.hour,
basedate.minute,
basedate.second,
]
)
),
}
schism_ds.time.encoding["units"] = unit
schism_ds.time.encoding["calendar"] = "proleptic_gregorian"
# Handle missing values more robustly
null_count = schism_ds.time_series.isnull().sum().item()
if null_count > 0:
logger.debug(f"Found {null_count} null values, applying interpolation and filling")
# Try interpolating along different dimensions
for dim in ["nOpenBndNodes", "time", "nLevels"]:
if dim in schism_ds.dims and len(schism_ds[dim]) > 1:
schism_ds["time_series"] = schism_ds.time_series.interpolate_na(
dim=dim
)
if not schism_ds.time_series.isnull().any():
logger.debug(f"Interpolated missing values along {dim} dimension")
break
# If still have NaNs, use more aggressive filling methods
if schism_ds.time_series.isnull().any():
# Find a reasonable fill value (median of non-NaN values)
valid_values = schism_ds.time_series.values[
~np.isnan(schism_ds.time_series.values)
]
fill_value = np.median(valid_values) if len(valid_values) > 0 else 0.0
schism_ds["time_series"] = schism_ds.time_series.fillna(fill_value)
logger.debug(f"Filled remaining nulls with constant value {fill_value}")
# Clean up encoding
for var in schism_ds.data_vars:
if "scale_factor" in schism_ds[var].encoding:
del schism_ds[var].encoding["scale_factor"]
if "add_offset" in schism_ds[var].encoding:
del schism_ds[var].encoding["add_offset"]
schism_ds[var].encoding["dtype"] = np.dtypes.Float64DType()
return schism_ds
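# A sketch of a 3D tracer boundary using this class (illustrative: the URI and
# variable name are placeholders for the actual ocean model output; SourceFile
# is from rompy.core.source):
#
#     temp_boundary = SCHISMDataBoundary(
#         id="TEM_3D",
#         source=SourceFile(uri="./ocean_model.nc"),
#         variables=["temperature"],
#         sel_method="interp",
#     )
#     outfile = temp_boundary.get(destdir="./simulation", grid=schism_grid, time=time_range)
#     # -> ./simulation/TEM_3D.th.nc with dims (time, nOpenBndNodes, nLevels, nComponents)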
class SCHISMData(RompyBaseModel):
"""
This class is used to gather all required input forcing for SCHISM
"""
data_type: Literal["schism"] = Field(
default="schism",
description="Model type discriminator",
)
atmos: Optional[SCHISMDataSflux] = Field(None, description="atmospheric data")
wave: Optional[Union[DataBlob, SCHISMDataWave]] = Field(
None, description="wave data"
)
boundary_conditions: Optional["SCHISMDataBoundaryConditions"] = Field(
None, description="unified boundary conditions (replaces tides and ocean)"
)
def get(
self,
destdir: str | Path,
grid: SCHISMGrid,
time: TimeRange,
) -> Dict[str, Any]:
"""
Process all SCHISM forcing data and generate necessary input files.
Parameters
----------
destdir : str | Path
Destination directory
grid : SCHISMGrid
SCHISM grid instance
time : TimeRange
Time range for the simulation
Returns
-------
Dict[str, Any]
Paths to generated files for each data component
"""
from rompy.formatting import ARROW
# Convert destdir to Path object
destdir = Path(destdir)
# Create destdir if it doesn't exist
if not destdir.exists():
destdir.mkdir(parents=True, exist_ok=True)
results = {}
# Process atmospheric data
if self.atmos:
logger.info(f"{ARROW} Processing atmospheric forcing data")
results["atmos"] = self.atmos.get(destdir, grid, time)
logger.info(f"{ARROW} Atmospheric data processed successfully")
# Process wave data
if self.wave:
logger.info(f"{ARROW} Processing wave boundary data")
# Get source information
if hasattr(self.wave, 'source') and hasattr(self.wave.source, 'uri'):
logger.info(f" • Source: {self.wave.source.uri}")
elif hasattr(self.wave, 'source') and hasattr(self.wave.source, 'catalog_uri'):
logger.info(f" • Source: {self.wave.source.catalog_uri} (dataset: {getattr(self.wave.source, 'dataset_id', 'unknown')})")
results["wave"] = self.wave.get(destdir, grid, time)
logger.info(f" • Output: {results['wave']}")
logger.info(f"{ARROW} Wave data processed successfully")
# Process boundary conditions
if self.boundary_conditions:
logger.info(f"{ARROW} Processing boundary conditions")
results["boundary_conditions"] = self.boundary_conditions.get(
destdir, grid, time
)
logger.info(f"{ARROW} Boundary conditions processed successfully")
return results
def _format_value(self, obj):
"""Custom formatter for SCHISMData values.
This method provides special formatting for specific types used in
SCHISMData such as atmospheric, wave, and boundary data components.
Args:
obj: The object to format
Returns:
A formatted string or None to use default formatting
"""
# Import specific types and formatting utilities
from rompy.logging import LoggingConfig
from rompy.formatting import get_formatted_header_footer
# Get ASCII mode setting from LoggingConfig
logging_config = LoggingConfig()
USE_ASCII_ONLY = logging_config.use_ascii
# Format SCHISMData (self-formatting)
if isinstance(obj, SCHISMData):
header, footer, bullet = get_formatted_header_footer(
title="SCHISM DATA CONFIGURATION", use_ascii=USE_ASCII_ONLY
)
lines = [header]
# Count and list data components
components = {}
if hasattr(obj, "atmos") and obj.atmos is not None:
components["Atmospheric"] = type(obj.atmos).__name__
# Add details for atmospheric data
if hasattr(obj.atmos, "air_1") and obj.atmos.air_1 is not None:
air_sources = 1
if hasattr(obj.atmos, "air_2") and obj.atmos.air_2 is not None:
air_sources = 2
lines.append(f" Air sources: {air_sources}")
if hasattr(obj.atmos, "rad_1") and obj.atmos.rad_1 is not None:
rad_sources = 1
if hasattr(obj.atmos, "rad_2") and obj.atmos.rad_2 is not None:
rad_sources = 2
lines.append(f" Radiation sources: {rad_sources}")
if hasattr(obj, "wave") and obj.wave is not None:
components["Wave"] = type(obj.wave).__name__
if hasattr(obj, "boundary_conditions") and obj.boundary_conditions is not None:
components["Boundary Conditions"] = type(obj.boundary_conditions).__name__
for comp_name, comp_type in components.items():
lines.append(f" {bullet} {comp_name}: {comp_type}")
if not components:
lines.append(f" {bullet} No data components configured")
lines.append(footer)
return "\n".join(lines)
# Format SCHISMDataSflux
if isinstance(obj, SCHISMDataSflux):
header, footer, bullet = get_formatted_header_footer(
title="ATMOSPHERIC DATA (SFLUX)", use_ascii=USE_ASCII_ONLY
)
lines = [header]
# Count air sources
air_sources = 0
if hasattr(obj, "air_1") and obj.air_1 is not None:
air_sources += 1
if hasattr(obj, "air_2") and obj.air_2 is not None:
air_sources += 1
if air_sources > 0:
lines.append(f" {bullet} Air sources: {air_sources}")
# Count radiation sources
rad_sources = 0
if hasattr(obj, "rad_1") and obj.rad_1 is not None:
rad_sources += 1
if hasattr(obj, "rad_2") and obj.rad_2 is not None:
rad_sources += 1
if rad_sources > 0:
lines.append(f" {bullet} Radiation sources: {rad_sources}")
# Check for precipitation
if hasattr(obj, "prc_1") and obj.prc_1 is not None:
lines.append(f" {bullet} Precipitation: Available")
lines.append(footer)
return "\n".join(lines)
# Format SCHISMDataWave
if isinstance(obj, SCHISMDataWave):
header, footer, bullet = get_formatted_header_footer(
title="WAVE DATA", use_ascii=USE_ASCII_ONLY
)
lines = [header]
if hasattr(obj, "sel_method"):
lines.append(f" {bullet} Selection method: {obj.sel_method}")
if hasattr(obj, "source") and obj.source is not None:
source_type = type(obj.source).__name__
lines.append(f" {bullet} Source: {source_type}")
lines.append(footer)
return "\n".join(lines)
# Format SCHISMDataBoundaryConditions
if isinstance(obj, SCHISMDataBoundaryConditions):
header, footer, bullet = get_formatted_header_footer(
title="BOUNDARY CONDITIONS", use_ascii=USE_ASCII_ONLY
)
lines = [header]
# Count boundary setups
boundary_count = 0
if hasattr(obj, "boundaries") and obj.boundaries is not None:
if isinstance(obj.boundaries, (list, dict)):
boundary_count = len(obj.boundaries)
else:
boundary_count = 1
if boundary_count > 0:
lines.append(f" {bullet} Boundary setups: {boundary_count}")
# Check for tidal components
if hasattr(obj, "tidal") and obj.tidal is not None:
lines.append(f" {bullet} Tidal forcing: Available")
lines.append(footer)
return "\n".join(lines)
# Use the new formatting framework
from rompy.formatting import format_value
return format_value(obj)
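# A top-level assembly sketch (all components are optional; the objects shown are
# the ones sketched earlier in this module and are placeholders for a real
# configuration):
#
#     data = SCHISMData(
#         atmos=sflux,
#         wave=wave,
#         boundary_conditions=boundary_conditions,
#     )
#     results = data.get(destdir="./simulation", grid=schism_grid, time=time_range)
#     # results is a dict with keys such as "atmos", "wave" and "boundary_conditions"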
class HotstartConfig(RompyBaseModel):
"""
Configuration for generating SCHISM hotstart files.
This class specifies parameters for creating hotstart.nc files from
temperature and salinity data sources already configured in boundary conditions.
"""
enabled: bool = Field(
default=False, description="Whether to generate hotstart file"
)
temp_var: str = Field(
default="temperature",
description="Name of temperature variable in source dataset",
)
salt_var: str = Field(
default="salinity", description="Name of salinity variable in source dataset"
)
time_offset: float = Field(
default=0.0, description="Offset to add to source time values (in days)"
)
time_base: datetime = Field(
default=datetime(2000, 1, 1), description="Base time for source time values"
)
output_filename: str = Field(
default="hotstart.nc", description="Name of the output hotstart file"
)
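# Example configuration (field names and defaults as defined above; the variable
# names are placeholders for those in the temperature/salinity source dataset):
#
#     hotstart = HotstartConfig(
#         enabled=True,
#         temp_var="water_temp",
#         salt_var="salinity",
#         output_filename="hotstart.nc",
#     )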
class BoundarySetupWithSource(BoundarySetup):
"""
Enhanced boundary setup that includes data sources.
This class extends BoundarySetup to provide a unified configuration
for both boundary conditions and their data sources.
"""
elev_source: Optional[Union[DataBlob, DataGrid, SCHISMDataBoundary]] = Field(
None, description="Data source for elevation boundary condition"
)
vel_source: Optional[Union[DataBlob, DataGrid, SCHISMDataBoundary]] = Field(
None, description="Data source for velocity boundary condition"
)
temp_source: Optional[Union[DataBlob, DataGrid, SCHISMDataBoundary]] = Field(
None, description="Data source for temperature boundary condition"
)
salt_source: Optional[Union[DataBlob, DataGrid, SCHISMDataBoundary]] = Field(
None, description="Data source for salinity boundary condition"
)
@model_validator(mode="after")
def validate_data_sources(self):
"""Ensure data sources are provided when needed for space-time boundary types."""
# Check elevation data source
if (
self.elev_type in [ElevationType.EXTERNAL, ElevationType.HARMONICEXTERNAL]
and self.elev_source is None
):
logger.warning(
"elev_source should be provided for EXTERNAL or HARMONICEXTERNAL elevation type"
)
# Check velocity data source
if (
self.vel_type
in [
VelocityType.EXTERNAL,
VelocityType.HARMONICEXTERNAL,
VelocityType.RELAXED,
]
and self.vel_source is None
):
logger.warning(
"vel_source should be provided for EXTERNAL, HARMONICEXTERNAL, or RELAXED velocity type"
)
# Check temperature data source
if self.temp_type == TracerType.EXTERNAL and self.temp_source is None:
logger.warning(
"temp_source should be provided for EXTERNAL temperature type"
)
# Check salinity data source
if self.salt_type == TracerType.EXTERNAL and self.salt_source is None:
logger.warning("salt_source should be provided for EXTERNAL salinity type")
return self
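# A sketch of a hybrid (tides + external data) boundary segment (assumptions:
# elev_type/vel_type/temp_type/salt_type are fields inherited from BoundarySetup,
# as used by the validator above; SourceFile is from rompy.core.source and all
# URIs and variable names are placeholders):
#
#     open_boundary = BoundarySetupWithSource(
#         elev_type=ElevationType.HARMONICEXTERNAL,
#         vel_type=VelocityType.HARMONICEXTERNAL,
#         temp_type=TracerType.EXTERNAL,
#         salt_type=TracerType.EXTERNAL,
#         elev_source=SCHISMDataBoundary(source=SourceFile(uri="./ssh.nc"), variables=["ssh"]),
#         temp_source=SCHISMDataBoundary(source=SourceFile(uri="./ocean.nc"), variables=["temperature"]),
#         salt_source=SCHISMDataBoundary(source=SourceFile(uri="./ocean.nc"), variables=["salinity"]),
#     )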
class SCHISMDataBoundaryConditions(RompyBaseModel):
"""
This class configures all boundary conditions for SCHISM including tidal,
ocean, river, and nested model boundaries.
It provides a unified interface for specifying boundary conditions and their
data sources, replacing the separate tides and ocean configurations.
"""
# Allow arbitrary types for schema generation
model_config = ConfigDict(arbitrary_types_allowed=True)
data_type: Literal["boundary_conditions"] = Field(
default="boundary_conditions",
description="Model type discriminator",
)
# Tidal dataset specification
tidal_data: Optional[TidalDataset] = Field(
None,
description="Tidal forcing dataset",
)
# Boundary configurations with integrated data sources
boundaries: Dict[int, BoundarySetupWithSource] = Field(
default_factory=dict,
description="Boundary configuration by boundary index",
)
# Predefined configuration types
setup_type: Optional[Literal["tidal", "hybrid", "river", "nested"]] = Field(
None, description="Predefined boundary setup type"
)
# Hotstart configuration
hotstart_config: Optional[HotstartConfig] = Field(
None, description="Configuration for hotstart file generation"
)
@model_validator(mode="before")
@classmethod
def convert_numpy_types(cls, data):
"""Convert any numpy values to Python native types"""
if not isinstance(data, dict):
return data
for key, value in list(data.items()):
if isinstance(value, (np.bool_, np.integer, np.floating, np.ndarray)):
data[key] = to_python_type(value)
return data
@model_validator(mode="after")
def validate_tidal_data(self):
"""Ensure tidal data is provided when needed for TIDAL or TIDALSPACETIME boundaries."""
boundaries = self.boundaries or {}
needs_tidal_data = False
# Check setup_type first
if self.setup_type in ["tidal", "hybrid"]:
needs_tidal_data = True
# Then check individual boundaries
for setup in boundaries.values():
if (
hasattr(setup, "elev_type")
and setup.elev_type
in [ElevationType.HARMONIC, ElevationType.HARMONICEXTERNAL]
) or (
hasattr(setup, "vel_type")
and setup.vel_type
in [VelocityType.HARMONIC, VelocityType.HARMONICEXTERNAL]
):
needs_tidal_data = True
break
if needs_tidal_data and not self.tidal_data:
raise ValueError(
"Tidal data is required for HARMONIC or HARMONICEXTERNAL boundary types but was not provided"
)
return self
@model_validator(mode="after")
def validate_setup_type(self):
"""Validate setup type specific requirements."""
# Skip validation if setup_type is not set
if not self.setup_type:
return self
if self.setup_type in ["tidal", "hybrid"]:
if not self.tidal_data:
raise ValueError(
"tidal_data is required for tidal or hybrid setup_type"
)
elif self.setup_type == "river":
if self.boundaries:
has_flow = any(
hasattr(s, "const_flow") and s.const_flow is not None
for s in self.boundaries.values()
)
if not has_flow:
raise ValueError(
"At least one boundary should have const_flow for river setup_type"
)
elif self.setup_type == "nested":
if self.boundaries:
for idx, setup in self.boundaries.items():
if (
hasattr(setup, "vel_type")
and setup.vel_type == VelocityType.RELAXED
):
if not hasattr(setup, "inflow_relax") or not hasattr(
setup, "outflow_relax"
):
logger.warning(
f"inflow_relax and outflow_relax are recommended for nested setup_type in boundary {idx}"
)
else:
raise ValueError(
f"Unknown setup_type: {self.setup_type}. Expected one of: tidal, hybrid, river, nested"
)
return self
def _create_boundary_config(self, grid):
"""Create a TidalBoundary object based on the configuration."""
# Get tidal data paths
tidal_database = None
if self.tidal_data:
if (
hasattr(self.tidal_data, "tidal_database")
and self.tidal_data.tidal_database
):
tidal_database = str(self.tidal_data.tidal_database)
# Ensure boundary information is computed
if hasattr(grid.pylibs_hgrid, "compute_bnd"):
grid.pylibs_hgrid.compute_bnd()
else:
logger.warning(
"Grid object doesn't have compute_bnd method. Boundary information may be missing."
)
# Create a new TidalBoundary with all the configuration
# Ensure boundary information is computed before creating the boundary
if not hasattr(grid.pylibs_hgrid, "nob") or not hasattr(
grid.pylibs_hgrid, "nobn"
):
logger.info("Computing boundary information before creating TidalBoundary")
# First try compute_bnd if available
if hasattr(grid.pylibs_hgrid, "compute_bnd"):
grid.pylibs_hgrid.compute_bnd()
# Then try compute_all if nob is still missing
if not hasattr(grid.pylibs_hgrid, "nob") and hasattr(
grid.pylibs_hgrid, "compute_all"
):
grid.pylibs_hgrid.compute_all()
# Verify boundary attributes are available
if not hasattr(grid.pylibs_hgrid, "nob"):
logger.error("Failed to set 'nob' attribute on grid.pylibs_hgrid")
raise AttributeError(
"Missing required 'nob' attribute on grid.pylibs_hgrid"
)
# Create TidalBoundary with pre-computed grid to avoid losing boundary info
# Get the grid path for TidalBoundary
grid_path = (
str(grid.hgrid.path)
if hasattr(grid, "hgrid") and hasattr(grid.hgrid, "path")
else None
)
if grid_path is None:
# Create a temporary file with the grid if needed
import tempfile
temp_file = tempfile.NamedTemporaryFile(suffix=".gr3", delete=False)
temp_path = temp_file.name
temp_file.close()
grid.pylibs_hgrid.write_hgrid(temp_path)
grid_path = temp_path
boundary = BoundaryHandler(grid_path=grid_path, tidal_data=self.tidal_data)
# Replace the TidalBoundary's grid with our pre-computed one to preserve boundary info
boundary.grid = grid.pylibs_hgrid
# Configure each boundary segment
for idx, setup in self.boundaries.items():
boundary_config = setup.to_boundary_config()
boundary.set_boundary_config(idx, boundary_config)
return boundary
def get(
self,
destdir: str | Path,
grid: SCHISMGrid,
time: TimeRange,
) -> Dict[str, str]:
"""
Process all boundary data and generate necessary input files.
Parameters
----------
destdir : str | Path
Destination directory
grid : SCHISMGrid
SCHISM grid instance
time : TimeRange
Time range for the simulation
Returns
-------
Dict[str, str]
Paths to generated files
"""
# Processing boundary conditions
# Convert destdir to Path object
destdir = Path(destdir)
# Create destdir if it doesn't exist
if not destdir.exists():
logger.info(f"Creating destination directory: {destdir}")
destdir.mkdir(parents=True, exist_ok=True)
# 1. Process tidal data if needed
if self.tidal_data:
logger.info(f"{ARROW} Processing tidal constituents: {', '.join(self.tidal_data.constituents) if hasattr(self.tidal_data, 'constituents') else 'default'}")
self.tidal_data.get(grid)
# 2. Create boundary condition file (bctides.in)
boundary = self._create_boundary_config(grid)
# Set start time and run duration
start_time = time.start
if time.end is not None and time.start is not None:
run_days = (
time.end - time.start
).total_seconds() / 86400.0 # Convert to days
else:
run_days = 1.0 # Default to 1 day if time is not properly specified
boundary.set_run_parameters(start_time, run_days)
# Generate bctides.in file
bctides_path = destdir / "bctides.in"
logger.info(f"{ARROW} Generating boundary condition file: bctides.in")
# Ensure grid object has complete boundary information before writing
if hasattr(grid.pylibs_hgrid, "compute_all"):
grid.pylibs_hgrid.compute_all()
# Double-check all required attributes are present
required_attrs = ["nob", "nobn", "iobn"]
missing_attrs = [
attr
for attr in required_attrs
if not (grid.pylibs_hgrid and hasattr(grid.pylibs_hgrid, attr))
]
if missing_attrs:
error_msg = (
f"Grid is missing required attributes: {', '.join(missing_attrs)}"
)
logger.error(error_msg)
raise AttributeError(error_msg)
# Write the boundary file - no fallbacks
boundary.write_boundary_file(bctides_path)
logger.info(f"{ARROW} Boundary conditions written successfully")
# 3. Process ocean data based on boundary configurations
processed_files = {"bctides": str(bctides_path)}
# Collect variables to process and source information for logging
variables_to_process = []
source_files = set()
for idx, setup in self.boundaries.items():
if setup.elev_type in [ElevationType.EXTERNAL, ElevationType.HARMONICEXTERNAL] and setup.elev_source:
variables_to_process.append("elevation")
if hasattr(setup.elev_source, 'source') and hasattr(setup.elev_source.source, 'uri'):
source_files.add(str(setup.elev_source.source.uri))
if setup.vel_type in [VelocityType.EXTERNAL, VelocityType.HARMONICEXTERNAL, VelocityType.RELAXED] and setup.vel_source:
variables_to_process.append("velocity")
if hasattr(setup.vel_source, 'source') and hasattr(setup.vel_source.source, 'uri'):
source_files.add(str(setup.vel_source.source.uri))
if setup.temp_type == TracerType.EXTERNAL and setup.temp_source:
variables_to_process.append("temperature")
if hasattr(setup.temp_source, 'source') and hasattr(setup.temp_source.source, 'uri'):
source_files.add(str(setup.temp_source.source.uri))
if setup.salt_type == TracerType.EXTERNAL and setup.salt_source:
variables_to_process.append("salinity")
if hasattr(setup.salt_source, 'source') and hasattr(setup.salt_source.source, 'uri'):
source_files.add(str(setup.salt_source.source.uri))
if variables_to_process:
unique_vars = list(dict.fromkeys(variables_to_process)) # Remove duplicates while preserving order
logger.info(f"{ARROW} Processing boundary data: {', '.join(unique_vars)}")
if source_files:
if len(source_files) == 1:
logger.info(f" • Source: {list(source_files)[0]}")
else:
logger.info(f" • Sources: {len(source_files)} files")
# Process each data source based on the boundary type
for idx, setup in self.boundaries.items():
# Process elevation data if needed
if setup.elev_type in [
ElevationType.EXTERNAL,
ElevationType.HARMONICEXTERNAL,
]:
if setup.elev_source:
if (
hasattr(setup.elev_source, "data_type")
and setup.elev_source.data_type == "boundary"
):
# Process using SCHISMDataBoundary interface
setup.elev_source.id = "elev2D" # Set the ID for the boundary
file_path = setup.elev_source.get(destdir, grid, time)
else:
# Process using DataBlob interface
file_path = setup.elev_source.get(str(destdir))
processed_files[f"elev_boundary_{idx}"] = file_path
# Process velocity data if needed
if setup.vel_type in [
VelocityType.EXTERNAL,
VelocityType.HARMONICEXTERNAL,
VelocityType.RELAXED,
]:
if setup.vel_source:
if (
hasattr(setup.vel_source, "data_type")
and setup.vel_source.data_type == "boundary"
):
# Process using SCHISMDataBoundary interface
setup.vel_source.id = "uv3D" # Set the ID for the boundary
file_path = setup.vel_source.get(destdir, grid, time)
else:
# Process using DataBlob interface
file_path = setup.vel_source.get(str(destdir))
processed_files[f"vel_boundary_{idx}"] = file_path
# Process temperature data if needed
if setup.temp_type == TracerType.EXTERNAL:
if setup.temp_source:
if (
hasattr(setup.temp_source, "data_type")
and setup.temp_source.data_type == "boundary"
):
# Process using SCHISMDataBoundary interface
setup.temp_source.id = "TEM_3D" # Set the ID for the boundary
file_path = setup.temp_source.get(destdir, grid, time)
else:
# Process using DataBlob interface
file_path = setup.temp_source.get(str(destdir))
processed_files[f"temp_boundary_{idx}"] = file_path
# Process salinity data if needed
if setup.salt_type == TracerType.EXTERNAL:
if setup.salt_source:
if (
hasattr(setup.salt_source, "data_type")
and setup.salt_source.data_type == "boundary"
):
# Process using SCHISMDataBoundary interface
setup.salt_source.id = "SAL_3D" # Set the ID for the boundary
file_path = setup.salt_source.get(destdir, grid, time)
else:
# Process using DataBlob interface
file_path = setup.salt_source.get(str(destdir))
processed_files[f"salt_boundary_{idx}"] = file_path
# Generate hotstart file if configured
if self.hotstart_config and self.hotstart_config.enabled:
logger.info(f"{ARROW} Generating hotstart file")
hotstart_path = self._generate_hotstart(destdir, grid, time)
processed_files["hotstart"] = hotstart_path
logger.info(f" • Output: {hotstart_path}")
# Log summary of processed files with more details
boundary_data_files = [f for k, f in processed_files.items() if 'boundary' in k]
if boundary_data_files:
logger.info(f" • Files: {', '.join([Path(f).name for f in boundary_data_files])}")
return processed_files
def _generate_hotstart(
self,
destdir: Union[str, Path],
grid: SCHISMGrid,
time: Optional[TimeRange] = None,
) -> str:
"""
Generate hotstart file using boundary condition data sources.
Args:
destdir: Destination directory for the hotstart file
grid: SCHISM grid object
time: Time range for the data
Returns:
Path to the generated hotstart file
"""
from rompy.schism.hotstart import SCHISMDataHotstart
# Find a boundary that has both temperature and salinity sources
temp_source = None
salt_source = None
for boundary_config in self.boundaries.values():
if boundary_config.temp_source is not None:
temp_source = boundary_config.temp_source
if boundary_config.salt_source is not None:
salt_source = boundary_config.salt_source
# If we found both, we can proceed
if temp_source is not None and salt_source is not None:
break
if temp_source is None or salt_source is None:
raise ValueError(
"Hotstart generation requires both temperature and salinity sources "
"to be configured in boundary conditions"
)
# Create hotstart instance using the first available source
# (assuming temp and salt sources point to the same dataset)
# Include both temperature and salinity variables for hotstart generation
temp_var_name = (
self.hotstart_config.temp_var if self.hotstart_config else "temperature"
)
salt_var_name = (
self.hotstart_config.salt_var if self.hotstart_config else "salinity"
)
# Log hotstart generation details
logger.info(f" • Variables: {temp_var_name}, {salt_var_name}")
if hasattr(temp_source, 'source') and hasattr(temp_source.source, 'uri'):
logger.info(f" • Source: {temp_source.source.uri}")
hotstart_data = SCHISMDataHotstart(
source=temp_source.source,
variables=[temp_var_name, salt_var_name],
coords=getattr(temp_source, "coords", None),
temp_var=temp_var_name,
salt_var=salt_var_name,
time_offset=(
self.hotstart_config.time_offset if self.hotstart_config else 0.0
),
time_base=(
self.hotstart_config.time_base
if self.hotstart_config
else datetime(2000, 1, 1)
),
output_filename=(
self.hotstart_config.output_filename
if self.hotstart_config
else "hotstart.nc"
),
)
return hotstart_data.get(str(destdir), grid=grid, time=time)
# def check_bctides_flags(cls, v):
# # TODO: Add check for bc flags in the event of 3D inputs
# # Should possibly move these flags out of the SCHISMDataTides class as they cover more than
# # just tides
# return cls
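# A sketch of a tidal-plus-external-data configuration for the class above
# (assumptions: `tidal_dataset` is a TidalDataset instance constructed elsewhere,
# `open_boundary` and `hotstart` are the objects sketched earlier, and boundary
# index 0 refers to the first open boundary in the hgrid):
#
#     boundary_conditions = SCHISMDataBoundaryConditions(
#         setup_type="hybrid",
#         tidal_data=tidal_dataset,
#         boundaries={0: open_boundary},
#         hotstart_config=hotstart,
#     )
#     files = boundary_conditions.get(destdir="./simulation", grid=schism_grid, time=time_range)
#     # files["bctides"] -> ./simulation/bctides.in, plus any *_boundary_<idx> data files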
def get_valid_rename_dict(ds, rename_dict):
"""Construct a valid renaming dictionary that only includes names which need renaming."""
valid_rename_dict = {}
for old_name, new_name in rename_dict.items():
if old_name in ds.dims and new_name not in ds.dims:
valid_rename_dict[old_name] = new_name
return valid_rename_dict
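# Illustrative behaviour: for a dataset with dims ("time", "latitude", "longitude"),
# get_valid_rename_dict(ds, {"latitude": "ny_grid", "longitude": "nx_grid"}) returns
# the full mapping, while entries whose old name is absent from ds.dims or whose new
# name already exists there are silently dropped.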