# Compositionally-Restricted Attention-Based Network (CrabNet)

The Compositionally-Restricted Attention-Based Network (CrabNet), inspired by natural language processing transformers, uses compositional information to predict material properties.
:warning: This is a fork of the original CrabNet repository :warning:
This is a refactored version of CrabNet, published to PyPI (`pip`) and Anaconda (`conda`). In addition to using `.csv` files, it allows direct passing of Pandas DataFrames as training and validation datasets, similar to automatminer. It also exposes many of the model parameters at the top level via `CrabNet` and uses the `sklearn`-like "instantiate, fit, predict" workflow. An `extend_features` parameter is implemented which allows the use of data other than the elemental compositions (e.g. state variables such as temperature or applied load). These changes make CrabNet portable, extensible, and more broadly applicable, and will be incorporated into the parent repository at a later date. Please refer to the CrabNet documentation for details on installation and usage. If you find CrabNet useful, please consider citing the following publication in npj Computational Materials:
## Citing

```bibtex
@article{Wang2021crabnet,
  author = {Wang, Anthony Yu-Tung and Kauwe, Steven K. and Murdock, Ryan J. and Sparks, Taylor D.},
  year = {2021},
  title = {Compositionally restricted attention-based network for materials property predictions},
  pages = {77},
  volume = {7},
  number = {1},
  doi = {10.1038/s41524-021-00545-1},
  publisher = {{Nature Publishing Group}},
  shortjournal = {npj Comput. Mater.},
  journal = {npj Computational Materials}
}
```
## Installation

```bash
conda install -c sgbaird crabnet
```

or install PyTorch first (specific to your hardware, e.g. `pip install torch==1.10.2+cu102 -f https://download.pytorch.org/whl/cu102/torch_stable.html`), then:

```bash
pip install crabnet
```
## Basic Usage

### Load Some Test Data

```python
from crabnet.model import get_data
from crabnet.data.materials_data import elasticity

train_df, val_df = get_data(elasticity, "train.csv")
```
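If you want to supply your own data instead of the bundled test sets, the examples here imply a DataFrame with a `formula` column of composition strings and a numeric `target` column. A minimal sketch (the property values below are dummies):

```python
import pandas as pd

# Minimal DataFrame in the layout the examples use:
# composition strings in "formula" and property values in "target"
my_train_df = pd.DataFrame(
    {
        "formula": ["Al2O3", "SiO2", "Fe2O3"],
        "target": [252.0, 94.0, 205.0],  # dummy property values
    }
)
```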
### Instantiate CrabNet Model

```python
from crabnet.crabnet_ import CrabNet

cb = CrabNet(mat_prop="elasticity")
```
### Training

```python
cb.fit(train_df)
```
### Predictions

Predict on the training data:

```python
train_pred, train_sigma = cb.predict(train_df, return_uncertainty=True)
```

Predict on the validation data:

```python
val_pred, val_sigma = cb.predict(val_df, return_uncertainty=True)
```
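The predictions and uncertainties come back as arrays, so standard error metrics can be computed directly. A sketch with dummy stand-in arrays rather than real model output:

```python
import numpy as np

# dummy stand-ins for val_pred and the true target values
val_true = np.array([100.0, 150.0, 200.0])
val_pred = np.array([95.0, 160.0, 190.0])

# mean absolute error between predictions and targets
mae = np.mean(np.abs(val_pred - val_true))
```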
## Extend Features

To include additional features that get added after the transformer architecture, but before a residual neural network, include the additional features in your DataFrames and pass the name(s) of these additional features (i.e. columns) as a list via `extend_features`.
```python
import numpy as np

train_df["state_var0"] = np.random.rand(train_df.shape[0])  # dummy state variable

cb = CrabNet(
    mat_prop="hardness",
    train_df=train_df,  # contains "formula", "target", and "state_var0" columns
    extend_features=["state_var0"],
)
```
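Presumably any DataFrame used at prediction time must then carry the same extended-feature column(s) as the training data. A sketch with dummy data (the formulas and values here are illustrative only):

```python
import numpy as np
import pandas as pd

# validation data must include the same "state_var0" column used in training
val_df = pd.DataFrame({"formula": ["MgO", "TiO2"], "target": [5.8, 9.0]})
val_df["state_var0"] = np.random.rand(val_df.shape[0])  # dummy state variable
```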
## Reproduce publication results
To reproduce the publication results, please see the README instructions for CrabNet versions v1.. or earlier. For example, the first release: https://github.com/sparks-baird/CrabNet/releases/tag/release-for-chemrxiv. Trained weights are provided at: http://doi.org/10.5281/zenodo.4633866.
As a reference, on a desktop computer with an Intel i9-9900K processor, 32 GB of RAM, and two NVIDIA RTX 2080 Ti GPUs, training our largest network (OQMD) takes roughly two hours.