autopdex.implicit_diff.custom_root
- autopdex.implicit_diff.custom_root(residual_fun: Callable, mat_fun: Callable, solve: Callable, free_dofs: Any = None, has_aux: bool = False, mode='reverse', reference_signature: Callable | None = None)
Decorator for adding implicit differentiation to a root solver.
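The sketch below shows the idea behind such a decorator. It is a simplified illustration, not autopdex's implementation, and it makes these assumptions:

- a dense tangent matrix;
- a single parameter array theta in place of *args;
- no free_dofs and no has_aux;
- reverse mode only.

The names toy_custom_root and newton_solver, and the cubic residual, are hypothetical. Differentiating residual_fun(sol(theta), theta) = 0 gives (dR/dsol) dsol = -(dR/dtheta) dtheta. The cotangent of theta therefore follows from one adjoint solve with the transposed tangent matrix, without differentiating through the solver's iterations.

```python
import jax
import jax.numpy as jnp


def toy_custom_root(residual_fun, mat_fun, solve):
    """Hypothetical, simplified custom_root-style decorator (reverse mode only).

    For residual_fun(sol, theta) == 0, the implicit function theorem gives
    theta_bar = -(dR/dtheta)^T lam, where (dR/dsol)^T lam = sol_bar.
    """
    def decorator(solver):
        @jax.custom_vjp
        def wrapped(init_dofs, theta):
            return solver(init_dofs, theta)

        def fwd(init_dofs, theta):
            sol = solver(init_dofs, theta)
            return sol, (sol, theta)

        def bwd(res, sol_bar):
            sol, theta = res
            # Adjoint solve with the transposed tangent matrix dR/dsol.
            lam = solve(mat_fun(sol, theta).T, sol_bar)
            # Pull -lam back through the residual's dependence on theta.
            _, vjp_theta = jax.vjp(lambda t: residual_fun(sol, t), theta)
            (theta_bar,) = vjp_theta(-lam)
            # The root does not depend on the initial guess.
            return jnp.zeros_like(sol), theta_bar

        wrapped.defvjp(fwd, bwd)
        return wrapped

    return decorator


def residual_fun(x, theta):
    # Hypothetical elementwise residual with a unique real root per entry.
    return x**3 + x - theta


def mat_fun(x, theta):
    # Dense tangent dR/dx (stands in for a sparse BCOO matrix).
    return jax.jacfwd(residual_fun)(x, theta)


@toy_custom_root(residual_fun, mat_fun, jnp.linalg.solve)
def newton_solver(init_dofs, theta):
    def step(_, x):
        return x - jnp.linalg.solve(mat_fun(x, theta), residual_fun(x, theta))

    return jax.lax.fori_loop(0, 20, step, init_dofs)


theta = jnp.array([2.0, 10.0])
x0 = jnp.ones(2)
sol = newton_solver(x0, theta)  # roots of x**3 + x = theta: [1.0, 2.0]
# d(sum(sol))/dtheta = 1 / (3 sol**2 + 1) = [0.25, 1/13]
grad_theta = jax.grad(lambda t: newton_solver(x0, t).sum())(theta)
```

The backward rule only uses the converged solution, so JAX never differentiates the solver body. A non-differentiable loop such as jax.lax.while_loop could therefore be used inside the solver.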
- Parameters:
residual_fun – A callable that returns the (possibly nonlinear) residual whose root is sought, called as residual_fun(dofs, *args). The invariant is residual_fun(sol, *args) == 0 at the solution (root) sol.
mat_fun – A callable that takes dofs and args as arguments and returns the sparse tangent matrix as a jax.experimental.sparse.BCOO. It can also be a pure callback.
solve – A linear solver of the form solve(mat[jax.experimental.sparse.BCOO], b[jnp.ndarray]).
free_dofs – For constraining certain degrees of freedom. free_dofs has to be a boolean mask with the same structure as dofs (jnp.ndarray or dict of jnp.ndarrays), indicating the dofs that are not constrained. Requires args[1][‘dirichlet conditions’] to be defined (see the source code of _root_jvp and _root_vjp for details, or solver.adaptive_load_stepping for an example of its use).
has_aux – Whether the decorated root solver function returns auxiliary data.
mode – The differentiation mode (‘forward’ or ‘reverse’/‘backward’).
reference_signature – Optional function signature (i.e. arguments and keyword arguments) with which the solver and optimality functions are expected to agree. Defaults to residual_fun. The solver and optimality functions are required to share the same input signature, but both might be defined in such a way that the signature correspondence is ambiguous (e.g. if both accept catch-all **kwargs). To satisfy custom_root’s requirement, any function with an unambiguous signature can be provided here.
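The sketch below builds callables with the forms described above for a hypothetical toy problem, R(u; f) = K u + u**3 - f. The matrix K, the problem size and the choice of conjugate gradients are assumptions for illustration only.

- mat_fun returns the tangent as a jax.experimental.sparse.BCOO, built here from a dense Jacobian.
- solve accepts that matrix and a dense right-hand side.
- free_dofs is a boolean mask with the same structure as the dofs.

How autopdex combines the mask with args[1][‘dirichlet conditions’] is not shown here.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse
from jax.scipy.sparse.linalg import cg

n = 5
# Hypothetical SPD stiffness-like matrix (1D Laplacian stencil).
K = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)


def residual_fun(dofs, load):
    # Nonlinear residual R(u; f) = K u + u**3 - f.
    return K @ dofs + dofs**3 - load


def mat_fun(dofs, load):
    # Tangent dR/du as a BCOO matrix. Converting a dense Jacobian is only
    # acceptable for tiny examples; real codes assemble the sparse pattern.
    return sparse.BCOO.fromdense(jax.jacfwd(residual_fun)(dofs, load))


def solve(mat, b):
    # Matrix-free conjugate gradients using the BCOO matvec; valid here only
    # because K + diag(3 u**2) is symmetric positive definite.
    x, _ = cg(lambda v: mat @ v, b)
    return x


# Boolean mask with the same structure as dofs; False marks constrained dofs.
free_dofs = jnp.ones(n, dtype=bool).at[0].set(False).at[-1].set(False)

dofs = jnp.linspace(0.0, 1.0, n)
load = jnp.ones(n)
mat = mat_fun(dofs, load)
rhs = residual_fun(dofs, load)
du = solve(mat, rhs)  # solves mat @ du = rhs
```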
- Returns:
The decorated root solver function, equipped with a custom VJP or JVP rule.
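For mode='forward', the corresponding rule is a JVP: the tangent of the solution solves (dR/dsol) sol_dot = -(dR/dtheta) theta_dot. The minimal sketch below uses jax.custom_jvp under simplifying assumptions: a dense tangent, one hypothetical parameter array theta, and no Dirichlet constraints. The names root and newton, and the cubic residual, are illustrative and are not autopdex functions.

```python
import jax
import jax.numpy as jnp


def residual_fun(x, theta):
    # Hypothetical elementwise residual; its root satisfies x**3 + x = theta.
    return x**3 + x - theta


def mat_fun(x, theta):
    # Dense tangent dR/dx, standing in for the sparse tangent matrix.
    return jax.jacfwd(residual_fun)(x, theta)


def newton(x0, theta):
    def step(_, x):
        return x - jnp.linalg.solve(mat_fun(x, theta), residual_fun(x, theta))

    return jax.lax.fori_loop(0, 20, step, x0)


@jax.custom_jvp
def root(x0, theta):
    return newton(x0, theta)


@root.defjvp
def root_jvp(primals, tangents):
    x0, theta = primals
    _, theta_dot = tangents  # the initial guess does not affect the root
    sol = root(x0, theta)
    # Directional derivative of the residual with respect to theta.
    _, r_dot = jax.jvp(lambda t: residual_fun(sol, t), (theta,), (theta_dot,))
    # Implicit function theorem: (dR/dsol) sol_dot = -(dR/dtheta) theta_dot.
    sol_dot = jnp.linalg.solve(mat_fun(sol, theta), -r_dot)
    return sol, sol_dot


x0 = jnp.ones(2)
theta = jnp.array([2.0, 10.0])
# Forward-mode Jacobian of the root w.r.t. theta: diag(1 / (3 sol**2 + 1)).
jac = jax.jacfwd(lambda t: root(x0, t))(theta)
```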
Example
See e.g. the implementation of autopdex.solver.adaptive_load_stepping.