e3tools.nn package¶
Module contents¶
- class e3tools.nn.Attention(irreps_in: Irreps, irreps_out: Irreps, irreps_sh: Irreps, irreps_query: Irreps, irreps_key: Irreps, edge_attr_dim: int, conv: Callable[[...], Module] | None = None, return_attention: bool = False)[source]¶
Bases: Module
Equivariant attention layer.
ref: https://arxiv.org/abs/2006.10503
- forward(node_attr: Tensor, edge_index: Tensor, edge_attr: Tensor, edge_sh: Tensor, mask: Tensor | None = None) Tensor | Tuple[Tensor, Tensor][source]¶
Computes the forward pass of the equivariant graph attention layer.
Let N be the number of nodes and E be the number of edges.
- Parameters:
node_attr ([N, irreps_in.dim])
edge_index ([2, E])
edge_attr ([E, edge_attr_dim])
edge_sh ([E, irreps_sh.dim])
mask ([E] or None)
- Returns:
out, or a tuple (out, attention) if return_attention is True
- Return type:
[N, irreps_out.dim]
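A minimal usage sketch (not from the library docs): it assumes e3nn for the irreps utilities, and the random edge_attr below merely stands in for radial features of size edge_attr_dim.

import torch
from e3nn import o3
from e3tools.nn import Attention

irreps = o3.Irreps("8x0e + 4x1o")
irreps_sh = o3.Irreps.spherical_harmonics(2)  # 1x0e + 1x1o + 1x2e
attn = Attention(
    irreps_in=irreps,
    irreps_out=irreps,
    irreps_sh=irreps_sh,
    irreps_query=o3.Irreps("8x0e"),
    irreps_key=o3.Irreps("8x0e"),
    edge_attr_dim=16,
)

N, E = 10, 30
node_attr = irreps.randn(N, -1)               # [N, irreps_in.dim]
edge_index = torch.randint(0, N, (2, E))      # [2, E]
edge_attr = torch.randn(E, 16)                # [E, edge_attr_dim]
edge_sh = o3.spherical_harmonics(irreps_sh, torch.randn(E, 3), normalize=True)

out = attn(node_attr, edge_index, edge_attr, edge_sh)  # [N, irreps_out.dim]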
- class e3tools.nn.AxisToMul(irreps_in: Irreps, factor: int)[source]¶
Bases: Module
Collapses the second-to-last axis by flattening it into the irreps dimension. Compatible with torch.compile.
- class e3tools.nn.Conv(irreps_in: str | Irreps, irreps_out: str | Irreps, irreps_sh: str | Irreps, edge_attr_dim: int, radial_nn: Callable[[...], Module] | None = None, tensor_product: Callable[[...], Module] | None = None)[source]¶
Bases: Module
Equivariant convolution layer.
ref: https://arxiv.org/abs/1802.08219
- forward(node_attr: Tensor, edge_index: Tensor, edge_attr: Tensor, edge_sh: Tensor) Tensor[source]¶
Computes the forward pass of the equivariant convolution.
Let N be the number of nodes and E be the number of edges.
- Parameters:
node_attr ([N, irreps_in.dim])
edge_index ([2, E])
edge_attr ([E, edge_attr_dim])
edge_sh ([E, irreps_sh.dim])
- Returns:
out
- Return type:
[N, irreps_out.dim]
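A usage sketch under the same conventions (edge vectors and attributes are random placeholders; in practice edge_attr would be something like a radial basis expansion of edge lengths):

import torch
from e3nn import o3
from e3tools.nn import Conv

irreps = o3.Irreps("8x0e + 4x1o")
irreps_sh = o3.Irreps.spherical_harmonics(2)
conv = Conv(irreps, irreps, irreps_sh, edge_attr_dim=16)

N, E = 10, 30
node_attr = irreps.randn(N, -1)
edge_index = torch.randint(0, N, (2, E))
edge_attr = torch.randn(E, 16)
edge_sh = o3.spherical_harmonics(irreps_sh, torch.randn(E, 3), normalize=True)

out = conv(node_attr, edge_index, edge_attr, edge_sh)  # [N, irreps_out.dim]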
- class e3tools.nn.ConvBlock(irreps_in: str | Irreps, irreps_out: str | Irreps, irreps_sh: str | Irreps, edge_attr_dim: int, act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None, conv: Callable[[...], Module] | None = None)[source]¶
Bases: Module
Equivariant convolution with gated non-linearity and linear self-interaction.
- forward(node_attr: Tensor, edge_index: Tensor, edge_attr: Tensor, edge_sh: Tensor) Tensor[source]¶
Computes the forward pass of the equivariant convolution block.
Let N be the number of nodes and E be the number of edges.
- Parameters:
node_attr ([N, irreps_in.dim])
edge_index ([2, E])
edge_attr ([E, edge_attr_dim])
edge_sh ([E, irreps_sh.dim])
- Returns:
out
- Return type:
[N, irreps_out.dim]
- class e3tools.nn.DepthwiseTensorProduct(irreps_in1: str | Irreps, irreps_in2: str | Irreps, irreps_out: str | Irreps)[source]¶
Bases: Module
Depthwise tensor product.
ref: https://arxiv.org/abs/2206.11990
ref: https://github.com/atomicarchitects/equiformer/blob/a4360ada2d213ba7b4d884335d3dc54a92b7a371/nets/graph_attention_transformer.py#L157
- forward(x: Tensor, y: Tensor, weight: Tensor) Tensor[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
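A sketch of calling the module directly. The weight_numel attribute is an assumption borrowed from e3nn's TensorProduct convention; in a network these weights would typically come from a radial MLP on edge features.

import torch
from e3nn import o3
from e3tools.nn import DepthwiseTensorProduct

irreps = o3.Irreps("8x0e + 4x1o")
irreps_sh = o3.Irreps.spherical_harmonics(2)
tp = DepthwiseTensorProduct(irreps, irreps_sh, irreps)

E = 30
x = irreps.randn(E, -1)
y = o3.spherical_harmonics(irreps_sh, torch.randn(E, 3), normalize=True)
w = torch.randn(E, tp.weight_numel)  # weight_numel assumed, as in e3nn's TensorProduct
out = tp(x, y, w)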
- class e3tools.nn.EquivariantMLP(irreps_in: Irreps, irreps_out: Irreps, irreps_hidden_list: list[Irreps], act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None, norm_layer: Callable[[...], Module] | None = None)[source]¶
Bases: Sequential
An equivariant multi-layer perceptron with gated non-linearities.
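A construction sketch (the hidden irreps are arbitrary choices for illustration):

from e3nn import o3
from e3tools.nn import EquivariantMLP

irreps_in = o3.Irreps("8x0e + 4x1o")
mlp = EquivariantMLP(
    irreps_in,
    o3.Irreps("1x0e"),
    irreps_hidden_list=[o3.Irreps("16x0e + 8x1o"), o3.Irreps("16x0e + 8x1o")],
)
x = irreps_in.randn(5, -1)
y = mlp(x)  # [5, 1]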
- class e3tools.nn.ExtractIrreps(irreps_in: Irreps, irrep_extract: Irrep)[source]¶
Bases: Module
Extracts specific irreps from a tensor.
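A small sketch; the expected output shape follows from the docstring rather than from tested behavior:

from e3nn import o3
from e3tools.nn import ExtractIrreps

irreps_in = o3.Irreps("4x0e + 2x1o")
extract = ExtractIrreps(irreps_in, o3.Irrep("1o"))

x = irreps_in.randn(3, -1)  # [3, 10]
v = extract(x)              # expected: the 2x1o block, shape [3, 6]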
- class e3tools.nn.FusedConv(irreps_in: str | Irreps, irreps_out: str | Irreps, irreps_sh: str | Irreps, edge_attr_dim: int, radial_nn: Callable[[...], Module] | None = None, tensor_product: Callable[[...], Module] | None = None)[source]¶
Bases: Module
Fused version of the equivariant convolution layer with OpenEquivariance kernels.
ref: https://arxiv.org/abs/1802.08219
ref: https://arxiv.org/abs/2501.13986
- forward(node_attr: Tensor, edge_index: Tensor, edge_attr: Tensor, edge_sh: Tensor) Tensor[source]¶
Computes the forward pass of the equivariant convolution.
Let N be the number of nodes and E be the number of edges.
- Parameters:
node_attr ([N, irreps_in.dim])
edge_index ([2, E])
edge_attr ([E, edge_attr_dim])
edge_sh ([E, irreps_sh.dim])
- Returns:
out
- Return type:
[N, irreps_out.dim]
- class e3tools.nn.FusedSeparableConv(*args, **kwargs)[source]¶
Bases: FusedConv
Equivariant convolution layer using a separable tensor product, with fused OpenEquivariance kernels.
ref: https://arxiv.org/abs/1802.08219
ref: https://arxiv.org/abs/2206.11990
- class e3tools.nn.FusedSeparableConvBlock(*args, **kwargs)[source]¶
Bases: ConvBlock
e3tools.nn.ConvBlock with FusedSeparableConv as the underlying convolution layer.
- class e3tools.nn.Gate(irreps_out: str | Irreps, act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None)[source]¶
Bases: Module
Equivariant non-linear gate.
- Parameters:
irreps_out (e3nn.o3.Irreps) – output feature irreps (input irreps are inferred from output irreps)
act (Mapping[int, nn.Module]) – Mapping from parity to activation module for the gated features. If None, defaults to {1: torch.nn.LeakyReLU(), -1: torch.nn.Tanh()}
act_gates (Mapping[int, nn.Module]) – Mapping from parity to activation module for the scalar gates. If None, defaults to {1: torch.nn.Sigmoid(), -1: torch.nn.Tanh()}
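A construction sketch. Because the input irreps (gated features plus the extra scalar gates) are inferred, only the output irreps are given; the SiLU variant is an illustrative choice, not a library default.

import torch
from e3nn import o3
from e3tools.nn import Gate

# Default activations: LeakyReLU/Tanh on features, Sigmoid/Tanh on gates.
gate = Gate(o3.Irreps("8x0e + 4x1o"))

# Same gate with SiLU for even-parity features.
gate_silu = Gate(o3.Irreps("8x0e + 4x1o"), act={1: torch.nn.SiLU(), -1: torch.nn.Tanh()})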
- class e3tools.nn.GateWrapper(irreps_in: Irreps, irreps_out: Irreps, irreps_gate: Irreps)[source]¶
Bases: Module
Applies a linear transformation before and after the gate.
- class e3tools.nn.Gated(layer: Callable[[...], Module], irreps_in: str | Irreps, irreps_out: str | Irreps, act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None)[source]¶
Bases: Module
Wraps another layer with an equivariant gate.
- class e3tools.nn.LayerNorm(irreps: Irreps, eps: float = 1e-06)[source]¶
Bases: Module
Equivariant layer normalization, compatible with torch.compile. Each irrep is normalized independently.
ref: https://github.com/atomicarchitects/equiformer/blob/master/nets/fast_layer_norm.py
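A usage sketch, including the advertised torch.compile compatibility:

import torch
from e3nn import o3
from e3tools.nn import LayerNorm

irreps = o3.Irreps("8x0e + 4x1o")
norm = LayerNorm(irreps)

x = irreps.randn(10, -1)
y = norm(x)  # same shape; each irrep normalized independently

compiled_norm = torch.compile(norm)
y2 = compiled_norm(x)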
- class e3tools.nn.Linear(irreps_in: Irreps, irreps_out: Irreps, *, f_in: int | None = None, f_out: int | None = None, internal_weights: bool | None = None, shared_weights: bool | None = None, instructions: List[Tuple[int, int]] | None = None, biases: bool | List[bool] = False, path_normalization: str = 'element', _optimize_einsums: bool | None = None)[source]¶
Bases: CodeGenMixin, Module
Linear operation equivariant to O(3).
Notes
e3nn.o3.Linear objects created with different partitionings of the same irreps, such as Linear("10x0e", "0e") and Linear("3x0e + 7x0e", "0e"), are not equivalent: the second module has more instructions, which affects normalization. In a rough sense:

Linear("10x0e", "0e") = normalization_coeff_0 * W_0 @ input
Linear("3x0e + 7x0e", "0e") = normalization_coeff_1 * W_1 @ input[:3] + normalization_coeff_2 * W_2 @ input[3:]

To make them equivalent, simplify irreps_in before constructing network modules:

o3.Irreps("3x0e + 7x0e").simplify()  # => 10x0e
- Parameters:
irreps_in (e3nn.o3.Irreps) – representation of the input
irreps_out (e3nn.o3.Irreps) – representation of the output
internal_weights (bool) – whether the e3nn.o3.Linear should store its own weights. Defaults to
Trueunlessshared_weightsis explicitly set toFalse, for consistancy with e3nn.o3.TensorProduct.shared_weights (bool) – whether the e3nn.o3.Linear should be weighted individually for each input in a batch. Defaults to
True. Cannot beFalseifinternal_weightsisTrue.instructions (list of 2-tuples, optional) – list of tuples
(i_in, i_out)indicating which irreps inirreps_inshould contribute to which irreps inirreps_out. IfNone(the default), all allowable instructions will be created: every(i_in, i_out)such thatirreps_in[i_in].ir == irreps_out[i_out].ir.biases (list of bool, optional) – indicates for each element of
irreps_outif it has a bias. By default there is no bias. Ifbiases=Trueit gives bias to all scalars (l=0 and p=1).
- weight_numel¶
the size of the weights for this e3nn.o3.Linear
- Type:
int
Examples
Linearly combines 4 scalars into 8 scalars and 16 vectors into 8 vectors:

>>> lin = Linear("4x0e+16x1o", "8x0e+8x1o")
>>> lin.weight_numel
160

Create a "block sparse" linear that does not combine two different groups of scalars; note that the number of weights is 4*4 + 3*3 = 25:

>>> lin = Linear("4x0e + 3x0e", "4x0e + 3x0e", instructions=[(0, 0), (1, 1)])
>>> lin.weight_numel
25

Be careful: because they have different instructions, the following two operations are not normalized in the same way, even though they contain all the same "connections":

>>> lin1 = Linear("10x0e", "0e")
>>> lin2 = Linear("3x0e + 7x0e", "0e")
>>> lin1.weight_numel == lin2.weight_numel
True
>>> with torch.no_grad():
...     lin1.weight.fill_(1.0)
...     lin2.weight.fill_(1.0)
Parameter containing:
...
>>> x = torch.arange(10.0)
>>> (lin1(x) - lin2(x)).abs().item() < 1e-5
True
- forward(features, weight: Tensor | None = None, bias: Tensor | None = None)[source]¶
Evaluate the linear operation.
- Parameters:
features (torch.Tensor) – tensor of shape (..., irreps_in.dim)
weight (torch.Tensor, optional) – required if internal_weights is False
- Returns:
tensor of shape (..., irreps_out.dim)
- Return type:
torch.Tensor
- internal_weights: bool¶
- weight_numel: int¶
- weight_view_for_instruction(instruction: int, weight: Tensor | None = None) Tensor[source]¶
View of weights corresponding to instruction.
- Parameters:
instruction (int) – The index of the instruction to get a view on the weights for.
weight (torch.Tensor, optional) – like the weight argument to forward()
- Returns:
A view on weight or this object's internal weights for the weights corresponding to the instruction-th instruction.
- Return type:
torch.Tensor
- weight_views(weight: Tensor | None = None, yield_instruction: bool = False)[source]¶
Iterator over weight views for all instructions.
- Parameters:
weight (torch.Tensor, optional) – like the weight argument to forward()
yield_instruction (bool, default False) – Whether to also yield the corresponding instruction.
- Yields:
If yield_instruction is True, yields (instruction_index, instruction, weight_view). Otherwise, yields weight_view.
- class e3tools.nn.LinearSelfInteraction(f)[source]¶
Bases: Module
Equivariant linear self-interaction layer.
- Parameters:
f (nn.Module) – Equivariant layer to wrap. f.irreps_in and f.irreps_out must be defined
- class e3tools.nn.MulToAxis(irreps_in: Irreps, factor: int)[source]¶
Bases: Module
Adds a new axis by factoring out irreps. Compatible with torch.compile.
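A round-trip sketch pairing MulToAxis with AxisToMul. The per-slice irreps passed to AxisToMul reflect our assumption about how the two modules compose; shapes marked "presumably" are inferred from the docstrings.

from e3nn import o3
from e3tools.nn import AxisToMul, MulToAxis

irreps = o3.Irreps("8x0e + 8x1o")
split = MulToAxis(irreps, factor=4)                    # factor out 4 slices of 2x0e + 2x1o
merge = AxisToMul(o3.Irreps("2x0e + 2x1o"), factor=4)  # inverse operation (assumed)

x = irreps.randn(10, -1)  # [10, 32]
h = split(x)              # presumably [10, 4, 8], where 8 = (2x0e + 2x1o).dim
y = merge(h)              # presumably back to [10, 32]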
- class e3tools.nn.MultiheadAttention(irreps_in: Irreps, irreps_out: Irreps, irreps_sh: Irreps, irreps_query: Irreps, irreps_key: Irreps, edge_attr_dim: int, num_heads: int, conv: Callable[[...], Module] | None = None, return_attention: bool = False)[source]¶
Bases: Module
Equivariant attention layer with multiple heads.
ref: https://arxiv.org/abs/2006.10503
- forward(node_attr: Tensor, edge_index: Tensor, edge_attr: Tensor, edge_sh: Tensor, mask: Tensor | None = None) Tensor | Tuple[Tensor, Tensor][source]¶
Computes the forward pass of the equivariant multi-head graph attention layer.
Let N be the number of nodes and E be the number of edges.
- Parameters:
node_attr ([N, irreps_in.dim])
edge_index ([2, E])
edge_attr ([E, edge_attr_dim])
edge_sh ([E, irreps_sh.dim])
mask ([E] or None)
- Returns:
out, or a tuple (out, attention) if return_attention is True
- Return type:
[N, irreps_out.dim]
- class e3tools.nn.ScalarMLP(in_features: int, out_features: int, hidden_features: list[int], activation_layer: Callable[[...], Module] = torch.nn.ReLU, norm_layer: Callable[[...], Module] | None = None, dropout=0.0, bias=True)[source]¶
Bases: Sequential
A multi-layer perceptron for scalar inputs and outputs.
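A construction sketch (feature sizes are arbitrary; the remark that such a module could serve as the radial_nn of the conv layers above is an inference, not documented behavior):

import torch
from e3tools.nn import ScalarMLP

mlp = ScalarMLP(
    in_features=16,
    out_features=32,
    hidden_features=[64, 64],
    activation_layer=torch.nn.SiLU,
    dropout=0.1,
)
y = mlp(torch.randn(8, 16))  # [8, 32]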
- class e3tools.nn.ScaleIrreps(irreps_in: Irreps)[source]¶
Bases: Module
Scales each irrep by a weight.
- class e3tools.nn.SeparableConv(*args, **kwargs)[source]¶
Bases: Conv
Equivariant convolution layer using a separable tensor product.
ref: https://arxiv.org/abs/1802.08219
ref: https://arxiv.org/abs/2206.11990
- class e3tools.nn.SeparableConvBlock(*args, **kwargs)[source]¶
Bases: ConvBlock
e3tools.nn.ConvBlock with SeparableConv as the underlying convolution layer.
- class e3tools.nn.SeparableTensorProduct(irreps_in1: str | Irreps, irreps_in2: str | Irreps, irreps_out: str | Irreps)[source]¶
Bases: Module
Tensor product factored into depthwise and pointwise components.
ref: https://arxiv.org/abs/2206.11990
ref: https://github.com/atomicarchitects/equiformer/blob/a4360ada2d213ba7b4d884335d3dc54a92b7a371/nets/graph_attention_transformer.py#L157
- forward(x, y, weight)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class e3tools.nn.TransformerBlock(irreps_in: Irreps, irreps_out: Irreps, irreps_sh: Irreps, edge_attr_dim: int, num_heads: int, irreps_query: Irreps | None = None, irreps_key: Irreps | None = None, irreps_ff_hidden_list: list[Irreps] | None = None, conv: Callable[[...], Module] | None = None)[source]¶
Bases: Module
Equivariant transformer block.
- forward(node_attr, edge_index, edge_attr, edge_sh)[source]¶
Computes the forward pass of the equivariant transformer block.
Let N be the number of nodes and E be the number of edges.
- Parameters:
node_attr ([N, irreps_in.dim])
edge_index ([2, E])
edge_attr ([E, edge_attr_dim])
edge_sh ([E, irreps_sh.dim])
- Returns:
out
- Return type:
[N, irreps_out.dim]
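A usage sketch mirroring the conv and attention examples above; num_heads=4 assumes the multiplicities in irreps divide evenly across heads.

import torch
from e3nn import o3
from e3tools.nn import TransformerBlock

irreps = o3.Irreps("8x0e + 4x1o")
irreps_sh = o3.Irreps.spherical_harmonics(2)
block = TransformerBlock(irreps, irreps, irreps_sh, edge_attr_dim=16, num_heads=4)

N, E = 10, 30
node_attr = irreps.randn(N, -1)
edge_index = torch.randint(0, N, (2, E))
edge_attr = torch.randn(E, 16)
edge_sh = o3.spherical_harmonics(irreps_sh, torch.randn(E, 3), normalize=True)

out = block(node_attr, edge_index, edge_attr, edge_sh)  # [N, irreps_out.dim]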