nep_fitting

autoplex.fitting.common.utils.nep_fitting(db_dir, hyperparameters=NEP_HYPERS, ref_energy_name='REF_energy', ref_force_name='REF_forces', ref_virial_name='REF_virial', species_list=None, gpu_identifier_indices=list[0], fit_kwargs=None)

Perform the NEP (neuroevolution potential) model fitting.

Parameters:
  • db_dir (str | Path) – Directory containing the training and testing data files.

  • hyperparameters (NEPSettings) – NEP hyperparameters. Default is NEP_HYPERS: NEPSettings(version=4, type=[1, 'X'], type_weight=1.0, model_type=0, prediction=0, cutoff=[6, 5], n_max=[4, 4], basis_size=[8, 8], l_max=[4, 2, 1], neuron=80, lambda_1=0.0, lambda_e=1.0, lambda_f=1.0, lambda_v=0.1, force_delta=0, batch=1000, population=60, generation=100000, zbl=2).

  • ref_energy_name (str, optional) – Reference energy name. Default is 'REF_energy'.

  • ref_force_name (str, optional) – Reference force name. Default is 'REF_forces'.

  • ref_virial_name (str, optional) – Reference virial name. Default is 'REF_virial'.

  • species_list (list | None) – List of element names (strings).

  • gpu_identifier_indices (list[int]) – Indices identifying the GPU(s) that NEP should run on.

  • fit_kwargs (dict | None) – Optional dictionary of parameters for NEP fitting, with the same keys as mlip-rss-defaults.json.
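A minimal sketch of how the arguments for a call might be assembled, based only on the signature above. The helper `build_call_arguments`, the directory name `fit_data`, the Pb/Te species, and the override values are all illustrative assumptions. That entries in `fit_kwargs` take precedence over the defaults in `hyperparameters` is likewise an assumption, not something this page states.

```python
from pathlib import Path


def build_call_arguments(db_dir, species, gpu_indices=(0,), overrides=None):
    """Collect keyword arguments matching the nep_fitting signature (hypothetical helper)."""
    return {
        "db_dir": Path(db_dir),
        "species_list": list(species),
        "gpu_identifier_indices": list(gpu_indices),
        # Keys are assumed to be the names listed under Keyword Arguments.
        "fit_kwargs": dict(overrides) if overrides else None,
        "ref_energy_name": "REF_energy",
        "ref_force_name": "REF_forces",
        "ref_virial_name": "REF_virial",
    }


call_args = build_call_arguments(
    "fit_data",  # hypothetical directory holding the training and testing files
    species=["Pb", "Te"],
    overrides={"neuron": 40, "generation": 5000},  # illustrative values only
)

# With autoplex installed and a GPU available, the call would then be:
# from autoplex.fitting.common.utils import nep_fitting
# result = nep_fitting(**call_args)
```

Keeping the arguments in a plain dictionary makes it easy to log or inspect them before an expensive fit is started.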

Keyword Arguments:
  • version (int) – NEP model version to train; can be 3 or 4. Default is 4.

  • type (list[int, str]) – Number of atom types followed by the chemical species. The first element must be an integer; the remaining elements are the chemical symbols, as given in the periodic table, of the species the model is trained for. Default is [1, "X"] as a placeholder. Example: [2, "Pb", "Te"].

  • type_weight (float) – Weights for different chemical species. Default is 1.0.

  • model_type (int) – Type of model that is being trained. Can be 0 (potential), 1 (dipole), 2 (polarizability). Default is 0.

  • prediction (int) – Mode of NEP run. Set 0 for training and 1 for inference. Default is 0.

  • cutoff (list[int, int]) – Radial and angular cutoff. First element is for radial cutoff and second element is for angular cutoff. Default is [6, 5].

  • n_max (list[int, int]) – Number of radial and angular descriptors. First element is for radial and second element is for angular. Default is [4, 4].

  • basis_size (list[int, int]) – Number of basis functions that are used to build the radial and angular descriptor. First element is for radial descriptor and second element is for angular descriptor. Default is [8, 8].

  • l_max (list[int, int, int]) – The maximum expansion order for the angular terms. First element is for three-body, second element is for four-body and third element is for five-body. Default is [4, 2, 1].

  • neuron (int) – Number of neurons in the hidden layer. Default is 80.

  • lambda_1 (float) – Weight for L1 regularization. Default is 0.

  • lambda_e (float) – Weight for energy loss. Default is 1.

  • lambda_f (float) – Weight for force loss. Default is 1.

  • lambda_v (float) – Weight for virial loss. Default is 0.1.

  • force_delta (float) – Sets a bias on the loss function to put more emphasis on obtaining accurate predictions for smaller forces. Default is 0.

  • batch (int) – Batch size for training. Default is 1000.

  • population (int) – Size of the population used by the SNES algorithm. Default is 60.

  • generation (int) – Maximum number of generations for the SNES algorithm. Default is 100000.

  • zbl (float) – Cutoff for the universal ZBL potential at short distances. Acceptable values range from 1 to 2.5. Default is 2.
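The constraints stated in the list above can be checked before a fit is launched. The following is a rough sketch, not autoplex code: `type_entry` and `check_nep_kwargs` are hypothetical helper names, and only the ranges and shapes documented here are enforced.

```python
def type_entry(species):
    """Build the `type` value: atom-type count followed by the symbols."""
    return [len(species), *species]


def check_nep_kwargs(kwargs):
    """Return a list of problems found in a dict of NEP keyword arguments."""
    problems = []
    if kwargs.get("version", 4) not in (3, 4):
        problems.append("version must be 3 or 4")
    if kwargs.get("model_type", 0) not in (0, 1, 2):
        problems.append("model_type must be 0, 1, or 2")
    if kwargs.get("prediction", 0) not in (0, 1):
        problems.append("prediction must be 0 or 1")
    if not 1 <= kwargs.get("zbl", 2) <= 2.5:
        problems.append("zbl must lie between 1 and 2.5")
    # Two-element settings: radial value first, angular value second.
    for key in ("cutoff", "n_max", "basis_size"):
        if key in kwargs and len(kwargs[key]) != 2:
            problems.append(f"{key} needs [radial, angular]")
    # l_max holds the three-, four-, and five-body expansion orders.
    if "l_max" in kwargs and len(kwargs["l_max"]) != 3:
        problems.append("l_max needs three entries")
    if "type" in kwargs:
        count, *symbols = kwargs["type"]
        if count != len(symbols):
            problems.append("type count does not match the number of symbols")
    return problems


# Illustrative settings for a Pb-Te model.
settings = {"type": type_entry(["Pb", "Te"]), "cutoff": [6, 5], "zbl": 2}
problems = check_nep_kwargs(settings)
```

Returning a list of messages rather than raising on the first failure lets all problems in a settings dictionary be reported at once.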

Returns:

A dictionary with the keys 'train_error', 'test_error', and 'mlip_path'.

Return type:

dict[str, float]
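As a sketch of how the returned dictionary might be consumed, the snippet below uses only the three keys named above. The function `summarize_fit` is a hypothetical helper, the values in `example_result` are made up for illustration, and the units of the errors are not stated on this page.

```python
def summarize_fit(result):
    """Format the entries of a nep_fitting-style result dictionary."""
    # A positive gap means the model does worse on the test set than on the training set.
    gap = result["test_error"] - result["train_error"]
    return (
        f"train error: {result['train_error']:.4f}, "
        f"test error: {result['test_error']:.4f} "
        f"(gap {gap:+.4f}), model in {result['mlip_path']}"
    )


# Hypothetical values standing in for a real return value.
example_result = {"train_error": 0.0125, "test_error": 0.025, "mlip_path": "fit_data"}
summary = summarize_fit(example_result)
print(summary)
```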
