API Reference

Workflow (High-Level API)

The primary interface for managing the estimation workflow.

Estimator

Estimator module for the Econox framework. Orchestrates the estimation process by connecting Data, Model, Solver, and Objective.

class econox.workflow.estimator.Estimator(model, param_space, method, solver=None, optimizer=<factory>, verbose=False)

Bases: Module

Orchestrates the structural estimation process.

Handles:

  1. Parameter transformation (Raw <-> Constrained) via ParameterSpace.

  2. Solving the model (single run or batched simulation/SMM).

  3. Evaluating the loss function via EstimationMethod.

  4. Minimizing the loss with the configured optimizer (a Minimizer).
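
The four steps can be illustrated without econox at all. The pure-Python sketch below is hypothetical: a single probability parameter plays the role of the structural model. `to_constrained` stands in for ParameterSpace.transform, `solve` for the solver, `neg_log_likelihood` for the EstimationMethod, and a finite-difference gradient descent for the Minimizer.

```python
import math

def to_constrained(raw):
    # Step 1 (ParameterSpace role): map an unconstrained real number into (0, 1).
    return 1.0 / (1.0 + math.exp(-raw))

def solve(p):
    # Step 2 (Solver role): the toy "solution" is the choice-probability pair.
    return {"choice_probs": (1.0 - p, p)}

def neg_log_likelihood(solution, actions):
    # Step 3 (EstimationMethod role): negative log-likelihood of observed 0/1 actions.
    probs = solution["choice_probs"]
    return -sum(math.log(probs[a]) for a in actions)

def fit(actions, raw=0.0, lr=0.1, steps=500, h=1e-6):
    # Step 4 (Minimizer role): gradient descent on the raw parameter.
    def loss(r):
        return neg_log_likelihood(solve(to_constrained(r)), actions)
    for _ in range(steps):
        grad = (loss(raw + h) - loss(raw - h)) / (2 * h)  # central finite difference
        raw -= lr * grad
    return to_constrained(raw), loss(raw)

actions = [1, 0, 1, 1, 0, 1, 1, 1]
p_hat, final_loss = fit(actions)
print(round(p_hat, 3))  # prints 0.75, the sample share of action 1
```

The optimizer works on `raw` rather than on `p` directly. This keeps every iterate inside (0, 1) without any bound constraints.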

Variables:
  • model – StructuralModel - The structural model to estimate.

  • param_space – ParameterSpace - Parameter transformation and constraints.

  • method – EstimationMethod - Strategy for estimation (Loss definition & Inference).

  • solver – Solver | None - Solver to compute model solutions. Not required for reduced-form estimation.

  • optimizer – Minimizer - Optimization strategy for minimizing the loss.

  • verbose – bool - If True, enables detailed logging for debugging.

Examples

>>> # 1. Setup components
>>> model = Model.from_data(...)
>>> param_space = ParameterSpace.create(...)
>>> solver = ValueIterationSolver(utility=..., dist=..., discount_factor=0.95)
>>> method = MaximumLikelihood(model_key="choice_probs", obs_key="actions")
>>> # 2. Initialize Estimator
>>> estimator = Estimator(
...     model=model,
...     param_space=param_space,
...     solver=solver,
...     method=method
... )
>>> # 3. Run estimation
>>> result = estimator.fit(observations=data)
>>> print(result.params)

model: StructuralModel
param_space: ParameterSpace
method: EstimationMethod
solver: Solver | None = None
optimizer: Minimizer
verbose: bool = False
fit(observations, initial_params=None, sample_size=None, force_numerical=False)

Estimates the model parameters by minimizing the objective (loss) function.

Parameters:
  • observations (Any) – Observed data to match (passed to Objective).

  • initial_params (dict | None) – Dictionary of initial parameter values (Constrained space). If None, uses initial_params from ParameterSpace.

  • sample_size (int | None) – Effective sample size for variance calculations. Note: This argument is primarily for numerical estimation. If an analytical solution is found, this argument is ignored and the actual data size (n_obs) is used instead.

  • force_numerical (bool) – If True, forces numerical optimization even if an analytical solution is available.

Returns:

EstimationResult containing:

  • params: Estimated parameters (Constrained space).

  • loss: Final loss value.

  • success: Whether optimization was successful.

  • std_errors: Standard errors of estimates (if computed).

  • vcov: Variance-covariance matrix (if computed).

  • t_values: t-statistics of estimates (if computed).

  • solver_result: Final solver result (if applicable).

Return type:

EstimationResult

__init__(model, param_space, method, solver=None, optimizer=<factory>, verbose=False)
Return type:

None

Simulator

Counterfactual Simulation and Policy Evaluation Workflow.

This module provides the infrastructure for structural simulation in Econox. It enables researchers to conduct “What-if” analyses by modifying model environments (Scenario), solving for new agent behaviors, and evaluating outcomes through differentiable objective functions.

Key Components:
  • Scenario: A container pairing a model environment with its solution.

  • SimulatorObjective: A flexible wrapper for outcome evaluation.

  • Simulator: The orchestrator for the simulation workflow.

class econox.workflow.simulator.Scenario(model, result)

Bases: Module

A context container that pairs a structural model with its computed solution.

In economic simulation, a result (e.g., choice probabilities) must always be interpreted relative to its environment (e.g., prices, taxes). This class ensures that the model’s data and the solver’s output are bundled together.

Variables:
  • model – The StructuralModel instance representing the environment.

  • result – The SolverResult containing the policy or equilibrium solution.

Note

Scenario is immutable (frozen) and should be treated as a read-only container. This prevents accidental modifications during simulation workflows and supports JAX/Equinox functional-style usage.
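
Equinox modules behave like frozen dataclasses, so assigning to a field after construction raises an error. The plain-Python analogue below reproduces that behaviour with the standard-library `dataclasses` module. `ToyScenario` is a hypothetical name and is not part of econox.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class ToyScenario:
    # Hypothetical stand-in for Scenario: an environment paired with its solution.
    model: dict
    result: tuple

base = ToyScenario(model={"tax": 0.2}, result=(0.6, 0.4))

try:
    base.result = (0.5, 0.5)  # in-place field reassignment is rejected
except dataclasses.FrozenInstanceError as err:
    print("read-only:", err)
```

`frozen=True` only blocks reassigning fields. A mutable value stored in a field, such as the dict above, can still be changed in place. Treating the container as read-only is therefore a convention as well as a mechanism.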

Examples

>>> # Accessing data within a scenario
>>> tax_rate = scenario.model.data['tax']
>>> choices = scenario.result.profile

model: StructuralModel
result: SolverResult
__init__(model, result)
Return type:

None

class econox.workflow.simulator.SimulatorObjective(func)

Bases: Module, Generic[T]

Interface for evaluating simulation outcomes.

A SimulatorObjective calculates metrics such as welfare changes or tax revenue by comparing scenarios. It remains differentiable and can carry its own internal state (e.g., social welfare weights).
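
Conceptually, an objective is a stored function applied to a (counterfactual, baseline, params) triple. The plain-Python sketch below shows that shape together with a decorator that produces it. `ToyObjective` and `toy_objective_from_func` are hypothetical stand-ins, and plain dicts replace Scenario objects.

```python
from typing import Any, Callable, Generic, TypeVar

T = TypeVar("T")

class ToyObjective(Generic[T]):
    """Hypothetical stand-in for SimulatorObjective: wraps a (cf, base, params) -> T function."""

    def __init__(self, func: Callable[[Any, Any, Any], T]):
        self.func = func

    def __call__(self, cf: Any, base: Any, params: Any) -> T:
        return self.func(cf, base, params)

def toy_objective_from_func(func: Callable[[Any, Any, Any], T]) -> ToyObjective[T]:
    # Decorator: replace the plain function with an object that carries it.
    return ToyObjective(func)

@toy_objective_from_func
def tax_revenue_gain(cf, base, params):
    # Revenue = tax_rate * income, compared across the two scenarios.
    return cf["tax"] * cf["income"] - base["tax"] * base["income"]

gain = tax_revenue_gain({"tax": 0.25, "income": 100.0}, {"tax": 0.2, "income": 110.0}, None)
print(gain)
```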

Examples

>>> # Usage via the decorator (recommended)
>>> @simulator_objective_from_func
... def welfare_change(cf, base, params):
...     # Difference in average outcomes between scenarios
...     return cf.result.solution.mean() - base.result.solution.mean()

Parameters:

func (Callable[[Scenario, Scenario, PyTree], T])

func: Callable[[Scenario, Scenario, PyTree], T]
__init__(func)
Parameters:

func (Callable[[Scenario, Scenario, PyTree], T])

Return type:

None

econox.workflow.simulator.simulator_objective_from_func(func)

A decorator that converts a Python function into a SimulatorObjective.

Parameters:

func (Callable[[Scenario, Scenario, PyTree], TypeVar(T)]) – A function with signature (cf, base, params) -> T.

Returns:

A JAX-compatible objective module.

Return type:

SimulatorObjective[T]

Examples

>>> @simulator_objective_from_func
... def tax_revenue_gain(cf, base, params):
...     # Revenue = tax_rate * income
...     rev_cf = cf.model.data['tax'] * cf.result.aux['income']
...     rev_base = base.model.data['tax'] * base.result.aux['income']
...     return rev_cf - rev_base

class econox.workflow.simulator.Simulator(solver, base_model, objective_function)

Bases: Module

Counterfactual Simulation Engine.

The Simulator automates the process of updating model data, solving for new agent behaviors, and evaluating outcomes. It is designed to work under JAX transformations such as jax.jit and jax.grad, which makes it suitable for gradient-based policy optimization.
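
The update, solve and evaluate cycle can be sketched in plain Python. Everything in this sketch is hypothetical:

  • A two-action logit "solver" in which the utility of working falls with the tax rate.

  • A dict standing in for StructuralModel.

  • `simulate` standing in for calling a Simulator.

It illustrates only the order of operations. It is not differentiable the way the JAX version is meant to be.

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyScenario:
    model: dict   # hypothetical environment: a plain dict of data constants
    result: dict  # hypothetical solver output

def solve(params, data):
    # Hypothetical solver: binary logit probability of working, lowered by the tax rate.
    u_work = params["alpha"] - params["beta"] * data["tax"]
    return {"p_work": 1.0 / (1.0 + math.exp(-u_work))}

def participation_change(cf, base, params):
    # Objective: change in the probability of working.
    return cf.result["p_work"] - base.result["p_work"]

def simulate(params, base_data, updates, objective):
    # 1. Solve the baseline environment.
    base = ToyScenario(base_data, solve(params, base_data))
    # 2. Build the counterfactual environment without mutating the baseline.
    unknown = set(updates) - set(base_data)
    if unknown:
        raise KeyError(f"unknown data keys: {sorted(unknown)}")
    cf_data = {**base_data, **updates}
    # 3. Re-solve, then 4. evaluate the objective on the pair of scenarios.
    cf = ToyScenario(cf_data, solve(params, cf_data))
    return objective(cf, base, params)

params = {"alpha": 1.0, "beta": 2.0}
base_data = {"tax": 0.2}
effect = simulate(params, base_data, {"tax": 0.25}, participation_change)
print(effect)  # negative: a higher tax lowers participation in this toy model
```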

Variables:
  • solver – The solver used to find the model’s solution.

  • base_model – The reference structural model.

  • objective_function – The logic used to evaluate results.

Examples

>>> # 1. Define objective
>>> @simulator_objective_from_func
... def diff_obj(cf, base, params):
...     return cf.result.solution.mean() - base.result.solution.mean()
>>>
>>> # 2. Initialize Simulator
>>> sim = Simulator(solver=my_solver, base_model=model, objective_function=diff_obj)
>>>
>>> # 3. Run simulation
>>> benefits = sim(params, updates={'tax': 0.25})

solver: Solver
base_model: StructuralModel
objective_function: SimulatorObjective[Any]
__init__(solver, base_model, objective_function)
Return type:

None

Core Interfaces (Protocols)

Abstract base classes and protocols defining the contract for custom components.

Protocol definitions for the Econox framework.

This module defines the core interfaces (contracts) that enable modularity. By adhering to these protocols, users can swap out components (e.g., changing utility functions or solver algorithms) without modifying the rest of the workflow.
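
Python's `typing.Protocol` provides structural ("duck") typing. Any class with the right method signature satisfies the contract, and no inheritance is required. The minimal sketch below assumes a simplified utility protocol. `UtilityLike`, `ConstantUtility` and `ScaledUtility` are hypothetical names, and nested lists stand in for JAX arrays.

```python
from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class UtilityLike(Protocol):
    # Simplified contract: any object with this method qualifies.
    def compute_flow_utility(self, params: Any, model: Any) -> Any: ...

class ConstantUtility:
    # Does not inherit from UtilityLike; it matches structurally.
    def compute_flow_utility(self, params, model):
        return [[params["c"]] * model["num_actions"] for _ in range(model["num_states"])]

class ScaledUtility:
    def compute_flow_utility(self, params, model):
        return [[params["beta"] * x for x in row] for row in model["x"]]

def total_utility(utility: UtilityLike, params, model) -> float:
    # Workflow code depends only on the protocol, so components can be swapped freely.
    return sum(sum(row) for row in utility.compute_flow_utility(params, model))

model = {"num_states": 2, "num_actions": 2, "x": [[1.0, 2.0], [3.0, 4.0]]}
print(total_utility(ConstantUtility(), {"c": 1.0}, model))   # prints 4.0
print(total_utility(ScaledUtility(), {"beta": 0.5}, model))  # prints 5.0
```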

class econox.protocols.StructuralModel(*args, **kwargs)

Bases: Protocol

Represents the economic environment (State Space and Constraints).

In structural estimation, a model \(M\) is defined by the tuple \((S, A, T, \Omega)\), where:

  • \(S\): State space (num_states)

  • \(A\): Action space (num_actions)

  • \(T\): Time horizon (num_periods)

  • \(\Omega\): Information set / Data constants (data)

This protocol abstracts away the storage details of these elements.

property num_states: int
property num_actions: int
property num_periods: int | float

Number of periods T in the model. Should be a positive integer for finite horizon or np.inf for infinite horizon.

property data: PyTree

Immutable constants \(\Omega\) (features, matrices, etc.).

property transitions: PyTree | None

Transition structure (e.g., transition matrix or adjacency). Corresponds to \(P(s' | s, a)\).

property availability: PyTree | None

Feasible action set \(A(s)\). Boolean mask of shape (num_states, num_actions).
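
A common way to honour such a mask when computing logit choice probabilities is to set infeasible action values to negative infinity before the softmax, so those actions receive zero probability. The helper below is a hypothetical pure-Python sketch. It does not necessarily reflect how econox applies the mask internally.

```python
import math

def masked_choice_probabilities(values, availability):
    """Row-wise softmax over actions that gives zero probability to unavailable actions.

    values: one row of action values per state.
    availability: same shape; 1/True = feasible, 0/False = infeasible.
    Assumes every state has at least one feasible action.
    """
    probs = []
    for v_row, a_row in zip(values, availability):
        masked = [v if a else -math.inf for v, a in zip(v_row, a_row)]
        m = max(masked)  # finite because at least one action is feasible
        exps = [math.exp(v - m) for v in masked]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs

values = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]
availability = [[1, 1, 0], [1, 1, 1]]
probs = masked_choice_probabilities(values, availability)
print(probs)
```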

replace_data(key, value)

Returns a new instance of the model with the specified data key updated. Required by FeedbackMechanism implementations to update the environment (e.g., prices).

Parameters:
  • key (str) – The name of the data field to update.

  • value (Any) – The new value for that field.

Return type:

StructuralModel

Returns:

A new StructuralModel instance (immutable update).

class econox.protocols.Utility(*args, **kwargs)

Bases: Protocol

Structural Utility Function \(u(s, a; \theta)\).

Defines the instantaneous payoff an agent receives from taking action \(a\) in state \(s\).

compute_flow_utility(params, model)

Calculates the utility matrix given parameters and model state.

Return type:

Float[Array, 'n_states n_actions']

class econox.protocols.Distribution(*args, **kwargs)

Bases: Protocol

Stochastic Shock Distribution \(F(\epsilon)\).

Defines the properties of the unobserved state variables (error terms). Handles the smoothing of the max operator (Emax) and choice probabilities (CCP).

Common examples: Type-I Extreme Value (Logit), Normal (Probit).

expected_max(values)

Computes the expected maximum value: \(E[\max_a (v(s, a) + \epsilon(a))]\)

Return type:

Float[Array, 'n_states']

Parameters:

values (Float[jaxlib._jax.Array, 'n_states n_actions'])

choice_probabilities(values)

Computes conditional choice probabilities (CCP): \(P(a | s) = P(v(s, a) + \epsilon(a) \ge v(s, a') + \epsilon(a'), \forall a')\)

Return type:

Float[Array, 'n_states n_actions']

Parameters:

values (Float[jaxlib._jax.Array, 'n_states n_actions'])

class econox.protocols.FeedbackMechanism(*args, **kwargs)

Bases: Protocol

Equilibrium/Market Clearing Condition.

Defines how aggregate agent behaviors affect the environment (e.g., prices, congestion).

Mathematically: \(\Omega' = \Gamma(\sigma, \Omega)\), where \(\sigma\) is the policy.
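
Reaching equilibrium then means alternating solve and update until the environment stops changing. That is a fixed point \(\Omega^* = \Gamma(\sigma^*, \Omega^*)\), where \(\sigma^*\) is the policy solved under \(\Omega^*\). The toy below is purely illustrative and hypothetical. It uses a single "rent" constant, a logit share of agents choosing a location, and an update rule that raises rent with demand. The loop is iterated with damping.

```python
import math

def solve(params, data):
    # Hypothetical inner solve: share of agents choosing the location (binary logit).
    v = params["amenity"] - params["rent_coef"] * data["rent"]
    return 1.0 / (1.0 + math.exp(-v))

def update(params, share, data):
    # Hypothetical feedback rule Gamma: rent rises linearly with demand.
    return {**data, "rent": params["rent_base"] + params["rent_slope"] * share}

def equilibrium(params, data, damping=0.5, tol=1e-12, max_iter=1000):
    for _ in range(max_iter):
        share = solve(params, data)
        target = update(params, share, data)["rent"]
        # Damped step: move only part of the way toward the updated environment.
        new_rent = damping * target + (1.0 - damping) * data["rent"]
        converged = abs(new_rent - data["rent"]) < tol
        data = {**data, "rent": new_rent}
        if converged:
            return data, solve(params, data)
    raise RuntimeError("feedback loop did not converge")

params = {"amenity": 1.0, "rent_coef": 1.0, "rent_base": 0.5, "rent_slope": 1.0}
eq_data, eq_share = equilibrium(params, {"rent": 0.0})
print(eq_data["rent"], eq_share)
```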

update(params, current_result, model)

Updates the model environment based on the current solution.

Return type:

StructuralModel

class econox.protocols.Dynamics(*args, **kwargs)

Bases: Protocol

State Transition Law of Motion \(s' = f(s, a, \xi)\).

Defines how the distribution of agents over states evolves over time. Used for simulation and calculating steady-state distributions.
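
Take a policy \(\sigma(a | s)\) and transitions \(P(s' | s, a)\). The distribution of agents \(\mu\) evolves as \(\mu'(s') = \sum_s \mu(s) \sum_a \sigma(a | s) P(s' | s, a)\), and a steady state is a fixed point of that map. Below is a minimal pure-Python sketch under those assumptions. The functions `step` and `steady_state` are hypothetical and are not part of the Dynamics protocol.

```python
def step(mu, policy, transitions):
    """One period of the law of motion.

    mu[s]: mass of agents in state s.
    policy[s][a]: probability of action a in state s.
    transitions[s][a][s2]: probability of moving from s to s2 after action a.
    """
    n = len(mu)
    new_mu = [0.0] * n
    for s in range(n):
        for a, p_a in enumerate(policy[s]):
            for s2 in range(n):
                new_mu[s2] += mu[s] * p_a * transitions[s][a][s2]
    return new_mu

def steady_state(policy, transitions, tol=1e-12, max_iter=10_000):
    n = len(policy)
    mu = [1.0 / n] * n  # start from the uniform distribution
    for _ in range(max_iter):
        new_mu = step(mu, policy, transitions)
        if max(abs(x - y) for x, y in zip(new_mu, mu)) < tol:
            return new_mu
        mu = new_mu
    raise RuntimeError("distribution did not converge")

# Two states, two actions: action 0 = stay, action 1 = move to the other state.
transitions = [
    [[1.0, 0.0], [0.0, 1.0]],  # from state 0
    [[0.0, 1.0], [1.0, 0.0]],  # from state 1
]
policy = [[0.9, 0.1], [0.7, 0.3]]  # move w.p. 0.1 from state 0, 0.3 from state 1
mu_star = steady_state(policy, transitions)
print(mu_star)  # approximately [0.75, 0.25]
```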

class econox.protocols.TerminalApproximator(*args, **kwargs)

Bases: Protocol

Terminal Value Function Approximator for Finite Horizon Models.

approximate(expected, params, model)

Computes the terminal value function approximation.

Parameters:
  • expected (Float[Array, 'num_states num_actions']) – The expected future value matrix before terminal adjustment.

  • params (PyTree) – Model parameters (may include growth rates, etc.).

  • model (StructuralModel) – The structural model instance providing data and metadata.

Return type:

Float[Array, 'num_states num_actions']

Returns:

The adjusted expected future value matrix.

class econox.protocols.Solver(*args, **kwargs)

Bases: Protocol

Computational Engine for the Model.

A Solver finds the solution (Policy function / Value function) that satisfies the optimality conditions defined by the Model Primitives.

Implementation Note:

Concrete solvers should store their specific logic (Utility, Distribution) internally. The solve method strictly takes the parameters and the environment.
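
Following this note, a concrete solver keeps its behavioural logic as fields and exposes only `solve(params, model)`. The sketch below is hypothetical and is not econox's ValueIterationSolver. It iterates the logit Bellman operator \(V(s) = \text{scale} \cdot \log \sum_a \exp\{(u(s, a) + \beta \sum_{s'} P(s' | s, a) V(s')) / \text{scale}\}\) to a fixed point. It then returns the value function and the CCPs. Arrays are plain lists, and the model is a dict.

```python
import math
from dataclasses import dataclass
from typing import Callable

def logsumexp(xs, scale=1.0):
    # scale * log(sum(exp(x / scale))), shifted by the max for stability
    m = max(xs)
    return m + scale * math.log(sum(math.exp((x - m) / scale) for x in xs))

@dataclass(frozen=True)
class ToyValueIterationSolver:
    utility: Callable          # (params, model) -> u[s][a]; stored on the solver
    scale: float = 1.0         # Gumbel scale (the "distribution" logic)
    discount_factor: float = 0.9
    tol: float = 1e-10
    max_iter: int = 10_000

    def solve(self, params, model):
        u = self.utility(params, model)
        P = model["transitions"]  # P[s][a][s2]
        S, A = model["num_states"], model["num_actions"]
        V = [0.0] * S
        for _ in range(self.max_iter):
            # Choice-specific values q(s, a) = u(s, a) + beta * E[V(s') | s, a]
            q = [[u[s][a] + self.discount_factor * sum(P[s][a][t] * V[t] for t in range(S))
                  for a in range(A)] for s in range(S)]
            new_V = [logsumexp(q[s], self.scale) for s in range(S)]
            if max(abs(x - y) for x, y in zip(new_V, V)) < self.tol:
                ccp = [[math.exp((q[s][a] - new_V[s]) / self.scale) for a in range(A)]
                       for s in range(S)]
                return {"solution": new_V, "profile": ccp, "success": True}
            V = new_V
        return {"solution": V, "profile": None, "success": False}

def utility(params, model):
    # Action 1 pays theta, action 0 pays nothing, in every state.
    return [[0.0, params["theta"]] for _ in range(model["num_states"])]

model = {"num_states": 1, "num_actions": 2, "transitions": [[[1.0], [1.0]]]}
res = ToyValueIterationSolver(utility=utility).solve({"theta": 1.0}, model)
print(res["solution"], res["profile"])
```

With one absorbing state the fixed point is \(V = \log(1 + e) / (1 - \beta)\), which gives a quick check on the loop.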

solve(params, model)

Executes the solution algorithm.

Parameters:
  • params (PyTree) – Structural parameters \(\theta\).

  • model (StructuralModel) – The economic environment \((S, A, ...)\).

Returns:

A result object (e.g., SolverResult) containing the solution.

Return type:

Any

Structures (Data & State)

Data containers for models, parameters, and results.

Structural Model

class econox.structures.model.Model(num_states, num_actions, data, num_periods=inf, availability=None, transitions=None)

Bases: Module

Immutable container representing the structural environment.

This class holds all exogenous data (states, transitions, covariates) required to define the economic model. It is designed to be purely data-centric and logic-free, ensuring separation between the environment and the behavioral logic (Utility/Solver).

As an Equinox Module, this class is a valid JAX PyTree: it can be passed into JIT-compiled functions, and gradients can be taken with respect to its array fields.

Parameters:
  • num_states (int)

  • num_actions (int)

  • data (Dict[str, Float[jaxlib._jax.Array, '...']])

  • num_periods (int | float)

  • availability (Int[jaxlib._jax.Array, 'num_states num_actions'] | None)

  • transitions (Float[jaxlib._jax.Array, '...'] | BCOO | None)

num_states: int

Total cardinality of the state space (\(S\)). Used to determine array shapes.

num_actions: int

Total cardinality of the action space (\(A\)).

data: Dict[str, Float[jaxlib._jax.Array, '...']]

Dictionary of environment constants and exogenous variables.

Keys are identifiers (e.g., ‘wage’, ‘distance’, ‘rent’) that must match the keys expected by the Utility component. Values are typically JAX arrays of shape (\(S\), \(A\)), (\(S\),), or scalars.

num_periods: int | float = inf

Time horizon (\(T\)) of the model.

  • np.inf: Infinite horizon (default).

  • int: Finite horizon.

availability: Int[jaxlib._jax.Array, 'num_states num_actions'] | None = None

Binary mask indicating feasible actions. Shape (\(S\), \(A\)). 1 (or True) indicates action \(a\) is available in state \(s\), while 0 (or False) indicates it is physically impossible.

transitions: Float[jaxlib._jax.Array, '...'] | BCOO | None = None

Exogenous transition structure.

Depending on the model type, this could be:

  • Transition probability matrix: \(P(s' | s, a)\)

  • Adjacency matrix (for spatial models)

  • Deterministic mapping logic
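
When transitions is a dense probability matrix, a useful sanity check before building the model is that every conditional distribution \(P(\cdot | s, a)\) is non-negative and sums to one. The helper below is a hypothetical sketch that assumes the layout (num_states, num_actions, num_states). Check which axis order your solver actually expects.

```python
import math

def check_transition_matrix(P, num_states, num_actions, atol=1e-9):
    """Validate P[s][a][s2] = P(s2 | s, a) for a dense (S, A, S) nested list."""
    if len(P) != num_states:
        raise ValueError(f"expected {num_states} origin states, got {len(P)}")
    for s, rows in enumerate(P):
        if len(rows) != num_actions:
            raise ValueError(f"state {s}: expected {num_actions} actions, got {len(rows)}")
        for a, row in enumerate(rows):
            if len(row) != num_states:
                raise ValueError(f"(s={s}, a={a}): expected {num_states} destinations")
            if any(p < 0 for p in row):
                raise ValueError(f"(s={s}, a={a}): negative probability")
            if not math.isclose(sum(row), 1.0, abs_tol=atol):
                raise ValueError(f"(s={s}, a={a}): probabilities sum to {sum(row)}")

P = [[[1.0, 0.0], [0.2, 0.8]],
     [[0.0, 1.0], [0.5, 0.5]]]
check_transition_matrix(P, num_states=2, num_actions=2)  # passes silently
```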

__init__(num_states, num_actions, data, num_periods=inf, availability=None, transitions=None)
Parameters:
  • num_states (int)

  • num_actions (int)

  • data (Dict[str, Float[jaxlib._jax.Array, '...']])

  • num_periods (int | float)

  • availability (Int[jaxlib._jax.Array, 'num_states num_actions'] | None)

  • transitions (Float[jaxlib._jax.Array, '...'] | BCOO | None)

Return type:

None

classmethod from_data(num_states, num_actions, data, availability=None, transitions=None, num_periods=np.inf)

Factory method to initialize a Model from raw Python/NumPy data.

This is the recommended entry point. It automatically converts Python lists and NumPy arrays into JAX arrays, ensuring compatibility with JIT compilation.

Parameters:
  • num_states (int) – Total number of states (\(S\)). Must be positive.

  • num_actions (int) – Total number of actions (\(A\)). Must be positive.

  • data (Dict[str, Any]) – Dictionary of feature matrices (e.g., {‘wage’: […]}). Keys should be strings, values can be lists, NumPy arrays, or JAX arrays.

  • availability (Any | None) – Optional mask for feasible actions. Shape must be (num_states, num_actions).

  • transitions (Any | None) – Optional transition matrices.

  • num_periods (int | float) – Time horizon. Defaults to np.inf (Infinite). Must be positive.

Return type:

Model

Returns:

A frozen, JAX-ready Model instance.

Raises:
  • ValueError – If dimensions are invalid.

  • TypeError – If data types are incompatible.

Examples

>>> import numpy as np
>>> from econox.structures.model import Model
>>> # Infinite horizon model
>>> model = Model.from_data(
...     num_states=10,
...     num_actions=3,
...     data={'wage': np.random.randn(10, 3)}
... )
>>> # Finite horizon with availability constraints
>>> model = Model.from_data(
...     num_states=5,
...     num_actions=2,
...     data={'utility': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]},
...     availability=np.ones((5, 2)),
...     num_periods=10
... )

replace_data(key, value)

Creates a NEW Model instance with specific data updated.

This method is essential for counterfactual simulations. It allows you to create a “What-if” world (e.g., “What if subsidies increase by 10%?”) without mutating the original baseline model.

Parameters:
  • key (str) – The name of the data field to update (must exist in data).

  • value (Any) – The new value for that field.

Return type:

Model

Returns:

A new Model instance with the updated data.
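
The key property is that the baseline is left untouched. Below is a plain-Python sketch of these copy-on-update semantics using a frozen dataclass and a read-only mapping. `ToyModel` and the "subsidy" key are hypothetical. The key check mirrors the documented requirement that key must already exist in data.

```python
import dataclasses
from types import MappingProxyType

@dataclasses.dataclass(frozen=True)
class ToyModel:
    num_states: int
    num_actions: int
    data: MappingProxyType  # read-only view of the data dict

    def replace_data(self, key, value):
        if key not in self.data:
            raise KeyError(f"unknown data key: {key!r}")
        new_data = MappingProxyType({**self.data, key: value})
        return dataclasses.replace(self, data=new_data)  # new instance; self unchanged

baseline = ToyModel(2, 2, MappingProxyType({"subsidy": 1.0, "wage": 3.0}))
counterfactual = baseline.replace_data("subsidy", baseline.data["subsidy"] * 1.1)
print(baseline.data["subsidy"], counterfactual.data["subsidy"])
```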

property features: Dict[str, Array]

Alias for data, provided for semantic convenience.

get(name)

Retrieves a data array by name with error handling.

Raises:

KeyError – If the name is not found in data.

Return type:

Array

Parameters:

name (str)

property shapes: Dict[str, tuple[int, ...]]

Returns the shape of each array in data for debugging.

Parameter Management

econox.structures.params.ConstraintKind

Specifies the type of constraint applied to a parameter.

Options:
  • free: No constraints (-inf, +inf).

  • positive: Must be positive (0, +inf). Used for variances, etc.

  • negative: Must be negative (-inf, 0).

  • probability: Constrained to (0, 1).

  • unit_interval: Alias for “probability”.

  • fixed: Parameter is fixed to its initial value and not optimized.

  • bounded: Constrained to a specific range [lower, upper].

alias of Literal[‘free’, ‘positive’, ‘negative’, ‘probability’, ‘unit_interval’, ‘fixed’, ‘bounded’]
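
Under the transformation method used by ParameterSpace, each constraint kind needs a smooth bijection from the real line onto its domain. The functions below show one standard choice: identity, exp, negated exp, logistic and scaled logistic. They form a hypothetical pure-Python sketch. econox's actual transforms, including its eps and log-clipping safeguards, may differ. With a logistic map, the bounded domain is the open interval (lower, upper).

```python
import math

def to_constrained(raw, kind, bounds=None):
    # Map an unconstrained real number onto the domain of `kind`.
    if kind == "free":
        return raw
    if kind == "positive":
        return math.exp(raw)
    if kind == "negative":
        return -math.exp(raw)
    if kind in ("probability", "unit_interval"):
        return 1.0 / (1.0 + math.exp(-raw))
    if kind == "bounded":
        lower, upper = bounds
        return lower + (upper - lower) / (1.0 + math.exp(-raw))
    raise ValueError(f"no transform for kind {kind!r}")  # e.g. 'fixed' is never optimized

def to_raw(value, kind, bounds=None):
    # Inverse map: constrained value -> unconstrained real number.
    if kind == "free":
        return value
    if kind == "positive":
        return math.log(value)
    if kind == "negative":
        return math.log(-value)
    if kind in ("probability", "unit_interval"):
        return math.log(value / (1.0 - value))
    if kind == "bounded":
        lower, upper = bounds
        z = (value - lower) / (upper - lower)
        return math.log(z / (1.0 - z))
    raise ValueError(f"no transform for kind {kind!r}")

for value, kind, bounds in [(2.0, "positive", None), (-0.5, "negative", None),
                            (0.3, "probability", None), (4.0, "bounded", (1.0, 5.0))]:
    raw = to_raw(value, kind, bounds)
    print(kind, raw, to_constrained(raw, kind, bounds))  # round trip recovers value
```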

class econox.structures.params.ParameterSpace(initial_params, constraints, bounds, eps=1e-08, log_clip_min=-20.0, log_clip_max=20.0)

Bases: Module

Manages parameter constraints and transformations with numerical stability. Compliant with the ParameterSpace protocol.

Handles the mapping between:

  1. Raw parameters (real space, \(\mathbb{R}^n\)): for the optimizer.

  2. Model parameters (constrained space): for the economic model.

Examples

>>> # Define initial values
>>> init_params = {
...     "beta": 0.95,
...     "sigma": 1.0,
...     "alpha": 0.5,
...     "gamma": 2.0
... }
>>> # Define constraints
>>> constraints = {
...     "beta": "fixed",          # Not optimized
...     "sigma": "positive",      # Domain: (0, inf)
...     "alpha": "probability",   # Domain: (0, 1)
...     "gamma": "free"           # Domain: (-inf, inf) (Default)
... }
>>> # Create the parameter space
>>> pspace = ParameterSpace.create(init_params, constraints)

Parameters:
  • initial_params (Dict[str, Any])

  • constraints (Dict[str, Literal['free', 'positive', 'negative', 'probability', 'unit_interval', 'fixed', 'bounded']])

  • bounds (Dict[str, tuple[float, float]])

  • eps (float)

  • log_clip_min (float)

  • log_clip_max (float)

initial_params: Dict[str, Any]

Initial values of the parameters (Constrained space).

constraints: Dict[str, Literal['free', 'positive', 'negative', 'probability', 'unit_interval', 'fixed', 'bounded']]

Dictionary mapping parameter names to their constraint types.

bounds: Dict[str, tuple[float, float]]

Dictionary mapping parameter names to (lower, upper) bounds.

eps: float = 1e-08

Small constant for numerical stability.

log_clip_min: float = -20.0

Minimum value for log transformations.

log_clip_max: float = 20.0

Maximum value for log transformations.

__init__(initial_params, constraints, bounds, eps=1e-08, log_clip_min=-20.0, log_clip_max=20.0)
Parameters:
  • initial_params (Dict[str, Any])

  • constraints (Dict[str, Literal['free', 'positive', 'negative', 'probability', 'unit_interval', 'fixed', 'bounded']])

  • bounds (Dict[str, tuple[float, float]])

  • eps (float)

  • log_clip_min (float)

  • log_clip_max (float)

Return type:

None

classmethod create(initial_params, constraints=None, bounds=None)

Factory method to initialize ParameterSpace. Validates keys, bounds, and fills default constraints (‘free’).

Parameters:
  • initial_params (Dict[str, Any]) – Dictionary of initial parameter values.

  • constraints (Dict[str, ConstraintKind] | None) – Optional dictionary specifying constraints for each parameter. Defaults to ‘free’ for unspecified parameters.

  • bounds (Dict[str, tuple[float, float]] | None) – Optional dictionary specifying (lower, upper) bounds for ‘bounded’ parameters.

Return type:

ParameterSpace

transform(raw_params)

Transform raw (unconstrained) parameters to model (constrained) parameters.

Parameters:

raw_params (Dict[str, Any]) – Dictionary of unconstrained parameters. Fixed parameters should NOT be included in this dictionary.

Return type:

Dict[str, Any]

Returns:

Dictionary of constrained parameters including fixed parameters.

Raises:

ValueError – If required (non-fixed) parameters are missing or unexpected parameters are present.
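
The contract above can be sketched in a few lines. Fixed values are re-inserted from the initial values, and the raw dictionary must contain exactly the non-fixed names. `transform_params` is hypothetical and, to stay short, supports only three constraint kinds.

```python
import math

_TRANSFORMS = {
    "free": lambda r: r,
    "positive": math.exp,
    "probability": lambda r: 1.0 / (1.0 + math.exp(-r)),
}

def transform_params(raw_params, initial_params, constraints):
    free_names = {k for k, c in constraints.items() if c != "fixed"}
    missing = free_names - set(raw_params)
    unexpected = set(raw_params) - free_names  # includes fixed names passed by mistake
    if missing or unexpected:
        raise ValueError(f"missing={sorted(missing)}, unexpected={sorted(unexpected)}")
    out = {}
    for name, kind in constraints.items():
        if kind == "fixed":
            out[name] = initial_params[name]  # fixed: copied from the initial values
        else:
            out[name] = _TRANSFORMS[kind](raw_params[name])
    return out

init = {"beta": 0.95, "sigma": 1.0, "alpha": 0.5}
constraints = {"beta": "fixed", "sigma": "positive", "alpha": "probability"}
print(transform_params({"sigma": 0.0, "alpha": 0.0}, init, constraints))
# prints {'beta': 0.95, 'sigma': 1.0, 'alpha': 0.5}
```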

inverse_transform(model_params)

Maps model (constrained) parameters back to raw (unconstrained) parameters.

Return type:

Dict[str, Any]

Parameters:

model_params (Dict[str, Any])

get_bounds()

Protocol: Returns parameter bounds for the optimizer.

Returns None because this class uses the ‘Transformation Method’ (Unconstrained Optimization). The optimizer operates on ‘raw_params’ which are unbounded (-inf, +inf). Constraints are enforced via the ‘transform’ method, not by the optimizer’s bound constraints.

Return type:

tuple[PyTree, PyTree] | None

property fixed_mask: Dict[str, bool]

Returns a boolean mask where True indicates a parameter is FIXED. Useful for masking gradients in the Estimator.

property num_total_params: int

Returns the total number of parameters, including fixed ones.

property num_free_params: int

Returns the number of free (non-fixed) parameters.

Result Containers

Data structures for holding computation results. Uses Equinox modules to allow mixins and PyTree registration.

class econox.structures.results.ResultMixin

Bases: object

Provides generic summary() and save() methods for result objects; save() implements the ‘Directory Bundle’ strategy.

summary(short=False, print_summary=True)

Generate a summary string of the result object.

Parameters:
  • short (bool, optional) – If True, generate a shorter summary. Default is False.

  • print_summary (bool, optional) – If True, print the summary to the console. Default is True.

Returns:

The summary string.

Return type:

str

save(path, overwrite=False)

Save the result object to a directory using the ‘Directory Bundle’ strategy.

Parameters:
  • path (Union[str, Path]) – The target directory path where the result will be saved.

  • overwrite (bool, optional) – If True, overwrite the directory if it already exists. Default is False.

Raises:

FileExistsError – If the target directory already exists and overwrite is False.

Return type:

None

class econox.structures.results.SolverResult(solution, profile=None, inner_result=None, success=False, aux=<factory>)

Bases: ResultMixin, Module

Container for the output of a Solver (Inner/Outer Loop).

Parameters:
  • solution (PyTree)

  • profile (PyTree | None)

  • inner_result (SolverResult | None)

  • success (Bool[jaxlib._jax.Array, ''] | bool)

  • aux (Dict[str, Any])

solution: PyTree

Main solution returned by the solver (the fixed point).

  • DP: Expected Value Function \(EV(s)\) (Integrated Value Function / Emax). Represents the expected value before the realization of the shock \(\epsilon\).

  • GE: Equilibrium allocations (e.g., Population Distribution \(D\)) or Prices \(P\).

profile: PyTree | None = None

Associated profile information derived from the solution.

  • DP: Conditional Choice Probabilities (CCP) \(P(a|s)\). The probability of choosing action \(a\) given state \(s\).

  • GE: Market prices (Wage, Rent) or aggregate states corresponding to the solution.

inner_result: SolverResult | None = None

Associated inner solver result used during nested solving.

success: Bool[jaxlib._jax.Array, ''] | bool = False

Whether the solver converged successfully.

aux: Dict[str, Any]

Additional auxiliary information (e.g., diagnostics).

__init__(solution, profile=None, inner_result=None, success=False, aux=<factory>)
Parameters:
  • solution (PyTree)

  • profile (PyTree | None)

  • inner_result (SolverResult | None)

  • success (Bool[jaxlib._jax.Array, ''] | bool)

  • aux (Dict[str, Any])

Return type:

None

class econox.structures.results.EstimationResult(params, loss, success=False, solver_result=None, std_errors=None, vcov=None, diagnostics=<factory>, meta=<factory>, initial_params=None, fixed_mask=None)

Bases: ResultMixin, Module

Container for the output of an Estimator.

Parameters:
  • params (PyTree)

  • loss (Shaped[jaxlib._jax.Array, ''] | float)

  • success (Bool[jaxlib._jax.Array, ''] | bool)

  • solver_result (SolverResult | None)

  • std_errors (PyTree | None)

  • vcov (Float[jaxlib._jax.Array, 'n_params n_params'] | None)

  • diagnostics (Dict[str, Any])

  • meta (Dict[str, Any])

  • initial_params (PyTree | None)

  • fixed_mask (PyTree | None)

params: PyTree

Estimated parameters.

loss: Shaped[jaxlib._jax.Array, ''] | float

Final value of the loss function (e.g., negative log-likelihood).

success: Bool[jaxlib._jax.Array, ''] | bool = False

Whether the estimation converged successfully.

solver_result: SolverResult | None = None

Associated (outermost) solver result used during estimation.

__init__(params, loss, success=False, solver_result=None, std_errors=None, vcov=None, diagnostics=<factory>, meta=<factory>, initial_params=None, fixed_mask=None)
Parameters:
  • params (PyTree)

  • loss (Shaped[jaxlib._jax.Array, ''] | float)

  • success (Bool[jaxlib._jax.Array, ''] | bool)

  • solver_result (SolverResult | None)

  • std_errors (PyTree | None)

  • vcov (Float[jaxlib._jax.Array, 'n_params n_params'] | None)

  • diagnostics (Dict[str, Any])

  • meta (Dict[str, Any])

  • initial_params (PyTree | None)

  • fixed_mask (PyTree | None)

Return type:

None

std_errors: PyTree | None = None

Standard errors of the estimated parameters, if available.

vcov: Float[jaxlib._jax.Array, 'n_params n_params'] | None = None

Variance-covariance matrix of the estimated parameters, if available.

diagnostics: Dict[str, Any]

Additional diagnostics about the estimation process.

meta: Dict[str, Any]

Additional metadata about the estimation process (e.g., convergence criteria, iteration counts, duration).

initial_params: PyTree | None = None

Initial parameters used for estimation, if available.

fixed_mask: PyTree | None = None

Boolean mask indicating which parameters were fixed during estimation.

property t_values: PyTree | None

Compute t-values if standard errors are available.
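
This sketch assumes the conventional definitions. Standard errors are the square roots of the diagonal of vcov. Each t-value is the estimate divided by its standard error, which tests the null hypothesis that the parameter is zero. The pure-Python sketch below uses illustrative numbers only.

```python
import math

def standard_errors(vcov):
    # Square roots of the diagonal of the variance-covariance matrix.
    return [math.sqrt(vcov[i][i]) for i in range(len(vcov))]

def t_values(estimates, std_errors):
    # t = estimate / standard error, for H0: parameter = 0.
    return [b / se for b, se in zip(estimates, std_errors)]

estimates = [0.8, -1.5]  # illustrative values only
vcov = [[0.04, 0.01],
        [0.01, 0.25]]
se = standard_errors(vcov)
print(t_values(estimates, se))  # approximately [4.0, -3.0]
```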

to_dataframe()

Convert the estimation results to a Pandas DataFrame for easy analysis.

Returns:

DataFrame containing parameters, standard errors, and t-values.

Return type:

pd.DataFrame

to_latex(split_cols=True, threshold=25)

Convert the estimation results to a LaTeX table.

Returns:

LaTeX table as a string.

Parameters:
  • split_cols (bool)

  • threshold (int)

Return type:

str

summary(short=False, print_summary=True)

Generate a summary string of the estimation result.

Parameters:
  • short (bool, optional) – If True, generate a shorter summary. Default is False.

  • print_summary (bool, optional) – If True, print the summary to the console. Default is True.

Returns:

The summary string.

Return type:

str

Logic (Physics & Rules)

Components defining the economic logic and mechanics of the model.

Utility Functions

Utility components for the Econox framework.

class econox.logic.utility.LinearUtility(param_keys, feature_key)

Bases: Module

Computes flow utility as a linear combination of features and parameters.

This module implements the standard linear utility specification:

\(U(s, a) = \sum_{k} \beta_k \cdot X_k(s, a)\)

It expects parameters to be provided as individual scalars (or consistent arrays) corresponding to the last dimension of the feature tensor. This design enforces explicit naming of parameters, facilitating integration with ParameterSpace for constraints and interpretation.

Variables:
  • param_keys (tuple[str, ...]) – A sequence of keys to retrieve coefficients from the parameter PyTree. Order must match the last dimension of the feature tensor. Example: (“beta_income”, “beta_distance”)

  • feature_key (str) – Key to retrieve the feature tensor from model.data. The tensor must have shape (num_states, num_actions, num_features).

Parameters:
  • param_keys (tuple[str, ...])

  • feature_key (str)

Example

>>> # Model data has features shape (100, 5, 2) -> 2 features
>>> utility = LinearUtility(param_keys=("beta_0", "beta_1"), feature_key="X")
>>> # Params must contain "beta_0" and "beta_1"
>>> u = utility.compute_flow_utility(params, model)
>>> u.shape
(100, 5)

param_keys: tuple[str, ...]

Keys in params for the coefficients corresponding to each feature.

feature_key: str

Key in model.data for the feature tensor of shape (num_states, num_actions, num_features).

compute_flow_utility(params, model)

Calculates flow utility using matrix multiplication.

The method retrieves parameters specified by param_keys, stacks them into a single vector, and computes the dot product with the feature tensor.

Parameters:
  • params (PyTree) – Parameter PyTree containing the coefficients. Values for param_keys should typically be scalars.

  • model (StructuralModel) – StructuralModel containing the feature tensor at model.data[self.feature_key]. Expected shape: (num_states, num_actions, num_features).

Returns:

The calculated flow utility matrix.

Return type:

Float[Array, “num_states num_actions”]

Raises:
  • ValueError – If the feature tensor is not 3D.

  • ValueError – If the number of param_keys does not match the last dimension (num_features) of the feature tensor.

  • ValueError – If parameters cannot be stacked (e.g., shape mismatch).
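
The computation and the checks listed above can be mirrored in plain Python. `linear_flow_utility` is a hypothetical sketch, not the econox implementation. It uses nested lists instead of JAX arrays and a loop instead of a tensor contraction.

```python
def linear_flow_utility(params, data, param_keys, feature_key):
    """U[s][a] = sum_k params[param_keys[k]] * X[s][a][k] (hypothetical sketch)."""
    X = data[feature_key]
    # The feature tensor must be 3D: (num_states, num_actions, num_features).
    if not all(isinstance(cell, (list, tuple)) for row in X for cell in row):
        raise ValueError("feature tensor must be 3D (states, actions, features)")
    num_features = len(X[0][0])
    if any(len(cell) != num_features for row in X for cell in row):
        raise ValueError("all (state, action) cells need the same number of features")
    if len(param_keys) != num_features:
        raise ValueError(f"{len(param_keys)} param_keys but {num_features} features")
    # Stack the coefficients in param_keys order, then take a dot product per cell.
    beta = [params[k] for k in param_keys]
    return [[sum(b * x for b, x in zip(beta, cell)) for cell in row] for row in X]

# 2 states x 2 actions x 2 features: (income, distance)
X = [[[1.0, 0.0], [2.0, 1.0]],
     [[0.5, 2.0], [3.0, 0.5]]]
params = {"beta_income": 1.0, "beta_distance": -0.5}
U = linear_flow_utility(params, {"X": X}, ("beta_income", "beta_distance"), "X")
print(U)  # prints [[1.0, 1.5], [-0.5, 2.75]]
```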

__init__(param_keys, feature_key)
Parameters:
  • param_keys (tuple[str, ...])

  • feature_key (str)

Return type:

None

class econox.logic.utility.FunctionUtility(func)

Bases: Module

Wraps a user-defined function to satisfy the Utility protocol. Allows defining utility logic as a simple function.

Variables:

func (Callable) – A function with signature (params: PyTree, model: StructuralModel) -> Float[Array, “num_states num_actions”]. This function computes the flow utility matrix for the given parameters and model.

Parameters:

func (Callable)

Example

>>> import econox as ecx
>>> # Define a custom utility function
>>> def my_utility(params, model):
...     # params["beta"] is a scalar, model.data["x"] is (num_states, num_actions)
...     return params["beta"] * model.data["x"]
>>> # Wrap it as a FunctionUtility
>>> utility = ecx.utility(my_utility)  # a FunctionUtility, which satisfies the Utility protocol
>>> u = utility.compute_flow_utility(params, model)
>>> u.shape  # (num_states, num_actions)

func: Callable
__init__(func)
Parameters:

func (Callable)

Return type:

None

compute_flow_utility(params, model)

Calls the user-defined function to compute flow utility.

Parameters:
  • params (PyTree) – Parameter PyTree.

  • model (StructuralModel) – StructuralModel instance.

Returns:

The computed flow utility matrix.

Return type:

Float[Array, “num_states num_actions”]

econox.logic.utility.utility(func)

Decorator to convert a function into a Utility module.

This allows you to define a utility function with the standard signature and automatically wrap it as a module compatible with the Econox framework.

Parameters:

func (Callable) – A function with signature (params: PyTree, model: StructuralModel) -> Float[Array, “num_states num_actions”]. The function should compute and return the flow utility matrix for the given parameters and model.

Returns:

An object with a compute_flow_utility(params, model) method that calls the provided function.

Return type:

FunctionUtility

Example

>>> import econox as ecx
>>> # Define a custom utility function
>>> @ecx.utility
... def my_utility(params, model):
...     # params["beta"] is a scalar, model.data["x"] is (num_states, num_actions)
...     return params["beta"] * model.data["x"]
>>> # my_utility is now a FunctionUtility instance
>>> u = my_utility.compute_flow_utility(params, model)
>>> u.shape
(num_states, num_actions)

Probability Distributions

Distribution components for Econox. Handles stochastic parts of the model (error terms).

class econox.logic.distribution.GumbelDistribution(scale=1.0)[source]

Bases: Module

Type I Extreme Value (Gumbel) distribution logic for Logit models. Provides Emax (LogSumExp) and choice probabilities (Softmax).

Variables:

scale (float) – Scale parameter of the Gumbel distribution.

Parameters:

scale (float)

scale: float = 1.0
expected_max(values)[source]

Computes the expected maximum value E[max(v + epsilon)].

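Formula: scale * log( sum( exp(v / scale) ) ), commonly called the 'inclusive value'. (For a location-zero Gumbel error, the exact expectation also contains the additive constant scale * γ, where γ ≈ 0.5772 is the Euler-Mascheroni constant. This formula does not include that constant.)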

Return type:

Float[Array, 'num_states']

Parameters:

values (Float[jaxlib._jax.Array, 'num_states num_actions'])

choice_probabilities(values)[source]

Computes the choice probabilities P(choice | state).

Formula: exp(v / scale) / sum( exp(v / scale) )

Return type:

Float[Array, 'num_states num_actions']

Parameters:

values (Float[jaxlib._jax.Array, 'num_states num_actions'])

__init__(scale=1.0)
Parameters:

scale (float)

Return type:

None
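The expected_max and choice_probabilities formulas above correspond directly to numerically stable JAX primitives. The following is a standalone sketch with hypothetical function names and is not the library's code. It uses `jax.scipy.special.logsumexp` and `jax.nn.softmax`, which avoid overflow when values are large.

```python
import jax.numpy as jnp
from jax.nn import softmax
from jax.scipy.special import logsumexp


def gumbel_expected_max(values, scale=1.0):
    # Inclusive value per state: scale * log(sum_a exp(v[s, a] / scale)).
    return scale * logsumexp(values / scale, axis=-1)


def gumbel_choice_probabilities(values, scale=1.0):
    # Logit CCPs: exp(v / scale) / sum_a exp(v / scale), normalized over actions.
    return softmax(values / scale, axis=-1)


values = jnp.array([[0.0, 0.0], [1.0, 3.0]])  # (num_states=2, num_actions=2)
emax = gumbel_expected_max(values)             # shape (2,); emax[0] equals log(2)
probs = gumbel_choice_probabilities(values)    # shape (2, 2); each row sums to 1
```

Dividing by scale before the log-sum-exp keeps both functions consistent. As the scale grows, choices become more random, and the inclusive value rises further above the row maximum.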

Feedback Mechanisms

Feedback mechanisms for General Equilibrium (GE) interactions.

This module provides the infrastructure for defining and executing feedback loops in structural models. It supports both component-wise updates (FunctionFeedback) and full-model updates (CustomUpdateFeedback).

class econox.logic.feedback.CompositeFeedback(feedbacks)[source]

Bases: Module

A container that executes multiple feedback mechanisms sequentially. Useful when you want to chain multiple update steps.

Note

If multiple feedbacks share intermediate calculations (e.g., population density), using CompositeFeedback may result in redundant computations, as each feedback is applied independently and may recompute shared intermediates. In such cases, it is more efficient to use a joint update via @model_feedback (i.e., CustomUpdateFeedback), which allows you to compute shared intermediates only once.

Usage:

@ecx.function_feedback(target_key="wage")
def wage_update(data, params, result):
    # ... calculation ...
    return new_wage_values

@ecx.model_feedback # Both model_feedback and function_feedback can be used here
def rent_update(params, result, model):
    # ... calculation ...
    return new_model

feedback = CompositeFeedback(feedbacks=[wage_update, rent_update])
Parameters:

feedbacks (Sequence[FeedbackMechanism])

feedbacks: Sequence[FeedbackMechanism]
update(params, current_result, model)[source]

Applies the feedback mechanisms in order, passing the model returned by each one to the next.

Return type:

StructuralModel

__init__(feedbacks)
Parameters:

feedbacks (Sequence[FeedbackMechanism])

Return type:

None
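The chaining behaviour, and the redundancy described in the note, can be illustrated without Econox. In this hypothetical sketch, a "model" is a plain dict and each feedback is a function `(params, result, model) -> model`, which matches the CustomUpdateFeedback signature. `compose_feedbacks` stands in for CompositeFeedback.update and is not its actual code.

```python
from functools import reduce


def compose_feedbacks(feedbacks):
    """Hypothetical stand-in for CompositeFeedback.update."""
    def update(params, result, model):
        # Apply each feedback in order; each sees the model returned by the previous one.
        return reduce(lambda m, fb: fb(params, result, m), feedbacks, model)
    return update


def wage_update(params, result, model):
    density = sum(result["population"])  # shared intermediate, computed here...
    return {**model, "wage": params["w0"] / density}


def rent_update(params, result, model):
    density = sum(result["population"])  # ...and recomputed here
    return {**model, "rent": params["r0"] * density + model["wage"]}


update = compose_feedbacks([wage_update, rent_update])
new_model = update({"w0": 10.0, "r0": 2.0}, {"population": [1.0, 4.0]}, {"wage": 0.0, "rent": 0.0})
print(new_model)  # {'wage': 2.0, 'rent': 12.0}
```

Because rent_update reads the wage produced by wage_update, the order of the list matters. The duplicated `density` computation is the cost that a single `@model_feedback` function would avoid.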

class econox.logic.feedback.CustomUpdateFeedback(func)[source]

Bases: Module

A feedback mechanism that allows the user to define a custom function to update the ENTIRE model structure.

This is the most flexible approach, allowing for complex dependencies (e.g., wage and rent depending on the same density calculation) without redundant computations or shape mismatches.

Variables:

func – A callable with signature (params, result, model) -> StructuralModel.

Parameters:

func (Callable)

func: Callable
update(params, current_result, model)[source]

Delegates the update logic entirely to the user-defined function.

Return type:

StructuralModel

__init__(func)
Parameters:

func (Callable)

Return type:

None

econox.logic.feedback.model_feedback(func)[source]

Decorator to register a function as a CustomUpdateFeedback.

Usage:

@ecx.model_feedback
def my_ge_loop(params, result, model):
    # ... calculation ...
    new_model = model.replace_data("wage", new_wage_values)
    new_model = new_model.replace_data("rent", new_rent_values)
    return new_model
Return type:

CustomUpdateFeedback

Parameters:

func (Callable)

class econox.logic.feedback.FunctionFeedback(func, target_key)[source]

Bases: Module

A simpler wrapper for updating a specific key in model.data. Best for independent, single-variable updates.

Parameters:
  • func (Callable)

  • target_key (str)

func: Callable
target_key: str
update(params, current_result, model)[source]
Return type:

StructuralModel

__init__(func, target_key)
Parameters:
  • func (Callable)

  • target_key (str)

Return type:

None

econox.logic.feedback.function_feedback(target_key)[source]

Decorator for simple single-variable updates.

Usage:

@ecx.function_feedback(target_key="wage")
def wage_update(params, result, data):
    # ... calculation ...
    return new_wage_values
Return type:

Callable[..., FunctionFeedback]

Parameters:

target_key (str)

Transition Dynamics

Dynamics logic components for the Econox framework. Defines how the population/state distribution evolves over time (Law of Motion).

class econox.logic.dynamics.SimpleDynamics(use_transitions=False)[source]

Bases: Module

Standard Law of Motion for dynamic models.

Supports two modes:

  1. Explicit Transition: \(D_{t+1}(s') = \sum_{s,a} D_t(s)\, P(a|s)\, T(s'|s,a)\), written in array form as (D_t * P(a|s)) @ T(s'|s,a). Used when a transition matrix is provided in the model.

  2. Direct Mapping: \(D_{t+1}(a) = \sum_{s} D_t(s)\, P(a|s)\), i.e., D_t @ P(a|s). Used when no transition matrix is provided. Assumes the action space maps one-to-one onto the state space (A = S).

Parameters:

use_transitions (bool)

use_transitions: bool = False
__init__(use_transitions=False)
Parameters:

use_transitions (bool)

Return type:

None
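Both modes can be sketched in a few lines of JAX. This sketch makes an explicit assumption that is not stated in the source: the explicit-transition matrix has shape (S*A, S), as described for TrajectoryDynamics, with row index s*A + a. The function names are hypothetical.

```python
import jax.numpy as jnp


def step_with_transitions(D, P, T):
    """Explicit transition. D: (S,), P: (S, A), T: (S*A, S) with row index s*A + a."""
    # Mass in each (state, action) cell, D_t(s) * P(a|s), flattened row-major.
    flows = (D[:, None] * P).reshape(-1)
    return flows @ T


def step_direct(D, P):
    """Direct mapping (A = S): choosing action a moves the agent to state a."""
    return D @ P


D = jnp.array([0.5, 0.5])
P = jnp.array([[0.9, 0.1], [0.2, 0.8]])
# A transition matrix that sends (s, a) to state a reproduces the direct mapping.
T = jnp.tile(jnp.eye(2), (2, 1))  # rows: (0,0), (0,1), (1,0), (1,1)
D_next = step_with_transitions(D, P, T)
```

Both steps conserve total mass whenever P and T are row-stochastic.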

class econox.logic.dynamics.TrajectoryDynamics(enforce_boundary=True)[source]

Bases: Module

Dynamics for solving path-dependent problems (e.g., Rational Expectations).

Handles the evolution of the entire state trajectory over a finite horizon, often involving:

  1. Explicit transition matrices (S*A -> S) to handle complex flows.

  2. Boundary conditions (e.g., fixing the t=0 population).

This class expects the following in the model:

  • The transition matrix, accessible via the model's transitions property (model.transitions should return an (S*A, S) matrix defining physical movement).

  • "initial_year_indices" (in model.data): indices at which to enforce the boundary conditions.

  • "initial_year_values" (in model.data): values to enforce at the boundary.

Parameters:

enforce_boundary (bool)

enforce_boundary: bool = True
__init__(enforce_boundary=True)
Parameters:

enforce_boundary (bool)

Return type:

None

Terminal Approximators

Terminal value function approximators for dynamic programming solvers.

These classes define strategies to close finite-horizon dynamic models by approximating the expected value function \(EV(T)\) at the simulation horizon.

class econox.logic.terminal.IdentityTerminal[source]

Bases: Module

Identity terminal approximator (Zero modification).

This strategy assumes the terminal value is already correctly initialized and performs no modification to the input matrix.

\[\mathbb{E}V_T(s, a) = \mathbb{E}V_{T}^{input}(s, a)\]

Examples

>>> approximator = IdentityTerminal()
>>> # Returns expected value matrix as-is
>>> adjusted_ev = approximator.approximate(expected, params, model)
approximate(expected, params, model)[source]

Returns the expected value matrix without any modifications.

This method acts as a pass-through, preserving the original Emax values computed by the Bellman operator.

Return type:

Float[Array, 'S A']

Parameters:
  • expected (Float[jaxlib._jax.Array, 'S A'])

  • params (PyTree)

  • model (StructuralModel)

__init__()
Return type:

None

class econox.logic.terminal.StationaryTerminal(term_idx, prev_idx)[source]

Bases: Module

Stationary terminal approximator (Steady-state Boundary).

Approximates the terminal period by assuming the system has reached a time-invariant steady state, where the value function at \(T\) replicates the value at \(T-1\).

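\[\mathbb{E}V_T(s, a) = \mathbb{E}V_{T-1}(s', a) \quad \forall s \in \mathcal{S}_{term}\]

where \(s' \in \mathcal{S}_{prev}\) is the predecessor paired with \(s\): the k-th entry of prev_idx supplies the values for the k-th entry of term_idx.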
Parameters:
  • term_idx (tuple[int, ...]) – Indices of the terminal states \(\mathcal{S}_{term}\).

  • prev_idx (tuple[int, ...]) – Indices of the predecessor states \(\mathcal{S}_{prev}\).

Examples

>>> # Assuming states (4, 5) are terminal and (2, 3) are T-1
>>> approximator = StationaryTerminal(
...     term_idx=(4, 5),
...     prev_idx=(2, 3)
... )
>>> adjusted_ev = approximator.approximate(expected, params, model)
term_idx: tuple[int, ...]
prev_idx: tuple[int, ...]
approximate(expected, params, model)[source]

Overwrites terminal state values with values from predecessor states.

This implementation performs a scatter operation where:

\[EV_{adj}[term\_idx, :] = EV_{raw}[prev\_idx, :]\]
Return type:

Float[Array, 'S A']

Parameters:
  • expected (Float[jaxlib._jax.Array, 'S A'])

  • params (PyTree)

  • model (StructuralModel)

__init__(term_idx, prev_idx)
Parameters:
  • term_idx (tuple[int, ...])

  • prev_idx (tuple[int, ...])

Return type:

None
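The scatter \(EV_{adj}[term\_idx, :] = EV_{raw}[prev\_idx, :]\) corresponds to a single functional update in JAX. This is a hypothetical standalone sketch, not the class's source.

```python
import jax.numpy as jnp


def stationary_terminal(expected, term_idx, prev_idx):
    # Copy row prev_idx[k] into row term_idx[k] for every k (a JAX scatter).
    term = jnp.asarray(term_idx)
    prev = jnp.asarray(prev_idx)
    return expected.at[term, :].set(expected[prev, :])


ev = jnp.arange(12.0).reshape(6, 2)  # 6 states x 2 actions
adj = stationary_terminal(ev, term_idx=(4, 5), prev_idx=(2, 3))
# Rows 4 and 5 now equal the original rows 2 and 3; other rows are unchanged.
```

`.at[...].set(...)` returns a new array rather than mutating `ev`, so the update is safe inside `jax.jit`.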

class econox.logic.terminal.ExponentialTrendTerminal(term_idx, prev_idx, pre_prev_idx=None, growth_rate_keys=None)[source]

Bases: Module

Exponential trend terminal approximator with Adaptive Branching.

This approximator handles non-stationary growth at the horizon by applying a growth rate \(\gamma_s\) to the value function.

Branching Priority:

  1. Exogenous (Parameter-driven): If growth_rate_keys is not None, the approximator takes the growth rates \(\gamma\) from params.

  2. Endogenous (Data-driven): If growth_rate_keys is None and pre_prev_idx is provided, the approximator extrapolates the trend from the model’s internal dynamics (\(T-1\) and \(T-2\)).

It supports growth rates through three parameter specification patterns:

  1. Global Scalar: A single key mapping to a scalar value (e.g., "g"). The same growth rate is applied to all terminal states.

  2. Aggregated Scalars: A list or tuple of keys (e.g., ["g1", "g2", "g3"]). The number of keys must match the length of term_idx (and prev_idx). Each scalar parameter corresponds to a terminal state in order.

  3. State-Indexed Vector: A single key mapping to an array of length \(n\) (e.g., "g_vector"), where \(n\) is the length of term_idx. Each element corresponds to a terminal state in order.

\[\begin{split}\mathbb{E}V_T(s, a) = \begin{cases} (1 + \gamma_s) \mathbb{E}V_{T-1}(s', a) & \text{if keys provided (Exogenous)} \\ \frac{\mathbb{E}V_{T-1}(s', a)}{\mathbb{E}V_{T-2}(s'', a)} \mathbb{E}V_{T-1}(s', a) & \text{if pre\_prev\_idx provided (Endogenous)} \end{cases}\end{split}\]
Parameters:
  • term_idx (tuple[int, ...]) – Indices of the terminal states \(T\).

  • prev_idx (tuple[int, ...]) – Indices of the predecessor states \(T-1\).

  • pre_prev_idx (tuple[int, ...] | None) – Indices of the states \(T-2\). Required for endogenous mode.

  • growth_rate_keys (Union[str, List[str], Tuple[str, ...], None]) – Identifier(s) for growth rate \(\gamma\). Accepts a single str for global/vector parameters, or a list[str] to aggregate multiple regional scalars.

Raises:

ValueError – If both growth_rate_keys and pre_prev_idx are None.

Examples

>>> import jax.numpy as jnp
>>> term_idx = (13, 14, 15)
>>> prev_idx = (10, 11, 12)
>>> # Pattern 1: Global scalar growth
>>> approx = ExponentialTrendTerminal(term_idx, prev_idx, growth_rate_keys="g")
>>> params = {"g": 0.02}
>>> # Pattern 2: Aggregated scalars (one key per terminal state)
>>> approx = ExponentialTrendTerminal(
...     term_idx, prev_idx, growth_rate_keys=["g1", "g2", "g3"]
... )
>>> params = {"g1": 0.02, "g2": 0.03, "g3": 0.01}
>>> # Pattern 3: State-indexed vector
>>> approx = ExponentialTrendTerminal(term_idx, prev_idx, growth_rate_keys="g_vector")
>>> params = {"g_vector": jnp.array([0.02, 0.03, 0.01])}
>>> # Endogenous mode: extrapolate the T-2 -> T-1 trend (no params needed)
>>> pre_prev_idx = (7, 8, 9)
>>> approx = ExponentialTrendTerminal(term_idx, prev_idx, pre_prev_idx=pre_prev_idx)
>>> adjusted_ev = approx.approximate(expected, {}, model)
term_idx: tuple[int, ...]
prev_idx: tuple[int, ...]
pre_prev_idx: tuple[int, ...] | None = None
growth_rate_keys: str | List[str] | Tuple[str, ...] | None = None
approximate(expected, params, model)[source]

Applies exponential growth to the terminal horizon.

The method adaptively switches between:

  • Exogenous Growth: Multiplying \(T-1\) values by \((1 + \gamma)\) from params.

  • Endogenous Extrapolation: Multiplying \(T-1\) values by the ratio \(EV_{T-1} / EV_{T-2}\).

It automatically handles spatial heterogeneity by mapping parameter keys or vector elements to the corresponding state indices.

Return type:

Float[Array, 'S A']

Parameters:
  • expected (Float[jaxlib._jax.Array, 'S A'])

  • params (PyTree)

  • model (StructuralModel)

__init__(term_idx, prev_idx, pre_prev_idx=None, growth_rate_keys=None)
Parameters:
  • term_idx (tuple[int, ...])

  • prev_idx (tuple[int, ...])

  • pre_prev_idx (tuple[int, ...] | None)

  • growth_rate_keys (str | List[str] | Tuple[str, ...] | None)

Return type:

None
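The two branches and the three key patterns can be sketched in standalone JAX. The helpers `resolve_rates` and `exponential_trend_terminal` are hypothetical and are not the class's code. In particular, resolving the keys into a length-n rate vector before applying growth is one assumed way of handling the patterns.

```python
import jax.numpy as jnp


def resolve_rates(params, keys, n):
    """Hypothetical helper covering the three key patterns."""
    if isinstance(keys, str):
        # Global scalar or state-indexed vector stored under a single key.
        return jnp.broadcast_to(jnp.asarray(params[keys], dtype=float), (n,))
    # Aggregated scalars: one key per terminal state, in term_idx order.
    return jnp.stack([jnp.asarray(params[k], dtype=float) for k in keys])


def exponential_trend_terminal(expected, term_idx, prev_idx, rates=None, pre_prev_idx=None):
    ev_prev = expected[jnp.asarray(prev_idx), :]  # EV_{T-1}, shape (n, A)
    if rates is not None:
        # Exogenous: scale each predecessor row by (1 + gamma_s).
        new_rows = (1.0 + rates)[:, None] * ev_prev
    elif pre_prev_idx is not None:
        # Endogenous: extrapolate the per-cell ratio EV_{T-1} / EV_{T-2}.
        ev_pre_prev = expected[jnp.asarray(pre_prev_idx), :]
        new_rows = (ev_prev / ev_pre_prev) * ev_prev
    else:
        raise ValueError("Provide growth rates or pre_prev_idx.")
    return expected.at[jnp.asarray(term_idx), :].set(new_rows)


ev = jnp.array([[1.0, 1.0], [2.0, 2.0], [2.0, 4.0], [4.0, 4.0], [0.0, 0.0], [0.0, 0.0]])
term_idx, prev_idx, pre_prev_idx = (4, 5), (2, 3), (0, 1)
exo = exponential_trend_terminal(ev, term_idx, prev_idx, rates=resolve_rates({"g": 0.1}, "g", 2))
endo = exponential_trend_terminal(ev, term_idx, prev_idx, pre_prev_idx=pre_prev_idx)
```

The endogenous branch divides by \(EV_{T-2}\), so it implicitly assumes those entries are nonzero and of the same sign as \(EV_{T-1}\).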

class econox.logic.terminal.LinearTrendTerminal(term_idx, prev_idx, pre_prev_idx=None, drift_keys=None)[source]

Bases: Module

Linear trend terminal approximator with Adaptive Branching.

Approximates the terminal value by adding a drift component \(\delta_s\).

Branching Priority:

  1. Exogenous (Parameter-driven): If drift_keys is not None, uses drift terms \(\delta\) from params.

  2. Endogenous (Data-driven): If drift_keys is None and pre_prev_idx is provided, extrapolates the linear difference between \(T-1\) and \(T-2\).

Similar to the exponential variant, it supports three patterns for \(\delta\):

  1. Global Drift: A single key mapping to a scalar drift value applied to all terminal states.

  2. Aggregated Drifts: A list or tuple of keys. The number of keys must match the length of term_idx (and prev_idx). Each scalar corresponds to a terminal state in order.

  3. Drift Vector: A single key mapping to an array of length \(n\), where \(n\) is the length of term_idx. Each element corresponds to a terminal state in order.

\[\begin{split}\mathbb{E}V_T(s, a) = \begin{cases} \mathbb{E}V_{T-1}(s', a) + \delta_s & \text{if keys provided (Exogenous)} \\ \mathbb{E}V_{T-1}(s', a) + (\mathbb{E}V_{T-1}(s', a) - \mathbb{E}V_{T-2}(s'', a)) & \text{if pre\_prev\_idx provided (Endogenous)} \end{cases}\end{split}\]
Parameters:
  • term_idx (tuple[int, ...]) – Indices of the terminal states \(T\).

  • prev_idx (tuple[int, ...]) – Indices of the predecessor states \(T-1\).

  • pre_prev_idx (tuple[int, ...] | None) – Indices of the states \(T-2\). Required for endogenous mode.

  • drift_keys (Union[str, List[str], Tuple[str, ...], None]) – Identifier(s) for drift \(\delta\). Accepts a single str for global/vector parameters, or a list[str] to aggregate multiple regional scalars.

Raises:

ValueError – If both drift_keys and pre_prev_idx are None.

Examples

>>> # Linear drift via parameter keys
>>> approx = LinearTrendTerminal(term_idx, prev_idx, drift_keys="drift")
>>> params = {"drift": 500.0}
>>> adjusted_ev = approx.approximate(expected, params, model)
term_idx: tuple[int, ...]
prev_idx: tuple[int, ...]
pre_prev_idx: tuple[int, ...] | None = None
drift_keys: str | List[str] | Tuple[str, ...] | None = None
approximate(expected, params, model)[source]

Applies linear drift to the terminal horizon.

The method adaptively switches between:

  • Exogenous Drift: Adding \(\delta\) from params to \(T-1\) values.

  • Endogenous Extrapolation: Adding the difference \((EV_{T-1} - EV_{T-2})\) to \(EV_{T-1}\).

Return type:

Float[Array, 'S A']

Parameters:
  • expected (Float[jaxlib._jax.Array, 'S A'])

  • params (PyTree)

  • model (StructuralModel)

__init__(term_idx, prev_idx, pre_prev_idx=None, drift_keys=None)
Parameters:
  • term_idx (tuple[int, ...])

  • prev_idx (tuple[int, ...])

  • pre_prev_idx (tuple[int, ...] | None)

  • drift_keys (str | List[str] | Tuple[str, ...] | None)

Return type:

None
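The additive variant differs from the exponential one only in how the new terminal rows are formed. The following hypothetical standalone sketch takes the drift as an already-resolved scalar or length-n vector. It does not look up drift_keys.

```python
import jax.numpy as jnp


def linear_trend_terminal(expected, term_idx, prev_idx, drift=None, pre_prev_idx=None):
    ev_prev = expected[jnp.asarray(prev_idx), :]  # EV_{T-1}, shape (n, A)
    if drift is not None:
        # Exogenous: add delta_s (scalar or length-n vector) to each predecessor row.
        delta = jnp.broadcast_to(jnp.asarray(drift, dtype=float), (len(term_idx),))
        new_rows = ev_prev + delta[:, None]
    elif pre_prev_idx is not None:
        # Endogenous: extend the last observed difference EV_{T-1} - EV_{T-2}.
        ev_pre_prev = expected[jnp.asarray(pre_prev_idx), :]
        new_rows = ev_prev + (ev_prev - ev_pre_prev)
    else:
        raise ValueError("Provide drift values or pre_prev_idx.")
    return expected.at[jnp.asarray(term_idx), :].set(new_rows)


ev = jnp.array([[1.0, 1.0], [2.0, 2.0], [2.0, 4.0], [4.0, 4.0], [0.0, 0.0], [0.0, 0.0]])
exo = linear_trend_terminal(ev, (4, 5), (2, 3), drift=500.0)
endo = linear_trend_terminal(ev, (4, 5), (2, 3), pre_prev_idx=(0, 1))
```

Unlike the ratio-based exponential branch, the difference-based extrapolation is well defined when \(EV_{T-2}\) is zero or changes sign.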

Solvers (Algorithms)

Numerical algorithms for solving economic models (Forward Problems).

Dynamic Programming

Dynamic programming solver module for economic models. Can be used for static models as well by setting discount_factor=0.

class econox.solvers.dynamic_programming.ValueIterationSolver(utility, dist, discount_factor, terminal_approximator=<factory>, numerical_solver=<factory>)[source]

Bases: Module

Fixed-point solver using value function iteration.

Variables:
  • utility (Utility) – Utility function to compute flow utilities.

  • dist (Distribution) – Probability distribution for choice modeling.

  • discount_factor (float) – Discount factor for future utilities.

  • terminal_approximator (TerminalApproximator) – Approximator for terminal value function.

  • numerical_solver (FixedPoint) – Numerical solver for finding fixed points.

Examples

>>> # Define structural components
>>> utility = MyUtilityFunction()
>>> dist = GumbelDistribution()
>>> # Initialize solver
>>> solver = ValueIterationSolver(
...     utility=utility,
...     dist=dist,
...     discount_factor=0.99,
...     terminal_approximator=IdentityTerminal(),
... )
>>> # Solve the model
>>> result = solver.solve(params, model)
>>> # Access results
>>> EV = result.solution  # Expected Value Function EV(s)
>>> P = result.profile    # Choice Probabilities P(a|s)
utility: Utility
dist: Distribution
discount_factor: float
terminal_approximator: TerminalApproximator
numerical_solver: FixedPoint
solve(params, model)[source]

Solves for the fixed point of the structural model using value iteration.

Parameters:
  • params (PyTree) – Model parameters.

  • model (StructuralModel) – The structural model instance.

Returns:

A SolverResult with the following fields:

  • solution (Array): The computed Expected Value Function \(EV(s)\) (Integrated Value Function / Emax).

  • profile (Array): The Conditional Choice Probabilities (CCP) \(P(a|s)\) derived from the value function.

  • success (Bool): Whether the solver converged successfully.

  • aux (Dict): Auxiliary information, including number of steps taken.

Return type:

SolverResult

__init__(utility, dist, discount_factor, terminal_approximator=<factory>, numerical_solver=<factory>)
Return type:

None
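The value-iteration loop can be sketched as a standalone JAX program. This is an illustrative assumption, not the solver's source. It uses a hypothetical (S, A, S) transition array T[s, a, s'], a unit Gumbel scale, and a plain sup-norm stopping rule instead of the FixedPoint backend. It also applies no terminal approximator. With beta = 0, it reduces to the static logit case.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def bellman(ev, u, T, beta):
    # Choice-specific values: v(s, a) = u(s, a) + beta * sum_t T[s, a, t] * EV(t).
    v = u + beta * jnp.einsum("sat,t->sa", T, ev)
    # Logit Emax with unit Gumbel scale (Euler-Mascheroni constant dropped).
    return logsumexp(v, axis=-1), v


def value_iteration(u, T, beta, tol=1e-6, max_steps=10_000):
    def cond(state):
        _, diff, k = state
        return (diff > tol) & (k < max_steps)

    def body(state):
        ev, _, k = state
        new_ev, _ = bellman(ev, u, T, beta)
        return new_ev, jnp.max(jnp.abs(new_ev - ev)), k + 1

    ev0 = jnp.zeros(u.shape[0])
    init = (ev0, jnp.array(jnp.inf, dtype=ev0.dtype), jnp.array(0, dtype=jnp.int32))
    ev, _, steps = jax.lax.while_loop(cond, body, init)
    _, v = bellman(ev, u, T, beta)
    ccp = jax.nn.softmax(v, axis=-1)  # P(a|s)
    return ev, ccp, steps


# Toy problem: 3 states, action 0 = stay, action 1 = move to (s + 1) % 3.
T = jnp.stack([jnp.eye(3), jnp.roll(jnp.eye(3), 1, axis=1)], axis=1)  # (S, A, S)
u = jnp.array([[0.0, -1.0], [0.5, 0.0], [1.0, -0.5]])
ev, ccp, steps = value_iteration(u, T, beta=0.5)
```

Because the Bellman operator is a contraction with modulus beta, the loop converges from any starting point. Values of beta close to 1 need proportionally more iterations.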

Equilibrium Solvers

Equilibrium solver module for dynamic economic models. Can be used for static models as well by setting discount_factor=0.

class econox.solvers.equilibrium.EquilibriumSolver(inner_solver, feedback, dynamics, numerical_solver=<factory>, damping=1.0, initial_distribution=None)[source]

Bases: Module

Fixed-point solver for General Equilibrium (GE) or Stationary Equilibrium.

This solver searches for a distribution of agents (or prices) \(D^*\) such that:

\[D^* = \Phi(D^*, \theta)\]

where \(\Phi\) represents the compound operator of:

  1. Updating the environment (Feedback): \(\Omega' = \Gamma(D, \Omega)\)

  2. Solving the agent’s problem (Inner Solver): \(\sigma^* = \text{argmax} \, V(s; \Omega')\)

  3. Applying the law of motion (Dynamics): \(D' = f(D, \sigma^*)\)

Variables:
  • inner_solver (Solver) – The solver used to compute the optimal policy given a fixed environment.

  • feedback (FeedbackMechanism) – Logic to update model data based on aggregate results.

  • dynamics (Dynamics) – Law of motion describing how the distribution evolves.

  • numerical_solver (FixedPoint) – The numerical algorithm for the outer loop (e.g., Anderson Acceleration).

  • damping (float) – Damping factor for the update step \(D_{k+1} = (1-\lambda)D_k + \lambda D_{new}\).

  • initial_distribution (Array | None) – Initial guess for the distribution.

Examples

>>> # 1. Inner agent problem (e.g., Household optimization)
>>> inner_solver = ValueIterationSolver(...)
>>> # 2. Market clearing logic (e.g., Supply = Demand)
>>> feedback = FunctionFeedback(func=WageFeedback, target_key="wage")
>>> # 3. Dynamics (Law of Motion)
>>> dynamics = SimpleDynamics()
>>> # 4. Equilibrium Solver
>>> eq_solver = EquilibriumSolver(
...     inner_solver=inner_solver,
...     feedback=feedback,
...     dynamics=dynamics,
...     damping=0.5
... )
>>> # Solve for stationary equilibrium
>>> result = eq_solver.solve(params, model)
inner_solver: Solver
feedback: FeedbackMechanism
dynamics: Dynamics
numerical_solver: FixedPoint
damping: float = 1.0
initial_distribution: Float[jaxlib._jax.Array, 'num_states'] | None = None
solve(params, model)[source]

Solves for the fixed point of the structural model using equilibrium conditions.

Parameters:
  • params (PyTree) – Model parameters.

  • model (StructuralModel) – The structural model instance.

Returns:

The result object containing:

  • solution: Equilibrium Distribution \(D^*\)

  • profile: Equilibrium Policy \(P^*\)

  • inner_result: Full result from the inner solver (Value Function etc.)

Return type:

SolverResult

__init__(inner_solver, feedback, dynamics, numerical_solver=<factory>, damping=1.0, initial_distribution=None)
Return type:

None
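The compound operator \(\Phi\) and the damped outer loop can be illustrated with a toy, self-contained model. Everything in this sketch is hypothetical and chosen for illustration:

  • A congestion wage rule stands in for the feedback.
  • A static logit (discount factor 0) stands in for the inner solver.
  • The direct-mapping law of motion (A = S) stands in for the dynamics.

The sketch uses a plain Python loop rather than the FixedPoint backend.

```python
import jax.numpy as jnp
from jax.nn import softmax


def phi(D, w0, kappa, move_cost):
    """One pass of the compound operator on a toy model (all names hypothetical)."""
    S = D.shape[0]
    wage = w0 - kappa * D                                  # 1. Feedback: congestion lowers wages
    u = wage[None, :] - move_cost * (1.0 - jnp.eye(S))     # u(s, a): moving to a != s is costly
    P = softmax(u, axis=-1)                                # 2. Inner solver: static logit CCPs
    return D @ P, P                                        # 3. Dynamics: direct mapping (A = S)


def solve_equilibrium(D0, w0, kappa=0.5, move_cost=0.5, damping=0.5, tol=1e-6, max_steps=1000):
    D = D0
    for step in range(1, max_steps + 1):
        D_new, P = phi(D, w0, kappa, move_cost)
        D_next = (1.0 - damping) * D + damping * D_new    # damped update
        if float(jnp.max(jnp.abs(D_next - D))) < tol:
            return D_next, P, step
        D = D_next
    return D, P, max_steps


w0 = jnp.array([1.0, 0.0, 0.5])
D_star, P_star, steps = solve_equilibrium(jnp.full(3, 1.0 / 3.0), w0)
```

The damping factor trades speed for stability. With damping = 1, the update is plain iteration of \(\Phi\). Smaller values shrink each step, which can help when feedback effects are strong enough to make undamped iteration oscillate.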

Optimization Backends

Note

Import Note: This module is not exposed in the top-level econox namespace. You must access it via the submodule:

from econox.optim import Minimizer
# or
import econox.optim as opt

Optimization and Fixed-Point strategies using Optimistix. Wraps numerical solvers to provide a consistent interface for Econox components.

class econox.optim.MinimizerResult(params, loss, success, steps)[source]

Bases: Module

A generic container for optimization results. Decouples the Estimator from the specific backend (optimistix/jaxopt).

Variables:
  • params – The optimized parameters (PyTree).

  • loss – The final loss value (Scalar).

  • success – Whether the optimization was successful (Bool).

  • steps – Number of optimization steps taken (Int).

Parameters:
  • params (PyTree)

  • loss (Shaped[jaxlib._jax.Array, ''])

  • success (Bool[jaxlib._jax.Array, ''])

  • steps (Int[jaxlib._jax.Array, ''])

params: PyTree
loss: Shaped[jaxlib._jax.Array, '']
success: Bool[jaxlib._jax.Array, '']
steps: Int[jaxlib._jax.Array, '']
__init__(params, loss, success, steps)
Parameters:
  • params (PyTree)

  • loss (Shaped[jaxlib._jax.Array, ''])

  • success (Bool[jaxlib._jax.Array, ''])

  • steps (Int[jaxlib._jax.Array, ''])

Return type:

None

class econox.optim.Minimizer(method=LBFGS(rtol=1e-06, atol=1e-06, norm=<function max_norm>, use_inverse=True, descent=NewtonDescent(), search=BacktrackingArmijo(), history_length=10, verbose=frozenset()), max_steps=1000, throw=False)[source]

Bases: Module

Wrapper for optimistix.minimise. Implements the econox.protocols.Optimizer interface.

You can customize the method and tolerances at initialization.

Examples

>>> from econox.optim import Minimizer
>>> import optimistix as optx
>>> # Default (LBFGS, rtol=atol=1e-6)
>>> opt = Minimizer()
>>> # Custom method (e.g., Nelder-Mead) and tolerances
>>> opt = Minimizer(method=optx.NelderMead(rtol=1e-5, atol=1e-5))
Parameters:
  • method (AbstractMinimiser)

  • max_steps (int)

  • throw (bool)

method: AbstractMinimiser = LBFGS(rtol=1e-06, atol=1e-06, norm=<function max_norm>, use_inverse=True, descent=NewtonDescent(), search=BacktrackingArmijo(), history_length=10, verbose=frozenset())
max_steps: int = 1000
throw: bool = False
minimize(loss_fn, init_params, args=None)[source]

Minimizes the loss function using the specified method and tolerances.

Parameters:
  • loss_fn (Callable[[PyTree, Any], Scalar]) – The loss function to minimize. Takes parameters and additional arguments, returns a scalar loss.

  • init_params (PyTree) – Initial parameter values for optimization.

  • args (Any, optional) – Additional arguments passed to the loss function. Defaults to None.

Returns:

Contains the optimized parameters, final loss, success status, and iteration count.

Return type:

MinimizerResult

property method_name: str

Returns the name of the optimization method used.

__init__(method=LBFGS(rtol=1e-06, atol=1e-06, norm=<function max_norm>, use_inverse=True, descent=NewtonDescent(), search=BacktrackingArmijo(), history_length=10, verbose=frozenset()), max_steps=1000, throw=False)
Parameters:
  • method (AbstractMinimiser)

  • max_steps (int)

  • throw (bool)

Return type:

None

class econox.optim.FixedPointResult(value, success, steps)[source]

Bases: Module

Container for fixed-point computation results. Used by internal solvers (Bellman, Equilibrium) to report convergence status.

Variables:
  • value – The computed fixed-point value (PyTree).

  • success – Whether the fixed-point iteration was successful (Bool).

  • steps – Number of iterations taken (Int).

Parameters:
  • value (PyTree)

  • success (Bool[jaxlib._jax.Array, ''])

  • steps (Int[jaxlib._jax.Array, ''])

value: PyTree
success: Bool[jaxlib._jax.Array, '']
steps: Int[jaxlib._jax.Array, '']
__init__(value, success, steps)
Parameters:
  • value (PyTree)

  • success (Bool[jaxlib._jax.Array, ''])

  • steps (Int[jaxlib._jax.Array, ''])

Return type:

None

class econox.optim.FixedPoint(method=FixedPointIteration(rtol=1e-08, atol=1e-08), max_steps=2000, throw=False, adjoint=ImplicitAdjoint(linear_solver=AutoLinearSolver(well_posed=False)))[source]

Bases: Module

Wrapper for optimistix.fixed_point.

Examples

>>> # Default (FixedPointIteration)
>>> # Uses default max_steps (2000) and tolerances (rtol=1e-8, atol=1e-8)
>>> fp = FixedPoint()
>>> # Custom method and step budget
>>> import optimistix as optx
>>> fp = FixedPoint(method=optx.FixedPointIteration(rtol=1e-10, atol=1e-10), max_steps=5000)
Parameters:
  • method (AbstractFixedPointSolver)

  • max_steps (int)

  • throw (bool)

  • adjoint (AbstractAdjoint)

method: AbstractFixedPointSolver = FixedPointIteration(rtol=1e-08, atol=1e-08)
max_steps: int = 2000
throw: bool = False
adjoint: AbstractAdjoint = ImplicitAdjoint(linear_solver=AutoLinearSolver(well_posed=False))
find_fixed_point(step_fn, init_val, args=None)[source]

Solves for \(y\) such that \(y = \text{step\_fn}(y, \text{args})\). Returns a FixedPointResult containing the solution and status.

Parameters:
  • step_fn (Callable[[PyTree, Any], PyTree]) – The fixed-point function. Takes current value and args, returns next value.

  • init_val (PyTree) – Initial guess for the fixed-point iteration.

  • args (Any, optional) – Additional arguments passed to the fixed-point function.

Returns:

Contains the fixed-point value, success status, and iteration count.

Return type:

FixedPointResult

__init__(method=FixedPointIteration(rtol=1e-08, atol=1e-08), max_steps=2000, throw=False, adjoint=ImplicitAdjoint(linear_solver=AutoLinearSolver(well_posed=False)))
Parameters:
  • method (AbstractFixedPointSolver)

  • max_steps (int)

  • throw (bool)

  • adjoint (AbstractAdjoint)

Return type:

None
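The contract \(y = \text{step\_fn}(y, \text{args})\) can be illustrated with a plain JAX loop that returns the same three pieces of information as FixedPointResult. This sketch is an assumption-laden stand-in for optimistix.FixedPointIteration and does not reimplement it. The elementwise stopping rule used here, |y_new - y| ≤ atol + rtol·|y_new|, is a common choice. Unlike the wrapper, which uses ImplicitAdjoint for gradients, this `lax.while_loop` version supports only the forward solve and cannot be reverse-mode differentiated.

```python
import jax
import jax.numpy as jnp


def fixed_point_iteration(step_fn, init_val, args=None, rtol=1e-6, atol=1e-6, max_steps=2000):
    """Iterate y <- step_fn(y, args); return (value, success, steps) like FixedPointResult."""

    def converged(y_old, y_new):
        # Elementwise tolerance test: |y_new - y_old| <= atol + rtol * |y_new|.
        return jnp.all(jnp.abs(y_new - y_old) <= atol + rtol * jnp.abs(y_new))

    def cond(state):
        y_old, y_new, k = state
        return jnp.logical_not(converged(y_old, y_new)) & (k < max_steps)

    def body(state):
        _, y, k = state
        return y, step_fn(y, args), k + 1

    y0 = jnp.asarray(init_val, dtype=float)
    init = (y0, step_fn(y0, args), jnp.array(1, dtype=jnp.int32))
    y_old, y, steps = jax.lax.while_loop(cond, body, init)
    return y, converged(y_old, y), steps


# Example: the fixed point of cos (the Dottie number, about 0.739085).
value, success, steps = fixed_point_iteration(lambda y, _: jnp.cos(y), 1.0)
```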

Methods (Estimation Techniques)

Strategies for estimating model parameters from data (Inverse Problems).

Base Classes

Base module for method functions in the Econox framework.

class econox.methods.base.EstimationMethod(*, variance=None)[source]

Bases: Module

Base class for all estimation method functions in Econox.

This class serves three main purposes:

  1. Strategy Definition: Defines the loss function to be minimized during numerical estimation.

  2. Analytical Solution: Optionally provides a direct solution method (e.g., for OLS/2SLS).

  3. Inference: Optionally defines how to calculate standard errors (e.g., Hessian, Sandwich).

Users can create custom objectives by subclassing this class or by using the @method_from_loss decorator.

Variables:

variance – Variance | None

Parameters:

variance (Variance | None)

variance: Variance | None = None

Optional variance calculation strategy for inference.

abstractmethod compute_loss(result, observations, params, model)[source]

Calculates the scalar loss metric to be minimized.

This method is the core of the numerical estimation loop. It compares the model’s prediction (result) with the real-world data (observations).

Parameters:
  • result (Any | None) – The output from the Solver (e.g., SolverResult). If an analytical solution is being evaluated, this may be None.

  • observations (Any) – Observed data to fit the model against.

  • params (PyTree) – Current model parameters (useful for regularization terms).

  • model (StructuralModel) – The structural model environment.

Return type:

Float[Array, '']

Returns:

A scalar JAX array representing the loss (e.g., Negative Log-Likelihood).
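For a likelihood-based method, compute_loss typically reduces the solver's choice probabilities and the observed choices to a negative log-likelihood. The standalone function below sketches that reduction in JAX. The argument names are hypothetical. In a subclass, the probabilities would come from a field such as result.profile, as documented for ValueIterationSolver.

```python
import jax.numpy as jnp


def logit_nll(choice_probs, states, actions, eps=1e-12):
    """Negative log-likelihood of observed (state, action) pairs under P(a|s)."""
    # Gather P(a_i | s_i) for each observation i.
    p = choice_probs[states, actions]
    # Clip so a zero-probability observation yields a large finite loss instead of inf.
    return -jnp.sum(jnp.log(jnp.clip(p, eps, 1.0)))


P = jnp.array([[0.8, 0.2], [0.3, 0.7]])   # e.g., result.profile, shape (S, A)
states = jnp.array([0, 0, 1])
actions = jnp.array([0, 1, 1])
loss = logit_nll(P, states, actions)      # -(log 0.8 + log 0.2 + log 0.7)
```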

solve(model, observations, param_space)[source]

Computes the analytical solution for the parameters, if available.

This method allows the Estimator to bypass the numerical optimization loop for models that have a closed-form solution (e.g., OLS, 2SLS).

Parameters:
  • model (StructuralModel) – The structural model environment.

  • observations (Any) – Observed data.

  • param_space (Any) – The parameter space definition.

Returns:

Returns an EstimationResult if an analytical solution is found. Returns None otherwise (default), and the Estimator will fall back to numerical optimization using compute_loss.

Return type:

EstimationResult | None

classmethod from_function(func)[source]

Creates an EstimationMethod instance from a simple loss function.

This factory method allows users to define objectives using a simple function instead of defining a full class. The created objective will rely on numerical optimization (solve returns None) and will not compute standard errors by default.

Parameters:

func (Callable) – A function with the signature: (result, observations, params, model) -> Scalar

Return type:

EstimationMethod

Returns:

An instance of a dynamically created EstimationMethod subclass.

Example

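>>> import jax.numpy as jnp
>>> from econox.methods.base import EstimationMethod
>>> def mse_loss(result, observations, params, model):
...     return jnp.mean((result.solution - observations) ** 2)
>>> method = EstimationMethod.from_function(mse_loss)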
__init__(*, variance=None)
Parameters:

variance (Variance | None)

Return type:

None

econox.methods.base.method_from_loss(func)

Creates an EstimationMethod instance from a simple loss function.

This factory method allows users to define objectives using a simple function instead of defining a full class. The created objective will rely on numerical optimization (solve returns None) and will not compute standard errors by default.

Parameters:

func (Callable) – A function with the signature: (result, observations, params, model) -> Scalar

Return type:

EstimationMethod

Returns:

An instance of a dynamically created EstimationMethod subclass.

Example

>>> import jax.numpy as jnp
>>> from econox.methods.base import method_from_loss
>>> @method_from_loss
... def mse_loss(result, observations, params, model):
...     return jnp.mean((result.solution - observations) ** 2)

Analytical Methods (OLS/2SLS)

Analytical linear estimation methods (OLS, 2SLS) with fixed parameter support.

This module provides:

  • OLS (Ordinary Least Squares)

  • 2SLS (Two-Stage Least Squares)

Both support:

  • Fixed parameter constraints

  • Numerical stability via QR decomposition

  • Automatic fallback to numerical optimization for complex constraints

Example

>>> ols = LeastSquares(feature_key="X", target_key="y")
>>> # Use Estimator (Recommended)
>>> est = Estimator(model=model, param_space=param_space, method=ols)
>>> result = est.fit(observations=observations)
>>> # Or call solve() directly
>>> result = ols.solve(model=model, observations=observations, param_space=param_space)
class econox.methods.analytical.AnalyticalParameterHandler(is_fixed_mask, fixed_values_vec, param_names, n_total, n_fixed, n_free)[source]

Bases: Module

Handles fixed parameter constraints for analytical linear methods.

This handler:
  • Separates fixed and free parameters

  • Transforms design matrices to account for fixed values

  • Reconstructs full parameter vectors and covariance matrices

Variables:
  • is_fixed_mask – Boolean mask (True = fixed, False = free)

  • fixed_values_vec – Values for fixed parameters

  • param_names – Names of all parameters

  • n_total – Total number of parameters

  • n_fixed – Number of fixed parameters

  • n_free – Number of free parameters (n_total - n_fixed)

Parameters:
  • is_fixed_mask (Array)

  • fixed_values_vec (Array)

  • param_names (List[str])

  • n_total (int)

  • n_fixed (int)

  • n_free (int)

is_fixed_mask: Array
fixed_values_vec: Array
param_names: List[str]
n_total: int
n_fixed: int
n_free: int
classmethod from_params(param_space, param_names)[source]
Return type:

AnalyticalParameterHandler

Parameters:
  • param_space (ParameterSpace)

  • param_names (List[str])

transform(X, y)[source]
Return type:

Tuple[Array, Array]

Parameters:
  • X (Array)

  • y (Array)

reconstruct(beta_free, vcov_free)[source]
Return type:

Tuple[Array, Array]

Parameters:
  • beta_free (Array)

  • vcov_free (Array)

__init__(is_fixed_mask, fixed_values_vec, param_names, n_total, n_fixed, n_free)
Parameters:
  • is_fixed_mask (Array)

  • fixed_values_vec (Array)

  • param_names (List[str])

  • n_total (int)

  • n_fixed (int)

  • n_free (int)

Return type:

None
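The transform/reconstruct pair can be sketched in NumPy as follows. This is an illustration of the general technique, not econox's implementation: the free functions below are hypothetical, and they assume fixed_values_vec has length n_total, with entries at free positions ignored. Fixed coefficients are moved to the left-hand side, only the free block is estimated, and fixed parameters get zero variance in the reconstructed covariance matrix.

```python
import numpy as np


def transform(X, y, is_fixed_mask, fixed_values_vec):
    """Move the fixed-parameter contribution to the left-hand side."""
    X_fixed = X[:, is_fixed_mask]
    X_free = X[:, ~is_fixed_mask]
    # y - X_fixed @ beta_fixed = X_free @ beta_free + error
    y_adj = y - X_fixed @ fixed_values_vec[is_fixed_mask]
    return X_free, y_adj


def reconstruct(beta_free, vcov_free, is_fixed_mask, fixed_values_vec):
    """Merge free estimates back into full-length parameter and covariance arrays."""
    n_total = is_fixed_mask.shape[0]
    beta = np.array(fixed_values_vec, dtype=float)  # copy; fixed entries stay as given
    beta[~is_fixed_mask] = beta_free
    free_idx = np.flatnonzero(~is_fixed_mask)
    vcov = np.zeros((n_total, n_total))
    # Fixed parameters have zero variance; only the free block is filled in.
    vcov[np.ix_(free_idx, free_idx)] = vcov_free
    return beta, vcov


# Usage: the third coefficient is fixed at -2.0.
X = np.array([[1.0, 0.0, 2.0], [1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 3.0, 5.0]])
true_beta = np.array([0.5, 1.5, -2.0])
y = X @ true_beta
mask = np.array([False, False, True])
fixed = np.array([0.0, 0.0, -2.0])
X_free, y_adj = transform(X, y, mask, fixed)
beta_free, *_ = np.linalg.lstsq(X_free, y_adj, rcond=None)
beta, vcov = reconstruct(beta_free, 0.1 * np.eye(2), mask, fixed)
```
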

class econox.methods.analytical.LinearMethod(add_intercept=True, target_key='y', param_names=None, *, variance=<factory>)[source]

Bases: EstimationMethod

Base class implementing the shared analytical solve workflow (template method pattern); subclasses supply the core OLS/2SLS estimation step.

Parameters:
  • add_intercept (bool)

  • target_key (str)

  • param_names (List[str] | None)

  • variance (Variance | None)

add_intercept: bool = True
target_key: str = 'y'
param_names: List[str] | None = None
solve(model, observations, param_space)[source]

Analytical estimation workflow with fixed parameter support.

Workflow:
  1. Data Preparation: Extract y, construct X (and Z for 2SLS)

  2. Constraint Handling: Separate fixed/free parameters

  3. Core Estimation: Subclass-specific OLS/2SLS logic

  4. Reconstruction: Merge fixed and estimated parameters

  5. Statistics: Compute residuals, R², standard errors

Parameters:
  • model (StructuralModel) – Structural model containing data

  • observations (Any) – Observation data (used to extract target y)

  • param_space (ParameterSpace) – Parameter space with constraints and initial values

Return type:

EstimationResult | None

Returns:

EstimationResult if successful, or None if a fallback to numerical optimization is needed. For a returned result:
  • loss: The Sum of Squared Residuals (SSR).

    Note that this is the total sum, not the mean (MSE).

Note

Returns None when any constraint other than 'fixed' or 'free' is present, signaling the Estimator to use numerical optimization instead.
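The return contract of solve (an analytical result or None, with the SSR reported as a total rather than a mean) can be illustrated with the sketch below. The dictionary of constraint labels and the "positive" label are hypothetical stand-ins for whatever ParameterSpace actually stores; only the decision rule and the SSR definition follow the documentation above.

```python
import numpy as np

ANALYTICAL_CONSTRAINTS = {"fixed", "free"}  # labels the analytical path can handle


def solve_or_defer(constraints, X, y):
    """Return (beta, ssr) analytically, or None to defer to numerical optimization."""
    # `constraints` is a hypothetical {param_name: label} mapping.
    if any(label not in ANALYTICAL_CONSTRAINTS for label in constraints.values()):
        return None
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    resid = y - X @ beta
    ssr = float(resid @ resid)  # total sum of squared residuals, not divided by n
    return beta, ssr


X = np.column_stack([np.ones(4), np.arange(4.0)])
y = np.array([1.0, 2.0, 2.0, 4.0])
analytical = solve_or_defer({"b0": "free", "b1": "free"}, X, y)
deferred = solve_or_defer({"b0": "free", "b1": "positive"}, X, y)  # hypothetical label -> None
```

Here the SSR is 0.7, whereas the MSE would be 0.7 / 4 = 0.175; comparing losses across samples of different sizes therefore requires care.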

compute_loss(result, observations, params, model)[source]

Calculates the scalar loss metric to be minimized.

This method is the core of the numerical estimation loop. It compares the model’s prediction (result) with the real-world data (observations).

Parameters:
  • result (Any | None) – The output from the Solver (e.g., SolverResult). If an analytical solution is being evaluated, this may be None.

  • observations (Any) – Observed data to fit the model against.

  • params (PyTree) – Current model parameters (useful for regularization terms).

  • model (StructuralModel) – The structural model environment.

Return type:

Any

Returns:

A scalar JAX array representing the loss (e.g., Negative Log-Likelihood).

__init__(add_intercept=True, target_key='y', param_names=None, *, variance=<factory>)
Parameters:
  • add_intercept (bool)

  • target_key (str)

  • param_names (List[str] | None)

  • variance (Variance | None)

Return type:

None

class econox.methods.analytical.LeastSquares(add_intercept=True, target_key='y', param_names=None, feature_key='X', *, variance=<factory>)[source]

Bases: LinearMethod

Parameters:
  • add_intercept (bool)

  • target_key (str)

  • param_names (List[str] | None)

  • feature_key (str)

  • variance (Variance | None)

feature_key: str = 'X'
__init__(add_intercept=True, target_key='y', param_names=None, feature_key='X', *, variance=<factory>)
Parameters:
  • add_intercept (bool)

  • target_key (str)

  • param_names (List[str] | None)

  • feature_key (str)

  • variance (Variance | None)

Return type:

None

class econox.methods.analytical.TwoStageLeastSquares(add_intercept=True, target_key='y', param_names=None, endog_key='X', instrument_key='Z', controls_key=None, *, variance=<factory>)[source]

Bases: LinearMethod

Parameters:
  • add_intercept (bool)

  • target_key (str)

  • param_names (List[str] | None)

  • endog_key (str)

  • instrument_key (str)

  • controls_key (str | None)

  • variance (Variance | None)

endog_key: str = 'X'
instrument_key: str = 'Z'
controls_key: str | None = None
__init__(add_intercept=True, target_key='y', param_names=None, endog_key='X', instrument_key='Z', controls_key=None, *, variance=<factory>)
Parameters:
  • add_intercept (bool)

  • target_key (str)

  • param_names (List[str] | None)

  • endog_key (str)

  • instrument_key (str)

  • controls_key (str | None)

  • variance (Variance | None)

Return type:

None
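For reference, the textbook two-stage procedure can be sketched in NumPy as below. This is a generic 2SLS illustration, not econox's code: the function name and the column order of the coefficients (endogenous regressors first, then intercept and controls) are assumptions of the sketch. Note that second-stage residuals are computed with the original regressors, not the first-stage fitted values.

```python
import numpy as np


def two_stage_least_squares(y, X_endog, Z, controls=None, add_intercept=True):
    """Textbook 2SLS: project regressors on instruments, then regress y on fitted values."""
    n = y.shape[0]
    exog = [np.ones((n, 1))] if add_intercept else []
    if controls is not None:
        exog.append(controls.reshape(n, -1))
    W = np.hstack(exog) if exog else np.empty((n, 0))
    X_full = np.hstack([X_endog.reshape(n, -1), W])  # second-stage regressors
    Z_full = np.hstack([Z.reshape(n, -1), W])        # instruments plus exogenous columns
    # First stage: X_hat is the projection of X_full onto the column space of Z_full.
    gamma, *_ = np.linalg.lstsq(Z_full, X_full, rcond=None)
    X_hat = Z_full @ gamma
    # Second stage: regress y on X_hat.
    beta, *_ = np.linalg.lstsq(X_hat, y, rcond=None)
    # Residuals use the original regressors, not X_hat.
    resid = y - X_full @ beta
    return beta, float(resid @ resid)


# Usage: x is correlated with the unobserved u, so OLS is biased; z is a valid instrument.
rng = np.random.default_rng(0)
n = 5000
z = rng.normal(size=n)
u = rng.normal(size=n)
x = 0.8 * z + u + 0.1 * rng.normal(size=n)
y = 1.0 + 2.0 * x + u + 0.1 * rng.normal(size=n)
beta, ssr = two_stage_least_squares(y, x, z)  # beta = [slope, intercept]
```
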

Numerical Methods (MLE/GMM)

Numerical (loss-based) estimation methods, including standard approaches such as maximum likelihood (MLE) and GMM.

class econox.methods.numerical.CompositeMethod(methods, weights=None, *, variance=None)[source]

Bases: EstimationMethod

Combines multiple estimation methods into a single scalar loss. Assumes the methods are independent (block-diagonal weighting).

\(\mathcal{L} = \sum_i w_i \, \mathcal{L}_i\), where \(w_i\) is the weight and \(\mathcal{L}_i\) the loss of the i-th method.

Variables:
  • methods – Sequence[EstimationMethod] - The component estimation methods whose losses are combined.

  • weights – Sequence[float] | None - Optional weights for each method. If None, equal weights are used.

  • variance – Variance | None - Optional variance calculation strategy for inference. Note: by default, variance is not computed for composite methods because the combined loss may not correspond to a valid statistical model.

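Parameters:
  • methods (Sequence[EstimationMethod])

  • weights (Sequence[float] | None)

  • variance (Variance | None)
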
methods: Sequence[EstimationMethod]
weights: Sequence[float] | None = None
variance: Variance | None = None

Optional variance calculation strategy for inference.

compute_loss(result, observations, params, model)[source]

Calculates the scalar loss metric to be minimized.

This method is the core of the numerical estimation loop. It compares the model’s prediction (result) with the real-world data (observations).

Parameters:
  • result (Any | None) – The output from the Solver (e.g., SolverResult). If an analytical solution is being evaluated, this may be None.

  • observations (Any) – Observed data to fit the model against.

  • params (PyTree) – Current model parameters (useful for regularization terms).

  • model (StructuralModel) – The structural model environment.

Return type:

Array (scalar, shape ())

Returns:

A scalar JAX array representing the loss (e.g., Negative Log-Likelihood).

__init__(methods, weights=None, *, variance=None)
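Parameters:
  • methods (Sequence[EstimationMethod])

  • weights (Sequence[float] | None)

  • variance (Variance | None)
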
Return type:

None
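The weighted-sum rule can be sketched as below. The component losses are plain Python callables with the compute_loss argument order; the function name is hypothetical, and reading "equal weights" as a weight of 1.0 per method is an assumption of the sketch. Since component losses often live on very different scales (for example, an NLL summed over many observations versus a squared error on a few moments), the weights effectively decide which data dominate the fit.

```python
from typing import Any, Callable, Optional, Sequence

# Hypothetical signature mirroring compute_loss(result, observations, params, model).
LossFn = Callable[[Any, Any, Any, Any], float]


def composite_compute_loss(loss_fns: Sequence[LossFn], weights: Optional[Sequence[float]],
                           result: Any, observations: Any, params: Any, model: Any) -> float:
    """Weighted sum of component losses evaluated on the same solver result."""
    if weights is None:
        weights = [1.0] * len(loss_fns)  # assumption: equal weights of 1.0
    if len(weights) != len(loss_fns):
        raise ValueError("expected one weight per method")
    # Each component sees the same result and observations and picks out its own keys.
    return sum(w * fn(result, observations, params, model) for w, fn in zip(weights, loss_fns))


def choice_loss(result, observations, params, model):
    return 2.0  # placeholder for, e.g., an NLL term


def rent_loss(result, observations, params, model):
    return 3.0  # placeholder for, e.g., a Gaussian moment term


total = composite_compute_loss([choice_loss, rent_loss], [0.5, 2.0], None, None, None, None)
```
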

class econox.methods.numerical.MaximumLikelihood(choice_probs_key='profile', *, variance=<factory>)[source]

Bases: EstimationMethod

Standard MLE for discrete choice models (e.g., migration or occupation choice). Computes the negative log-likelihood (NLL) from the model-implied choice probabilities.

Parameters:
  • choice_probs_key (str)

  • variance (Variance | None)

choice_probs_key: str = 'profile'
compute_loss(result, observations, params, model)[source]

Calculates the scalar loss metric to be minimized.

This method is the core of the numerical estimation loop. It compares the model’s prediction (result) with the real-world data (observations).

Parameters:
  • result (Any | None) – The output from the Solver (e.g., SolverResult). If an analytical solution is being evaluated, this may be None.

  • observations (Any) – Observed data to fit the model against.

  • params (PyTree) – Current model parameters (useful for regularization terms).

  • model (StructuralModel) – The structural model environment.

Return type:

Array (scalar, shape ())

Returns:

A scalar JAX array representing the loss (e.g., Negative Log-Likelihood).

__init__(choice_probs_key='profile', *, variance=<factory>)
Parameters:
  • choice_probs_key (str)

  • variance (Variance | None)

Return type:

None
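The discrete-choice NLL can be illustrated with the NumPy sketch below. It assumes one plausible data layout (one row of probabilities per observation and an integer index of the chosen option) and sums over observations; econox's actual observation format and whether it sums or averages are not specified here, and choice_nll is a hypothetical name.

```python
import numpy as np


def choice_nll(choice_probs: np.ndarray, choices: np.ndarray, eps: float = 1e-12) -> float:
    """Negative log-likelihood of observed discrete choices.

    choice_probs: (n_obs, n_choices) model-implied probabilities; rows sum to 1.
    choices: (n_obs,) integer index of the option each observation chose.
    """
    # Pick each observation's probability of the option it actually chose.
    p_chosen = choice_probs[np.arange(choices.shape[0]), choices]
    # Clip to avoid log(0) when the model assigns zero probability to an observed choice.
    return float(-np.sum(np.log(np.clip(p_chosen, eps, 1.0))))


probs = np.array([[0.5, 0.5], [0.25, 0.75]])
choices = np.array([0, 1])
nll = choice_nll(probs, choices)  # -(log 0.5 + log 0.75)
```
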

class econox.methods.numerical.GaussianMomentMatch(obs_key, model_key, scale_param_key, log_transform=False, *, variance=None)[source]

Bases: EstimationMethod

Fits a continuous model variable (e.g. Rent, Wage) to observed data assuming a Gaussian (or Log-Normal) error structure.

Parameters:
  • obs_key (str)

  • model_key (str)

  • scale_param_key (str)

  • log_transform (bool)

  • variance (Variance | None)

obs_key: str
model_key: str
scale_param_key: str
log_transform: bool = False
variance: Variance | None = None

Optional variance calculation strategy for inference.

__init__(obs_key, model_key, scale_param_key, log_transform=False, *, variance=None)
Parameters:
  • obs_key (str)

  • model_key (str)

  • scale_param_key (str)

  • log_transform (bool)

  • variance (Variance | None)

Return type:

None

compute_loss(result, observations, params, model)[source]

Calculates the scalar loss metric to be minimized.

This method is the core of the numerical estimation loop. It compares the model’s prediction (result) with the real-world data (observations).

Parameters:
  • result (Any | None) – The output from the Solver (e.g., SolverResult). If an analytical solution is being evaluated, this may be None.

  • observations (Any) – Observed data to fit the model against.

  • params (PyTree) – Current model parameters (useful for regularization terms).

  • model (StructuralModel) – The structural model environment.

Return type:

Array (scalar, shape ())

Returns:

A scalar JAX array representing the loss (e.g., Negative Log-Likelihood).
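Under a Gaussian error structure, the natural loss is the Gaussian negative log-likelihood with the standard deviation read from params[scale_param_key]. The sketch below is an assumption about the form of the loss, not econox's code: gaussian_nll is hypothetical, sigma is assumed to be stored directly (positivity enforced elsewhere, e.g. by the ParameterSpace), and with log_transform=True the comparison is made on logs, which corresponds to log-normal errors (the Jacobian term of the log-normal density is omitted because it does not depend on the parameters).

```python
import math
from typing import Dict, Sequence


def gaussian_nll(observed: Sequence[float], predicted: Sequence[float],
                 params: Dict[str, float], scale_param_key: str,
                 log_transform: bool = False) -> float:
    """Gaussian NLL of observed values around model predictions."""
    sigma = params[scale_param_key]
    if log_transform:
        # Log-normal errors: compare logs of the (positive) levels.
        observed = [math.log(v) for v in observed]
        predicted = [math.log(v) for v in predicted]
    n = len(observed)
    ssr = sum((o - p) ** 2 for o, p in zip(observed, predicted))
    return n * math.log(sigma) + 0.5 * ssr / sigma ** 2 + 0.5 * n * math.log(2 * math.pi)


nll = gaussian_nll([1.0, 2.5, 0.3], [1.2, 2.0, 0.0], {"sigma_rent": 0.7}, "sigma_rent")
```
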

Variance & Inference

Variance calculation strategies for statistical inference. Handles the computation of standard errors and covariance matrices.

class econox.methods.variance.Variance[source]

Bases: Module

Base class for variance computation strategies.

compute(loss_fn, params, observations, num_observations)[source]

Calculates the standard errors and variance-covariance matrix.

Parameters:
  • loss_fn (Callable[[PyTree], Array]) – A differentiable function f(params) -> loss that returns a scalar: the objective function, with result and model bound via a closure.

  • params (PyTree) – The estimated optimal parameters.

  • observations (Any) – The observed data.

  • num_observations (int) – Number of data points (N).

Returns:

A tuple containing:

  • std_errors: PyTree of standard errors (same structure as params).

  • vcov: Variance-covariance matrix (n_params x n_params).

Return type:

tuple

__init__()
Return type:

None

class econox.methods.variance.Hessian[source]

Bases: Variance

Calculates variance using the inverse Hessian of the loss function.

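This is the standard approach for maximum likelihood estimation (MLE) and assumes the loss function is the negative log-likelihood: \(V = H^{-1} / N\), where \(H\) is the Hessian of the loss at the estimated parameters and \(N\) is num_observations.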

compute(loss_fn, params, observations, num_observations)[source]

Calculates standard errors and variance-covariance matrix using the Hessian.

Parameters:
  • loss_fn (Callable[[PyTree], Array]) – A differentiable function f(params) -> loss that returns a scalar.

  • params (PyTree) – The estimated optimal parameters.

  • observations (Any) – The observed data.

  • num_observations (int) – Number of data points (N).

Returns:

A tuple containing:

  • std_errors: Standard errors matching the structure of the input params (a 1D array if the input params are flattened).

  • vcov: Variance-covariance matrix.

Return type:

tuple

__init__()
Return type:

None
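The formula \(V = H^{-1} / N\) is the right scaling when the loss is the average negative log-likelihood; if the loss were the summed NLL, the division by N would be dropped. The NumPy sketch below computes the Hessian by central finite differences so that it runs without JAX (in JAX one would typically use jax.hessian instead); numerical_hessian and hessian_variance are hypothetical names, not econox APIs.

```python
import numpy as np


def numerical_hessian(f, x, h=1e-4):
    """Central finite-difference Hessian of a scalar function f at the 1D point x."""
    k = x.size
    H = np.zeros((k, k))
    for i in range(k):
        for j in range(k):
            ei = np.zeros(k)
            ei[i] = h
            ej = np.zeros(k)
            ej[j] = h
            H[i, j] = (f(x + ei + ej) - f(x + ei - ej)
                       - f(x - ei + ej) + f(x - ei - ej)) / (4 * h * h)
    return H


def hessian_variance(loss_fn, params, num_observations):
    """V = H^{-1} / N for a loss equal to the mean negative log-likelihood."""
    H = numerical_hessian(loss_fn, params)
    vcov = np.linalg.inv(H) / num_observations
    std_errors = np.sqrt(np.diag(vcov))
    return std_errors, vcov


# Usage: mean of N = 100 normal draws with known sigma = 2, so SE should be 2 / sqrt(100).
y = np.linspace(-3.0, 3.0, 100)


def mean_nll(p):
    return 0.5 * np.mean((y - p[0]) ** 2) / 4.0  # constant terms dropped


se, vcov = hessian_variance(mean_nll, np.array([y.mean()]), 100)
```
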

Configuration

Note

Import Note: Configuration constants are located in the econox.config submodule.

from econox import config
print(config.LOSS_PENALTY)

Global settings and constants used throughout the library.

Global configuration defaults for Econox.

econox.config.INLINE_ARRAY_SIZE_THRESHOLD: int = 10

Arrays with size <= this value will be saved inline in summary.txt and metadata.json.

econox.config.SUMMARY_STRING_MAX_LENGTH: int = 50

Maximum length of string representations in summary.txt; longer strings are truncated and suffixed with '…'.

econox.config.FLATTEN_MULTIDIM_ARRAYS: bool = True

If True, arrays with >2 dimensions will be flattened when saving to CSV.

econox.config.SUMMARY_FIELD_WIDTH: int = 25

Width for field name padding in summary.txt.

econox.config.SUMMARY_SEPARATOR_LENGTH: int = 60

Length of separator lines (=== bars) in summary.txt.

econox.config.LOSS_PENALTY: float = inf

Penalty value returned for invalid model solutions during optimization. Used to steer the optimizer away from unstable parameter regions. Default is positive infinity.
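The penalty mechanism can be illustrated with the sketch below; safe_loss is a hypothetical helper, and the validity flag stands in for whatever check econox applies to a model solution. Inside jit-compiled JAX code the Python if would have to be replaced by jnp.where, since traced values cannot drive Python control flow. With an infinite penalty a line search rejects the step outright; a large finite value is an alternative for optimizers that do arithmetic on the loss value.

```python
import math

LOSS_PENALTY = math.inf  # mirrors the documented default of econox.config.LOSS_PENALTY


def safe_loss(loss_value: float, solution_is_valid: bool, penalty: float = LOSS_PENALTY) -> float:
    """Replace the loss of an invalid or non-finite solution with a penalty."""
    if not solution_is_valid or not math.isfinite(loss_value):
        return penalty
    return loss_value
```
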

Utilities

Note

Import Note: Helper functions are located in the econox.utils submodule.

from econox.utils import get_from_pytree

General utility functions for array manipulation and PyTree handling.

General utility functions shared across the Econox package.

econox.utils.get_from_pytree(data, key, default=_MISSING)[source]

Retrieve a value from a data container, supporting both dict-style (['key']) and attribute-style (.key) access.

Parameters:
  • data (Any) – The container (dict, NamedTuple, PyTree, etc.).

  • key (str) – The key or attribute name to retrieve.

  • default (Union[TypeVar(T), object]) – Value to return if key is not found. If omitted, a missing key raises an error.

Return type:

Union[Any, TypeVar(T)]

Returns:

The value associated with the key, or default if not found.

Raises:
  • KeyError – If data is dict-like and key is missing (and no default).

  • AttributeError – If data is object-like and attribute is missing (and no default).
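The documented lookup rules can be mirrored in a short pure-Python sketch. get_from_pytree_sketch is a hypothetical re-implementation for illustration, not the econox source; it uses a private sentinel object so that default=None can still be passed as a legitimate default.

```python
from typing import Any, Mapping, NamedTuple

_MISSING = object()  # sentinel distinguishing "no default given" from default=None


def get_from_pytree_sketch(data: Any, key: str, default: Any = _MISSING) -> Any:
    """Dict-style lookup for mappings, attribute lookup for everything else."""
    if isinstance(data, Mapping):
        if key in data:
            return data[key]
        if default is _MISSING:
            raise KeyError(key)
        return default
    if hasattr(data, key):
        return getattr(data, key)
    if default is _MISSING:
        raise AttributeError(f"{type(data).__name__!r} object has no attribute {key!r}")
    return default


class Obs(NamedTuple):
    y: float
    X: float


obs_dict = {"y": 1.0}
obs_tuple = Obs(y=2.0, X=3.0)
```
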