from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union
from torch_geometric.utils.mixin import CastMixin
class ExplanationType(Enum):
    """Kinds of explanation that can be requested.

    Members:
        model: explain the model prediction.
        phenomenon: explain the phenomenon the model is trying to predict.
    """
    model = "model"
    phenomenon = "phenomenon"
class MaskType(Enum):
    """Granularities at which an explanation mask can be applied.

    Members:
        object: one mask entry per node/edge.
        common_attributes: one mask entry per feature.
        attributes: mask each feature across all nodes.
    """
    object = "object"
    common_attributes = "common_attributes"
    attributes = "attributes"
class ModelMode(Enum):
    """Enum class for the model mode (the kind of prediction problem)."""
    classification = 'classification'
    regression = 'regression'
class ModelTaskLevel(Enum):
    """Level of the graph at which the model makes its predictions.

    Members:
        node: node-level prediction.
        edge: edge-level prediction.
        graph: graph-level prediction.
    """
    node = "node"
    edge = "edge"
    graph = "graph"
class ModelReturnType(Enum):
    """Form of the values a model returns.

    Members:
        raw: raw values.
        probs: probabilities.
        log_probs: log-probabilities.
    """
    raw = "raw"
    probs = "probs"
    log_probs = "log_probs"
class ThresholdType(Enum):
    """Strategies for thresholding an explanation mask.

    Members:
        hard: binarize the mask against a value.
        topk: keep the top-k entries, zeroing the rest.
        topk_hard: like ``topk``, but kept entries are set to one.
    """
    hard = "hard"
    topk = "topk"
    topk_hard = "topk_hard"
    # TODO: add a 'connected' threshold type.
@dataclass
class ExplainerConfig(CastMixin):
    r"""Configuration class to store and validate high level explanation
    parameters.

    Args:
        explanation_type (ExplanationType or str): The type of explanation to
            compute. The possible values are:

                - :obj:`"model"`: Explains the model prediction.

                - :obj:`"phenomenon"`: Explains the phenomenon that the model
                  is trying to predict.

            In practice, this means that the explanation algorithm will either
            compute its losses with respect to the model output or the target
            output.
        node_mask_type (MaskType or str, optional): The type of mask to apply
            on nodes. The possible values are (default: :obj:`None`):

                - :obj:`None`: Will not apply any mask on nodes.

                - :obj:`"object"`: Will mask each node.

                - :obj:`"common_attributes"`: Will mask each feature.

                - :obj:`"attributes"`: Will mask each feature across all nodes.
        edge_mask_type (MaskType or str, optional): The type of mask to apply
            on edges. Same types as :obj:`node_mask_type`.
            (default: :obj:`None`)

    Raises:
        ValueError: If an argument is not a valid value of its enum, or if
            both :obj:`node_mask_type` and :obj:`edge_mask_type` are
            :obj:`None`.
    """
    explanation_type: ExplanationType
    node_mask_type: Optional[MaskType]
    edge_mask_type: Optional[MaskType]

    def __init__(
        self,
        explanation_type: Union[ExplanationType, str],
        node_mask_type: Optional[Union[MaskType, str]] = None,
        edge_mask_type: Optional[Union[MaskType, str]] = None,
    ):
        # Normalize strings to enum members; ``None`` means "no mask".
        if node_mask_type is not None:
            node_mask_type = MaskType(node_mask_type)
        if edge_mask_type is not None:
            edge_mask_type = MaskType(edge_mask_type)

        self.explanation_type = ExplanationType(explanation_type)
        self.node_mask_type = node_mask_type
        self.edge_mask_type = edge_mask_type

        # An explanation without any mask would have nothing to explain.
        if self.node_mask_type is None and self.edge_mask_type is None:
            raise ValueError("Either 'node_mask_type' or 'edge_mask_type' "
                             "must be provided.")
