Source code for econox.structures.params

# src/econox/structures/params.py

from __future__ import annotations

import math
from typing import Any, Dict, Literal, get_args

import jax
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import PyTree

from econox.config import LOG_CLIP_MIN, LOG_CLIP_MAX, NUMERICAL_EPSILON

ConstraintKind = Literal[
    "free",           # (-inf, +inf)
    "positive",       # (0, +inf)
    "negative",       # (-inf, 0)
    "probability",    # (0, 1)
    "unit_interval",  # (0, 1) Alias for "probability" constraint
    "fixed",          # Fixed value
    "bounded",        # (lower, upper)
]
"""
Specifies the type of constraint applied to a parameter.

Options:
    - **free**: No constraints (-inf, +inf). The raw value is used as-is.
    - **positive**: Must be positive (0, +inf). Used for variances, etc.
      Mapped from raw space via ``exp``.
    - **negative**: Must be negative (-inf, 0). Mapped via ``-exp``.
    - **probability**: Constrained to the open interval (0, 1). Mapped via
      the logistic sigmoid.
    - **unit_interval**: Alias for "probability".
    - **fixed**: Parameter is fixed to its initial value and not optimized.
    - **bounded**: Constrained to (lower, upper) given by the parameter's
      bounds. The transform produces values in the open interval; initial
      values may lie anywhere in the closed range [lower, upper].
"""

