e3tools.nn package
Module contents
- class e3tools.nn.Attention(irreps_in: Irreps, irreps_out: Irreps, irreps_sh: Irreps, irreps_query: Irreps, irreps_key: Irreps, edge_attr_dim: int, conv: Callable[[...], Module] | None = None, return_attention: bool = False)
Bases: Module
Equivariant attention layer
ref: https://arxiv.org/abs/2006.10503
- forward(node_attr: Tensor, edge_index: Tensor, edge_attr: Tensor, edge_sh: Tensor, mask: Tensor | None = None) → Tensor | Tuple[Tensor, Tensor]
Computes the forward pass of the equivariant graph attention
Let N be the number of nodes, and E be the number of edges
- Parameters:
node_attr ([N, irreps_in.dim])
edge_index ([2, E])
edge_attr ([E, edge_attr_dim])
edge_sh ([E, irreps_sh.dim])
- Returns:
out
- Return type:
[N, irreps_out.dim]
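The documentation above does not say how attention weights are normalized or what the mask encodes. The following is a minimal pure-Python sketch of one common choice in graph attention: a softmax over the edges that share a destination node, with masked edges receiving zero weight. The helper name `edge_softmax`, the convention that `mask[e] == True` keeps edge `e`, and the use of a list of destination indices are all assumptions made for illustration, not part of e3tools.

```python
import math


def edge_softmax(logits, dst, num_nodes, mask=None):
    """Softmax over edges grouped by destination node (hypothetical helper).

    logits: one attention score per edge, length E
    dst: destination node index per edge, length E
    mask: optional list of bools, True keeps the edge (assumed convention)
    """
    if mask is None:
        mask = [True] * len(logits)

    # Per-node maximum over kept edges, subtracted for numerical stability.
    node_max = [-math.inf] * num_nodes
    for z, j, keep in zip(logits, dst, mask):
        if keep:
            node_max[j] = max(node_max[j], z)

    # Unnormalized weights; masked edges contribute nothing.
    exp = [
        math.exp(z - node_max[j]) if keep else 0.0
        for z, j, keep in zip(logits, dst, mask)
    ]

    # Normalize so the kept edges into each node sum to one.
    node_sum = [0.0] * num_nodes
    for e, j in zip(exp, dst):
        node_sum[j] += e
    return [e / node_sum[j] if node_sum[j] > 0 else 0.0 for e, j in zip(exp, dst)]


weights = edge_softmax([0.0, 0.0, 1.0], dst=[0, 0, 1], num_nodes=2)
```

In this sketch the two edges into node 0 share the weight equally and the single edge into node 1 takes all of it; masking one of the first two edges moves its weight to the other.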
- class e3tools.nn.AxisToMul(irreps_in: Irreps, factor: int)
Bases: Module
Collapses the second-last axis by flattening the irreps. Compatible with torch.compile.
- class e3tools.nn.Conv(irreps_in: str | Irreps, irreps_out: str | Irreps, irreps_sh: str | Irreps, edge_attr_dim: int, radial_nn: Callable[[...], Module] | None = None, tensor_product: Callable[[...], Module] | None = None)
Bases: Module
Equivariant convolution layer
ref: https://arxiv.org/abs/1802.08219
- forward(node_attr, edge_index, edge_attr, edge_sh)
Computes the forward pass of the equivariant convolution.
Let N be the number of nodes, and E be the number of edges
- Parameters:
node_attr ([N, irreps_in.dim])
edge_index ([2, E])
edge_attr ([E, edge_attr_dim])
edge_sh ([E, irreps_sh.dim])
- Returns:
out
- Return type:
[N, irreps_out.dim]
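The shapes above depend on `irreps.dim`, the total width of a feature vector. As a rough illustration, the sketch below computes that width from an irreps string, using the fact that an irrep of degree l has 2l + 1 components, and lists the tensor shapes the forward pass expects. The helpers `parse_irreps`, `irreps_dim`, and `expected_shapes` are hypothetical and written in plain Python; in practice `e3nn.o3.Irreps(...).dim` gives the same number.

```python
import re


def parse_irreps(irreps: str) -> list[tuple[int, int, str]]:
    """Parse a string such as "4x0e + 16x1o" into (mul, l, parity) triples."""
    terms = []
    for term in irreps.replace(" ", "").split("+"):
        match = re.fullmatch(r"(?:(\d+)x)?(\d+)([eo])", term)
        if match is None:
            raise ValueError(f"cannot parse irrep term {term!r}")
        mul = int(match.group(1)) if match.group(1) else 1
        terms.append((mul, int(match.group(2)), match.group(3)))
    return terms


def irreps_dim(irreps: str) -> int:
    """Total width: each copy of a degree-l irrep contributes 2l + 1 entries."""
    return sum(mul * (2 * l + 1) for mul, l, _ in parse_irreps(irreps))


def expected_shapes(n_nodes, n_edges, irreps_in, irreps_out, irreps_sh, edge_attr_dim):
    """Shapes of the forward() arguments and output, as documented above."""
    return {
        "node_attr": (n_nodes, irreps_dim(irreps_in)),
        "edge_index": (2, n_edges),
        "edge_attr": (n_edges, edge_attr_dim),
        "edge_sh": (n_edges, irreps_dim(irreps_sh)),
        "out": (n_nodes, irreps_dim(irreps_out)),
    }


shapes = expected_shapes(
    n_nodes=5,
    n_edges=12,
    irreps_in="4x0e + 16x1o",
    irreps_out="8x0e + 8x1o",
    irreps_sh="0e + 1o + 2e",
    edge_attr_dim=16,
)
```

Here `node_attr` would have 4 + 16 * 3 = 52 columns and `edge_sh` would have 1 + 3 + 5 = 9.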
- class e3tools.nn.ConvBlock(irreps_in: str | Irreps, irreps_out: str | Irreps, irreps_sh: str | Irreps, edge_attr_dim: int, act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None, conv: Callable[[...], Module] | None = None)
Bases: Module
Equivariant convolution with gated non-linearity and linear self-interaction
- forward(node_attr: Tensor, edge_index: Tensor, edge_attr: Tensor, edge_sh: Tensor) → Tensor
Computes the forward pass of the equivariant convolution block
Let N be the number of nodes, and E be the number of edges
- Parameters:
node_attr ([N, irreps_in.dim])
edge_index ([2, E])
edge_attr ([E, edge_attr_dim])
edge_sh ([E, irreps_sh.dim])
- Returns:
out
- Return type:
[N, irreps_out.dim]
- class e3tools.nn.EquivariantMLP(irreps_in: Irreps, irreps_out: Irreps, irreps_hidden_list: list[Irreps], act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None, norm_layer: Callable[[...], Module] | None = None)
Bases: Sequential
An equivariant multi-layer perceptron with gated non-linearities.
- class e3tools.nn.ExperimentalTensorProduct(irreps_in1: str | Irreps, irreps_in2: str | Irreps, irreps_out: str | Irreps)
Bases: Module
Compileable tensor product
- forward(x, y, weight)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class e3tools.nn.ExtractIrreps(irreps_in: Irreps, irrep_extract: Irrep)
Bases: Module
Extracts specific irreps from a tensor.
- class e3tools.nn.Gate(irreps_out: str | Irreps, act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None)
Bases: Module
Equivariant non-linear gate
- Parameters:
irreps_out (e3nn.o3.Irreps) – output feature irreps (input irreps are inferred from output irreps)
act (Mapping[int, nn.Module]) – Mapping from parity to activation module. If None defaults to {1 : torch.nn.LeakyReLU(), -1: torch.nn.Tanh()}
act_gates (Mapping[int, nn.Module]) – Mapping from parity to activation module. If None defaults to {1 : torch.nn.Sigmoid(), -1: torch.nn.Tanh()}
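To illustrate why the input irreps can be inferred from the output irreps, here is a minimal pure-Python sketch of a gated non-linearity for even-parity features: scalars pass through the activation, and every non-scalar irrep is multiplied by an extra input scalar passed through the gate activation. The function `gate` and its argument layout (scalars, then gate scalars, then gated irreps) are assumptions for illustration; the sketch uses the even-parity defaults listed above (LeakyReLU and Sigmoid) and does not reproduce the e3tools implementation.

```python
import math


def leaky_relu(x, negative_slope=0.01):
    """LeakyReLU with the default slope used by torch.nn.LeakyReLU."""
    return x if x >= 0 else negative_slope * x


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gate(scalars, gates, gated):
    """Apply a gated non-linearity (hypothetical layout).

    scalars: output scalars before activation
    gates: one extra input scalar per gated irrep copy
    gated: list of non-scalar irreps, each a list of 2l + 1 components
    """
    if len(gates) != len(gated):
        raise ValueError("need exactly one gate scalar per gated irrep")
    out = [leaky_relu(s) for s in scalars]
    for g, block in zip(gates, gated):
        # Scaling a whole irrep by an invariant scalar keeps it equivariant.
        weight = sigmoid(g)
        out.extend(weight * component for component in block)
    return out


# Output irreps "1x0e + 1x1e" (4 numbers) need 1 + 1 + 3 = 5 input numbers.
y = gate(scalars=[-1.0], gates=[0.0], gated=[[2.0, 4.0, 6.0]])
```

In this example the scalar becomes -0.01 and the vector is halved, since sigmoid(0) = 0.5.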
- class e3tools.nn.GateWrapper(irreps_in: Irreps, irreps_out: Irreps, irreps_gate: Irreps)
Bases: Module
Applies a linear transformation before and after the gate.
- class e3tools.nn.Gated(layer: Callable[[...], Module], irreps_in: str | Irreps, irreps_out: str | Irreps, act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None)
Bases: Module
Wraps another layer with an equivariant gate.
- class e3tools.nn.LayerNorm(irreps: Irreps, eps: float = 1e-06)
Bases: Module
Equivariant layer normalization compatible with torch.compile. Each irrep is normalized independently.
ref: https://github.com/atomicarchitects/equiformer/blob/master/nets/fast_layer_norm.py
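As a rough sketch of what "each irrep is normalized independently" can mean, the pure-Python example below divides every irrep block by its own root-mean-square value. This is an assumption-laden simplification: the function `layer_norm_irreps` is hypothetical, it omits any learnable scale or bias and any mean subtraction for scalars, and the exact statistic used by e3tools may differ. Only the `eps` default is taken from the signature above.

```python
import math


def layer_norm_irreps(x, blocks, eps=1e-6):
    """Normalize each irrep block of a flat feature vector by its own RMS.

    x: flat list of floats of length sum(mul * (2l + 1))
    blocks: list of (mul, l) pairs describing the layout of x
    """
    out, start = [], 0
    for mul, l in blocks:
        width = mul * (2 * l + 1)
        chunk = x[start:start + width]
        # The mean square over a block is rotation invariant, so dividing by
        # its square root rescales the block without breaking equivariance.
        mean_sq = sum(c * c for c in chunk) / width
        scale = 1.0 / math.sqrt(mean_sq + eps)
        out.extend(c * scale for c in chunk)
        start += width
    return out


# Layout "2x0e + 1x1o": two scalars followed by one vector.
x = [3.0, 4.0, 1.0, 2.0, 2.0]
y = layer_norm_irreps(x, blocks=[(2, 0), (1, 1)])
```

Because the blocks are treated separately, rescaling the vector part of `x` leaves the normalized scalars unchanged.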
- class e3tools.nn.Linear(irreps_in: Irreps, irreps_out: Irreps, *, f_in: int | None = None, f_out: int | None = None, internal_weights: bool | None = None, shared_weights: bool | None = None, instructions: List[Tuple[int, int]] | None = None, biases: bool | List[bool] = False, path_normalization: str = 'element', _optimize_einsums: bool | None = None)
Bases: CodeGenMixin, Module
Linear operation equivariant to \(O(3)\)
Notes
e3nn.o3.Linear objects created with different partitionings of the same irreps, such as Linear("10x0e", "0e") and Linear("3x0e + 7x0e", "0e"), are not equivalent: the second module has more instructions, which affects normalization. In a rough sense:
Linear("10x0e", "0e") = normalization_coeff_0 * W_0 @ input
Linear("3x0e + 7x0e", "0e") = normalization_coeff_1 * W_1 @ input[:3] + normalization_coeff_2 * W_2 @ input[3:]
To make them equivalent, simplify irreps_in before constructing network modules:
o3.Irreps("3x0e + 7x0e").simplify() # => 10x0e
- Parameters:
irreps_in (e3nn.o3.Irreps) – representation of the input
irreps_out (e3nn.o3.Irreps) – representation of the output
internal_weights (bool) – whether the e3nn.o3.Linear should store its own weights. Defaults to True unless shared_weights is explicitly set to False, for consistency with e3nn.o3.TensorProduct.
shared_weights (bool) – whether the same weights are shared across all inputs in a batch; if False, each input in the batch is weighted individually. Defaults to True. Cannot be False if internal_weights is True.
instructions (list of 2-tuples, optional) – list of tuples (i_in, i_out) indicating which irreps in irreps_in should contribute to which irreps in irreps_out. If None (the default), all allowable instructions will be created: every (i_in, i_out) such that irreps_in[i_in].ir == irreps_out[i_out].ir.
biases (list of bool, optional) – indicates for each element of irreps_out if it has a bias. By default there is no bias. If biases=True, a bias is given to all scalars (l=0 and p=1).
- weight_numel
the size of the weights for this e3nn.o3.Linear
- Type:
int
Examples
Linearly combines 4 scalars into 8 scalars and 16 vectors into 8 vectors.
>>> lin = Linear("4x0e+16x1o", "8x0e+8x1o")
>>> lin.weight_numel
160
Create a “block sparse” linear that does not combine two different groups of scalars; note that the number of weights is 4*4 + 3*3 = 25:
>>> lin = Linear("4x0e + 3x0e", "4x0e + 3x0e", instructions=[(0, 0), (1, 1)])
>>> lin.weight_numel
25
Be careful: because they have different instructions, the following two operations are not normalized in the same way, even though they contain all the same “connections”:
>>> lin1 = Linear("10x0e", "0e")
>>> lin2 = Linear("3x0e + 7x0e", "0e")
>>> lin1.weight_numel == lin2.weight_numel
True
>>> with torch.no_grad():
...     lin1.weight.fill_(1.0)
...     lin2.weight.fill_(1.0)
Parameter containing:
...
>>> x = torch.arange(10.0)
>>> (lin1(x) - lin2(x)).abs().item() < 1e-5
True
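The weight counts quoted in these examples follow from simple arithmetic: each instruction (i_in, i_out) contributes mul_in * mul_out weights, and by default an instruction is created for every pair of matching irreps. The pure-Python sketch below reproduces the numbers without importing e3nn. The helpers `parse_irreps`, `default_instructions`, and `weight_numel` are hypothetical names for illustration, and biases are not counted.

```python
import re


def parse_irreps(irreps: str) -> list[tuple[int, str]]:
    """Parse "4x0e + 16x1o" into [(4, "0e"), (16, "1o")]."""
    terms = []
    for term in irreps.replace(" ", "").split("+"):
        match = re.fullmatch(r"(?:(\d+)x)?(\d+[eo])", term)
        if match is None:
            raise ValueError(f"cannot parse irrep term {term!r}")
        mul = int(match.group(1)) if match.group(1) else 1
        terms.append((mul, match.group(2)))
    return terms


def default_instructions(irreps_in, irreps_out):
    """Every (i_in, i_out) whose irreps match, mirroring the documented default."""
    return [
        (i, j)
        for i, (_, ir_in) in enumerate(irreps_in)
        for j, (_, ir_out) in enumerate(irreps_out)
        if ir_in == ir_out
    ]


def weight_numel(irreps_in: str, irreps_out: str, instructions=None) -> int:
    """Number of weights: sum of mul_in * mul_out over all instructions."""
    parsed_in, parsed_out = parse_irreps(irreps_in), parse_irreps(irreps_out)
    if instructions is None:
        instructions = default_instructions(parsed_in, parsed_out)
    return sum(parsed_in[i][0] * parsed_out[j][0] for i, j in instructions)


dense = weight_numel("4x0e+16x1o", "8x0e+8x1o")  # 4*8 + 16*8
block_sparse = weight_numel("4x0e + 3x0e", "4x0e + 3x0e", [(0, 0), (1, 1)])  # 4*4 + 3*3
```

Dropping the explicit instructions in the second call would give all four scalar-to-scalar paths, 7 * 7 = 49 weights, which is what the block-sparse layout avoids.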
- forward(features, weight: Tensor | None = None, bias: Tensor | None = None)
Evaluate the linear map on features.
- Parameters:
features (torch.Tensor) – tensor of shape (..., irreps_in.dim)
weight (torch.Tensor, optional) – required if internal_weights is False
- Returns:
tensor of shape (..., irreps_out.dim)
- Return type:
torch.Tensor
- internal_weights: bool
- weight_numel: int
- weight_view_for_instruction(instruction: int, weight: Tensor | None = None) → Tensor
View of weights corresponding to instruction.
- Parameters:
instruction (int) – The index of the instruction to get a view on the weights for.
weight (torch.Tensor, optional) – like the weight argument to forward()
- Returns:
A view on weight or this object’s internal weights for the weights corresponding to the instruction-th instruction.
- Return type:
torch.Tensor
- weight_views(weight: Tensor | None = None, yield_instruction: bool = False)
Iterator over weight views for all instructions.
- Parameters:
weight (torch.Tensor, optional) – like the weight argument to forward()
yield_instruction (bool, default False) – Whether to also yield the corresponding instruction.
- Yields:
If yield_instruction is True, yields (instruction_index, instruction, weight_view). Otherwise, yields weight_view.
- class e3tools.nn.LinearSelfInteraction(f)
Bases: Module
Equivariant linear self interaction layer
- Parameters:
f (nn.Module) – Equivariant layer to wrap. f.irreps_in and f.irreps_out must be defined
- class e3tools.nn.MulToAxis(irreps_in: Irreps, factor: int)
Bases: Module
Adds a new axis by factoring out irreps. Compatible with torch.compile.
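One way to picture "factoring out irreps" is to split the multiplicity of every irrep by `factor` and place the resulting groups along a new axis; AxisToMul would then be the inverse operation. The pure-Python sketch below shows this on flat lists. The functions `mul_to_axis` and `axis_to_mul` are hypothetical, and the order in which copies are assigned to rows is an assumption that may not match the e3tools memory layout.

```python
def mul_to_axis(x, blocks, factor):
    """Split a flat feature vector into `factor` rows (hypothetical layout).

    x: flat list of length sum(mul * (2l + 1))
    blocks: list of (mul, l) pairs; every mul must be divisible by factor
    """
    rows = [[] for _ in range(factor)]
    start = 0
    for mul, l in blocks:
        if mul % factor != 0:
            raise ValueError(f"multiplicity {mul} is not divisible by {factor}")
        per_row = (mul // factor) * (2 * l + 1)
        # Consecutive groups of mul // factor copies go to consecutive rows.
        for r in range(factor):
            rows[r].extend(x[start + r * per_row:start + (r + 1) * per_row])
        start += mul * (2 * l + 1)
    return rows


def axis_to_mul(rows, blocks, factor):
    """Inverse of mul_to_axis; blocks describe the flattened (output) layout."""
    out, start = [], 0
    for mul, l in blocks:
        per_row = (mul // factor) * (2 * l + 1)
        for row in rows:
            out.extend(row[start:start + per_row])
        start += per_row
    return out


# Layout "4x0e + 2x1o" (10 numbers) becomes 2 rows of "2x0e + 1x1o" (5 numbers).
blocks = [(4, 0), (2, 1)]
x = list(range(10))
rows = mul_to_axis(x, blocks, factor=2)
restored = axis_to_mul(rows, blocks, factor=2)
```

A layout like this is useful for multi-head attention, where each row can serve as the features of one head.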
- class e3tools.nn.MultiheadAttention(irreps_in: Irreps, irreps_out: Irreps, irreps_sh: Irreps, irreps_query: Irreps, irreps_key: Irreps, edge_attr_dim: int, num_heads: int, conv: Callable[[...], Module] | None = None, return_attention: bool = False)
Bases: Module
Equivariant attention layer with multiple heads
ref: https://arxiv.org/abs/2006.10503
- forward(node_attr: Tensor, edge_index: Tensor, edge_attr: Tensor, edge_sh: Tensor, mask: Tensor | None = None) → Tensor | Tuple[Tensor, Tensor]
Computes the forward pass of equivariant graph attention
Let N be the number of nodes, and E be the number of edges
- Parameters:
node_attr ([N, irreps_in.dim])
edge_index ([2, E])
edge_attr ([E, edge_attr_dim])
edge_sh ([E, irreps_sh.dim])
mask ([E] or None)
- Returns:
out
- Return type:
[N, irreps_out.dim]
- class e3tools.nn.ScalarMLP(in_features: int, out_features: int, hidden_features: list[int], activation_layer: Callable[[...], Module] = torch.nn.ReLU, norm_layer: Callable[[...], Module] | None = None, dropout=0.0, bias=True)
Bases: Sequential
A multi-layer perceptron for scalar inputs and outputs.
- class e3tools.nn.ScaleIrreps(irreps_in: Tensor)
Bases: Module
Scales each irrep by a weight.
- class e3tools.nn.SeparableConv(*args, **kwargs)
Bases: Conv
Equivariant convolution layer using separable tensor product
ref: https://arxiv.org/abs/1802.08219
ref: https://arxiv.org/abs/2206.11990
- class e3tools.nn.SeparableConvBlock(*args, **kwargs)
Bases: ConvBlock
e3tools.nn.ConvBlock with SeparableConv as the underlying convolution layer.
- class e3tools.nn.SeparableTensorProduct(irreps_in1: str | Irreps, irreps_in2: str | Irreps, irreps_out: str | Irreps)
Bases: Module
Tensor product factored into depthwise and pointwise components
ref: https://arxiv.org/abs/2206.11990
ref: https://github.com/atomicarchitects/equiformer/blob/a4360ada2d213ba7b4d884335d3dc54a92b7a371/nets/graph_attention_transformer.py#L157
- forward(x, y, weight)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class e3tools.nn.TransformerBlock(irreps_in: Irreps, irreps_out: Irreps, irreps_sh: Irreps, edge_attr_dim: int, num_heads: int, irreps_query: Irreps | None = None, irreps_key: Irreps | None = None, irreps_ff_hidden_list: list[Irreps] | None = None, conv: Callable[[...], Module] | None = None)
Bases: Module
Equivariant transformer block
- forward(node_attr, edge_index, edge_attr, edge_sh)
Computes the forward pass of the equivariant transformer block
Let N be the number of nodes, and E be the number of edges
- Parameters:
node_attr ([N, irreps_in.dim])
edge_index ([2, E])
edge_attr ([E, edge_attr_dim])
edge_sh ([E, irreps_sh.dim])
- Returns:
out
- Return type:
[N, irreps_out.dim]