Fisher Information Matrix (FIM) / IFVP
- dattri.func.fisher.ifvp_explicit(func: Callable, argnums: int = 0, regularization: float = 0.0) -> Callable
IFVP via explicit FIM calculation.
IFVP stands for inverse-FIM-vector product. For a given function func, this method first calculates the FIM explicitly and then wraps the FIM in a function that uses torch.linalg.solve to calculate the IFVP for any given vector.
- Parameters:
func (Callable) – A function taking one or more arguments and returning a single-element Tensor. The FIM is calculated from this function. For classification tasks, it should be a negative log-likelihood (e.g., cross-entropy loss). To calculate the empirical FIM, evaluate the loss with the ground-truth labels. To calculate the true FIM, evaluate the loss with the predicted labels.
argnums (int) – Specifies which argument of func the inverse FIM is computed with respect to. Defaults to 0.
regularization (float) – Specifies the regularization term added to the FIM, which is useful when the FIM is singular or ill-conditioned. The term is regularization * I, where I is the identity matrix, and it is added directly to the FIM. Defaults to 0.0.
- Returns:
A function that takes a tuple of Tensors x and a vector v and returns the IFVP of the FIM of func and v, that is, the inverse of the FIM evaluated at x applied to v.
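As a reading aid, the quantity described above can be written out as follows. The symbols used here (θ for the argument selected by argnums, ℓ_i for the per-sample loss, λ for regularization) are notation introduced for illustration, and the 1/n averaging in the empirical FIM is an assumed convention rather than something this page specifies.

```latex
\[
  \hat{F}(\theta) \;=\; \frac{1}{n}\sum_{i=1}^{n}
      \nabla_\theta \ell_i(\theta)\,\nabla_\theta \ell_i(\theta)^{\top},
  \qquad
  \mathrm{IFVP}(v) \;=\; \bigl(\hat{F}(\theta) + \lambda I\bigr)^{-1} v .
\]
```

Because the result is obtained with torch.linalg.solve, it is the solution u of the linear system (F̂(θ) + λI) u = v, so the inverse matrix never needs to be formed. With λ = 0 the system is only solvable when the FIM is non-singular, which is why a small positive regularization is useful in practice.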
- dattri.func.fisher.ifvp_at_x_explicit(func: Callable, *x, argnums: int | Tuple[int, ...] = 0, regularization: float = 0.0) -> Callable
IFVP via explicit FIM calculation.
IFVP stands for inverse-FIM-vector product. For a given function func and fixed inputs x, this method first calculates the FIM at x explicitly and then wraps the FIM in a function that uses torch.linalg.solve to calculate the IFVP for any given vector.
- Parameters:
func (Callable) – A function taking one or more arguments and returning a single-element Tensor. The FIM is calculated from this function. For classification tasks, it should be a negative log-likelihood (e.g., cross-entropy loss). To calculate the empirical FIM, evaluate the loss with the ground-truth labels. To calculate the true FIM, evaluate the loss with the predicted labels.
*x – The arguments for func, passed positionally. The FIM is evaluated at these inputs.
argnums (int | Tuple[int, ...]) – Specifies which argument (or arguments) of func the inverse FIM is computed with respect to. Defaults to 0.
regularization (float) – Specifies the regularization term added to the FIM, which is useful when the FIM is singular or ill-conditioned. The term is regularization * I, where I is the identity matrix, and it is added directly to the FIM. Defaults to 0.0.
- Returns:
A function that takes a vector v and returns the IFVP of the FIM of func and v.
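The "build the FIM once at fixed inputs, then solve for any v" pattern described above can be illustrated with a minimal, self-contained sketch. This is not dattri's implementation: the names nll and explicit_ifvp_at_x are hypothetical, the toy logistic-regression loss and the 1/n averaging of the empirical FIM are assumptions made for illustration, and torch.func requires PyTorch 2.0 or later.

```python
import torch
from torch.func import grad, vmap


def nll(theta, x, y):
    """Negative log-likelihood of one (x, y) pair under logistic regression."""
    logit = x @ theta
    # softplus(z) - y * z equals binary cross-entropy with logits.
    return torch.nn.functional.softplus(logit) - y * logit


def explicit_ifvp_at_x(loss_fn, theta, xs, ys, regularization=0.0):
    """Hypothetical helper: build the FIM once, return v -> (FIM + reg * I)^-1 v."""
    # Per-sample gradients w.r.t. theta, shape (n_samples, n_params).
    per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(theta, xs, ys)
    # Empirical FIM: averaged outer product of per-sample gradients
    # (the 1/n normalization is an assumption of this sketch).
    fim = per_sample_grads.T @ per_sample_grads / xs.shape[0]
    damped = fim + regularization * torch.eye(theta.numel(), dtype=theta.dtype)

    def ifvp(v):
        # Solve the linear system instead of forming the inverse explicitly.
        return torch.linalg.solve(damped, v)

    return ifvp


torch.manual_seed(0)
theta = torch.randn(3, dtype=torch.float64)
xs = torch.randn(8, 3, dtype=torch.float64)
# Ground-truth labels, so this corresponds to the empirical FIM.
ys = (torch.rand(8) > 0.5).to(torch.float64)

ifvp = explicit_ifvp_at_x(nll, theta, xs, ys, regularization=1e-3)
v = torch.ones(3, dtype=torch.float64)
result = ifvp(v)
print(result.shape)
```

The explicit FIM has one row and one column per parameter, so this approach is only practical when the number of parameters is small. The returned closure can be reused for many vectors v without recomputing the matrix.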