Source code for crabnet.kingcrab

"""Contains classes for transformer architecture within CrabNet."""
from os.path import join, dirname

import numpy as np
import pandas as pd

import torch
from torch import nn
from collections import OrderedDict

RNG_SEED = 42
torch.manual_seed(RNG_SEED)
np.random.seed(RNG_SEED)
data_type_torch = torch.float32


# %%
class ResidualNetwork(nn.Module):
    """Feed forward Residual Neural Network as seen in Roost.

    https://doi.org/10.1038/s41467-020-19964-7
    """
    def __init__(self, input_dim, output_dim, hidden_layer_dims, bias=False):
        """Instantiate a ResidualNetwork model.

        Parameters
        ----------
        input_dim : int
            Input dimensions for the Residual Network, specified in the SubCrab()
            model class, by default 512
        output_dim : int
            Output dimensions for the Residual Network, by default 3
        hidden_layer_dims : list(int)
            Hidden layer architecture for the Residual Network, by default
            [1024, 512, 256, 128]
        bias : bool
            Whether to bias the linear layers, by default False
        """
        super(ResidualNetwork, self).__init__()
        dims = [input_dim] + hidden_layer_dims
        self.fcs = nn.ModuleList(
            [nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)]
        )
        self.res_fcs = nn.ModuleList(
            [
                nn.Linear(dims[i], dims[i + 1], bias=bias)
                if (dims[i] != dims[i + 1])
                else nn.Identity()
                for i in range(len(dims) - 1)
            ]
        )
        self.acts = nn.ModuleList([nn.LeakyReLU() for _ in range(len(dims) - 1)])
        self.fc_out = nn.Linear(dims[-1], output_dim)
    def forward(self, fea):
        """Propagate Residual Network weights forward.

        Parameters
        ----------
        fea : torch.tensor (n_dim)
            Tensor output of the self attention block

        Returns
        -------
        fc_out
            The output of the Residual Network
        """
        for fc, res_fc, act in zip(self.fcs, self.res_fcs, self.acts):
            fea = act(fc(fea)) + res_fc(fea)
        return self.fc_out(fea)
    def __repr__(self):
        """Return the class name."""
        return f"{self.__class__.__name__}"
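
# Usage sketch (not part of the original source): ResidualNetwork can be
# exercised on its own with an arbitrary batch of pooled features; the shapes
# below are illustrative.
#
#   net = ResidualNetwork(input_dim=512, output_dim=3,
#                         hidden_layer_dims=[1024, 512, 256, 128])
#   fea = torch.randn(8, 512)   # batch of 8 pooled element representations
#   out = net(fea)              # -> shape (8, 3)
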
class TransferNetwork(nn.Module):
    """Learn extended representations of materials during transfer learning.

    This network was designed to have little impact on predictions during
    training and enhance learning with the inclusion of extended features.
    """
    def __init__(self, input_dims, output_dims):
        """Instantiate a TransferNetwork to learn extended representations.

        Parameters
        ----------
        input_dims : int
            Dimensions of input layer
        output_dims : int
            Dimensions of output layer
        """
        super().__init__()
        self.layers = nn.Sequential(
            OrderedDict(
                [
                    ("fc1", nn.Linear(input_dims, 512)),
                    ("leakyrelu1", nn.LeakyReLU()),
                    ("fc2", nn.Linear(512, output_dims)),
                    ("leakyrelu2", nn.LeakyReLU()),
                ]
            )
        )
    def forward(self, x):
        """Perform a forward pass of the TransferNetwork.

        Parameters
        ----------
        x : torch.tensor
            Tensor of extended features to transform

        Returns
        -------
        torch.tensor
            Learned representation of the extended features
        """
        x = self.layers(x)
        return x
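
# Usage sketch (illustrative, not from the original source): map a batch of
# extended-feature vectors (e.g. state variables) to a fixed-size representation.
#
#   transfer_nn = TransferNetwork(input_dims=10, output_dims=512)
#   x_ext = torch.randn(8, 10)   # 10 extended features per sample (assumed)
#   rep = transfer_nn(x_ext)     # -> shape (8, 512)
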
class Embedder(nn.Module):
    """Perform composition-based embeddings of elemental features."""
    def __init__(
        self,
        d_model: int,
        compute_device: str = None,
        elem_prop: str = "mat2vec",
    ):
        """Embed elemental features, similar to CBFV.

        Parameters
        ----------
        d_model : int
            Row dimensions of the elemental embeddings, by default 512
        compute_device : str
            Name of the device on which the model will be run
        elem_prop : str
            Which elemental feature vector to use. Possible values are "jarvis",
            "magpie", "mat2vec", "oliynyk", "onehot", "ptable", and "random_200",
            by default "mat2vec"
        """
        super().__init__()
        self.d_model = d_model
        self.compute_device = compute_device

        elem_dir = join(dirname(__file__), "data", "element_properties")
        # Choose what element information the model receives
        mat2vec = join(elem_dir, elem_prop + ".csv")  # element embedding
        # mat2vec = f'{elem_dir}/onehot.csv'  # onehot encoding (atomic number)
        # mat2vec = f'{elem_dir}/random_200.csv'  # random vec for elements
        cbfv = pd.read_csv(mat2vec, index_col=0).values
        feat_size = cbfv.shape[-1]
        self.fc_mat2vec = nn.Linear(feat_size, d_model).to(self.compute_device)
        zeros = np.zeros((1, feat_size))
        cat_array = np.concatenate([zeros, cbfv])
        cat_array = torch.as_tensor(cat_array, dtype=data_type_torch)
        # NOTE: Parameters within nn.Embedding
        self.cbfv = nn.Embedding.from_pretrained(cat_array).to(
            self.compute_device, dtype=data_type_torch
        )
    def forward(self, src):
        """Compute forward call for the Embedder class to perform elemental embeddings.

        Parameters
        ----------
        src : torch.tensor
            Tensor containing element numbers corresponding to elements in the compound

        Returns
        -------
        torch.tensor
            Tensor containing elemental embeddings for the compounds, reduced to
            d_model dimensions
        """
        mat2vec_emb = self.cbfv(src)
        x_emb = self.fc_mat2vec(mat2vec_emb)
        return x_emb
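
# Usage sketch (assumes the packaged element_properties CSV files are available
# on disk; element indices are atomic numbers, with 0 reserved for padding):
#
#   embedder = Embedder(d_model=512)
#   src = torch.tensor([[26, 8, 0]])   # e.g. Fe, O, and one padded slot
#   x_emb = embedder(src)              # -> shape (1, 3, 512)
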
# %%
class FractionalEncoder(nn.Module):
    """Encode element fractional amount using a "fractional encoding".

    This is inspired by the positional encoder discussed by Vaswani.
    https://arxiv.org/abs/1706.03762
    """
    def __init__(self, d_model, resolution=100, log10=False, compute_device=None):
        """Instantiate the FractionalEncoder.

        Parameters
        ----------
        d_model : int
            Model size, see paper, by default 512
        resolution : int
            Number of discretizations for the fractional prevalence encoding,
            by default 100
        log10 : bool
            Whether to apply a log operation to the fractional prevalence encoding,
            by default False
        compute_device : str
            The compute device to store and run the FractionalEncoder class
        """
        super().__init__()
        self.d_model = d_model // 2
        self.resolution = resolution
        self.log10 = log10
        self.compute_device = compute_device

        x = torch.linspace(
            0, self.resolution - 1, self.resolution, requires_grad=False
        ).view(self.resolution, 1)
        fraction = (
            torch.linspace(0, self.d_model - 1, self.d_model, requires_grad=False)
            .view(1, self.d_model)
            .repeat(self.resolution, 1)
        )

        pe = torch.zeros(self.resolution, self.d_model)
        pe[:, 0::2] = torch.sin(x / torch.pow(50, 2 * fraction[:, 0::2] / self.d_model))
        pe[:, 1::2] = torch.cos(x / torch.pow(50, 2 * fraction[:, 1::2] / self.d_model))
        # register_buffer returns None, so the buffer is registered without
        # re-assigning its return value
        self.register_buffer("pe", pe)
    def forward(self, x):
        """Perform the forward pass of the fractional encoding.

        Parameters
        ----------
        x : torch.tensor
            Tensor of element fractional amounts, discretized to the defined
            resolution

        Returns
        -------
        out
            Sinusoidal expansions of the elemental fractions
        """
        x = x.clone()
        if self.log10:
            x = 0.0025 * (torch.log2(x)) ** 2
            x[x > 1] = 1
            # x = 1 - x  # for sinusoidal encoding at x=0
        x[x < 1 / self.resolution] = 1 / self.resolution
        frac_idx = torch.round(x * (self.resolution)).to(dtype=torch.long) - 1
        out = self.pe[frac_idx]
        return out
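
# Usage sketch (illustrative, not from the original source): encode per-element
# fractional amounts into sinusoidal features of width d_model // 2.
#
#   fe = FractionalEncoder(d_model=512, resolution=100, log10=False)
#   frac = torch.tensor([[0.75, 0.25, 0.0]])   # fractions per element slot
#   enc = fe(frac)                             # -> shape (1, 3, 256)
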
# %%
class Encoder(nn.Module):
    """Create elemental descriptor matrix via element embeddings and frac. encodings.

    See the CrabNet paper for further details:
    https://www.nature.com/articles/s41524-021-00545-1
    """
    def __init__(
        self,
        d_model,
        N,
        heads,
        extend_features=None,
        fractional=True,
        attention=True,
        compute_device=None,
        pe_resolution=5000,
        ple_resolution=5000,
        elem_prop="mat2vec",
        emb_scaler=1.0,
        pos_scaler=1.0,
        pos_scaler_log=1.0,
        dim_feedforward=2048,
        dropout=0.1,
    ):
        """Instantiate the Encoder class to create the elemental descriptor matrix (EDM).

        Parameters
        ----------
        d_model : int
            Model size (dimension of the element embeddings), by default 512
        N : int, optional
            Number of encoder layers, by default 3
        heads : int, optional
            Number of attention heads to use, by default 4
        extend_features : Optional[List[str]]
            Additional features to grab from columns of the other DataFrames
            (e.g. state variables such as temperature or applied load), by default None
        fractional : bool, optional
            Whether to weight each element by its fractional contribution,
            by default True
        attention : bool, optional
            Whether to perform self attention, by default True
        pe_resolution : int, optional
            Number of discretizations for the prevalence encoding, by default 5000
        ple_resolution : int, optional
            Number of discretizations for the prevalence log encoding, by default 5000
        elem_prop : str, optional
            Which elemental feature vector to use. Possible values are "jarvis",
            "magpie", "mat2vec", "oliynyk", "onehot", "ptable", and "random_200",
            by default "mat2vec"
        emb_scaler : float, optional
            Scaling factor applied to the elemental embeddings, by default 1.0
        pos_scaler : float, optional
            Scaling factor applied to the fractional encoder, by default 1.0
        pos_scaler_log : float, optional
            Scaling factor applied to the log fractional encoder, by default 1.0
        dim_feedforward : int, optional
            Dimensions of the feed forward network following the transformer,
            by default 2048
        dropout : float, optional
            Percent dropout in the feed forward network following the transformer,
            by default 0.1
        """
        super().__init__()
        self.d_model = d_model
        self.N = N
        self.heads = heads
        self.extend_features = extend_features
        self.fractional = fractional
        self.attention = attention
        self.compute_device = compute_device
        self.pe_resolution = pe_resolution
        self.ple_resolution = ple_resolution
        self.elem_prop = elem_prop

        self.embed = Embedder(d_model=self.d_model, compute_device=self.compute_device)
        self.prevalence_encoder = FractionalEncoder(
            self.d_model, resolution=pe_resolution, log10=False
        )
        self.prevalence_log_encoder = FractionalEncoder(
            self.d_model, resolution=ple_resolution, log10=True
        )

        self.emb_scaler = nn.parameter.Parameter(torch.tensor([emb_scaler]))
        self.pos_scaler = nn.parameter.Parameter(torch.tensor([pos_scaler]))
        self.pos_scaler_log = nn.parameter.Parameter(torch.tensor([pos_scaler_log]))

        if self.attention:
            encoder_layer = nn.TransformerEncoderLayer(
                self.d_model,
                nhead=self.heads,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
            )
            self.transformer_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=self.N
            )
    def forward(self, src, frac, extra_features=None):
        """Compute the forward pass for encoding the elemental descriptor matrix.

        Parameters
        ----------
        src : torch.tensor
            Tensor containing integers corresponding to elements in the compound
        frac : torch.tensor
            Tensor containing the fractions of each element in the compound
        extra_features : torch.tensor, optional
            Extra features to append after encoding, by default None

        Returns
        -------
        torch.tensor
            Tensor containing flattened transformer representations of compounds
            concatenated with extended features.
        """
        x = self.embed(src) * self.emb_scaler  # * 2 ** self.emb_scaler
        pe = torch.zeros_like(x)
        ple = torch.zeros_like(x)
        pe_scaler = self.pos_scaler
        ple_scaler = self.pos_scaler_log
        pe[:, :, : self.d_model // 2] = self.prevalence_encoder(frac) * pe_scaler
        ple[:, :, self.d_model // 2 :] = self.prevalence_log_encoder(frac) * ple_scaler

        mask = frac.unsqueeze(dim=-1)
        mask = torch.matmul(mask, mask.transpose(-2, -1))
        mask[mask != 0] = 1
        src_mask = mask[:, 0] != 1

        if self.attention:
            x_src = x + pe + ple
            x_src = x_src.transpose(0, 1)
            x = self.transformer_encoder(x_src, src_key_padding_mask=src_mask)
            x = x.transpose(0, 1)

        if self.fractional:
            x = x * frac.unsqueeze(2).repeat(1, 1, self.d_model)

        hmask = mask[:, :, 0:1].repeat(1, 1, self.d_model)
        if mask is not None:
            x = x.masked_fill(hmask == 0, 0)

        if self.extend_features is not None:
            n_elements = x.shape[1]
            X_extra = extra_features.repeat(1, 1, n_elements).permute([1, 2, 0])
            x = torch.concat((x, X_extra), axis=2)

        return x
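
# Usage sketch (illustrative shapes, no extended features; assumes the packaged
# element_properties data files are available so the internal Embedder can load
# its CSV):
#
#   encoder = Encoder(d_model=512, N=3, heads=4)
#   src = torch.tensor([[26, 8, 0]])         # element indices, 0 = padding
#   frac = torch.tensor([[0.4, 0.6, 0.0]])   # fractional amounts per element
#   x = encoder(src, frac)                   # -> shape (1, 3, 512)
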
# %%
class SubCrab(nn.Module):
    """SubCrab model class which implements the transformer architecture."""
    def __init__(
        self,
        out_dims=3,
        d_model=512,
        extend_features=None,
        d_extend=0,
        N=3,
        heads=4,
        fractional=False,
        attention=True,
        compute_device=None,
        out_hidden=[1024, 512, 256, 128],
        pe_resolution=5000,
        ple_resolution=5000,
        elem_prop="mat2vec",
        bias=False,
        emb_scaler=1.0,
        pos_scaler=1.0,
        pos_scaler_log=1.0,
        dim_feedforward=2048,
        dropout=0.1,
    ):
        """Instantiate a SubCrab class to be used within CrabNet.

        Parameters
        ----------
        out_dims : int, optional
            Output dimensions for the Residual Network, by default 3
        d_model : int, optional
            Model size. See paper, by default 512
        extend_features : Optional[List[str]], optional
            Additional features to grab from columns of the other DataFrames
            (e.g. state variables such as temperature or applied load), by default None
        d_extend : int, optional
            Number of extended features, by default 0
        N : int, optional
            Number of attention layers, by default 3
        heads : int, optional
            Number of attention heads, by default 4
        fractional : bool, optional
            Whether to multiply `x` by the fractional amounts for each element,
            by default False
        attention : bool, optional
            Whether to perform self attention, by default True
        compute_device : str, optional
            Computing device to run the model on, by default None
        out_hidden : list(int), optional
            Architecture of hidden layers in the Residual Network, by default
            [1024, 512, 256, 128]
        pe_resolution : int, optional
            Number of discretizations for the prevalence encoding, by default 5000
        ple_resolution : int, optional
            Number of discretizations for the prevalence log encoding, by default 5000
        elem_prop : str, optional
            Which elemental feature vector to use. Possible values are "jarvis",
            "magpie", "mat2vec", "oliynyk", "onehot", "ptable", and "random_200",
            by default "mat2vec"
        bias : bool, optional
            Whether to bias the Residual Network, by default False
        emb_scaler : float, optional
            Float value by which to scale the elemental embeddings, by default 1.0
        pos_scaler : float, optional
            Float value by which to scale the fractional encodings, by default 1.0
        pos_scaler_log : float, optional
            Float value by which to scale the log fractional encodings, by default 1.0
        dim_feedforward : int, optional
            Dimensions of the feed forward network following the transformer,
            by default 2048
        dropout : float, optional
            Percent dropout in the feed forward network following the transformer,
            by default 0.1
        """
        super().__init__()
        self.avg = True
        self.out_dims = out_dims
        self.d_model = d_model
        self.extend_features = extend_features
        self.d_extend = d_extend
        self.N = N
        self.heads = heads
        self.fractional = fractional
        self.attention = attention
        self.compute_device = compute_device
        self.bias = bias
        self.encoder = Encoder(
            d_model=self.d_model,
            N=self.N,
            heads=self.heads,
            attention=self.attention,
            compute_device=self.compute_device,
            pe_resolution=pe_resolution,
            ple_resolution=ple_resolution,
            elem_prop=elem_prop,
            emb_scaler=emb_scaler,
            pos_scaler=pos_scaler,
            pos_scaler_log=pos_scaler_log,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.out_hidden = out_hidden
        self.output_nn = ResidualNetwork(
            self.d_model + self.d_extend,
            self.out_dims,
            self.out_hidden,
            self.bias,
        )
    def forward(self, src, frac, extra_features=None):
        """Compute forward pass of the SubCrab model class (i.e. transformer).

        Parameters
        ----------
        src : torch.tensor
            Tensor containing element numbers corresponding to elements in the compound
        frac : torch.tensor
            Tensor containing fractional amounts of each element in the compound
        extra_features : torch.tensor, optional
            Extra features to append after encoding, by default None

        Returns
        -------
        torch.tensor
            Model output containing the predicted value and the uncertainty for
            that value
        """
        output = self.encoder(src, frac, extra_features)
        # output = self.transfer_nn(output)

        # average the "element contribution", mask so you only average "elements"
        # (i.e. not padded zero values)
        elem_pad_mask = (src == 0).unsqueeze(-1).repeat(1, 1, self.out_dims)
        output = self.output_nn(output)  # simple linear
        if self.avg:
            output = output.masked_fill(elem_pad_mask, 0)
            output = output.sum(dim=1) / (~elem_pad_mask).sum(dim=1)
            output, logits = output.chunk(2, dim=-1)
            probability = torch.ones_like(output)
            probability[:, : logits.shape[-1]] = torch.sigmoid(logits)
            output = output * probability

        return output
# %%
if __name__ == "__main__":
    model = SubCrab()
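    # Smoke-test sketch (not part of the original source): the composition below
    # is illustrative. Element indices follow the Embedder convention, where 0 is
    # reserved for padding.
    src = torch.tensor([[26, 8, 0]])         # e.g. Fe, O, and one padded slot
    frac = torch.tensor([[0.4, 0.6, 0.0]])   # fractional amounts per element
    out = model(src, frac)
    print(out.shape)  # prediction scaled by the sigmoid "probability" term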