Hessian, HVP, and IHVP
- dattri.func.hessian.hvp(func: Callable, argnums: int = 0, mode: str = 'rev-rev', regularization: float = 0.0) → Callable
Hessian-vector product (HVP) calculation function.
Takes the function func whose Hessian is needed and returns a function that accepts x (the arguments of func) and a vector v, and computes the Hessian-vector product.
- Parameters:
func (Callable) – A Python function that takes one or more arguments and returns a single-element Tensor. The Hessian is computed for this function. The positional arguments to func must all be Tensors.
argnums (int) – An integer, default 0. Specifies the argument of func with respect to which the Hessian is computed.
mode (str) – The autodiff mode, default 'rev-rev'. One of:
  - rev-rev: computes the Hessian with two reverse-mode autodiff passes. It has better operator compatibility but uses more memory.
  - rev-fwd: computes the Hessian by composing reverse-mode and forward-mode autodiff. It is more memory-efficient but may not be supported by some operators.
regularization (float) – A float, default 0.0. A regularization term added to the Hessian-vector product, useful for the subsequent inverse computation when the Hessian is singular or ill-conditioned. Specifically, regularization * v is added to the product, so the result is (H + regularization * I) v.
- Returns:
A function that takes a tuple of Tensors x (the arguments of func) and a vector v, and returns the product of the Hessian of func at x with v.
Note
This method does not fix x, so it is suitable when the HVP is needed at several values of x. If x is fixed, consider using hvp_at_x instead.
- Raises:
IHVPUsageError – If mode is neither “rev-rev” nor “rev-fwd”.
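To make the two mode values concrete, the following is a minimal sketch written directly against torch.func rather than dattri. It assumes func takes a single 1-D tensor, and the helper hvp_sketch is hypothetical (it raises ValueError where dattri raises IHVPUsageError). rev-rev applies a vector-Jacobian product to the gradient, rev-fwd applies a Jacobian-vector product to it, and both should match the explicit Hessian times v.

```python
import torch
from torch.func import grad, hessian, jvp, vjp


def hvp_sketch(func, x, v, mode="rev-rev", regularization=0.0):
    """Hypothetical helper: (H + regularization * I) v for func at a 1-D tensor x."""
    grad_func = grad(func)  # reverse-mode gradient, x -> d func / dx
    if mode == "rev-rev":
        # Second reverse pass: v^T J_grad, which equals H v since H is symmetric.
        _, vjp_fn = vjp(grad_func, x)
        (hv,) = vjp_fn(v)
    elif mode == "rev-fwd":
        # Forward-mode pass through the gradient: J_grad v = H v.
        _, hv = jvp(grad_func, (x,), (v,))
    else:
        raise ValueError("mode must be 'rev-rev' or 'rev-fwd'")
    return hv + regularization * v


def loss(x):
    # A non-quadratic scalar function, so the Hessian depends on x.
    return (x.sin() * x).sum() + (x[0] * x[1]) ** 2


x = torch.tensor([0.5, -1.0, 2.0])
v = torch.tensor([1.0, 0.0, -1.0])
h_explicit = hessian(loss)(x) @ v
for mode in ("rev-rev", "rev-fwd"):
    print(mode, torch.allclose(hvp_sketch(loss, x, v, mode), h_explicit, atol=1e-5))
```

Neither mode materializes the Hessian; the explicit hessian call is there only as a reference for the comparison.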
- dattri.func.hessian.hvp_at_x(func: Callable, x: Tuple[torch.Tensor, ...], argnums: int = 0, mode: str = 'rev-rev', regularization: float = 0.0) → Callable
Hessian-vector product (HVP) calculation function (with fixed x).
Returns a function that takes a vector v and computes the Hessian-vector product at the fixed x.
- Parameters:
func (Callable) – A Python function that takes one or more arguments and returns a single-element Tensor. The Hessian is computed for this function. The positional arguments to func must all be Tensors.
x (Tuple[torch.Tensor, ...]) – The arguments of func at which the returned function computes the Hessian. argnums indicates which element of x is used as the primal, i.e. the variable the Hessian is taken with respect to.
argnums (int) – An integer, default 0. Specifies the argument of func with respect to which the Hessian is computed.
mode (str) – The autodiff mode, default 'rev-rev'. One of:
  - rev-rev: computes the Hessian with two reverse-mode autodiff passes. It has better operator compatibility but uses more memory.
  - rev-fwd: computes the Hessian by composing reverse-mode and forward-mode autodiff. It is more memory-efficient but may not be supported by some operators.
regularization (float) – A float, default 0.0. A regularization term added to the Hessian-vector product, useful for the subsequent inverse computation when the Hessian is singular or ill-conditioned. Specifically, regularization * v is added to the product, so the result is (H + regularization * I) v.
- Returns:
A function that takes a vector v and returns the product of the Hessian of func at x with v.
Note
This method fixes x, which avoids some redundant computation. If you have several values of x and want to use vmap to accelerate the computation, consider using hvp instead.
- Raises:
IHVPUsageError – If mode is neither “rev-rev” nor “rev-fwd”.
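One way to obtain the saving the note describes is to evaluate the gradient graph at the fixed x once and reuse the resulting vector-Jacobian closure for every v. The sketch below does this with torch.func.vjp; make_hvp_at_x is a hypothetical helper for a single 1-D tensor argument, and this is not necessarily how dattri implements hvp_at_x.

```python
import torch
from torch.func import grad, vjp


def make_hvp_at_x(func, x, regularization=0.0):
    """Hypothetical helper: returns v -> (H(x) + regularization * I) v for a fixed x."""
    # The gradient at x is traced once here; vjp_fn keeps what it needs.
    _, vjp_fn = vjp(grad(func), x)

    def hvp_at_x(v):
        (hv,) = vjp_fn(v)  # one reverse pass per v, no re-evaluation of func at x
        return hv + regularization * v

    return hvp_at_x


def loss(x):
    return 0.5 * (x**2).sum() + x.prod()


x = torch.tensor([1.0, 2.0, 3.0])
hvp_at_x = make_hvp_at_x(loss, x)
rows = torch.stack([hvp_at_x(e) for e in torch.eye(3)])  # H e_i for each basis vector
# rows equals [[1, 3, 2], [3, 1, 1], [2, 1, 1]], the Hessian of loss at x
```

Each call reuses the same closure; with hvp, by contrast, x is passed on every call and may change between calls.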
- dattri.func.hessian.ihvp_explicit(func: Callable, argnums: int = 0, regularization: float = 0.0) → Callable
IHVP via explicit Hessian calculation.
IHVP stands for inverse-Hessian-vector product. For a given function func, this method first computes the Hessian matrix explicitly and then wraps it in a function that uses torch.linalg.solve to compute the IHVP for any given vector.
- Parameters:
func (Callable) – A function that takes one or more arguments and returns a single-element Tensor. The Hessian is computed for this function.
argnums (int) – An integer, default 0. Specifies the argument of func with respect to which the inverse Hessian is computed.
regularization (float) – A float, default 0.0. A regularization term added to the Hessian matrix, useful when the Hessian is singular or ill-conditioned. The term is regularization * I, where I is the identity matrix, added directly to the Hessian matrix.
- Returns:
A function that takes a tuple of Tensors x (the arguments of func) and a vector v, and returns the product of the inverse Hessian of func at x with v.
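The explicit approach fits in a few lines with torch.func.hessian and torch.linalg.solve. This is a sketch under simplifying assumptions: ihvp_explicit_sketch is hypothetical, func takes a single 1-D tensor so its Hessian is a square matrix, and the returned function receives x directly rather than as a tuple.

```python
import torch
from torch.func import hessian


def ihvp_explicit_sketch(func, regularization=0.0):
    """Hypothetical helper: returns (x, v) -> (H(x) + regularization * I)^{-1} v."""

    def ihvp(x, v):
        h = hessian(func)(x)  # (n, n) matrix, recomputed for every x
        h = h + regularization * torch.eye(h.shape[0], dtype=h.dtype)
        return torch.linalg.solve(h, v)  # solve instead of forming the inverse

    return ihvp


def loss(x):
    # Quadratic 0.5 * x^T A x, whose Hessian is A for every x.
    a = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
    return 0.5 * x @ a @ x


ihvp = ihvp_explicit_sketch(loss)
z = ihvp(torch.tensor([1.0, -2.0]), torch.tensor([1.0, 2.0]))
# z solves [[4, 1], [1, 3]] z = [1, 2], i.e. z = [1/11, 7/11]
```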
- dattri.func.hessian.ihvp_at_x_explicit(func: Callable, *x, argnums: int | Tuple[int, ...] = 0, regularization: float = 0.0) → Callable
IHVP via explicit Hessian calculation (with fixed x).
IHVP stands for inverse-Hessian-vector product. For a given function func, this method first computes the Hessian matrix explicitly and then wraps it in a function that uses torch.linalg.solve to compute the IHVP for any given vector.
- Parameters:
func (Callable) – A function that takes one or more arguments and returns a single-element Tensor. The Hessian is computed for this function.
*x – The positional arguments of func at which the Hessian is computed.
argnums (int or Tuple[int], optional) – An integer or a tuple of integers specifying which arguments in *x the Hessian is computed with respect to. Default: 0.
regularization (float) – A float, default 0.0. A regularization term added to the Hessian matrix, useful when the Hessian is singular or ill-conditioned. The term is regularization * I, where I is the identity matrix, added directly to the Hessian matrix.
- Returns:
A function that takes a vector v and returns the product of the inverse Hessian of func at x with v.
Note
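This method stores the Hessian matrix explicitly, so memory grows quadratically with the number of parameters, and solving the linear system with torch.linalg.solve costs time cubic in that number. It is therefore practical only for small problems.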
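When argnums is a tuple, torch.func.hessian returns one block per pair of selected arguments. The hypothetical sketch below stitches those blocks into a single matrix and solves against a vector formed by concatenating the per-argument parts in argnums order; that layout is an assumption made for illustration, not a description of how dattri arranges multiple arguments.

```python
import torch
from torch.func import hessian


def ihvp_at_x_explicit_sketch(func, *x, argnums=(0, 1), regularization=0.0):
    """Hypothetical helper: fixed-x IHVP over several 1-D arguments of func."""
    blocks = hessian(func, argnums=argnums)(*x)  # blocks[i][j] = d2 func / dx_i dx_j
    # Stitch the (n_i, n_j) blocks into one (N, N) matrix with N = sum of n_i.
    # Storing it takes N**2 entries, which is the memory cost the note refers to.
    h = torch.cat([torch.cat(row, dim=1) for row in blocks], dim=0)
    h = h + regularization * torch.eye(h.shape[0], dtype=h.dtype)

    def ihvp(v):
        # v concatenates the per-argument vectors in argnums order (an assumption).
        return torch.linalg.solve(h, v)

    return ihvp


def f(w, b):
    # Hessian blocks: 2 * I for w, 3 for b, and 1 between w[0] and b[0].
    return w @ w + 1.5 * b @ b + w[0] * b[0]


ihvp = ihvp_at_x_explicit_sketch(f, torch.tensor([0.3, -0.7]), torch.tensor([2.0]))
z = ihvp(torch.tensor([1.0, 2.0, 3.0]))
# z solves [[2, 0, 1], [0, 2, 0], [1, 0, 3]] z = [1, 2, 3], i.e. z = [0, 1, 1]
```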
- dattri.func.hessian.ihvp_cg(func: Callable, argnums: int = 0, max_iter: int = 10, tol: float = 1e-07, mode: str = 'rev-rev', regularization: float = 0.0) → Callable
Conjugate gradient (CG) IHVP algorithm function.
IHVP stands for inverse-Hessian-vector product. Returns a function that, given a vector, computes the product of the inverse Hessian with that vector.
The algorithm builds an HVP function and applies it iteratively within the conjugate gradient method, so the Hessian matrix is never formed explicitly.
- Parameters:
func (Callable) – A Python function that takes one or more arguments and returns a single-element Tensor. The Hessian is computed for this function.
argnums (int) – An integer, default 0. Specifies the argument of func with respect to which the inverse Hessian is computed.
max_iter (int) – An integer, default 10. Specifies the maximum number of conjugate gradient iterations used to compute the IHVP.
tol (float) – A float, default 1e-7. The convergence tolerance: if the torch.norm of the residual falls below tol, the iteration stops early.
mode (str) – The autodiff mode, default 'rev-rev'. One of:
  - rev-rev: computes the Hessian with two reverse-mode autodiff passes. It has better operator compatibility but uses more memory.
  - rev-fwd: computes the Hessian by composing reverse-mode and forward-mode autodiff. It is more memory-efficient but may not be supported by some operators.
regularization (float) – A float, default 0.0. A regularization term added to the Hessian-vector product, useful for the subsequent inverse computation when the Hessian is singular or ill-conditioned. Specifically, regularization * v is added to the product, so the result is (H + regularization * I) v.
- Returns:
A function that takes a tuple of Tensors x (the arguments of func) and a vector v, and returns the product of the inverse Hessian of func at x with v.
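Because conjugate gradient only needs products H p, it can be driven entirely by an HVP callable. Below is a textbook CG loop, a sketch rather than dattri's code, with max_iter and tol used as documented above (stop once the residual norm falls below tol). It assumes the (regularized) Hessian is symmetric positive definite, which CG requires; cg_solve, make_hvp, and the quadratic test loss are hypothetical. Float64 is used so that reaching tol = 1e-7 is not limited by float32 round-off.

```python
import torch
from torch.func import grad, vjp


def cg_solve(hvp_fn, v, max_iter=10, tol=1e-7):
    """Conjugate gradient for H z = v using only products H p."""
    z = torch.zeros_like(v)
    r = v.clone()  # residual v - H z, with z = 0
    p = r.clone()  # first search direction
    rs_old = r @ r
    for _ in range(max_iter):
        if torch.sqrt(rs_old) < tol:  # residual norm below tol: converged
            break
        hp = hvp_fn(p)
        alpha = rs_old / (p @ hp)
        z = z + alpha * p
        r = r - alpha * hp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # next H-conjugate direction
        rs_old = rs_new
    return z


A = torch.tensor([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]], dtype=torch.float64)


def loss(x):
    # Quadratic whose Hessian is A (symmetric positive definite).
    return 0.5 * x @ A @ x + x.sum()


def make_hvp(func, x):
    # rev-rev HVP at x, reusing one vjp closure for every CG step.
    _, vjp_fn = vjp(grad(func), x)
    return lambda p: vjp_fn(p)[0]


x = torch.zeros(3, dtype=torch.float64)
v = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
z = cg_solve(make_hvp(loss, x), v)
```

For a different x, one would build a new HVP with make_hvp and rerun cg_solve, which parallels how ihvp_cg accepts x at call time.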
- dattri.func.hessian.ihvp_at_x_cg(func: Callable, *x, argnums: int = 0, max_iter: int = 10, tol: float = 1e-07, mode: str = 'rev-rev', regularization: float = 0.0) → Callable
Conjugate gradient (CG) IHVP algorithm function (with fixed x).
IHVP stands for inverse-Hessian-vector product. Returns a function that, given a vector, computes the product of the inverse Hessian with that vector.
The algorithm builds an HVP function and applies it iteratively within the conjugate gradient method, so the Hessian matrix is never formed explicitly.
- Parameters:
func (Callable) – A Python function that takes one or more arguments and returns a single-element Tensor. The Hessian is computed for this function.
*x – The positional arguments of func at which the Hessian is computed.
argnums (int) – An integer, default 0. Specifies the argument of func with respect to which the inverse Hessian is computed.
max_iter (int) – An integer, default 10. Specifies the maximum number of conjugate gradient iterations used to compute the IHVP.
tol (float) – A float, default 1e-7. The convergence tolerance: if the torch.norm of the residual falls below tol, the iteration stops early.
mode (str) – The autodiff mode, default 'rev-rev'. One of:
  - rev-rev: computes the Hessian with two reverse-mode autodiff passes. It has better operator compatibility but uses more memory.
  - rev-fwd: computes the Hessian by composing reverse-mode and forward-mode autodiff. It is more memory-efficient but may not be supported by some operators.
regularization (float) – A float, default 0.0. A regularization term added to the Hessian-vector product, useful for the subsequent inverse computation when the Hessian is singular or ill-conditioned. Specifically, regularization * v is added to the product, so the result is (H + regularization * I) v.
- Returns:
A function that takes a vector v and returns the product of the inverse Hessian of func at x with v.
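regularization matters most when the Hessian is singular: H z = v may then have no solution, but H + regularization * I is positive definite whenever H is positive semidefinite and regularization > 0, so CG is well posed. The hypothetical sketch below fixes x in a closure, computes HVPs in rev-fwd style with torch.func.jvp, and runs a compact CG loop on the singular Hessian [[2, 2], [2, 2]]. It illustrates the technique and is not dattri's implementation.

```python
import torch
from torch.func import grad, jvp


def make_ihvp_at_x_cg(func, x, max_iter=10, tol=1e-7, regularization=0.0):
    """Hypothetical helper: fixed-x IHVP via CG on (H + regularization * I)."""
    grad_func = grad(func)

    def hvp(p):
        # rev-fwd: forward-mode derivative of the reverse-mode gradient at the fixed x.
        _, hp = jvp(grad_func, (x,), (p,))
        return hp + regularization * p

    def ihvp(v):
        z, r = torch.zeros_like(v), v.clone()
        p, rs = r.clone(), r @ r
        for _ in range(max_iter):
            if torch.sqrt(rs) < tol:  # residual norm below tol
                break
            hp = hvp(p)
            alpha = rs / (p @ hp)
            z, r = z + alpha * p, r - alpha * hp
            rs_new = r @ r
            p, rs = r + (rs_new / rs) * p, rs_new
        return z

    return ihvp


def loss(x):
    # Hessian [[2, 2], [2, 2]] is singular; v = [1, 0] is not in its range.
    return (x[0] + x[1]) ** 2


x = torch.tensor([0.3, -1.2], dtype=torch.float64)
v = torch.tensor([1.0, 0.0], dtype=torch.float64)
z = make_ihvp_at_x_cg(loss, x, regularization=0.5)(v)
# z solves [[2.5, 2], [2, 2.5]] z = [1, 0], i.e. z = [10/9, -8/9]
```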
- dattri.func.hessian.ihvp_arnoldi(func: Callable, argnums: int = 0, max_iter: int = 100, tol: float = 1e-07, mode: str = 'rev-fwd', regularization: float = 0.0) → Callable
Arnoldi iteration IHVP algorithm function.
IHVP stands for inverse-Hessian-vector product. Returns a function that, given a vector, computes the product of the inverse Hessian with that vector.
Arnoldi iteration builds an approximately H-invariant subspace by constructing the n-th order Krylov subspace of the Hessian H together with an orthonormal basis for it.
- Parameters:
func (Callable) – A Python function that takes one or more arguments and returns a single-element Tensor. The Hessian is computed for this function.
argnums (int) – An integer, default 0. Specifies the argument of func with respect to which the inverse Hessian is computed.
max_iter (int) – An integer, default 100. Specifies the maximum number of Arnoldi iterations used to compute the IHVP.
tol (float) – A float, default 1e-7. The convergence tolerance: if the torch.norm of the new basis vector, before normalization, falls below tol, the Arnoldi iteration stops early.
mode (str) – The autodiff mode, default 'rev-fwd'. One of:
  - rev-rev: computes the Hessian with two reverse-mode autodiff passes. It has better operator compatibility but uses more memory.
  - rev-fwd: computes the Hessian by composing reverse-mode and forward-mode autodiff. It is more memory-efficient but may not be supported by some operators.
regularization (float) – A float, default 0.0. A regularization term added to the Hessian-vector product, useful for the subsequent inverse computation when the Hessian is singular or ill-conditioned. Specifically, regularization * v is added to the product, so the result is (H + regularization * I) v.
- Returns:
A function that takes a tuple of Tensors x (the arguments of func) and a vector v, and returns the product of the inverse Hessian of func at x with v.
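The iteration itself can be sketched with only an HVP callable. arnoldi_sketch below is hypothetical: it orthogonalizes each new vector hvp(q_k) against the existing basis (modified Gram-Schmidt), records the coefficients in a small upper Hessenberg matrix, and stops early when the unnormalized vector's norm drops below tol, as the tol parameter describes. For a symmetric Hessian the returned matrix, which equals Q^T H Q, is tridiagonal up to round-off. dattri's own iteration may differ in details such as the starting vector and normalization.

```python
import torch


def arnoldi_sketch(hvp_fn, v0, max_iter=100, tol=1e-7):
    """Hypothetical Arnoldi iteration returning (Q, Hs) with Q^T H Q = Hs."""
    n = v0.shape[0]
    m = min(max_iter, n)
    q = torch.zeros(n, m + 1, dtype=v0.dtype)
    hs = torch.zeros(m + 1, m, dtype=v0.dtype)
    q[:, 0] = v0 / torch.linalg.norm(v0)
    steps = m
    for k in range(m):
        w = hvp_fn(q[:, k])
        for j in range(k + 1):  # modified Gram-Schmidt against the basis so far
            hs[j, k] = q[:, j] @ w
            w = w - hs[j, k] * q[:, j]
        hs[k + 1, k] = torch.linalg.norm(w)
        if hs[k + 1, k] < tol:  # Krylov subspace is numerically H-invariant
            steps = k + 1
            break
        q[:, k + 1] = w / hs[k + 1, k]
    # Keep the square part: Q (n x steps) and Hs (steps x steps).
    return q[:, :steps], hs[:steps, :steps]


torch.manual_seed(0)
b = torch.randn(6, 6, dtype=torch.float64)
a = b @ b.T + 6 * torch.eye(6, dtype=torch.float64)  # symmetric positive definite
q, hs = arnoldi_sketch(lambda p: a @ p, torch.ones(6, dtype=torch.float64), max_iter=4)
```

In an Arnoldi-based IHVP, hs is the small matrix whose eigenvalues (Ritz values) approximate the extreme eigenvalues of H, which is what makes a cheap low-rank inverse possible.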
- dattri.func.hessian.ihvp_at_x_arnoldi(func: Callable, *x, argnums: int = 0, max_iter: int = 100, top_k: int = 100, norm_constant: float = 1.0, tol: float = 1e-07, mode: str = 'rev-fwd', regularization: float = 0.0) → Callable
Arnoldi iteration IHVP algorithm function (with fixed x).
IHVP stands for inverse-Hessian-vector product. Returns a function that, given a vector, computes the product of the inverse Hessian with that vector.
Arnoldi iteration builds an approximately H-invariant subspace by constructing the n-th order Krylov subspace of the Hessian H together with an orthonormal basis for it.
- Parameters:
func (Callable) – A Python function that takes one or more arguments and returns a single-element Tensor. The Hessian is computed for this function.
*x – The positional arguments of func at which the Hessian is computed.
argnums (int) – An integer, default 0. Specifies the argument of func with respect to which the inverse Hessian is computed.
max_iter (int) – An integer, default 100. Specifies the maximum number of Arnoldi iterations used to compute the IHVP.
top_k (int) – An integer, default 100. Specifies how many of the top eigenvalues and eigenvectors to distill from the Arnoldi approximation.
norm_constant (float) – A float, default 1.0. Specifies a constant value for the norm of each projection. In some situations (e.g. with a large number of parameters) it may be advisable to set norm_constant > 1 to avoid dividing projection components by a large normalization factor.
tol (float) – A float, default 1e-7. The convergence tolerance: if the torch.norm of the new basis vector, before normalization, falls below tol, the Arnoldi iteration stops early.
mode (str) – The autodiff mode, default 'rev-fwd'. One of:
  - rev-rev: computes the Hessian with two reverse-mode autodiff passes. It has better operator compatibility but uses more memory.
  - rev-fwd: computes the Hessian by composing reverse-mode and forward-mode autodiff. It is more memory-efficient but may not be supported by some operators.
regularization (float) – A float, default 0.0. A regularization term added to the Hessian-vector product, useful for the subsequent inverse computation when the Hessian is singular or ill-conditioned. Specifically, regularization * v is added to the product, so the result is (H + regularization * I) v.
- Returns:
A function that takes a vector v and returns the product of the inverse Hessian of func at x with v.
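top_k controls how much of the spectrum survives the distillation step. To isolate that effect, the sketch below skips the Arnoldi stage and takes eigenpairs directly from an explicit symmetric matrix with torch.linalg.eigh, then approximates H^{-1} v by the sum over the kept eigenpairs (λ_i, u_i) of (u_i · v / λ_i) u_i. Selecting eigenvalues by magnitude is an assumption, norm_constant is not modelled, and topk_ihvp is a hypothetical helper.

```python
import torch


def topk_ihvp(h, v, top_k):
    """Hypothetical helper: approximate H^{-1} v from the top_k eigenpairs of symmetric H."""
    eigvals, eigvecs = torch.linalg.eigh(h)  # eigenvectors are the columns
    keep = torch.argsort(eigvals.abs(), descending=True)[:top_k]
    lam, u = eigvals[keep], eigvecs[:, keep]
    # Project v onto the kept eigenvectors, divide by the eigenvalues, map back.
    return u @ ((u.T @ v) / lam)


h = torch.diag(torch.tensor([10.0, 5.0, 1.0, 0.1], dtype=torch.float64))
v = torch.ones(4, dtype=torch.float64)
full = topk_ihvp(h, v, top_k=4)  # values [0.1, 0.2, 1.0, 10.0], i.e. H^{-1} v
top2 = topk_ihvp(h, v, top_k=2)  # values [0.1, 0.2, 0.0, 0.0]
```

The discarded directions are those with the smallest eigenvalues, whose 1/λ factors would be the largest, so truncation trades accuracy in low-curvature directions for a smaller, cheaper approximation.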
- dattri.func.hessian.ihvp_lissa(func: Callable, argnums: int = 0, batch_size: int = 1, num_repeat: int = 1, recursion_depth: int = 5000, damping: int = 0.0, scaling: int = 50.0, collate_fn: Callable | None = None, mode: str = 'rev-rev') → Callable
IHVP via the LiSSA algorithm.
IHVP stands for inverse-Hessian-vector product. Returns a function that, given a vector, computes the product of the inverse Hessian with that vector.
The LiSSA algorithm approximates the IHVP by averaging multiple samples, each estimated by a recursion based on the Taylor (Neumann series) expansion of the inverse Hessian.
- Parameters:
func (Callable) – A Python function that takes one or more arguments and returns a single-element Tensor. The Hessian is estimated for this function.
argnums (int) – An integer, default 0. Specifies the argument of func with respect to which the inverse Hessian is computed.
batch_size (int) – An integer, default 1. Specifies the batch size used for each LiSSA inner-loop update.
num_repeat (int) – An integer, default 1. Specifies the number of IHVP samples to average.
recursion_depth (int) – An integer, default 5000. Specifies the number of recursion steps used to estimate each IHVP sample.
damping (int) – Default 0.0. Damping factor used to handle non-convexity in the LiSSA IHVP calculation.
scaling (int) – Default 50.0. Scaling factor used to ensure convergence of the LiSSA IHVP calculation.
collate_fn (Optional[Callable]) – A function that collates the input data into the form expected by func. If None, the input data are passed to func directly. This is useful when func takes nested inputs.
mode (str) – The autodiff mode, default 'rev-rev'. One of:
  - rev-rev: computes the Hessian with two reverse-mode autodiff passes. It has better operator compatibility but uses more memory.
  - rev-fwd: computes the Hessian by composing reverse-mode and forward-mode autodiff. It is more memory-efficient but may not be supported by some operators.
- Returns:
A function that takes a list of tuples of Tensors x and a vector v, and returns the product of the inverse Hessian of func, estimated on x, with v.
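The recursion can be sketched for a least-squares loss under assumptions of ours: each step samples batch_size rows without replacement, applies the update h <- v + (1 - damping) h - (H_batch h) / scaling starting from h = v, and returns h / scaling, with num_repeat independent runs averaged. This is a common form of the recursion in the influence-function literature; dattri's exact update and sampling scheme may differ, and lissa_sketch is hypothetical. With damping = 0 it converges to H^{-1} v when every Hessian eigenvalue lies in (0, 2 * scaling).

```python
import torch
from torch.func import grad, vjp


def lissa_sketch(loss_fn, params, data, v, batch_size=1, num_repeat=1,
                 recursion_depth=5000, damping=0.0, scaling=50.0):
    """Hypothetical LiSSA estimate of H^{-1} v for the mean loss over `data`."""
    n = data[0].shape[0]
    estimates = []
    for _ in range(num_repeat):
        h = v.clone()  # recursion starts from v
        for _ in range(recursion_depth):
            idx = torch.randperm(n)[:batch_size]  # mini-batch without replacement
            batch = tuple(t[idx] for t in data)
            # H_batch @ h via reverse-over-reverse autodiff at the fixed params.
            _, vjp_fn = vjp(grad(lambda p: loss_fn(p, *batch)), params)
            (hvp,) = vjp_fn(h)
            h = v + (1 - damping) * h - hvp / scaling
        estimates.append(h / scaling)
    return torch.stack(estimates).mean(dim=0)


def mse_loss(w, x, y):
    # 0.5 * mean squared error; its Hessian is x^T x / len(x), independent of w.
    return 0.5 * ((x @ w - y) ** 2).mean()


torch.manual_seed(0)
x = torch.randn(32, 3, dtype=torch.float64)
y = torch.randn(32, dtype=torch.float64)
w = torch.zeros(3, dtype=torch.float64)
v = torch.tensor([1.0, -1.0, 0.5], dtype=torch.float64)

# batch_size=32 uses every row at each step, so the result does not depend on sampling.
est = lissa_sketch(mse_loss, w, (x, y), v, batch_size=32,
                   recursion_depth=1000, scaling=10.0)
exact = torch.linalg.solve(x.T @ x / 32, v)
```

Under this form of the recursion, the limit with damping > 0 is (H + scaling * damping * I)^{-1} v, so damping acts like a diagonal shift, which is one reason it helps with non-convex losses.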
- dattri.func.hessian.ihvp_at_x_lissa(func: Callable, *x, in_dims: Tuple | None = None, argnums: int = 0, batch_size: int = 1, num_repeat: int = 1, recursion_depth: int = 5000, damping: int = 0.0, scaling: int = 50.0, collate_fn: Callable | None = None, mode: str = 'rev-rev') → Callable
IHVP via the LiSSA algorithm with fixed func inputs.
IHVP stands for inverse-Hessian-vector product. Returns a function that, given a vector, computes the product of the inverse Hessian with that vector.
The LiSSA algorithm approximates the IHVP by averaging multiple samples, each estimated by a recursion based on the Taylor (Neumann series) expansion of the inverse Hessian.
- Parameters:
func (Callable) – A Python function that takes one or more arguments and returns a single-element Tensor. The Hessian is estimated for this function.
*x – The positional arguments of func at which the Hessian is estimated.
in_dims (Optional[Tuple]) – A tuple with the same length as *x, indicating which dimension of each argument is the batch dimension. By default, the first dimension of each argument is treated as the batch dimension.
argnums (int) – An integer, default 0. Specifies the argument of func with respect to which the inverse Hessian is computed.
batch_size (int) – An integer, default 1. Specifies the batch size used for each LiSSA inner-loop update.
num_repeat (int) – An integer, default 1. Specifies the number of IHVP samples to average.
recursion_depth (int) – An integer, default 5000. Specifies the number of recursion steps used to estimate each IHVP sample.
damping (int) – Default 0.0. Damping factor used to handle non-convexity in the LiSSA IHVP calculation.
scaling (int) – Default 50.0. Scaling factor used to ensure convergence of the LiSSA IHVP calculation.
collate_fn (Optional[Callable]) – A function that collates the input data into the form expected by func. If None, the input data are passed to func directly. This is useful when func takes nested inputs.
mode (str) – The autodiff mode, default 'rev-rev'. One of:
  - rev-rev: computes the Hessian with two reverse-mode autodiff passes. It has better operator compatibility but uses more memory.
  - rev-fwd: computes the Hessian by composing reverse-mode and forward-mode autodiff. It is more memory-efficient but may not be supported by some operators.
- Returns:
A function that takes a vector v and returns the product of the inverse Hessian of func, estimated on x, with v.
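in_dims tells the fixed-x variant along which dimension of each argument mini-batches are drawn. A hypothetical helper that samples one shared set of batch_size indices along each argument's batch dimension might look like this; the name sample_batch, sampling without replacement, and the use of None for an argument without a batch dimension (mirroring torch.vmap's in_dims convention) are all assumptions, not dattri's code.

```python
import torch


def sample_batch(xs, in_dims=None, batch_size=1):
    """Hypothetical helper: sample shared indices along each argument's batch dimension.

    in_dims[i] is the batch dimension of xs[i]; None (an assumption here) marks an
    argument without a batch dimension, such as model parameters.
    """
    if in_dims is None:
        in_dims = (0,) * len(xs)  # default: the first dimension is the batch dimension
    sizes = {x.shape[d] for x, d in zip(xs, in_dims) if d is not None}
    if len(sizes) != 1:
        raise ValueError("batched arguments must share one batch size")
    idx = torch.randperm(sizes.pop())[:batch_size]
    return tuple(x if d is None else x.index_select(d, idx) for x, d in zip(xs, in_dims))


params = torch.zeros(8)
features = torch.randn(100, 8)  # batch along dim 0
targets_t = features[:, :4].T   # batch along dim 1 (stored transposed)
pb, fb, tb = sample_batch((params, features, targets_t), in_dims=(None, 0, 1), batch_size=16)
print(fb.shape, tb.shape)  # torch.Size([16, 8]) torch.Size([4, 16])
```

Sharing one index set keeps the sampled rows of every argument aligned, so each sampled example still pairs its features with its own targets.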