Tasks

class dattri.task.AttributionTask(loss_func: Callable, model: nn.Module, checkpoints: str | List[str] | List[Dict[str, torch.Tensor]] | Dict[str, torch.Tensor], target_func: Callable | None = None)

Bases: object

The abstraction of the attribution task information.

__init__(loss_func: Callable, model: nn.Module, checkpoints: str | List[str] | List[Dict[str, torch.Tensor]] | Dict[str, torch.Tensor], target_func: Callable | None = None) → None

Initialize the AttributionTask.

Parameters:
  • loss_func (Callable) –

    The loss function used to train the model. It can compute a wide range of quantities, but it must take the parameters and the data as input, and the model's forward pass must be written in torch.func style. If no target function is provided, the loss function is also used as the target function to be attributed. A typical example is as follows:

    ```python
    import torch
    from torch import nn

    def f(params, data):
        # `model` is the nn.Module passed to AttributionTask.
        image, label = data
        loss = nn.CrossEntropyLoss()
        yhat = torch.func.functional_call(model, params, image)
        return loss(yhat, label)
    ```

    This example calculates the cross-entropy (CE) loss of the model on the data.

  • model (nn.Module) – The model that the target function is based on, i.e., the model called inside the target function. Since only the model's computation graph is used, the model does not need to be loaded with trained parameters.

  • checkpoints (Union[str, List[str], List[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]) – The checkpoints of the model. Both a state_dict dictionary and a path to a checkpoint file are supported. If an ensemble is needed, a list of checkpoints is also supported.

  • target_func (Optional[Callable]) –

    The target function to be attributed. This input is optional; if it is not provided, the target function is the same as the loss function. The function can compute a wide range of quantities, but it must take the parameters and the data as input, and the model's forward pass must be written in torch.func style. A typical example is as follows:

    ```python
    import torch
    from torch import nn

    def f(params, data):
        # `model` is the nn.Module passed to AttributionTask.
        image, label = data
        loss = nn.CrossEntropyLoss()
        yhat = torch.func.functional_call(model, params, image)
        return loss(yhat, label)
    ```
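To make the expected signatures concrete, the following sketch defines a loss function and an alternative target function for a small classifier and evaluates them with parameters keyed as in model.named_parameters(). The model architecture and the mean-true-class-logit target are illustrative assumptions; the AttributionTask call is left as a comment and uses only the constructor arguments documented above.

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)

# Illustrative model; only its computation graph matters, so untrained weights are fine.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))


def loss_func(params, data):
    """Cross-entropy loss of `model` run with `params` on a batch `data`."""
    image, label = data
    yhat = functional_call(model, params, image)
    return nn.functional.cross_entropy(yhat, label)


def target_func(params, data):
    """Hypothetical target: mean logit assigned to the true class."""
    image, label = data
    yhat = functional_call(model, params, image)
    return yhat.gather(1, label.unsqueeze(1)).mean()


# Parameters keyed exactly as in model.named_parameters().
params = {name: p.detach() for name, p in model.named_parameters()}
data = (torch.randn(5, 4), torch.randint(0, 3, (5,)))

loss_value = loss_func(params, data)
target_value = target_func(params, data)

# With dattri installed, the task could then be built as:
# task = AttributionTask(loss_func=loss_func, model=model,
#                        checkpoints=model.state_dict(), target_func=target_func)
```

Because functional_call runs the same module with the supplied dictionary, loss_func on the model's own parameters matches an ordinary forward pass followed by cross-entropy.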

get_checkpoints() → List[Dict[str, torch.Tensor] | str]

Return the checkpoints of the model.

Returns:

The checkpoints of the task.

Return type:

List[Union[Dict[str, torch.Tensor], str]]
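The constructor accepts either a single checkpoint or a list, whereas get_checkpoints always returns a list. A minimal sketch of that normalization is shown below; normalize_checkpoints is a hypothetical helper written for illustration, not dattri's implementation, and the file names are placeholders.

```python
from typing import Dict, List, Union

import torch

# A checkpoint is either a path or an in-memory state_dict.
Checkpoint = Union[str, Dict[str, torch.Tensor]]


def normalize_checkpoints(
    checkpoints: Union[Checkpoint, List[Checkpoint]],
) -> List[Checkpoint]:
    """Hypothetical helper: always return a list of checkpoints."""
    if isinstance(checkpoints, (str, dict)):
        # A single path or state_dict becomes a one-element list.
        return [checkpoints]
    # A list (e.g. an ensemble) is returned as a shallow copy.
    return list(checkpoints)


single = normalize_checkpoints("model.pt")
ensemble = normalize_checkpoints(["seed0.pt", "seed1.pt"])
from_state = normalize_checkpoints({"weight": torch.zeros(2, 2)})
```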

get_grad_loss_func(in_dims: Tuple[None | int, ...] = (None, 1), layer_name: str | List[str] | None = None, ckpt_idx: int | None = None) → Callable

Return a function that computes the gradient of the loss function.

Parameters:
  • in_dims (Tuple[Union[None, int], ...]) – The input dimensions of the loss function. This should be a tuple of integers and None whose length equals the number of inputs of the loss function. If an input is not batched (for example, a scalar or the shared parameters), the corresponding element should be None; if an input is a batched tensor, the corresponding element should be the dimension along which it is batched.

  • layer_name (Optional[Union[str, List[str]]]) – The name of the layer with respect to which the gradient is calculated. If None, all the parameters are used to calculate the gradient of the loss. This should be a string, or a list of strings if multiple layers are needed. Layer names should follow the keys of model.named_parameters().

  • ckpt_idx (Optional[int]) – The index of the checkpoint to be loaded; only needed when layer_name is not None.

Returns:

The function that computes the gradient of the loss function.

Return type:

Callable
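The in_dims argument reads like the in_dims of torch.func.vmap. Assuming that interpretation, the sketch below builds a per-sample gradient function from a (params, data) loss with torch.func.grad and torch.func.vmap: None keeps the parameters shared across samples, and 0 batches every tensor in the data tuple along dimension 0. It uses (None, 0) rather than the documented default (None, 1) only because the example batch is stacked along dimension 0; it is a sketch, not dattri's implementation.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
model = nn.Linear(4, 3)
params = {name: p.detach() for name, p in model.named_parameters()}


def loss_func(params, data):
    """Loss on a single sample; vmap removes the batch dimension, so re-add it."""
    image, label = data
    yhat = functional_call(model, params, image.unsqueeze(0))
    return nn.functional.cross_entropy(yhat, label.unsqueeze(0))


# None: params are shared; 0: each tensor in `data` is batched along dim 0.
grad_loss_func = vmap(grad(loss_func), in_dims=(None, 0))

images = torch.randn(5, 4)
labels = torch.randint(0, 3, (5,))
per_sample_grads = grad_loss_func(params, (images, labels))
# One gradient per sample for every parameter, e.g. weight: (5, 3, 4).
```

Each slice of the result equals the gradient computed for that sample alone, which is what makes per-sample attribution possible without an explicit Python loop.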

get_grad_target_func(in_dims: Tuple[None | int, ...] = (None, 1), layer_name: str | List[str] | None = None, ckpt_idx: int | None = None) → Callable

Return a function that computes the gradient of the target function.

Parameters:
  • in_dims (Tuple[Union[None, int], ...]) – The input dimensions of the target function. This should be a tuple of integers and None whose length equals the number of inputs of the target function. If an input is not batched (for example, a scalar or the shared parameters), the corresponding element should be None; if an input is a batched tensor, the corresponding element should be the dimension along which it is batched.

  • layer_name (Optional[Union[str, List[str]]]) – The name of the layer with respect to which the gradient is calculated. If None, all the parameters are used to calculate the gradient of the target function. This should be a string, or a list of strings if multiple layers are needed. Layer names should follow the keys of model.named_parameters().

  • ckpt_idx (Optional[int]) – The index of the checkpoint to be loaded; only needed when layer_name is not None.

Returns:

The function that computes the gradient of the target function.

Return type:

Callable
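The layer_name option restricts differentiation to a subset of the keys of model.named_parameters(). One way to sketch this, assuming a hypothetical helper restrict_to_layers that freezes the remaining parameters, is to merge the selected subset back into the full parameter dictionary before calling the target function, and then differentiate with torch.func.grad with respect to the subset only.

```python
import torch
from torch import nn
from torch.func import functional_call, grad

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
all_params = {name: p.detach() for name, p in model.named_parameters()}


def target_func(params, data):
    image, label = data
    yhat = functional_call(model, params, image)
    return nn.functional.cross_entropy(yhat, label)


def restrict_to_layers(func, layer_name, full_params):
    """Hypothetical helper: make `func` depend only on the named parameters."""
    names = [layer_name] if isinstance(layer_name, str) else list(layer_name)
    frozen = {k: v for k, v in full_params.items() if k not in names}

    def partial_func(sub_params, data):
        # Merge the selected subset with the frozen remainder before the forward pass.
        return func({**frozen, **sub_params}, data)

    return partial_func, {k: full_params[k] for k in names}


partial_target, sub_params = restrict_to_layers(
    target_func, ["2.weight", "2.bias"], all_params
)
data = (torch.randn(5, 4), torch.randint(0, 3, (5,)))
sub_grads = grad(partial_target)(sub_params, data)
```

The restricted gradient contains only the selected entries, and they coincide with the corresponding entries of the full gradient.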

get_loss_func(flatten: bool = True, layer_name: str | List[str] | None = None, ckpt_idx: int | None = None) → Callable

Return a function that computes the loss function.

Parameters:
  • flatten (bool) – If True, the returned loss function is flattened, i.e., it takes the parameters as a single flattened tensor.

  • layer_name (Optional[Union[str, List[str]]]) – The name of the layer whose parameters are used as the input to calculate the loss. If None, all the parameters are used as the input of the loss function. This should be a string, or a list of strings if multiple layers are needed. Layer names should follow the keys of model.named_parameters().

  • ckpt_idx (Optional[int]) – The index of the checkpoint to be loaded; only needed when layer_name is not None.

Returns:

The loss function itself.

Return type:

Callable

Raises:

NotImplementedError – If layer_name is not None.
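Assuming that flatten=True means the returned function accepts all parameters as one flat vector (in the order produced by concatenating model.parameters()), such a wrapper can be sketched as follows. make_flat_func is a hypothetical name used only for illustration, not dattri's code.

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Linear(4, 3)

# Record the name, shape and size of every parameter in named_parameters() order.
names = [name for name, _ in model.named_parameters()]
shapes = [p.shape for _, p in model.named_parameters()]
numels = [p.numel() for _, p in model.named_parameters()]


def loss_func(params, data):
    image, label = data
    return nn.functional.cross_entropy(functional_call(model, params, image), label)


def make_flat_func(func):
    """Hypothetical wrapper: let `func` take one flat parameter vector."""

    def flat_func(flat_params, data):
        # Cut the vector back into per-parameter chunks and restore their shapes.
        chunks = torch.split(flat_params, numels)
        params = {n: c.reshape(s) for n, c, s in zip(names, chunks, shapes)}
        return func(params, data)

    return flat_func


flat_loss = make_flat_func(loss_func)
flat_params = torch.cat([p.detach().flatten() for p in model.parameters()])
data = (torch.randn(5, 4), torch.randint(0, 3, (5,)))
value = flat_loss(flat_params, data)
```

A flat input is convenient for attribution methods that work with gradient vectors and Hessian-vector products, at the cost of tracking shapes to rebuild the dictionary.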

get_model() → Module

Return the model of the task.

Returns:

The model of the task.

Return type:

nn.Module

get_param(ckpt_idx: int = 0, layer_name: str | List[str] | None = None, layer_split: bool | None = False, param_layer_map: List[int] | None = None) → Tuple[torch.Tensor | List[torch.Tensor], List[int] | None]

Return the flattened parameters of the model.

Parameters:
  • ckpt_idx (int) – The index of the checkpoint to be loaded.

  • layer_name (Optional[Union[str, List[str]]]) – Used when only a subset of the parameters needs to be extracted; it specifies which layers' parameters are extracted. If None, all the parameters are returned. This should be a string, or a list of strings if multiple layers are needed. Layer names should follow the keys of model.named_parameters(). Default is None.

  • layer_split (Optional[bool]) – Whether the returned parameters are split by layer. If True, the return value is a tuple of parameters where each element holds the parameters of one layer. If False, the return value is a single flattened tensor of all the parameters. Default is False.

  • param_layer_map (Optional[List[int]]) – A map stating which element of the parameter tuple belongs to which layer. It is only used when layer_split is True. Defaults to None, which means the map is generated automatically. If param_layer_map is explicitly set, it should have the same length as named_parameters. For example, for a two-layer model with params = (0.weight, 0.bias, 1.weight, 1.bias), param_layer_map should be [0, 0, 1, 1]. An explicitly set value is returned directly.

Returns:

If layer_split is True, the return value is a tuple of the parameters of each layer and the param_layer_map. If layer_split is False, the return value is the flattened parameter tensor of the model and None.

Return type:

Tuple[Union[torch.Tensor, List[torch.Tensor]], Optional[List[int]]]

Raises:

ValueError – If the length of param_layer_map is not the same as the length of named_parameters
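The relationship between layer_split and param_layer_map can be illustrated with plain PyTorch. The sketch below assumes that an automatically generated map groups parameters by their module prefix (so 0.weight and 0.bias share a layer); default_param_layer_map and split_by_layer are hypothetical helpers, and dattri's own grouping rule may differ.

```python
from typing import List

import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
named = list(model.named_parameters())  # 0.weight, 0.bias, 2.weight, 2.bias


def default_param_layer_map(named_params) -> List[int]:
    """Hypothetical rule: parameters sharing a module prefix share a layer index."""
    layer_ids = {}
    mapping = []
    for name, _ in named_params:
        prefix = name.rsplit(".", 1)[0]
        layer_ids.setdefault(prefix, len(layer_ids))
        mapping.append(layer_ids[prefix])
    return mapping


def split_by_layer(named_params, param_layer_map: List[int]) -> List[torch.Tensor]:
    """Concatenate the flattened parameters of each layer into one tensor per layer."""
    if len(param_layer_map) != len(named_params):
        raise ValueError("param_layer_map must match the length of named_parameters")
    per_layer = [[] for _ in range(max(param_layer_map) + 1)]
    for (_, p), layer in zip(named_params, param_layer_map):
        per_layer[layer].append(p.detach().flatten())
    return [torch.cat(chunks) for chunks in per_layer]


param_layer_map = default_param_layer_map(named)           # [0, 0, 1, 1]
flat = torch.cat([p.detach().flatten() for _, p in named])  # layer_split=False view
layers = split_by_layer(named, param_layer_map)             # layer_split=True view
```

Concatenating the per-layer tensors in order recovers the fully flattened vector, so both views describe the same parameters.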

get_target_func(flatten: bool = True, layer_name: str | List[str] | None = None, ckpt_idx: int | None = None) → Callable

Return a function that computes the target function.

Parameters:
  • flatten (bool) – If True, the returned target function is flattened, i.e., it takes the parameters as a single flattened tensor.

  • layer_name (Optional[Union[str, List[str]]]) – The name of the layer whose parameters are used as the input to calculate the target function. If None, all the parameters are used as the input of the target function. This should be a string, or a list of strings if multiple layers are needed. Layer names should follow the keys of model.named_parameters().

  • ckpt_idx (Optional[int]) – The index of the checkpoint to be loaded; only needed when layer_name is not None.

Returns:

The target function itself.

Return type:

Callable

Raises:

NotImplementedError – If layer_name is not None and flatten = False.

register_forward_hook(layer_name: str | List[str]) → Tuple[torch.Tensor, ...]

Register a forward hook on the layer(s) specified by layer_name.

Parameters:

layer_name (Union[str, List[str]]) – The name of the layer to be registered.

Raises:

NotImplementedError – This method has not been implemented yet.
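Although this method is not implemented yet, PyTorch's own nn.Module.register_forward_hook shows what capturing layer outputs involves. The sketch below records the outputs of selected submodules and returns them as a tuple, in the spirit of the documented Tuple[torch.Tensor, ...] return type. It looks layers up by module name from model.named_modules() (for example "0"), which is an assumption; capture_outputs is a hypothetical helper, not a preview of dattri's eventual implementation.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))


def capture_outputs(model, layer_name, inputs):
    """Hypothetical helper: run `model` once and return the outputs of the named layers."""
    names = [layer_name] if isinstance(layer_name, str) else list(layer_name)
    modules = dict(model.named_modules())
    captured = {}

    def make_hook(name):
        def hook(module, args, output):
            # Keep a detached copy of this layer's output.
            captured[name] = output.detach()

        return hook

    handles = [modules[n].register_forward_hook(make_hook(n)) for n in names]
    try:
        model(inputs)
    finally:
        # Always remove the hooks so later forward passes are unaffected.
        for handle in handles:
            handle.remove()
    return tuple(captured[n] for n in names)


activations = capture_outputs(model, ["0", "2"], torch.randn(5, 4))
```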