Source code for e3tools.nn._interaction
import torch
from e3nn import o3
[docs]
class LinearSelfInteraction(torch.nn.Module):
"""
Equivariant linear self interaction layer
Parameters
----------
f: torch.nn.Module
Equivariant layer to wrap.
f.irreps_in and f.irreps_out must be defined
"""
def __init__(self, f):
super().__init__()
self.f = f
self.irreps_in = f.irreps_in
self.irreps_out = f.irreps_out
self.skip_connection = o3.Linear(self.irreps_in, self.irreps_out)
self.self_interaction = o3.Linear(self.irreps_out, self.irreps_out)
[docs]
def forward(self, x: torch.Tensor, *args) -> torch.Tensor:
"""Combines the input layer with a skip connection."""
s = self.skip_connection(x)
x = self.f(x, *args)
x = self.self_interaction(x)
return x + s