Source code for qcd_ml.nn.matrix_layers.activation

r"""
------------

Activation functions for matrix-like fields, i.e., fields that transform as 

.. math::
    M(x) \rightarrow \Omega(x) M(x) \Omega(x)

"""

import torch


[docs] class LGE_ReTrAct(torch.nn.Module): r""" Given an activation function ``activation`` (:math:`F`) applies .. math:: W_j(x) \rightarrow F(\omega_j \mbox{Re}\mbox{Tr}(W_j(x)) \alpha_j) W_j(x) """ def __init__(self, activation, n_features): super(LGE_ReTrAct, self).__init__() self.activation = activation self.biases = torch.nn.Parameter(torch.randn(n_features, 1, 1, 1, 1, dtype=torch.double)) self.weights = torch.nn.Parameter(torch.randn(n_features, 1, 1, 1, 1, dtype=torch.double))
[docs] def forward(self, features): r""" .. math:: W_j(x) \rightarrow F(\omega_j \mbox{Re}\mbox{Tr}(W_j(x)) \alpha_j) W_j(x) """ re_tr = torch.einsum("...ii->...", features.real) prefactor = self.activation(self.weights.expand_as(re_tr) * re_tr + self.biases.expand_as(re_tr)) return torch.einsum("fabcd, fabcdij->fabcdij" , prefactor, features)