"""Source code for pyxtal.interface.ase_opt: ASE-based structure relaxation utilities."""

import signal
from time import time
import numpy as np
from ase.constraints import FixSymmetry
from ase.filters import UnitCellFilter
from ase.optimize.fire import FIRE
import logging

import os
from mace.calculators import mace_mp
# Module-level cache for the MACE-MP calculator. get_calculator() fills it
# lazily on the first non-'ANI' string request, so the model is built once.
_cached_mace_mp = None


def get_calculator(calculator):
    """
    Resolve a calculator specification into an ASE calculator.

    Args:
        calculator: a string or an already-constructed ASE calculator.
            The string ``'ANI'`` builds a TorchANI ANI-2x calculator.
            Any other string selects a MACE-MP calculator
            (``model='small'``, ``dispersion=True``, ``device='cpu'``),
            created on first use and cached at module level.
            Non-string values are returned unchanged.

    Returns:
        The calculator to attach to an ASE Atoms object.
    """
    global _cached_mace_mp

    if not isinstance(calculator, str):
        # Already a calculator object: use it as-is.
        return calculator

    if calculator == 'ANI':
        import torchani
        return torchani.models.ANI2x().ase()

    # Every other string (e.g. 'MACE') shares one cached MACE-MP instance,
    # so the model is only constructed on the first request.
    if _cached_mace_mp is None:
        _cached_mace_mp = mace_mp(model='small', dispersion=True, device='cpu')
    return _cached_mace_mp
def ASE_relax(struc, calculator, opt_cell=False, step=500, fmax=0.1,
              logfile=None, max_time=10.0, label='ase'):
    """
    Relax an ASE Atoms object with FIRE under a wall-clock time limit.

    A ``FixSymmetry`` constraint is attached to ``struc``.

    The relaxation runs in two stages:

    1. A short screening run of ``min(30, int(step/2))`` steps.
       Structures whose maximum force is still >= 1e3 afterwards are
       rejected.
    2. Survivors are relaxed for the remaining ``step - step_init`` steps.
       They are rejected if the final maximum force exceeds 100.

    Args:
        struc: ase atoms object. Modified in place: a calculator and a
            constraint are attached to it.
        calculator (str or calculator): 'ANI', 'MACE', or a calculator
            object, resolved through ``get_calculator``.
        opt_cell (bool): also relax the cell through ``UnitCellFilter``.
        step (int): total number of optimization steps.
        fmax (float): force convergence criterion passed to FIRE.
        logfile: FIRE log target. When None, FIRE is built without a
            ``logfile`` argument, so its own default applies.
        max_time (float): time limit in minutes. A non-positive value
            disables the limit.
        label (str): tag used in log messages.

    Returns:
        The relaxed atoms object. Returns None if the structure was
        rejected, timed out, or raised a TypeError (logged as an spglib
        error).

    Notes:
        Uses ``signal.SIGALRM``, so it must run in the main thread on a
        POSIX system.
    """
    def handler(signum, frame):
        raise TimeoutError("Optimization timed out")

    def max_force(dyn):
        # Largest row norm of the optimizable's forces. With UnitCellFilter
        # this includes the rows for the cell degrees of freedom.
        forces = dyn.optimizable.get_forces()
        return np.sqrt((forces ** 2).sum(axis=1).max())

    step_init = min([30, int(step / 2)])
    logger = logging.getLogger()
    seconds = max_time * 60
    # signal.alarm(0) means "no alarm", so a positive limit must never
    # truncate to 0 seconds.
    timeout = max(int(seconds), 1) if seconds > 0 else 0
    previous_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(timeout)

    try:
        calc = get_calculator(calculator)
        struc.calc = calc
        struc.set_constraint(FixSymmetry(struc))

        target = UnitCellFilter(struc) if opt_cell else struc
        fire_kwargs = {'a': 0.1}
        if logfile is not None:
            fire_kwargs['logfile'] = logfile
        dyn = FIRE(target, **fire_kwargs)

        # Short screening run to discard structures with huge forces early.
        dyn.run(fmax=fmax, steps=step_init)
        _fmax = max_force(dyn)

        if _fmax < 1e+3 and step > step_init:
            dyn.run(fmax=fmax, steps=step - step_init)
            _fmax = max_force(dyn)
            eng = struc.get_potential_energy() / len(struc)
            if _fmax > 100:
                logger.info(f"Warning {label} big stress {eng:.2f} / {_fmax:.2f}, skip")
                struc = None
            else:
                logger.info(f"{label} Success {eng:.2f} / {_fmax:.2f}")
        else:
            logger.info(f"Warning {label} big stress {_fmax:.2f} for {step_init} steps, skip")
            struc = None

    except TimeoutError:
        logger.info(f"Warning {label} timed out after {timeout} seconds.")
        struc = None

    except TypeError:
        logger.info(f"Warning {label} spglib error in getting the lattice")
        struc = None

    finally:
        # Always cancel the pending alarm and restore the caller's handler,
        # even if an unexpected exception propagates. Otherwise a stray
        # SIGALRM could raise TimeoutError in unrelated code later.
        signal.alarm(0)
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

    tag = 'False' if struc is None else 'True'
    logger.info(f"Finishing {label} {tag}")
    return struc
