Fisher Information Matrix (FIM) / IFVP

dattri.func.fisher.ifvp_explicit(func: Callable, argnums: int = 0, regularization: float = 0.0) → Callable

IFVP via explicit FIM calculation.

IFVP stands for inverse-FIM-vector product. For a given function func, this method first calculates the FIM explicitly and then wraps the FIM in a function that uses torch.linalg.solve to calculate the IFVP for any given vector.

Parameters:
  • func (Callable) – A function taking one or more arguments and returning a single-element Tensor. The FIM will be calculated based on this function. Notably, for classification tasks this function should be the negative log-likelihood (e.g., cross-entropy loss). To calculate the empirical FIM, use the ground-truth labels in the loss; to calculate the true FIM, use the predicted labels in the loss.

  • argnums (int) – An integer, defaulting to 0. Specifies which argument of func to compute the inverse FIM with respect to.

  • regularization (float) – A float, defaulting to 0.0. Specifies the regularization term regularization * I (where I is the identity matrix) added directly to the FIM before inversion. This is useful when the FIM is singular or ill-conditioned.

Returns:

A function that takes a tuple of Tensors x and a vector v and returns the product of the inverse FIM of func (evaluated at x) and v.
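The computation behind this function can be sketched in plain PyTorch: build the empirical FIM as the average outer product of per-sample NLL gradients, then solve against the regularized FIM. This is a minimal illustration, not the dattri implementation; the names (nll, ifvp) and the toy logistic-regression setup are hypothetical.

```python
import torch

# Hypothetical sketch of an explicit-FIM IFVP in plain PyTorch.
# The empirical FIM is the average outer product of per-sample
# NLL gradients; the IFVP is a single torch.linalg.solve.

torch.manual_seed(0)
X = torch.randn(8, 3)                  # inputs
y = torch.randint(0, 2, (8,)).float()  # ground-truth binary labels
w = torch.randn(3)                     # parameters

def nll(w, x, y):
    # per-sample negative log-likelihood (binary cross-entropy)
    return torch.nn.functional.binary_cross_entropy_with_logits(x @ w, y)

# empirical FIM: average of per-sample gradient outer products
grads = torch.stack([torch.func.grad(nll)(w, X[i], y[i]) for i in range(len(X))])
fim = grads.T @ grads / len(X)

regularization = 1e-3
def ifvp(v):
    # solve (FIM + reg * I) u = v rather than forming an explicit inverse
    return torch.linalg.solve(fim + regularization * torch.eye(3), v)

v = torch.randn(3)
u = ifvp(v)
```

Using the ground-truth labels in nll, as here, yields the empirical FIM; swapping in labels sampled from the model's predictions would yield the true FIM, as noted above.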

dattri.func.fisher.ifvp_at_x_explicit(func: Callable, *x, argnums: int | Tuple[int, ...] = 0, regularization: float = 0.0) → Callable

IFVP via explicit FIM calculation.

IFVP stands for inverse-FIM-vector product. For a given function func, this method first calculates the FIM explicitly and then wraps the FIM in a function that uses torch.linalg.solve to calculate the IFVP for any given vector.

Parameters:
  • func (Callable) – A function taking one or more arguments and returning a single-element Tensor. The FIM will be calculated based on this function. Notably, for classification tasks this function should be the negative log-likelihood (e.g., cross-entropy loss). To calculate the empirical FIM, use the ground-truth labels in the loss; to calculate the true FIM, use the predicted labels in the loss.

  • *x – List of arguments for func.

  • argnums (int or Tuple[int, ...]) – An integer or tuple of integers, defaulting to 0. Specifies which argument(s) of func to compute the inverse FIM with respect to.

  • regularization (float) – A float, defaulting to 0.0. Specifies the regularization term regularization * I (where I is the identity matrix) added directly to the FIM before inversion. This is useful when the FIM is singular or ill-conditioned.

Returns:

A function that takes a vector v and returns the product of the inverse FIM of func (evaluated at the fixed x) and v.
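The regularization parameter matters whenever the FIM is rank-deficient, which is common when the number of samples is smaller than the number of parameters. A minimal sketch, independent of the dattri API, of why the damped solve is well-posed while the undamped one is not:

```python
import torch

# A rank-1 (singular) 2x2 matrix standing in for an ill-conditioned FIM:
# solving F u = v directly is ill-posed, but adding reg * I restores a
# unique solution that exactly satisfies the damped system.

F = torch.tensor([[1.0, 1.0],
                  [1.0, 1.0]])   # rank-1, singular
v = torch.tensor([1.0, 0.0])

reg = 1e-2
damped = F + reg * torch.eye(2)
u = torch.linalg.solve(damped, v)

# residual of the damped system is (numerically) zero
residual = damped @ u - v
```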

dattri.func.fisher.ifvp_datainf(func: Callable, argnums: int, in_dims: Tuple[None | int, ...], regularization: float | List[float] | None = None, param_layer_map: List[int] | None = None) → Callable

DataInf IFVP algorithm function.

IFVP stands for inverse-FIM-vector product. This method returns a function that, given a vector, computes the product of the inverse FIM and that vector.

DataInf assumes the loss to be cross-entropy and thus derives a closed-form IFVP without having to approximate the FIM explicitly. Reference implementation: https://github.com/ykwon0407/DataInf/blob/main/src/influence.py

Parameters:
  • func (Callable) – A Python function that takes one or more arguments and returns a single-element Tensor. The layer-wise gradients will be calculated on this function. Note that DataInf expects the loss to be cross-entropy.

  • argnums (int) – An integer specifying which argument of func to compute the inverse FIM with respect to.

  • in_dims (Tuple[Union[None, int], ...]) – Parameter passed to vmap to produce batched layer-wise gradients. Example: for arguments (inputs, weights, labels), in_dims=(0, None, 0) batches over inputs and labels but not weights.

  • regularization (float or List[float], optional) – A float or list of floats, defaulting to None (no regularization). Specifies the regularization term added to the empirical FIM in each layer. This is useful when the empirical FIM is singular or ill-conditioned. The regularization term is regularization * I, where I is the identity matrix directly added to the empirical FIM. If a list, it must have length L, the total number of layers.

  • param_layer_map (Optional[List[int]]) – Specifies how the parameters are grouped into layers. Should be the same length as the parameters tuple. For example, for a two-layer model with params = (0.weights, 0.bias, 1.weights, 1.bias), param_layer_map should be [0, 0, 1, 1], resulting in two layers as expected.

Returns:

A function that takes a list of tuples of Tensors x and a tuple of tensors v (layer-wise) and returns the approximate IFVP of the FIM of func and v.

Raises:

IFVPUsageError – If the length of regularization is not the same as the number of layers.
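The DataInf closed form for a single layer can be sketched as follows: the inverse of the averaged per-sample FIM is swapped for the average of per-sample inverses, and each per-sample inverse (λI + g gᵀ)⁻¹v is evaluated with the Sherman-Morrison identity, so no matrix is ever inverted explicitly. This is an illustrative sketch of the math, not the dattri implementation; all names are hypothetical.

```python
import torch

# DataInf one-layer sketch: average the Sherman-Morrison closed form
# (lam*I + g g^T)^{-1} v = (v - g * (g.v) / (lam + g.g)) / lam
# over per-sample gradients g, avoiding any explicit matrix inverse.

torch.manual_seed(0)
n, d, lam = 5, 3, 0.1
grads = torch.randn(n, d)        # per-sample layer-wise gradients
v = torch.randn(d)

def datainf_ifvp(grads, v, lam):
    out = torch.zeros_like(v)
    for g in grads:
        coef = (g @ v) / (lam + g @ g)   # Sherman-Morrison coefficient
        out += (v - coef * g) / lam
    return out / grads.shape[0]

fast = datainf_ifvp(grads, v, lam)

# reference: the same average computed with explicit per-sample inverses
ref = torch.stack([
    torch.linalg.inv(lam * torch.eye(d) + torch.outer(g, g)) @ v
    for g in grads
]).mean(0)
```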

dattri.func.fisher.ifvp_at_x_datainf(func: Callable, argnums: int, in_dims: Tuple[None | int, ...], regularization: List[float] | None = None, *x, param_layer_map: List[int] | None = None) → Callable

DataInf IFVP algorithm function (with fixed x).

IFVP stands for inverse-FIM-vector product. This method returns a function that, given a vector, computes the product of the inverse FIM and that vector.

DataInf assumes the loss to be cross-entropy and thus derives a closed-form IFVP without having to approximate the FIM explicitly.

Parameters:
  • func (Callable) – A Python function that takes one or more arguments and returns a single-element Tensor. The layer-wise gradients will be calculated on this function. Note that DataInf expects the loss to be cross-entropy.

  • argnums (int) – An integer specifying which argument of func to compute the inverse FIM with respect to.

  • in_dims (Tuple[Union[None, int], ...]) – Parameter passed to vmap to produce batched layer-wise gradients. Example: for arguments (inputs, weights, labels), in_dims=(0, None, 0) batches over inputs and labels but not weights.

  • regularization (List[float], optional) – A list of floats, defaulting to None (no regularization). Specifies the regularization term added to the empirical FIM in each layer. This is useful when the empirical FIM is singular or ill-conditioned. The regularization term is regularization * I, where I is the identity matrix directly added to the empirical FIM. The list is of length L, where L is the total number of layers.

  • param_layer_map (Optional[List[int]]) – Specifies how the parameters are grouped into layers. Should be the same length as the parameters tuple. For example, for a two-layer model with params = (0.weights, 0.bias, 1.weights, 1.bias), param_layer_map should be [0, 0, 1, 1], resulting in two layers as expected.

  • *x – List of arguments for func.

Returns:

A function that takes a tuple of tensors v (layer-wise) and returns the tuple of IFVPs of the FIM of func and v.
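The param_layer_map grouping described above can be sketched as a simple partition of the flat parameter tuple: entry i assigns parameter i to layer param_layer_map[i]. The parameter names here are illustrative placeholders.

```python
# Sketch of param_layer_map semantics: group a flat parameter tuple
# into layers, where entry i maps parameter i to a layer index.

params = ("0.weight", "0.bias", "1.weight", "1.bias")
param_layer_map = [0, 0, 1, 1]

layers = {}
for name, layer_id in zip(params, param_layer_map):
    layers.setdefault(layer_id, []).append(name)

# layers == {0: ["0.weight", "0.bias"], 1: ["1.weight", "1.bias"]}
```

Each resulting group then receives its own empirical FIM and, if regularization is a list, its own regularization entry.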

dattri.func.fisher.ifvp_at_x_ekfac(func: Callable, *x, in_dims: Tuple | None = None, batch_size: int = 1, max_iter: int | None = None, mlp_cache: MLPCache | List[MLPCache], damping: float = 0.0) → Callable

IFVP via EK-FAC algorithm.

IFVP stands for inverse-FIM-vector product. This method returns a function that, given a vector, computes the product of the inverse FIM and that vector.

The EK-FAC algorithm provides a layer-wise approximation of the IFVP function. The FIM is estimated based on the Gauss-Newton Hessian.

Parameters:
  • func (Callable) – A Python function that takes one or more arguments and returns the following:

    - losses: a tensor of shape (batch_size,).
    - mask (optional): a tensor of shape (batch_size, t), where 1's indicate the input positions on which the IFVP will be estimated and 0's indicate irrelevant positions (e.g., padding tokens). Here t is the number of steps, i.e., the sequence length of the input data; if the input data are non-sequential, t should be set to 1.

    The FIM will be estimated on this function.

  • *x – List of arguments for func.

  • in_dims (Tuple, optional) – A tuple with the same shape as *x, indicating which dimension of each argument should be considered the batch dimension. Defaults to taking the first dimension as the batch dimension.

  • batch_size (int) – An integer, defaulting to 1, indicating the batch size used for estimating the covariance matrices and lambdas.

  • max_iter (int, optional) – An integer indicating the maximum number of batches that will be used for estimating the covariance matrices and lambdas.

  • mlp_cache (Union[MLPCache, List[MLPCache]]) – A single or list of registered caches, used to record the input and hidden vectors as well as their relevant gradients during the forward and backward calls of func.

  • damping (float) – Damping factor used to handle non-convexity in the EK-FAC IFVP calculation.

Returns:

A function that takes a nested structure of vectors v and returns the IFVP of the FIM of func (evaluated at the fixed x) and v.
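The efficiency of the layer-wise factorization can be illustrated with the underlying Kronecker identity: if a layer's FIM is approximated as A ⊗ S for two small covariance factors, then (A ⊗ S)⁻¹v reduces to two small solves and the full matrix is never materialized. This is a sketch of the K-FAC-style algebra only, not the dattri EK-FAC implementation; A and S are stand-ins for the activation and pre-activation-gradient covariances.

```python
import torch

# Kronecker-factored inverse: with row-major flattening,
# torch.kron(A, S) @ Y.flatten() == (A @ Y @ S.T).flatten(),
# so for symmetric A, S the inverse acts as Y -> A^{-1} Y S^{-1}.

torch.manual_seed(0)
da, ds = 3, 2
# symmetric positive-definite stand-ins for the covariance factors
A = torch.randn(da, da); A = A @ A.T + torch.eye(da)
S = torch.randn(ds, ds); S = S @ S.T + torch.eye(ds)

v = torch.randn(da * ds)
V = v.reshape(da, ds)            # reshape matching torch.kron's layout

# factored IFVP: A^{-1} V S^{-1} via two small solves
fast = torch.linalg.solve(S, torch.linalg.solve(A, V).T).T.reshape(-1)

# reference: direct solve against the full Kronecker product
full = torch.linalg.solve(torch.kron(A, S), v)
```

For a layer with d_in inputs and d_out outputs, this replaces one solve of size d_in·d_out with one of size d_in and one of size d_out, which is what makes the layer-wise approximation tractable for large models.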