m3gnet_fitting

autoplex.fitting.common.utils.m3gnet_fitting(db_dir, path_to_hyperparameters=MLIP_RSS_DEFAULTS_FILE_PATH, device='cuda', ref_energy_name='REF_energy', ref_force_name='REF_forces', ref_virial_name='REF_virial', fit_kwargs=None)

Perform the M3GNet potential fitting.

Parameters:
  • db_dir (Path) – Directory containing the training and testing data files.

  • path_to_hyperparameters (str or Path) – Path to the JSON file containing the M3GNet hyperparameters.

  • device (str) – Device on which the model will be trained, e.g., ‘cuda’ or ‘cpu’.

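  • ref_energy_name (str, optional) – Key under which the reference energies are stored in the training data.

  • ref_force_name (str, optional) – Key under which the reference forces are stored in the training data.

  • ref_virial_name (str, optional) – Key under which the reference virials are stored in the training data.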

  • fit_kwargs (dict, optional) – Dictionary of parameters for the M3GNet fit, using the same keys as mlip-rss-defaults.json.
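The wording of `fit_kwargs` suggests that its entries are meant to take precedence over the values read from `path_to_hyperparameters`. The sketch below shows one way such an override could work. It is not autoplex's implementation: the helper `load_m3gnet_hyperparameters`, the `"M3GNET"` section name, and the example values are all hypothetical and do not describe the real layout or defaults of mlip-rss-defaults.json.

```python
import json
import tempfile
from pathlib import Path


def load_m3gnet_hyperparameters(path, section="M3GNET", overrides=None):
    """Read one section of a JSON hyperparameter file and apply overrides.

    Hypothetical helper: the section name and file layout are assumptions.
    """
    with open(path, encoding="utf-8") as handle:
        all_params = json.load(handle)
    # Copy the section so the loaded defaults are never mutated.
    params = dict(all_params.get(section, {}))
    # Explicitly passed keys (e.g. from fit_kwargs) win over the file defaults.
    params.update(overrides or {})
    return params


# Illustrative defaults only; these are not autoplex's real default values.
example_defaults = {"M3GNET": {"max_epochs": 100, "batch_size": 32, "cutoff": 5.0}}

with tempfile.TemporaryDirectory() as tmp:
    defaults_path = Path(tmp) / "hyperparameters.json"
    defaults_path.write_text(json.dumps(example_defaults), encoding="utf-8")
    params = load_m3gnet_hyperparameters(defaults_path, overrides={"max_epochs": 10})

print(params)  # {'max_epochs': 10, 'batch_size': 32, 'cutoff': 5.0}
```

Copying the section before updating keeps the file defaults reusable across several fits with different overrides.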

Keyword Arguments:
  • exp_name (str) – Name of the experiment, used for saving model checkpoints and logs.

  • results_dir (str) – Directory to store the training results and fitted model.

  • cutoff (float) – Cutoff radius for atomic interactions, in Å.

  • threebody_cutoff (float) – Cutoff radius for three-body interactions, in Å.

  • batch_size (int) – Number of structures per batch during training.

  • max_epochs (int) – Maximum number of training epochs.

  • include_stresses (bool) – If True, includes stress tensors in the model predictions and training process.

  • hidden_dim (int) – Dimensionality of the hidden layers in the model.

  • num_units (int) – Number of units in each dense layer of the model.

  • max_l (int) – Maximum degree of spherical harmonics.

  • max_n (int) – Maximum degree of the radial basis functions.

  • test_equal_to_val (bool) – If True, the testing dataset will be the same as the validation dataset.
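To make the two cutoffs concrete: in M3GNet, atom pairs closer than `cutoff` form bonds (two-body terms), while three-body terms are built from pairs of bonds that share a central atom and are both shorter than `threebody_cutoff`. The sketch below counts these for a small non-periodic cluster. It ignores periodic images, counts unordered pairs, and is only an illustration of the idea, not how matgl or autoplex build their graphs; `count_bonds_and_triplets` is a hypothetical helper.

```python
import math
from itertools import combinations


def count_bonds_and_triplets(positions, cutoff, threebody_cutoff):
    """Count pairs within `cutoff` and angle triplets within `threebody_cutoff`.

    Non-periodic sketch: periodic images are ignored, pairs are unordered.
    """
    n_pairs = 0
    # Short bonds per central atom, used to form three-body triplets.
    short_neighbors = {i: [] for i in range(len(positions))}
    for i, j in combinations(range(len(positions)), 2):
        d = math.dist(positions[i], positions[j])
        if d < cutoff:
            n_pairs += 1
        if d < threebody_cutoff:
            short_neighbors[i].append(j)
            short_neighbors[j].append(i)

    # A triplet (j, i, k) is an angle centred on atom i with both bonds short.
    n_triplets = sum(math.comb(len(nbrs), 2) for nbrs in short_neighbors.values())
    return n_pairs, n_triplets


# Linear chain with 1.5 Å spacing (coordinates in Å).
chain = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (3.0, 0.0, 0.0), (4.5, 0.0, 0.0)]
print(count_bonds_and_triplets(chain, cutoff=3.5, threebody_cutoff=2.0))  # (5, 2)
```

Raising `threebody_cutoff` quickly increases the number of triplets (here from 2 to 8 at 3.5 Å), which is why it is usually kept smaller than `cutoff`.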

Returns:

A dictionary containing keys such as ‘train_error’, ‘test_error’, and ‘path_to_fitted_model’, representing the training error, test error, and the location of the saved model, respectively.
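The returned dictionary can be inspected directly after fitting. The sketch below stands in for a real `m3gnet_fitting` call (which needs training data and a full autoplex installation) with a dummy result dictionary whose numbers are made up; `summarize_fit` is a hypothetical helper, and the units and metric of the errors are not specified on this page.

```python
from pathlib import Path


def summarize_fit(results):
    """Format errors and model location from a fitting-result dictionary.

    Hypothetical helper; the keys follow those documented for m3gnet_fitting.
    """
    model_path = Path(results["path_to_fitted_model"])
    # A large test-minus-train gap can hint at overfitting.
    gap = results["test_error"] - results["train_error"]
    return (
        f"train={results['train_error']:.4f} "
        f"test={results['test_error']:.4f} "
        f"(test - train = {gap:.4f}) model={model_path.name}"
    )


# Dummy values for illustration only; not output from a real fit.
dummy_results = {
    "train_error": 0.0123,
    "test_error": 0.0150,
    "path_to_fitted_model": "results/m3gnet_example",
}
print(summarize_fit(dummy_results))
```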

Return type:

dict[str, float]

References