@dataclass
class ModelConfig(CastMixin):
    r"""Configuration class to store model parameters.

    Args:
        mode (ModelMode or str): The mode of the model. The possible values
            are:

                - :obj:`"classification"`: A classification model.

                - :obj:`"regression"`: A regression model.
        task_level (ModelTaskLevel or str): The task-level of the model.
            The possible values are:

                - :obj:`"node"`: A node-level prediction model.

                - :obj:`"edge"`: An edge-level prediction model.

                - :obj:`"graph"`: A graph-level prediction model.
        return_type (ModelReturnType or str, optional): The return type of the
            model. The possible values are (default: :obj:`None`):

                - :obj:`"raw"`: The model returns raw values.

                - :obj:`"probs"`: The model returns probabilities.

                - :obj:`"log_probs"`: The model returns log-probabilities.

            If :obj:`None`, it defaults to :obj:`"raw"` for regression models
            and must be given explicitly for any other :obj:`mode`.

    Raises:
        ValueError: If an argument is not a valid value of its enum, if
            :obj:`return_type` is missing for a non-regression model, or if a
            regression model is configured to return non-raw outputs.
    """
    mode: ModelMode
    task_level: ModelTaskLevel
    return_type: ModelReturnType

    def __init__(
        self,
        mode: Union[ModelMode, str],
        task_level: Union[ModelTaskLevel, str],
        return_type: Optional[Union[ModelReturnType, str]] = None,
    ):
        self.mode = ModelMode(mode)
        self.task_level = ModelTaskLevel(task_level)

        if return_type is None:
            if self.mode != ModelMode.regression:
                # Without this check, ``ModelReturnType(None)`` would fail
                # with an unhelpful "None is not a valid ..." message.
                raise ValueError(f"'return_type' must be provided for a "
                                 f"model in mode '{self.mode.value}'")
            # Regression models can only return raw values.
            return_type = ModelReturnType.raw

        self.return_type = ModelReturnType(return_type)

        if (self.mode == ModelMode.regression
                and self.return_type != ModelReturnType.raw):
            raise ValueError(f"A model for regression needs to return raw "
                             f"outputs (got {self.return_type.value})")
@dataclass
class ThresholdConfig(CastMixin):
    r"""Configuration class to store and validate threshold parameters.

    Args:
        threshold_type (ThresholdType or str): The type of threshold to apply.
            The possible values are:

                - :obj:`"hard"`: A hard threshold is applied to each mask.
                  The elements of the mask with a value below the :obj:`value`
                  are set to :obj:`0`, the others are set to :obj:`1`.

                - :obj:`"topk"`: A soft threshold is applied to each mask.
                  The top :obj:`value` elements of each mask are kept, the
                  others are set to :obj:`0`.

                - :obj:`"topk_hard"`: Same as :obj:`"topk"` but values are set
                  to :obj:`1` for all elements which are kept.
        value (int or float): The value to use when thresholding. Must lie in
            :obj:`[0, 1]` for :obj:`"hard"`, and be a positive integer for
            :obj:`"topk"` and :obj:`"topk_hard"`.

    Raises:
        ValueError: If :obj:`threshold_type` is not a valid
            :class:`ThresholdType`, or if :obj:`value` is not valid for it.
    """
    type: ThresholdType
    value: Union[float, int]

    def __init__(
        self,
        threshold_type: Union[ThresholdType, str],
        value: Union[float, int],
    ):
        self.type = ThresholdType(threshold_type)
        self.value = value

        # ``bool`` is a subclass of ``int``; reject it explicitly so that
        # e.g. ``value=True`` is not silently treated as ``1``.
        if (isinstance(self.value, bool)
                or not isinstance(self.value, (int, float))):
            raise ValueError(f"Threshold value must be a float or int "
                             f"(got {type(self.value)}).")

        # The chained comparison is False for NaN, so NaN is rejected too.
        if self.type == ThresholdType.hard and not 0 <= self.value <= 1:
            raise ValueError(f"Threshold value must be between 0 and 1 "
                             f"(got {self.value})")

        if self.type in [ThresholdType.topk, ThresholdType.topk_hard]:
            if not isinstance(self.value, int):
                raise ValueError(f"Threshold value needs to be an integer "
                                 f"(got {type(self.value)}).")
            if self.value <= 0:
                raise ValueError(f"Threshold value needs to be positive "
                                 f"(got {self.value}).")