88import numpy as np
99from bmipy .bmi import Bmi
1010from PyStemmusScope .bmi .utils import InapplicableBmiMethods
11+ from PyStemmusScope .bmi .utils import nested_set
12+ from PyStemmusScope .bmi .variable_reference import VARIABLES
13+ from PyStemmusScope .bmi .variable_reference import BmiVariable
1114from PyStemmusScope .config_io import read_config
1215
1316
14- MODEL_INPUT_VARNAMES : tuple [str , ...] = ("soil_temperature" ,)
15-
16- MODEL_OUTPUT_VARNAMES : tuple [str , ...] = (
17- "soil_temperature" ,
18- "respiration" ,
17+ MODEL_INPUT_VARNAMES : tuple [str , ...] = tuple (
18+ var .name for var in VARIABLES if var .input
1919)
2020
21- MODEL_VARNAMES : tuple [str , ...] = tuple (
22- set ( MODEL_INPUT_VARNAMES + MODEL_OUTPUT_VARNAMES )
21+ MODEL_OUTPUT_VARNAMES : tuple [str , ...] = tuple (
22+ var . name for var in VARIABLES if var . output
2323)
2424
25- VARNAME_UNITS : dict [str , str ] = {"respiration" : "unknown" , "soil_temperature" : "degC" }
25+ MODEL_VARS : dict [str , BmiVariable ] = {var .name : var for var in VARIABLES }
26+
27+ MODEL_VARNAMES : tuple [str , ...] = tuple (var .name for var in VARIABLES )
28+
29+ VARNAME_UNITS : dict [str , str ] = {var .name : var .units for var in VARIABLES }
2630
27- VARNAME_DTYPE : dict [str , str ] = {
28- "respiration" : "float64" ,
29- "soil_temperature" : "float64" ,
30- }
31+ VARNAME_DTYPE : dict [str , str ] = {var .name : var .dtype for var in VARIABLES }
3132
32- VARNAME_GRID : dict [str , int ] = {
33- "respiration" : 0 ,
34- "soil_temperature" : 1 ,
35- }
33+ VARNAME_GRID : dict [str , int ] = {var .name : var .grid for var in VARIABLES }
34+
35+ VARNAME_LOC : dict [str , list [str ]] = {var .name : var .keys for var in VARIABLES }
3636
3737NO_STATE_MSG = (
3838 "The model state is not available. Please run `.update()` before requesting "
@@ -59,23 +59,32 @@ def load_state(config: dict) -> h5py.File:
5959 return h5py .File (matfile , mode = "a" )
6060
6161
62- def get_variable (state : h5py .File , varname : str ) -> np .ndarray :
62+ def get_variable (
63+ state : h5py .File , varname : str
64+ ) -> np .ndarray : # noqa: PLR0911 PLR0912 C901
6365 """Get a variable from the model state.
6466
6567 Args:
6668 state: STEMMUS_SCOPE model state
6769 varname: Variable name
6870 """
69- if varname == "respiration" :
70- return state ["fluxes" ]["Resp" ][0 ]
71+ if varname not in MODEL_VARNAMES :
72+ msg = "Unknown variable name"
73+ raise ValueError (msg )
74+
75+ # deviating implemetation:
7176 elif varname == "soil_temperature" :
7277 return state ["TT" ][0 , :- 1 ]
78+
79+ # default implementation:
80+ _s = state
81+ for _loc in VARNAME_LOC [varname ]:
82+ _s = _s .get (_loc )
83+
84+ if MODEL_VARS [varname ].all_timesteps :
85+ return _s [0 ].astype (VARNAME_DTYPE [varname ])[[int (state ["KT" ][0 ])]]
7386 else :
74- if varname in MODEL_VARNAMES :
75- msg = "Varname is missing in get_variable! Contact devs."
76- else :
77- msg = "Unknown variable name"
78- raise ValueError (msg )
87+ return _s [0 ].astype (VARNAME_DTYPE [varname ])
7988
8089
8190def set_variable (
@@ -101,16 +110,21 @@ def set_variable(
101110 else :
102111 vals = value
103112
113+ if varname in MODEL_OUTPUT_VARNAMES and varname not in MODEL_INPUT_VARNAMES :
114+ msg = "This variable is a model output variable only. You cannot set it."
115+ raise ValueError (msg )
116+ elif varname not in MODEL_INPUT_VARNAMES :
117+ msg = "Uknown variable name"
118+ raise ValueError (msg )
119+
120+ # deviating implementations:
104121 if varname == "soil_temperature" :
105122 state ["TT" ][0 , :- 1 ] = vals
123+ elif varname == "groundwater_coupling_enabled" :
124+ state ["GroundwaterSettings" ]["GroundwaterCoupling" ][0 ] = vals .astype ("float" )
125+ # default:
106126 else :
107- if varname in MODEL_OUTPUT_VARNAMES and varname not in MODEL_INPUT_VARNAMES :
108- msg = "This variable is a model output variable only. You cannot set it."
109- elif varname in MODEL_VARNAMES :
110- msg = "Varname is missing in set_variable! Contact devs."
111- else :
112- msg = "Uknown variable name"
113- raise ValueError (msg )
127+ nested_set (state , VARNAME_LOC [varname ] + [0 ], vals )
114128 return state
115129
116130
@@ -401,6 +415,9 @@ def set_value(self, name: str, src: np.ndarray) -> None:
401415 """
402416 if self .state is None :
403417 raise ValueError (NO_STATE_MSG )
418+ if src .size != self .get_grid_size (self .get_var_grid (name )):
419+ msg = f"Size of `src` and variable '{ name } ' grid size are not equal!"
420+ raise ValueError (msg )
404421 self .state = set_variable (self .state , name , src )
405422
406423 def set_value_at_indices (
@@ -419,6 +436,9 @@ def set_value_at_indices(
419436 """
420437 if self .state is None :
421438 raise ValueError (NO_STATE_MSG )
439+ if inds .size != src .size :
440+ msg = "Sizes of `inds` and `src` are not equal!"
441+ raise ValueError (msg )
422442 self .state = set_variable (self .state , name , src , inds )
423443
424444 ### GRID INFO ###
0 commit comments