# Source code for qcd_ml.nn.matrix_layers.convolution
r"""
Convolutions for matrix-like fields, i.e., fields that transform as

.. math::
    M(x) \rightarrow \Omega(x) M(x) \Omega(x)^\dagger
"""
import torch
from ...base.paths import PathBuffer
class LGE_Convolution(torch.nn.Module):
    r"""
    Provides a convolution for matrix-like fields, i.e., fields that transform as

    .. math::
        M(x) \rightarrow \Omega(x) M(x) \Omega(x)^\dagger

    The convolution is defined as

    .. math::
        W_i(x) \rightarrow \sum_{j\mu k} \omega_{i\mu k j} U_{\mu k}(x) W_j(x+k\mu) U_{\mu k}^\dagger(x)

    See 10.1103/PhysRevLett.128.032003 for more details.

    We implement this convolution differently: We define a gauge transporter
    along an arbitrary path :math:`T_p` as

    .. math::
        (T_p(M))(x) = ((\prod\limits_{\mu_k \in p} H_{\mu_k}) M)(x)

    where

    .. math::
        (H_{\mu} M)(x) = U_{\mu}(x) M(x + \mu) U_{\mu}(x)^\dagger

    Then, the convolution is defined as

    .. math::
        W_i(x) \rightarrow \sum_{j k} \omega_{j i k} T_{p_k}(W_j)(x)

    Parameters
    ----------
    n_input : int
        Number of input matrix-valued feature fields.
    n_output : int
        Number of output matrix-valued feature fields.
    paths : list
        Paths that define the gauge transporters :math:`T_{p_k}`; one complex
        weight is learned per (input, output, path) triple.
    disable_cache : bool
        If ``False``, the ``PathBuffer`` objects built for a given link field
        ``U`` are cached, keyed on ``id(U)``. Caching is off by default; see
        the note in ``__init__``.
    """

    def __init__(self, n_input, n_output, paths, disable_cache=True):
        super().__init__()
        self.n_input = n_input
        self.n_output = n_output
        self.paths = paths
        self.disable_cache = disable_cache
        # Store path buffers by link field.
        # We expect that the link field is a torch tensor. In this case
        # use id(U) as a key for the hash. This seems OK, since it is
        # recommended here:
        # https://github.com/pytorch/pytorch/issues/7733#issuecomment-390912112
        # See also the entire issue discussion
        # https://github.com/pytorch/pytorch/issues/7733.
        # NOTE(review): id() values can be reused once a tensor is garbage
        # collected, so a long-lived cache may serve stale buffers; presumably
        # this is why caching is disabled by default — confirm with callers.
        self.path_buffer_cache = {}
        # One complex weight per (input feature, output feature, path).
        self.weights = torch.nn.Parameter(
            torch.randn(n_input, n_output, len(paths), dtype=torch.cdouble))

    def clear_path_buffers(self):
        """
        If ``disable_cache=False``, this method can be used to clear the
        pre-computed ``PathBuffer`` cache.
        """
        self.path_buffer_cache = {}

    def forward(self, U, features_in):
        r"""
        Apply the convolution

        .. math::
            W_i(x) \rightarrow \sum_{j\mu k} \omega_{i\mu k j} U_{\mu k}(x) W_j(x+k\mu) U_{\mu k}^\dagger(x)

        ``U`` is the gauge link field used to build the transporters;
        ``features_in`` is an iterable of ``n_input`` matrix-like fields.
        """
        # Single dict lookup instead of `in` followed by indexing.
        path_buffers = self.path_buffer_cache.get(id(U))
        if path_buffers is None:
            path_buffers = [PathBuffer(U, path) for path in self.paths]
            if not self.disable_cache:
                self.path_buffer_cache[id(U)] = path_buffers

        # transported[j, k] = T_{p_k}(W_j); shape (n_input, n_paths, ...).
        transported = torch.stack([
            torch.stack([pb.m_transport(feature) for pb in path_buffers])
            for feature in features_in])
        # Contract the input-feature index (i) and path index (l) against the
        # weights, leaving the output-feature index (k) in front.
        return torch.einsum("ikl,il...->k...", self.weights, transported)