class ASE_optimizer:
    """
    ASE optimizer for organic crystal structure relaxation.

    The input geometry is assumed to be already well optimized by a
    classical force field. This class refines it with FIRE while keeping
    the symmetry fixed.

    Args:
        struc: pyxtal object
        calculator (str or calculator): 'ANI', 'MACE', or a calculator object
        opt_lat (bool): whether to also relax the lattice
        logfile (str): FIRE output file
    """

    def __init__(self, struc, calculator='MACE', opt_lat=True, logfile=None):
        self.structure = struc
        self.calculator = get_calculator(calculator)
        self.opt_lat = opt_lat
        self.stress = None
        self.forces = None
        self.optimized = True
        self.positions = None
        self.cell = None
        self.cputime = 0
        self.logfile = logfile

    def run(self, steps=10):
        """
        Relax ``self.structure`` with FIRE and write the result back.

        Workflow:

        1. Convert the structure to ASE Atoms and attach a FixSymmetry
           constraint.
        2. Relax with ``fmax=0.1``. The cell is also relaxed through
           UnitCellFilter when ``opt_lat`` is True.
        3. Copy the relaxed lattice and molecular-site coordinates back
           into ``self.structure``. Each atom's fractional displacement is
           unwrapped to the periodic image nearest its starting position.

        Args:
            steps (int): maximum number of FIRE steps

        Side effects:
            Sets ``self.structure.energy``, ``self.cell`` and
            ``self.cputime``, and calls ``self.structure.optimize_lattice()``.

        Raises:
            RuntimeError: if an atom ends up 2.0 Å or more from its starting
                position. This presumably means the molecule broke apart
                during the relaxation.
        """
        t0 = time()
        atoms = self.structure.to_ase(resort=False)
        atoms.set_constraint(FixSymmetry(atoms))
        atoms.calc = self.calculator

        target = UnitCellFilter(atoms) if self.opt_lat else atoms
        dyn = FIRE(target, a=0.1, logfile=self.logfile)
        dyn.run(fmax=0.1, steps=steps)

        self.structure.lattice.set_matrix(atoms.get_cell())
        positions = atoms.get_scaled_positions()
        cell = atoms.get_cell()  # loop-invariant, used to measure distances

        # NOTE(review): assumes to_ase(resort=False) lists each site's atoms
        # contiguously, with the site's first molecule first. TODO confirm.
        count = 0
        for site in self.structure.mol_sites:
            n_atoms = len(site.molecule.mol)
            coords0, _ = site._get_coords_and_species(first=True)
            coords1 = positions[count : count + n_atoms]
            for j, coor in enumerate(coords1):
                diff = coor - coords0[j]
                diff -= np.round(diff)  # nearest periodic image
                dist = np.linalg.norm(np.dot(diff, cell))
                if dist >= 2.0:
                    raise RuntimeError(
                        f"Atom {j} moved {dist:.3f} A during optimization "
                        f"(from {coords0[j]} to {coor})"
                    )
                coords1[j] = coords0[j] + diff
            site.update(coords1, self.structure.lattice)
            count += n_atoms * site.wp.multiplicity

        self.structure.optimize_lattice()
        self.structure.energy = atoms.get_potential_energy()
        self.cell = atoms.get_cell()

        # Detach the calculator and constraint from the temporary Atoms.
        atoms.calc = None
        atoms.set_constraint()
        self.cputime = time() - t0
if __name__ == "__main__":
    import os
    import warnings

    from pyxtal.db import database

    warnings.filterwarnings("ignore")

    # Scratch directory, created if missing.
    work_dir = "tmp"
    os.makedirs(work_dir, exist_ok=True)

    # Relax the ACSALA entry from the bundled test database and report
    # the lattice before and after, plus the final energy.
    db = database("pyxtal/database/test.db")
    struc = db.get_pyxtal("ACSALA")
    calc = ASE_optimizer(struc)
    print(calc.structure.lattice)
    calc.run()
    print(calc.structure.energy)
    print(calc.structure.lattice)