e3tools.nn package

Module contents

class e3tools.nn.Attention(irreps_in: Irreps, irreps_out: Irreps, irreps_sh: Irreps, irreps_query: Irreps, irreps_key: Irreps, edge_attr_dim, conv: Callable[[...], Module] | None = None, return_attention: bool = False)

Bases: Module

Equivariant attention layer

ref: https://arxiv.org/abs/2006.10503

forward(node_attr, edge_index, edge_attr, edge_sh)

Computes the forward pass of the equivariant graph attention layer.

Let N be the number of nodes, and E be the number of edges

Parameters:
  • node_attr ([N, irreps_in.dim])

  • edge_index ([2, E])

  • edge_attr ([E, edge_attr_dim])

  • edge_sh ([E, irreps_sh.dim])

Returns:

out

Return type:

[N, irreps_out.dim]
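
A minimal sketch of the normalization step only, assuming edge_index[0] holds source and edge_index[1] destination node indices and that each edge carries one scalar attention logit: the weights are a softmax over the edges that share a destination node, and the output is the weighted sum of per-edge values. The helpers segment_softmax and aggregate are hypothetical, not part of e3tools, and the equivariant query, key, and value computation is omitted.

```python
import torch

def segment_softmax(logits: torch.Tensor, dst: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Softmax of per-edge logits [E], normalized over edges sharing a destination node."""
    # Per-node maximum, subtracted before exp() for numerical stability.
    node_max = torch.full((num_nodes,), float("-inf"), dtype=logits.dtype)
    node_max = node_max.scatter_reduce(0, dst, logits, reduce="amax")
    exp = torch.exp(logits - node_max[dst])
    denom = torch.zeros(num_nodes, dtype=logits.dtype).index_add(0, dst, exp)
    return exp / denom[dst]

def aggregate(values: torch.Tensor, alpha: torch.Tensor, dst: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Attention-weighted sum of per-edge values [E, D] into nodes -> [N, D]."""
    out = torch.zeros(num_nodes, values.shape[-1], dtype=values.dtype)
    return out.index_add(0, dst, alpha.unsqueeze(-1) * values)

edge_index = torch.tensor([[0, 1, 2], [2, 2, 0]])  # edges 0->2, 1->2, 2->0 (assumed order)
alpha = segment_softmax(torch.tensor([0.0, 0.0, 1.0]), edge_index[1], num_nodes=3)
print(alpha)  # tensor([0.5000, 0.5000, 1.0000])
out = aggregate(torch.ones(3, 4), alpha, edge_index[1], num_nodes=3)
```

Subtracting the per-node maximum does not change the softmax but avoids overflow for large logits; nodes with no incoming edges simply receive zeros.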

class e3tools.nn.AxisToMul(irreps_in: Irreps, factor: int)

Bases: Module

Collapses the second-last axis by flattening the irreps.

forward(x: Tensor) → Tensor

Collapses the second-last axis by flattening the irreps.

Parameters:

x – torch.Tensor of shape […, factor, irreps.dim // factor]

Returns:

torch.Tensor of shape […, irreps.dim]
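
A minimal sketch of this reshaping, assuming irreps are given as hypothetical (mul, l) pairs stored e3nn-style as contiguous [mul, 2l+1] blocks, that irreps describes the flattened output, and that every multiplicity is divisible by factor. The function axis_to_mul, and the choice of which irrep copies come from which slice of the factor axis, are assumptions rather than the e3tools implementation. The reshape is done block by block because a single reshape of the last two axes would interleave components of different irreps.

```python
import torch

def axis_to_mul(x: torch.Tensor, irreps: list[tuple[int, int]], factor: int) -> torch.Tensor:
    """[..., factor, dim // factor] -> [..., dim], where dim = sum(mul * (2l + 1))."""
    assert x.shape[-2] == factor
    batch = x.shape[:-2]
    out, start = [], 0
    for mul, l in irreps:
        assert mul % factor == 0, "each multiplicity must be divisible by factor"
        d = 2 * l + 1
        width = (mul // factor) * d  # width of this irrep block within one slice
        block = x[..., start:start + width]  # [..., factor, (mul // factor) * d]
        start += width
        # Merge the factor axis back into the multiplicity: [..., mul * d]
        out.append(block.reshape(*batch, mul * d))
    return torch.cat(out, dim=-1)

# "4x0e + 2x1o": each of the 2 slices holds "2x0e + 1x1o" (5 components).
y = torch.tensor([[0.0, 1, 4, 5, 6], [2, 3, 7, 8, 9]])
print(axis_to_mul(y, [(4, 0), (2, 1)], factor=2))
# tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
```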

class e3tools.nn.Conv(irreps_in: str | Irreps, irreps_out: str | Irreps, irreps_sh: str | Irreps, edge_attr_dim: int, radial_nn: Callable[[...], Module] | None = None, tensor_product: Callable[[...], Module] | None = None)

Bases: Module

Equivariant convolution layer

ref: https://arxiv.org/abs/1802.08219

apply_per_edge(node_attr_src, edge_attr, edge_sh)

forward(node_attr, edge_index, edge_attr, edge_sh)

Computes the forward pass of the equivariant convolution.

Let N be the number of nodes, and E be the number of edges

Parameters:
  • node_attr ([N, irreps_in.dim])

  • edge_index ([2, E])

  • edge_attr ([E, edge_attr_dim])

  • edge_sh ([E, irreps_sh.dim])

Returns:

out

Return type:

[N, irreps_out.dim]
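
The split between apply_per_edge and forward suggests the usual message-passing pattern: one message per edge computed from the source-node features, then a sum of messages into each destination node. The toy sketch below shows that pattern under loudly stated assumptions: row 0 of edge_index is the source, messages are summed without normalization, and a scalar-only product stands in for the equivariant tensor product, so it is not itself equivariant for l > 0 features. The functions toy_apply_per_edge and toy_conv_forward are hypothetical illustrations, not the e3tools methods.

```python
import torch

def toy_apply_per_edge(node_attr_src, edge_attr, edge_sh, radial):
    """One message per edge, [E, D]. Scalar-only stand-in for the tensor product:
    source features scaled by radial weights and the l = 0 harmonic component."""
    return radial(edge_attr) * node_attr_src * edge_sh[:, :1]

def toy_conv_forward(node_attr, edge_index, edge_attr, edge_sh, radial):
    src, dst = edge_index  # assumption: row 0 = source, row 1 = destination
    messages = toy_apply_per_edge(node_attr[src], edge_attr, edge_sh, radial)
    out = torch.zeros(node_attr.shape[0], messages.shape[-1], dtype=messages.dtype)
    return out.index_add(0, dst, messages)  # sum incoming messages per node

N, E, D, edge_attr_dim = 4, 5, 8, 3
radial = torch.nn.Linear(edge_attr_dim, D)  # hypothetical stand-in for radial_nn
node_attr = torch.randn(N, D)
edge_index = torch.tensor([[0, 1, 2, 3, 0], [1, 2, 3, 0, 2]])
edge_attr = torch.randn(E, edge_attr_dim)
edge_sh = torch.randn(E, 4)  # e.g. components of irreps_sh = "1x0e + 1x1o"
out = toy_conv_forward(node_attr, edge_index, edge_attr, edge_sh, radial)
print(out.shape)  # torch.Size([4, 8])
```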

class e3tools.nn.ConvBlock(irreps_in: str | Irreps, irreps_out: str | Irreps, irreps_sh: str | Irreps, edge_attr_dim: int, act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None, conv: Callable[[...], Module] | None = None)

Bases: Module

Equivariant convolution with gated non-linearity and linear self-interaction

forward(node_attr, edge_index, edge_attr, edge_sh)

Computes the forward pass of the equivariant convolution block.

Let N be the number of nodes, and E be the number of edges

Parameters:
  • node_attr ([N, irreps_in.dim])

  • edge_index ([2, E])

  • edge_attr ([E, edge_attr_dim])

  • edge_sh ([E, irreps_sh.dim])

Returns:

out

Return type:

[N, irreps_out.dim]

class e3tools.nn.EquivariantMLP(irreps_in: Irreps, irreps_out: Irreps, irreps_hidden_list: list[Irreps], act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None, norm_layer: Callable[[...], Module] | None = None)

Bases: Sequential

An equivariant multi-layer perceptron with gated non-linearities.

class e3tools.nn.ExperimentalConv(*args, **kwargs)

Bases: Conv

class e3tools.nn.ExperimentalTensorProduct(irreps_in1: str | Irreps, irreps_in2: str | Irreps, irreps_out: str | Irreps)

Bases: Module

Compilable tensor product.

forward(x, y, weight)

Computes the tensor product of x (with irreps irreps_in1) and y (with irreps irreps_in2), parameterized by weight, producing features with irreps irreps_out.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class e3tools.nn.ExtractIrreps(irreps_in: Irreps, irrep_extract: Irrep)

Bases: Module

Extracts specific irreps from an e3nn tensor.

forward(data: Tensor) → Tensor

Extracts the specified irreps from the input tensor.

Parameters:

data – torch.Tensor of shape […, irreps_in.dim]

Returns:

torch.Tensor of shape […, irreps_out.dim]
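
A minimal sketch of the extraction, assuming a hypothetical representation of irreps as (mul, l, parity) tuples laid out as contiguous blocks, and assuming every block matching the requested irrep is kept in its original order; e3tools may handle ordering differently. extract_irreps is an illustration, not the library function.

```python
import torch

def extract_irreps(data, irreps, irrep):
    """Keep every block of `irreps` whose (l, parity) equals `irrep`."""
    pieces, start = [], 0
    for mul, l, p in irreps:
        width = mul * (2 * l + 1)
        if (l, p) == irrep:
            pieces.append(data[..., start:start + width])
        start += width
    if not pieces:
        raise ValueError(f"{irrep} not found in {irreps}")
    return torch.cat(pieces, dim=-1)

# irreps_in = "2x0e + 1x1o + 3x0e" (dim 2 + 3 + 3 = 8)
irreps_in = [(2, 0, 1), (1, 1, -1), (3, 0, 1)]
data = torch.arange(8.0)
print(extract_irreps(data, irreps_in, (0, 1)))   # tensor([0., 1., 5., 6., 7.])
print(extract_irreps(data, irreps_in, (1, -1)))  # tensor([2., 3., 4.])
```

Under these assumptions irreps_out for the first call would be 5x0e.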

class e3tools.nn.Gate(irreps_out: str | Irreps, act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None)

Bases: Module

Equivariant non-linear gate

Parameters:
  • irreps_out (e3nn.o3.Irreps) – output feature irreps (input irreps are inferred from output irreps)

  • act (Mapping[int, torch.nn.Module]) – Mapping from parity to the activation applied to scalar features. If None, defaults to {1: torch.nn.LeakyReLU(), -1: torch.nn.Tanh()}.

  • act_gates (Mapping[int, torch.nn.Module]) – Mapping from parity to the activation applied to the gate scalars. If None, defaults to {1: torch.nn.Sigmoid(), -1: torch.nn.Tanh()}.

forward(x: Tensor) → Tensor

Apply the gate to the input tensor.
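
A minimal sketch of the gating scheme, modeled on e3nn.nn.Gate: scalar outputs pass through act, and every non-scalar irrep copy is multiplied by its own scalar gate passed through act_gates. This is why the input irreps can be inferred from irreps_out: the input needs one extra scalar per gated copy. The sketch assumes even scalars only, a single gated l, and the input ordering [scalars, gates, gated]; whether e3tools uses exactly this layout is an assumption, and gate here is a hypothetical helper.

```python
import torch

def gate(x, num_scalars, num_gated, l_gated=1, act=None, act_gates=None):
    """Input layout (assumed): [scalars | gates | gated], one gate per gated copy.
    Output: [act(scalars) | act_gates(gates) * gated]."""
    act = act if act is not None else torch.nn.LeakyReLU()
    act_gates = act_gates if act_gates is not None else torch.nn.Sigmoid()
    d = 2 * l_gated + 1
    assert x.shape[-1] == num_scalars + num_gated + num_gated * d
    batch = x.shape[:-1]
    scalars = x[..., :num_scalars]
    gates = x[..., num_scalars:num_scalars + num_gated]
    gated = x[..., num_scalars + num_gated:].reshape(*batch, num_gated, d)
    # Each copy is scaled by an invariant scalar, so rotations commute with the gate.
    gated = act_gates(gates).unsqueeze(-1) * gated
    return torch.cat([act(scalars), gated.reshape(*batch, num_gated * d)], dim=-1)

# irreps_out = "2x0e + 3x1o" -> inferred input "2x0e + 3x0e + 3x1o" (dim 2 + 3 + 9 = 14)
x = torch.randn(5, 14)
y = gate(x, num_scalars=2, num_gated=3)
print(y.shape)  # torch.Size([5, 11])
```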

class e3tools.nn.GateWrapper(irreps_in: Irreps, irreps_out: Irreps, irreps_gate: Irreps)

Bases: Module

Applies a linear transformation before and after the gate.

forward(x: Tensor) → Tensor

Apply the pre-gate, gate, and post-gate transformations.

class e3tools.nn.Gated(layer: Callable[[...], Module], irreps_in: str | Irreps, irreps_out: str | Irreps, act: Mapping[int, Module] | None = None, act_gates: Mapping[int, Module] | None = None)

Bases: Module

Wraps another layer with an equivariant gate.

forward(*args, **kwargs)

Apply the layer and then the gate to the input tensor.

class e3tools.nn.LayerNorm(irreps: Irreps, eps: float = 1e-06)

Bases: Module

Equivariant layer normalization.

ref: https://github.com/atomicarchitects/equiformer/blob/master/nets/fast_layer_norm.py

forward(x: Tensor) → Tensor

Apply layer normalization to input tensor. Each irrep is normalized independently.
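
A minimal sketch of one common way to normalize each irrep independently, in the spirit of the Equiformer reference but not copied from it: scalar (l = 0) blocks are centered and divided by their standard deviation across copies, while l > 0 blocks are only divided by the root-mean-square of their components across copies, a rotation-invariant quantity. The (mul, l) representation and irrepwise_layer_norm are hypothetical, and learnable affine parameters are omitted.

```python
import torch

def irrepwise_layer_norm(x: torch.Tensor, irreps: list[tuple[int, int]], eps: float = 1e-6) -> torch.Tensor:
    batch = x.shape[:-1]
    out, start = [], 0
    for mul, l in irreps:
        d = 2 * l + 1
        block = x[..., start:start + mul * d].reshape(*batch, mul, d)
        start += mul * d
        if l == 0:
            # Scalars: center across the multiplicity axis, as in ordinary layer norm.
            block = block - block.mean(dim=-2, keepdim=True)
        # Divide by the RMS over all copies and components (the std for centered scalars).
        # The mean of squared components is rotation invariant, so equivariance is kept.
        rms = block.pow(2).mean(dim=(-2, -1), keepdim=True).add(eps).sqrt()
        out.append((block / rms).reshape(*batch, mul * d))
    return torch.cat(out, dim=-1)

irreps = [(3, 0), (2, 1)]  # "3x0e + 2x1o"
print(irrepwise_layer_norm(torch.randn(4, 9), irreps).shape)  # torch.Size([4, 9])
```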

class e3tools.nn.LinearSelfInteraction(f)

Bases: Module

Equivariant linear self-interaction layer

Parameters:

f (torch.nn.Module) – Equivariant layer to wrap. f.irreps_in and f.irreps_out must be defined.

forward(x: Tensor, *args) → Tensor

Combines the output of the wrapped layer f with a skip connection from the input x.
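
A minimal sketch of the skip-connection pattern, assuming the combination is additive: the output is f(x, *args) plus a learned linear map of x. That would explain why f.irreps_in and f.irreps_out must be defined, since they presumably size the linear skip. Here torch.nn.Linear stands in for an equivariant linear layer (such as e3nn's o3.Linear), so the sketch is only equivariant for scalar features; ToyLinearSelfInteraction and its dim_in/dim_out arguments are hypothetical.

```python
import torch

class ToyLinearSelfInteraction(torch.nn.Module):
    def __init__(self, f: torch.nn.Module, dim_in: int, dim_out: int):
        super().__init__()
        self.f = f
        # Learned linear skip from the input features to the output features.
        self.skip = torch.nn.Linear(dim_in, dim_out, bias=False)

    def forward(self, x: torch.Tensor, *args) -> torch.Tensor:
        # Extra positional arguments are forwarded to the wrapped layer only.
        return self.f(x, *args) + self.skip(x)

f = torch.nn.Linear(4, 6)
layer = ToyLinearSelfInteraction(f, dim_in=4, dim_out=6)
x = torch.randn(2, 4)
out = layer(x)
print(out.shape)  # torch.Size([2, 6])
```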

class e3tools.nn.MulToAxis(irreps_in: Irreps, factor: int)

Bases: Module

Adds a new axis by factoring out irreps.

forward(x: Tensor) → Tensor

Adds a new axis by factoring out irreps.

Parameters:

x – torch.Tensor of shape […, irreps.dim]

Returns:

torch.Tensor of shape […, factor, irreps.dim // factor]
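
A minimal sketch of the factoring, assuming irreps are given as hypothetical (mul, l) pairs stored e3nn-style as contiguous [mul, 2l+1] blocks and that every multiplicity is divisible by factor. Each irrep block is split along its multiplicity, so every slice of the new axis holds a copy of the irreps with multiplicities divided by factor. The function mul_to_axis, and which copies land in which slice, are assumptions rather than the e3tools implementation.

```python
import torch

def mul_to_axis(x: torch.Tensor, irreps: list[tuple[int, int]], factor: int) -> torch.Tensor:
    """[..., dim] -> [..., factor, dim // factor], where dim = sum(mul * (2l + 1))."""
    batch = x.shape[:-1]
    out, start = [], 0
    for mul, l in irreps:
        assert mul % factor == 0, "each multiplicity must be divisible by factor"
        d = 2 * l + 1
        block = x[..., start:start + mul * d]  # [..., mul * d]
        start += mul * d
        # Split the multiplicity: [..., factor, mul // factor, d], then flatten per slice.
        block = block.reshape(*batch, factor, mul // factor, d)
        out.append(block.reshape(*batch, factor, (mul // factor) * d))
    return torch.cat(out, dim=-1)

irreps = [(4, 0), (2, 1)]  # "4x0e + 2x1o", dim 4 + 6 = 10
y = mul_to_axis(torch.arange(10.0), irreps, factor=2)
print(y)
# tensor([[0., 1., 4., 5., 6.],
#         [2., 3., 7., 8., 9.]])
```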

class e3tools.nn.MultiheadAttention(irreps_in: Irreps, irreps_out: Irreps, irreps_sh: Irreps, irreps_query: Irreps, irreps_key: Irreps, edge_attr_dim: int, n_head: int, conv: Callable[[...], Module] | None = None, return_attention: bool = False)

Bases: Module

Equivariant attention layer with multiple heads

ref: https://arxiv.org/abs/2006.10503

forward(node_attr, edge_index, edge_attr, edge_sh)

Computes the forward pass of the multi-head equivariant graph attention layer.

Let N be the number of nodes, and E be the number of edges

Parameters:
  • node_attr ([N, irreps_in.dim])

  • edge_index ([2, E])

  • edge_attr ([E, edge_attr_dim])

  • edge_sh ([E, irreps_sh.dim])

Returns:

out

Return type:

[N, irreps_out.dim]

class e3tools.nn.ScalarMLP(in_features: int, out_features: int, hidden_features: list[int], activation_layer: Callable[[...], Module] = torch.nn.ReLU, norm_layer: Callable[[...], Module] | None = None, dropout=0.0, bias=True)

Bases: Sequential

A multi-layer perceptron for scalar inputs and outputs.
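
A minimal sketch of an equivalent torch.nn.Sequential, assuming each hidden layer is Linear, then the optional norm_layer, then activation_layer, then Dropout, followed by a final Linear to out_features. This order mirrors torchvision.ops.MLP, which shares the norm_layer, activation_layer, dropout, and bias arguments, but the exact e3tools layer order is an assumption; scalar_mlp is a hypothetical helper.

```python
import torch

def scalar_mlp(in_features, out_features, hidden_features,
               activation_layer=torch.nn.ReLU, norm_layer=None, dropout=0.0, bias=True):
    layers, d = [], in_features
    for h in hidden_features:
        layers.append(torch.nn.Linear(d, h, bias=bias))
        if norm_layer is not None:
            layers.append(norm_layer(h))
        layers.append(activation_layer())
        layers.append(torch.nn.Dropout(dropout))
        d = h
    layers.append(torch.nn.Linear(d, out_features, bias=bias))  # output projection
    return torch.nn.Sequential(*layers)

mlp = scalar_mlp(3, 2, [16, 16], norm_layer=torch.nn.LayerNorm)
print(mlp(torch.randn(5, 3)).shape)  # torch.Size([5, 2])
```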

class e3tools.nn.ScaleIrreps(irreps_in: Tensor)

Bases: Module

Scales each irrep by a weight.

forward(data: Tensor, weights: Tensor) → Tensor

Scales each irrep by a weight.

Parameters:
  • data – torch.Tensor of shape […, irreps_in.dim]

  • weights – torch.Tensor of shape […, irreps_in.num_irreps]

Returns:

torch.Tensor of shape […, irreps_in.dim]
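
A minimal sketch of how the per-irrep weights broadcast to the components, assuming hypothetical (mul, l) pairs with each irrep copy stored as 2l + 1 contiguous components; scale_irreps is an illustration, not the e3tools code. Because every copy is multiplied by a single scalar, the result stays equivariant as long as the weights are invariant.

```python
import torch

def scale_irreps(data, weights, irreps):
    """data: [..., dim], weights: [..., num_irreps] -> [..., dim]."""
    # Repeat each copy's weight over its 2l + 1 components.
    repeats = torch.tensor([2 * l + 1 for mul, l in irreps for _ in range(mul)])
    return data * weights.repeat_interleave(repeats, dim=-1)

irreps = [(2, 0), (1, 1)]  # "2x0e + 1x1o": num_irreps = 3, dim = 5
print(scale_irreps(torch.ones(5), torch.tensor([1.0, 2.0, 3.0]), irreps))
# tensor([1., 2., 3., 3., 3.])
```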

class e3tools.nn.SeparableConv(*args, **kwargs)

Bases: Conv

Equivariant convolution layer using a separable tensor product

ref: https://arxiv.org/abs/1802.08219
ref: https://arxiv.org/abs/2206.11990

class e3tools.nn.SeparableTensorProduct(irreps_in1: str | Irreps, irreps_in2: str | Irreps, irreps_out: str | Irreps)

Bases: Module

Tensor product factored into depthwise and pointwise components.

ref: https://arxiv.org/abs/2206.11990
ref: https://github.com/atomicarchitects/equiformer/blob/a4360ada2d213ba7b4d884335d3dc54a92b7a371/nets/graph_attention_transformer.py#L157

forward(x, y, weight)

Computes the tensor product of x (with irreps irreps_in1) and y (with irreps irreps_in2), parameterized by weight, producing features with irreps irreps_out.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class e3tools.nn.TransformerBlock(irreps_in: Irreps, irreps_out: Irreps, irreps_sh: Irreps, edge_attr_dim: int, n_head: int = 1, irreps_query: Irreps | None = None, irreps_key: Irreps | None = None, irreps_ff_hidden_list: list[Irreps] | None = None, conv: Callable[[...], Module] | None = None)

Bases: Module

Equivariant transformer block

forward(node_attr, edge_index, edge_attr, edge_sh)

Computes the forward pass of the equivariant transformer block.

Let N be the number of nodes, and E be the number of edges

Parameters:
  • node_attr ([N, irreps_in.dim])

  • edge_index ([2, E])

  • edge_attr ([E, edge_attr_dim])

  • edge_sh ([E, irreps_sh.dim])

Returns:

out

Return type:

[N, irreps_out.dim]