Compositionally-Restricted Attention-Based Network (CrabNet)

The Compositionally-Restricted Attention-Based Network (CrabNet), inspired by natural language processing transformers, uses compositional information to predict material properties.

:warning: This is a fork of the original CrabNet repository :warning:

This is a refactored version of CrabNet, published to PyPI (pip) and Anaconda (conda). In addition to reading .csv files, it accepts Pandas DataFrames directly as training and validation datasets, similar to automatminer. It also exposes many of the model parameters at the top level via the CrabNet class and uses the sklearn-like "instantiate, fit, predict" workflow. An extend_features argument allows use of data other than the elemental compositions (e.g., state variables such as temperature or applied load). These changes make CrabNet portable, extensible, and more broadly applicable, and they will be incorporated into the parent repository at a later date. Please refer to the CrabNet documentation for details on installation and usage. If you find CrabNet useful, please consider citing the following publication in npj Computational Materials:

Citing

@article{Wang2021crabnet,
 author = {Wang, Anthony Yu-Tung and Kauwe, Steven K. and Murdock, Ryan J. and Sparks, Taylor D.},
 year = {2021},
 title = {Compositionally restricted attention-based network for materials property predictions},
 pages = {77},
 volume = {7},
 number = {1},
 doi = {10.1038/s41524-021-00545-1},
 publisher = {{Nature Publishing Group}},
 shortjournal = {npj Comput. Mater.},
 journal = {npj Computational Materials}
}

Installation

conda install -c sgbaird crabnet

or

Install PyTorch (specific to your hardware, e.g. pip install torch==1.10.2+cu102 -f https://download.pytorch.org/whl/cu102/torch_stable.html), then

pip install crabnet
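
As an optional sanity check (not part of the original instructions), importing the package confirms the install succeeded:

python -c "import crabnet"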

Basic Usage

Load Some Test Data

from crabnet.model import get_data
from crabnet.data.materials_data import elasticity
train_df, val_df = get_data(elasticity, "train.csv")
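
Both DataFrames follow CrabNet's "formula"/"target" schema: a column of composition strings and a column of property values. A quick inspection confirms this:

print(train_df.columns.tolist())  # expect at least "formula" and "target"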

Instantiate CrabNet Model

from crabnet.crabnet_ import CrabNet

cb = CrabNet(mat_prop="elasticity")
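
Many model parameters are exposed at the top level and can be overridden at instantiation. A minimal sketch (epochs and verbose are assumed parameter names here; consult the documentation for the actual list):

cb = CrabNet(mat_prop="elasticity", epochs=40, verbose=True)  # parameter names assumed; see the docs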

Training

cb.fit(train_df)

Predictions

Predict on the training data:

train_pred, train_sigma = cb.predict(train_df, return_uncertainty=True)

Predict on the validation data:

val_pred, val_sigma = cb.predict(val_df, return_uncertainty=True)
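
Because val_df retains its "target" column, the predictions can be scored with standard scikit-learn metrics (a minimal sketch, assuming scikit-learn is installed):

from sklearn.metrics import mean_absolute_error

mae = mean_absolute_error(val_df["target"], val_pred)  # compare true targets to predictions
print(f"validation MAE: {mae:.3f}")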

Extend Features

To include additional features that are incorporated after the transformer architecture but before a recurrent neural network, add the extra features to your DataFrames and pass their column name(s) as a list via extend_features.

train_df["state_var0"] = np.random.rand(train_df.shape[0]) # dummy state variable
cb = CrabNet(
    mat_prop="hardness",
    train_df=train_df, # contains "formula", "target", and "state_var0" columns
    extend_features=["state_var0"],
    )
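
Training and prediction then follow the same fit/predict pattern as before; note (an assumption worth checking against the docs) that any DataFrame passed to predict must also contain the extended-feature column(s):

cb.fit(train_df)  # uses "formula", "target", and "state_var0"
val_df["state_var0"] = np.random.rand(val_df.shape[0])  # dummy state variable for the validation set
val_pred = cb.predict(val_df)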

Reproduce publication results

To reproduce the publication results, please see the README instructions for CrabNet versions v1.x or earlier, e.g., the first release: https://github.com/sparks-baird/CrabNet/releases/tag/release-for-chemrxiv. Trained weights are provided at http://doi.org/10.5281/zenodo.4633866.

For reference, training our largest network (OQMD) takes roughly two hours on a desktop computer with an Intel i9-9900K processor, 32 GB of RAM, and two NVIDIA RTX 2080 Ti GPUs.