e3tools.nn package

Module contents

class e3tools.nn.Attention(irreps_in: Irreps, irreps_out: Irreps, irreps_sh: Irreps, irreps_query: Irreps, irreps_key: Irreps, edge_attr_dim: int, conv: Callable[[...], Module] | None = None, return_attention: bool = False)[source]

Bases: Module

Equivariant attention layer

ref: https://arxiv.org/abs/2006.10503

forward(node_attr: Tensor, edge_index: Tensor, edge_attr: Tensor, edge_sh: Tensor, mask: Tensor | None = None) Tensor | Tuple[Tensor, Tensor][source]

Computes the forward pass of the equivariant graph attention layer

Let N be the number of nodes, and E be the number of edges

Parameters:
  • node_attr ([N, irreps_in.dim])

  • edge_index ([2, E])

  • edge_attr ([E, edge_attr_dim])

  • edge_sh ([E, irreps_sh.dim])

  • mask ([E] or None)

Returns:

out

Return type:

[N, irreps_out.dim]
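
A minimal usage sketch based on the signature above; the irreps, graph sizes, and the random edge_attr (standing in for e.g. a radial embedding of edge lengths) are illustrative choices, not requirements:

import torch
from e3nn import o3
from e3tools.nn import Attention

irreps = o3.Irreps("8x0e + 4x1o")
irreps_sh = o3.Irreps.spherical_harmonics(lmax=2)

layer = Attention(
    irreps_in=irreps,
    irreps_out=irreps,
    irreps_sh=irreps_sh,
    irreps_query=o3.Irreps("8x0e"),
    irreps_key=o3.Irreps("8x0e"),
    edge_attr_dim=16,
)

N, E = 10, 30
node_attr = irreps.randn(N, -1)
edge_index = torch.randint(0, N, (2, E))   # node index pairs, one column per edge
edge_vec = torch.randn(E, 3)               # relative positions for each edge
edge_sh = o3.spherical_harmonics(irreps_sh, edge_vec, normalize=True, normalization="component")
edge_attr = torch.randn(E, 16)             # placeholder for a radial embedding

out = layer(node_attr, edge_index, edge_attr, edge_sh)   # [N, irreps.dim]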

class e3tools.nn.AxisToMul(irreps_in: Irreps, factor: int)[source]

Bases: Module

Collapses the second-last axis by flattening the irreps. Compatible with torch.compile.

forward(x: Tensor) Tensor[source]

Collapses the second-last axis by flattening the irreps.

Parameters:

x – torch.Tensor of shape […, factor, irreps_in.dim]

Returns:

torch.Tensor of shape […, irreps_out.dim]
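
A small shape sketch (illustrative irreps; assumes irreps_out is the input irreps with multiplicities scaled by factor):

import torch
from e3nn import o3
from e3tools.nn import AxisToMul

irreps_in = o3.Irreps("4x0e + 2x1o")   # dim = 4 + 6 = 10
collapse = AxisToMul(irreps_in, factor=3)

x = irreps_in.randn(5, 3, -1)   # [..., factor, irreps_in.dim] = [5, 3, 10]
y = collapse(x)                 # [..., irreps_out.dim]        = [5, 30]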

class e3tools.nn.Conv(irreps_in: str | Irreps, irreps_out: str | Irreps, irreps_sh: str | Irreps, edge_attr_dim: int, radial_nn: Callable[[...], Module] | None = None, tensor_product: Callable[[...], Module] | None = None)[source]

Bases: Module

Equivariant convolution layer

ref: https://arxiv.org/abs/1802.08219

apply_per_edge(node_attr_src, edge_attr, edge_sh)[source]

Computes the per-edge convolution messages from source-node features, edge attributes, and edge spherical harmonics.

forward(node_attr, edge_index, edge_attr, edge_sh)[source]

Computes the forward pass of the equivariant convolution.

Let N be the number of nodes, and E be the number of edges

Parameters:
  • node_attr ([N, irreps_in.dim])

  • edge_index ([2, E])

  • edge_attr ([E, edge_attr_dim])

  • edge_sh ([E, irreps_sh.dim])

Returns:

out

Return type:

[N, irreps_out.dim]
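
A minimal usage sketch with illustrative irreps and graph sizes; the random edge_attr stands in for e.g. a radial basis embedding of edge lengths:

import torch
from e3nn import o3
from e3tools.nn import Conv

irreps_in = o3.Irreps("8x0e + 4x1o")
irreps_out = o3.Irreps("8x0e + 4x1o + 2x2e")
irreps_sh = o3.Irreps.spherical_harmonics(lmax=2)

conv = Conv(irreps_in, irreps_out, irreps_sh, edge_attr_dim=16)

N, E = 10, 30
node_attr = irreps_in.randn(N, -1)
edge_index = torch.randint(0, N, (2, E))
edge_vec = torch.randn(E, 3)
edge_sh = o3.spherical_harmonics(irreps_sh, edge_vec, normalize=True, normalization="component")
edge_attr = torch.randn(E, 16)

out = conv(node_attr, edge_index, edge_attr, edge_sh)   # [N, irreps_out.dim]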

class e3tools.nn.ConvBlock(irreps_in: str | Irreps, irreps_out: str | Irreps, irreps_sh: str | Irreps, edge_attr_dim: int, act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None, conv: Callable[[...], Module] | None = None)[source]

Bases: Module

Equivariant convolution with gated non-linearity and linear self-interaction

forward(node_attr: Tensor, edge_index: Tensor, edge_attr: Tensor, edge_sh: Tensor) Tensor[source]

Computes the forward pass of the equivariant convolution block

Let N be the number of nodes, and E be the number of edges

Parameters:
  • node_attr ([N, irreps_in.dim])

  • edge_index ([2, E])

  • edge_attr ([E, edge_attr_dim])

  • edge_sh ([E, irreps_sh.dim])

Returns:

out

Return type:

[N, irreps_out.dim]
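
ConvBlock is constructed and called like Conv; a brief sketch with illustrative irreps, leaving act, act_gates, and conv at their defaults:

from e3nn import o3
from e3tools.nn import ConvBlock

block = ConvBlock(
    "8x0e + 4x1o",
    "8x0e + 4x1o",
    o3.Irreps.spherical_harmonics(lmax=2),
    edge_attr_dim=16,
)
# takes the same (node_attr, edge_index, edge_attr, edge_sh) tensors as Conv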

class e3tools.nn.EquivariantMLP(irreps_in: Irreps, irreps_out: Irreps, irreps_hidden_list: list[Irreps], act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None, norm_layer: Callable[[...], Module] | None = None)[source]

Bases: Sequential

An equivariant multi-layer perceptron with gated non-linearities.
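
A minimal construction sketch with illustrative irreps:

from e3nn import o3
from e3tools.nn import EquivariantMLP

mlp = EquivariantMLP(
    irreps_in=o3.Irreps("8x0e + 4x1o"),
    irreps_out=o3.Irreps("1x0e"),
    irreps_hidden_list=[o3.Irreps("16x0e + 8x1o"), o3.Irreps("16x0e + 8x1o")],
)

x = o3.Irreps("8x0e + 4x1o").randn(10, -1)
y = mlp(x)   # [10, 1]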

class e3tools.nn.ExperimentalConv(*args, **kwargs)[source]

Bases: Conv

class e3tools.nn.ExperimentalTensorProduct(irreps_in1: str | Irreps, irreps_in2: str | Irreps, irreps_out: str | Irreps)[source]

Bases: Module

Tensor product compatible with torch.compile

forward(x, y, weight)[source]

Computes the tensor product of x and y, weighted by weight.

class e3tools.nn.ExtractIrreps(irreps_in: Irreps, irrep_extract: Irrep)[source]

Bases: Module

Extracts specific irreps from a tensor.

forward(data: Tensor) Tensor[source]

Extracts the specified irreps from the input tensor.

Parameters:

data – torch.Tensor of shape […, irreps_in.dim]

Returns:

torch.Tensor of shape […, irreps_out.dim]
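
A small sketch extracting the vector (1o) block from a mixed feature tensor (illustrative irreps):

from e3nn import o3
from e3tools.nn import ExtractIrreps

irreps_in = o3.Irreps("8x0e + 4x1o + 2x2e")   # dim = 8 + 12 + 10 = 30
extract = ExtractIrreps(irreps_in, o3.Irrep("1o"))

x = irreps_in.randn(10, -1)   # [10, 30]
vec = extract(x)              # [10, 12], the 4x1o block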

class e3tools.nn.Gate(irreps_out: str | Irreps, act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None)[source]

Bases: Module

Equivariant non-linear gate

Parameters:
  • irreps_out (e3nn.o3.Irreps) – output feature irreps (input irreps are inferred from output irreps)

  • act (Mapping[int, nn.Module]) – Mapping from parity to activation module. If None, defaults to {1: torch.nn.LeakyReLU(), -1: torch.nn.Tanh()}

  • act_gates (Mapping[int, nn.Module]) – Mapping from parity to gate activation module. If None, defaults to {1: torch.nn.Sigmoid(), -1: torch.nn.Tanh()}

forward(x: Tensor) Tensor[source]

Apply the gate to the input tensor.
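
A brief sketch, assuming the module exposes the inferred input irreps as irreps_in and follows e3nn's gate convention of one gate scalar per non-scalar irrep:

from e3nn import o3
from e3tools.nn import Gate

gate = Gate("8x0e + 4x1o")

# inferred input: 8x0e (activated) + 4x0e (gate scalars) + 4x1o (gated)
x = gate.irreps_in.randn(10, -1)   # [10, 24]
y = gate(x)                        # [10, 20] = [10, (8x0e + 4x1o).dim]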

class e3tools.nn.GateWrapper(irreps_in: Irreps, irreps_out: Irreps, irreps_gate: Irreps)[source]

Bases: Module

Applies a linear transformation before and after the gate.

forward(x: Tensor) Tensor[source]

Apply the pre-gate, gate, and post-gate transformations.

class e3tools.nn.Gated(layer: Callable[[...], Module], irreps_in: str | Irreps, irreps_out: str | Irreps, act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None)[source]

Bases: Module

Wraps another layer with an equivariant gate.

forward(*args, **kwargs)[source]

Apply the layer and then the gate to the input tensor.

class e3tools.nn.LayerNorm(irreps: Irreps, eps: float = 1e-06)[source]

Bases: Module

Equivariant layer normalization compatible with torch.compile. Each irrep is normalized independently.

ref: https://github.com/atomicarchitects/equiformer/blob/master/nets/fast_layer_norm.py

forward(x: Tensor) Tensor[source]

Apply layer normalization to input tensor. Each irrep is normalized independently.

Parameters:

x (torch.Tensor) – Input tensor of shape […, self.irreps_in.dim]

Returns:

Normalized tensor of shape […, self.irreps_out.dim]

Return type:

torch.Tensor
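
A usage sketch with an explicit equivariance check (illustrative irreps):

import torch
from e3nn import o3
from e3tools.nn import LayerNorm

irreps = o3.Irreps("8x0e + 4x1o")
norm = LayerNorm(irreps)

x = irreps.randn(10, -1)
y = norm(x)   # [10, irreps.dim], each irrep normalized independently

# rotating the input and normalizing equals normalizing and then rotating
D = irreps.D_from_matrix(o3.rand_matrix())
assert torch.allclose(norm(x @ D.T), norm(x) @ D.T, atol=1e-5)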

class e3tools.nn.Linear(irreps_in: Irreps, irreps_out: Irreps, *, f_in: int | None = None, f_out: int | None = None, internal_weights: bool | None = None, shared_weights: bool | None = None, instructions: List[Tuple[int, int]] | None = None, biases: bool | List[bool] = False, path_normalization: str = 'element', _optimize_einsums: bool | None = None)[source]

Bases: CodeGenMixin, Module

Linear operation equivariant to O(3)

Notes

e3nn.o3.Linear objects created with different partitionings of the same irreps, such as Linear("10x0e", "0e") and Linear("3x0e + 7x0e", "0e"), are not equivalent: the second module has more instructions, which affects normalization. In a rough sense:

Linear("10x0e", "0e")       = normalization_coeff_0 * W_0 @ input
Linear("3x0e + 7x0e", "0e") = normalization_coeff_1 * W_1 @ input[:3] + normalization_coeff_2 * W_2 @ input[3:]

To make them equivalent, simplify irreps_in before constructing network modules:

o3.Irreps("3x0e + 7x0e").simplify()  # => 10x0e

Parameters:
  • irreps_in (e3nn.o3.Irreps) – representation of the input

  • irreps_out (e3nn.o3.Irreps) – representation of the output

  • internal_weights (bool) – whether the e3nn.o3.Linear should store its own weights. Defaults to True unless shared_weights is explicitly set to False, for consistency with e3nn.o3.TensorProduct.

  • shared_weights (bool) – whether a single set of weights is shared across all inputs in a batch (True) or each input in the batch has its own weights (False). Defaults to True. Cannot be False if internal_weights is True.

  • instructions (list of 2-tuples, optional) – list of tuples (i_in, i_out) indicating which irreps in irreps_in should contribute to which irreps in irreps_out. If None (the default), all allowable instructions will be created: every (i_in, i_out) such that irreps_in[i_in].ir == irreps_out[i_out].ir.

  • biases (list of bool, optional) – indicates for each element of irreps_out if it has a bias. By default there is no bias. If biases=True, all scalar outputs (l=0 and p=1) receive a bias.

weight_numel

the size of the weights for this e3nn.o3.Linear

Type:

int

Examples

Linearly combines 4 scalars into 8 scalars and 16 vectors into 8 vectors.

>>> lin = Linear("4x0e+16x1o", "8x0e+8x1o")
>>> lin.weight_numel
160

Create a “block sparse” linear that does not combine two different groups of scalars; note that the number of weights is 4*4 + 3*3 = 25:

>>> lin = Linear("4x0e + 3x0e", "4x0e + 3x0e", instructions=[(0, 0), (1, 1)])
>>> lin.weight_numel
25

Be careful: because they have different instructions, the following two operations are not normalized in the same way, even though they contain all the same “connections”:

>>> lin1 = Linear("10x0e", "0e")
>>> lin2 = Linear("3x0e + 7x0e", "0e")
>>> lin1.weight_numel == lin2.weight_numel
True
>>> with torch.no_grad():
...     lin1.weight.fill_(1.0)
...     lin2.weight.fill_(1.0)
Parameter containing:
...
>>> x = torch.arange(10.0)
>>> (lin1(x) - lin2(x)).abs().item() < 1e-5
True

forward(features, weight: Tensor | None = None, bias: Tensor | None = None)[source]

Evaluate the linear transformation.

Parameters:
  • features (torch.Tensor) – tensor of shape (..., irreps_in.dim)

  • weight (torch.Tensor, optional) – required if internal_weights is False

Returns:

tensor of shape (..., irreps_out.dim)

Return type:

torch.Tensor

internal_weights: bool
shared_weights: bool
weight_numel: int
weight_view_for_instruction(instruction: int, weight: Tensor | None = None) Tensor[source]

View of weights corresponding to instruction.

Parameters:
  • instruction (int) – The index of the instruction to get a view on the weights for.

  • weight (torch.Tensor, optional) – like weight argument to forward()

Returns:

A view on weight or this object’s internal weights for the weights corresponding to the instruction-th instruction.

Return type:

torch.Tensor

weight_views(weight: Tensor | None = None, yield_instruction: bool = False)[source]

Iterator over weight views for all instructions.

Parameters:
  • weight (torch.Tensor, optional) – like weight argument to forward()

  • yield_instruction (bool, default False) – Whether to also yield the corresponding instruction.

Yields:
  • If yield_instruction is True, yields (instruction_index, instruction, weight_view).

  • Otherwise, yields weight_view.

class e3tools.nn.LinearSelfInteraction(f)[source]

Bases: Module

Equivariant linear self-interaction layer

Parameters:

f (nn.Module) – Equivariant layer to wrap. f.irreps_in and f.irreps_out must be defined

forward(x: Tensor, *args) Tensor[source]

Combines the input layer with a skip connection.
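
A sketch wrapping an e3nn linear layer, which exposes the required irreps_in and irreps_out attributes; the exact way the skip connection is combined is internal to the module:

from e3nn import o3
from e3tools.nn import LinearSelfInteraction

inner = o3.Linear("8x0e + 4x1o", "8x0e + 4x1o")
layer = LinearSelfInteraction(inner)

x = o3.Irreps("8x0e + 4x1o").randn(10, -1)
y = layer(x)   # inner(x) combined with a learned linear skip connection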

class e3tools.nn.MulToAxis(irreps_in: Irreps, factor: int)[source]

Bases: Module

Adds a new axis by factoring out irreps. Compatible with torch.compile.

forward(x: Tensor) Tensor[source]

Adds a new axis by factoring out irreps.

Parameters:

x – torch.Tensor of shape […, irreps_in.dim]

Returns:

torch.Tensor of shape […, factor, irreps_out.dim]
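
A round-trip sketch with AxisToMul, assuming multiplicities divisible by factor and that the module exposes its output irreps as irreps_out:

import torch
from e3nn import o3
from e3tools.nn import AxisToMul, MulToAxis

irreps = o3.Irreps("8x0e + 4x1o")     # dim = 20
split = MulToAxis(irreps, factor=4)   # [..., 20] -> [..., 4, 5] (2x0e + 1x1o per slice)
merge = AxisToMul(split.irreps_out, factor=4)

x = irreps.randn(10, -1)
assert merge(split(x)).shape == x.shape   # round trip restores the flat layout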

class e3tools.nn.MultiheadAttention(irreps_in: Irreps, irreps_out: Irreps, irreps_sh: Irreps, irreps_query: Irreps, irreps_key: Irreps, edge_attr_dim: int, num_heads: int, conv: Callable[[...], Module] | None = None, return_attention: bool = False)[source]

Bases: Module

Equivariant attention layer with multiple heads

ref: https://arxiv.org/abs/2006.10503

forward(node_attr: Tensor, edge_index: Tensor, edge_attr: Tensor, edge_sh: Tensor, mask: Tensor | None = None) Tensor | Tuple[Tensor, Tensor][source]

Computes the forward pass of the multi-head equivariant graph attention layer

Let N be the number of nodes, and E be the number of edges

Parameters:
  • node_attr ([N, irreps_in.dim])

  • edge_index ([2, E])

  • edge_attr ([E, edge_attr_dim])

  • edge_sh ([E, irreps_sh.dim])

  • mask ([E] or None)

Returns:

out

Return type:

[N, irreps_out.dim]
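
Construction mirrors Attention with an extra num_heads argument; a sketch with illustrative irreps (multiplicities chosen divisible by num_heads):

from e3nn import o3
from e3tools.nn import MultiheadAttention

attn = MultiheadAttention(
    irreps_in=o3.Irreps("8x0e + 4x1o"),
    irreps_out=o3.Irreps("8x0e + 4x1o"),
    irreps_sh=o3.Irreps.spherical_harmonics(lmax=2),
    irreps_query=o3.Irreps("8x0e"),
    irreps_key=o3.Irreps("8x0e"),
    edge_attr_dim=16,
    num_heads=4,
)
# called with the same tensors as Attention, plus an optional [E] mask over edges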

class e3tools.nn.ScalarMLP(in_features: int, out_features: int, hidden_features: list[int], activation_layer: Callable[[...], Module] = torch.nn.ReLU, norm_layer: Callable[[...], Module] | None = None, dropout=0.0, bias=True)[source]

Bases: Sequential

A multi-layer perceptron for scalar inputs and outputs.
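
A minimal usage sketch:

import torch
from e3tools.nn import ScalarMLP

mlp = ScalarMLP(in_features=16, out_features=1, hidden_features=[64, 64], dropout=0.1)
y = mlp(torch.randn(10, 16))   # [10, 1]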

class e3tools.nn.ScaleIrreps(irreps_in: Irreps)[source]

Bases: Module

Scales each irrep by a weight.

forward(data: Tensor, weights: Tensor) Tensor[source]

Scales each irrep by a weight.

Parameters:
  • data – torch.Tensor of shape […, irreps_in.dim]

  • weights – torch.Tensor of shape […, irreps_in.num_irreps]

Returns:

torch.Tensor of shape […, irreps_in.dim]
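
A shape sketch (illustrative irreps); note that weights has one entry per irrep instance (num_irreps), not per component (dim):

import torch
from e3nn import o3
from e3tools.nn import ScaleIrreps

irreps = o3.Irreps("8x0e + 4x1o")   # num_irreps = 12, dim = 20
scale = ScaleIrreps(irreps)

x = irreps.randn(10, -1)                # [10, 20]
w = torch.rand(10, irreps.num_irreps)   # [10, 12], one weight per irrep
y = scale(x, w)                         # [10, 20]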

class e3tools.nn.SeparableConv(*args, **kwargs)[source]

Bases: Conv

Equivariant convolution layer using separable tensor product

ref: https://arxiv.org/abs/1802.08219
ref: https://arxiv.org/abs/2206.11990

class e3tools.nn.SeparableConvBlock(*args, **kwargs)[source]

Bases: ConvBlock

e3tools.nn.ConvBlock with SeparableConv as the underlying convolution layer.

class e3tools.nn.SeparableTensorProduct(irreps_in1: str | Irreps, irreps_in2: str | Irreps, irreps_out: str | Irreps)[source]

Bases: Module

Tensor product factored into depthwise and pointwise components

ref: https://arxiv.org/abs/2206.11990
ref: https://github.com/atomicarchitects/equiformer/blob/a4360ada2d213ba7b4d884335d3dc54a92b7a371/nets/graph_attention_transformer.py#L157

forward(x, y, weight)[source]

Computes the separable tensor product of x and y, weighted by weight.

class e3tools.nn.TransformerBlock(irreps_in: Irreps, irreps_out: Irreps, irreps_sh: Irreps, edge_attr_dim: int, num_heads: int, irreps_query: Irreps | None = None, irreps_key: Irreps | None = None, irreps_ff_hidden_list: list[Irreps] | None = None, conv: Callable[[...], Module] | None = None)[source]

Bases: Module

Equivariant transformer block

forward(node_attr, edge_index, edge_attr, edge_sh)[source]

Computes the forward pass of the equivariant transformer block

Let N be the number of nodes, and E be the number of edges

Parameters:
  • node_attr ([N, irreps_in.dim])

  • edge_index ([2, E])

  • edge_attr ([E, edge_attr_dim])

  • edge_sh ([E, irreps_sh.dim])

Returns:

out

Return type:

[N, irreps_out.dim]
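
A minimal end-to-end sketch with illustrative irreps and graph sizes, leaving the query, key, feed-forward, and conv options at their defaults:

import torch
from e3nn import o3
from e3tools.nn import TransformerBlock

irreps = o3.Irreps("8x0e + 4x1o")
irreps_sh = o3.Irreps.spherical_harmonics(lmax=2)

block = TransformerBlock(irreps, irreps, irreps_sh, edge_attr_dim=16, num_heads=4)

N, E = 10, 30
node_attr = irreps.randn(N, -1)
edge_index = torch.randint(0, N, (2, E))
edge_vec = torch.randn(E, 3)
edge_sh = o3.spherical_harmonics(irreps_sh, edge_vec, normalize=True, normalization="component")
edge_attr = torch.randn(E, 16)   # placeholder for a radial embedding

out = block(node_attr, edge_index, edge_attr, edge_sh)   # [N, irreps_out.dim]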