class ParameterSpace(eqx.Module):
    """
    Manages parameter constraints and transformations with numerical stability.

    Compliant with the ParameterSpace protocol. Handles the mapping between:

    1. Raw Parameters (Real space, R^n): For the optimizer.
    2. Model Parameters (Constrained space): For the economic model.

    Examples:
        >>> # Define initial values
        >>> init_params = {
        ...     "beta": 0.95,
        ...     "sigma": 1.0,
        ...     "alpha": 0.5,
        ...     "gamma": 2.0
        ... }
        >>> # Define constraints
        >>> constraints = {
        ...     "beta": "fixed",         # Not optimized
        ...     "sigma": "positive",     # Domain: (0, inf)
        ...     "alpha": "probability",  # Domain: (0, 1)
        ...     "gamma": "free"          # Domain: (-inf, inf) (Default)
        ... }
        >>> # Create the parameter space
        >>> pspace = ParameterSpace.create(init_params, constraints)
    """

    # ---Fields (Immutable)---
    initial_params: Dict[str, Any]
    """Initial values of the parameters (Constrained space)."""
    constraints: Dict[str, ConstraintKind]
    """Dictionary mapping parameter names to their constraint types."""
    bounds: Dict[str, tuple[float, float]]
    """Dictionary mapping parameter names to (lower, upper) bounds."""

    # ---Numerical Stability Constants---
    eps: float = eqx.field(default=NUMERICAL_EPSILON, static=True)
    """Small constant for numerical stability (floor for log / logit inputs)."""
    log_clip_min: float = eqx.field(default=LOG_CLIP_MIN, static=True)
    """Lower clip applied to raw values before exp / sigmoid in ``transform``."""
    log_clip_max: float = eqx.field(default=LOG_CLIP_MAX, static=True)
    """Upper clip applied to raw values before exp / sigmoid in ``transform``."""

    # ---Factory Method (Replace __init__)---
    @classmethod
    def create(
        cls,
        initial_params: Dict[str, Any],
        constraints: Dict[str, ConstraintKind] | None = None,
        bounds: Dict[str, tuple[float, float]] | None = None,
    ) -> ParameterSpace:
        """
        Factory method to initialize ParameterSpace.

        Validates keys, constraint kinds and bounds, and fills the default
        constraint ('free') for parameters without an explicit one. A
        'bounded' parameter whose lower and upper bounds are equal is
        converted to 'fixed'.

        Args:
            initial_params (Dict[str, Any]): Dictionary of initial parameter values.
            constraints (Dict[str, ConstraintKind] | None): Optional dictionary
                specifying constraints for each parameter. Defaults to 'free'
                for unspecified parameters.
            bounds (Dict[str, tuple[float, float]] | None): Optional dictionary
                specifying (lower, upper) bounds for 'bounded' parameters.

        Returns:
            ParameterSpace: A new instance holding shallow copies of the
            given dictionaries, so later mutation by the caller does not
            leak into this (immutable) module.

        Raises:
            ValueError: If ``initial_params`` is empty; if a constraint names an
                unknown parameter or uses an unknown constraint kind; or if a
                'bounded' parameter has missing, non-finite, or inverted bounds,
                or an initial value outside its bounds.
        """
        if not initial_params:
            raise ValueError("initial_params cannot be empty")

        constraints = constraints or {}

        # Validate unknown keys in constraints
        unknown_keys = set(constraints) - set(initial_params)
        if unknown_keys:
            raise ValueError(f"Constraints defined for unknown parameters: {unknown_keys}")

        # Fail fast on typos: otherwise an unknown kind only surfaces later,
        # inside transform()/inverse_transform(), at trace time.
        valid_kinds = get_args(ConstraintKind)
        invalid = {k: v for k, v in constraints.items() if v not in valid_kinds}
        if invalid:
            raise ValueError(
                f"Unknown constraint types {invalid}; expected one of {list(valid_kinds)}."
            )

        # Fill defaults
        filled_constraints = {k: constraints.get(k, "free") for k in initial_params}

        # Validate bounds (copied so the caller's dict is not aliased)
        filled_bounds = dict(bounds) if bounds else {}
        for k, kind in list(filled_constraints.items()):
            if kind != "bounded":
                continue
            if k not in filled_bounds:
                raise ValueError(f"Parameter '{k}' has 'bounded' constraint but no bounds specified.")

            lower, upper = filled_bounds[k]
            # Infinite bounds would make transform() compute inf * sigmoid -> NaN.
            if not (math.isfinite(lower) and math.isfinite(upper)):
                raise ValueError(
                    f"Bounds for '{k}' must be finite, got ({lower}, {upper}). "
                    "Use 'positive' or 'negative' for half-bounded parameters."
                )
            if lower > upper:
                raise ValueError(f"Bounds for '{k}' must satisfy lower <= upper, got ({lower}, {upper}).")
            if lower == upper:
                # Degenerate interval: treat as fixed.
                filled_constraints[k] = "fixed"

            # Validate initial value within bounds
            init_val = initial_params[k]
            if not (lower <= init_val <= upper):
                raise ValueError(f"Initial value for '{k}' ({init_val}) is out of bounds ({lower}, {upper}).")

        return cls(
            initial_params=dict(initial_params),
            constraints=filled_constraints,
            bounds=filled_bounds,
        )

    # ---Protocol Implementation---
    def transform(self, raw_params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Transform raw (unconstrained) parameters to model (constrained) parameters.

        Args:
            raw_params: Dictionary of unconstrained parameters.
                Fixed parameters should NOT be included in this dictionary;
                if present, their values are ignored.

        Returns:
            Dictionary of constrained parameters including fixed parameters
            (which take their initial values).

        Raises:
            TypeError: If ``raw_params`` is not a dict.
            ValueError: If required (non-fixed) parameters are missing,
                unexpected parameters are present, or a constraint kind is
                unknown.
        """
        if not isinstance(raw_params, dict):
            raise TypeError("ParameterSpace currently expects a dictionary of parameters.")

        input_keys = set(raw_params.keys())
        all_keys = set(self.initial_params.keys())

        # Check for missing required parameters
        required_keys = {k for k in all_keys if self.constraints.get(k, "free") != "fixed"}
        if not input_keys.issuperset(required_keys):
            missing = required_keys - input_keys
            raise ValueError(f"Missing required parameters in raw_params: {missing}")

        # Check for unexpected extra parameters
        extra = input_keys - all_keys
        if extra:
            raise ValueError(f"Unexpected parameters in raw_params: {extra}")

        def _transform_leaf(value, name):
            kind = self.constraints.get(name, "free")
            if kind == "fixed":
                # FIXED parameters return their initial value, not the raw value
                return self.initial_params[name]
            if kind == "free":
                return value

            # Common clip for numerical stability (keeps exp/sigmoid finite)
            clipped = jnp.clip(value, self.log_clip_min, self.log_clip_max)

            if kind == "positive":
                return jnp.exp(clipped)
            elif kind == "negative":
                return -jnp.exp(clipped)
            elif kind in ("probability", "unit_interval"):
                return jax.nn.sigmoid(clipped)
            elif kind == "bounded":
                lower, upper = self.bounds[name]
                normalized = jax.nn.sigmoid(clipped)
                return lower + (upper - lower) * normalized
            else:
                raise ValueError(f"Unknown constraint type: {kind}")

        return {
            name: _transform_leaf(raw_params.get(name), name)
            for name in self.initial_params.keys()
        }

    def inverse_transform(self, model_params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Model parameters (Constrained) -> Raw parameters (Unconstrained).

        Args:
            model_params: Dictionary with exactly the same keys as
                ``initial_params`` (fixed parameters included).

        Returns:
            Dictionary of unconstrained parameters, excluding fixed ones.
            Values at or beyond a constraint's boundary are clipped before
            the log / logit is taken.

        Raises:
            TypeError: If ``model_params`` is not a dict.
            ValueError: If the keys do not match ``initial_params`` or a
                constraint kind is unknown.
        """
        if not isinstance(model_params, dict):
            raise TypeError("ParameterSpace currently expects a dictionary of parameters.")

        input_keys = set(model_params.keys())
        expected_keys = set(self.initial_params.keys())
        if input_keys != expected_keys:
            missing = expected_keys - input_keys
            extra = input_keys - expected_keys
            error_msg = "Parameter keys mismatch."
            if missing:
                error_msg += f" Missing: {missing}."
            if extra:
                error_msg += f" Extra (unexpected): {extra}."
            raise ValueError(error_msg)

        def _inv_transform_leaf(value, name):
            # Fixed parameters never reach here: they are filtered out below.
            kind = self.constraints.get(name, "free")
            if kind == "free":
                return value
            if kind == "positive":
                safe_value = jnp.maximum(value, self.eps)
                return jnp.log(safe_value)
            elif kind == "negative":
                safe_value = jnp.maximum(-value, self.eps)
                return jnp.log(safe_value)
            elif kind in ("probability", "unit_interval"):
                safe_value = jnp.clip(value, self.eps, 1.0 - self.eps)
                return jax.scipy.special.logit(safe_value)
            elif kind == "bounded":
                lower, upper = self.bounds[name]
                denom = upper - lower
                # Handle potential degeneracy in bounds (possible only when the
                # module is constructed directly rather than via create()).
                is_degenerate = jnp.abs(denom) < self.eps
                safe_denom = jnp.where(is_degenerate, 1.0, denom)
                # Normalize and clip.
                # NOTE(review): hard-coded 1e-6 rather than self.eps; kept to
                # preserve existing numerics.
                normalized = (value - lower) / safe_denom
                normalized_safe = jnp.clip(normalized, 1e-6, 1.0 - 1e-6)
                unconstrained = jax.scipy.special.logit(normalized_safe)
                # If degenerate, return 0.0
                return jnp.where(is_degenerate, 0.0, unconstrained)
            else:
                raise ValueError(f"Unknown constraint type: {kind}")

        return {
            name: _inv_transform_leaf(value, name)
            for name, value in model_params.items()
            if self.constraints.get(name, "free") != "fixed"  # Exclude FIXED params
        }

    def get_bounds(self) -> tuple[PyTree, PyTree] | None:
        """
        Protocol: Returns parameter bounds for the optimizer.

        Returns None because this class uses the 'Transformation Method'
        (Unconstrained Optimization). The optimizer operates on 'raw_params'
        which are unbounded (-inf, +inf). Constraints are enforced via the
        'transform' method, not by the optimizer's bound constraints.
        """
        return None

    # ---Helper Properties---
    @property
    def fixed_mask(self) -> Dict[str, bool]:
        """
        Returns a boolean mask where True indicates a parameter is FIXED.
        Useful for masking gradients in the Estimator.
        """
        return {k: (v == "fixed") for k, v in self.constraints.items()}

    @property
    def num_total_params(self) -> int:
        """
        Returns the number of all parameters.
        """
        return len(self.constraints)

    @property
    def num_free_params(self) -> int:
        """
        Returns the number of non-fixed parameters (any constraint kind
        other than 'fixed', not only 'free').
        """
        return sum(1 for kind in self.constraints.values() if kind != "fixed")