Source code for pyxtal.db

"""
Database class
"""

import os
import logging

import numpy as np
import pymatgen.analysis.structure_matcher as sm
from ase.calculators.calculator import CalculationFailed
from ase.db import connect

from pyxtal import pyxtal
from pyxtal.util import ase2pymatgen


def setup_worker_logger(log_file):
    """
    Point the root logger of a worker process at ``log_file``.

    Handlers already attached to the root logger are removed first, so that
    ``logging.basicConfig`` installs a fresh file handler at INFO level with
    the ``"%(asctime)s| %(message)s"`` format.

    Args:
        log_file (str): path of the log file for this worker
    """
    root = logging.getLogger()
    root.handlers.clear()
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format="%(asctime)s| %(message)s",
    )
def call_opt_single(p):
    """
    Run ``opt_single`` on one work item and drop the status flag.

    Args:
        p (tuple): ``(id, xtal, calc, *args)``; the whole tuple is forwarded
            positionally to ``opt_single``.

    Returns:
        tuple: ``(id, xtal, eng)`` where ``id`` is ``p[0]`` and ``xtal`` /
        ``eng`` are the first two values returned by ``opt_single``.
    """
    structure, energy, _status = opt_single(*p)
    return p[0], structure, energy
def opt_single(id, xtal, calc, *args):
    """
    Optimize a structure using the specified calculator.

    Args:
        id (int): Identifier of the structure to be optimized.
        xtal: Crystal structure object to be optimized.
        calc (str): The calculator to use ('GULP', 'DFTB', 'VASP', 'MACE').
        *args: Additional arguments forwarded to the calculator-specific
            function (``gulp_opt_single``, ``dftb_opt_single``,
            ``vasp_opt_single`` or ``mace_opt_single``).

    Returns:
        tuple: ``(xtal, energy, status)`` as returned by the selected
        calculator function.

    Raises:
        ValueError: If an unsupported calculator is specified.
    """
    if calc == 'GULP':
        return gulp_opt_single(id, xtal, *args)
    if calc == 'DFTB':
        return dftb_opt_single(id, xtal, *args)
    if calc == 'VASP':
        return vasp_opt_single(id, xtal, *args)
    if calc == 'MACE':
        return mace_opt_single(id, xtal, *args)
    raise ValueError(f"Cannot support this calculator: {calc}")
def dftb_opt_single(id, xtal, skf_dir, steps, symmetrize, criteria, kresol=0.05):
    """
    Single DFTB optimization for a given atomic xtal.

    Args:
        id (int): id of the given xtal (used in messages)
        xtal: pyxtal instance
        skf_dir (str): path of skf files
        steps (int): number of relaxation steps; ``int(steps / 2)`` is passed
            to each relaxation call
        symmetrize (bool): if True, relax once with ``DFTB_relax``; otherwise
            run a loose (``scc_error=0.1``, ``1.5 * kresol``) and then a tight
            (``scc_error=1e-4``) ``vc-relax`` with ``DFTB``
        criteria (dict): validity criteria passed to ``process_xtal``
        kresol (float): k-point resolution

    Returns:
        tuple: ``(xtal, eng, status)`` with the optimized pyxtal, the energy
        per atom and the validity flag, or ``(None, None, False)`` on failure.
    """
    from pyxtal.interface.dftb import DFTB, DFTB_relax

    cwd = os.getcwd()
    atoms = xtal.to_ase(resort=False)
    eng = None
    # DFTB driver; stays None on the symmetrize path, which never creates one
    my = None
    try:
        if symmetrize:
            s = DFTB_relax(
                atoms,
                skf_dir,
                True,
                int(steps / 2),
                kresol=kresol,
                folder=".",
                scc_error=1e-5,
                scc_iter=100,
                logfile="ase.log",
            )
        else:
            my = DFTB(
                atoms,
                skf_dir,
                kresol=kresol * 1.5,
                folder=".",
                scc_error=0.1,
                scc_iter=100,
            )
            s, eng = my.run(mode="vc-relax", step=int(steps / 2))
            my = DFTB(s, skf_dir, kresol=kresol, folder=".", scc_error=1e-4, scc_iter=100)
            s, eng = my.run(mode="vc-relax", step=int(steps / 2))
            s = my.struc
    except CalculationFailed:
        if my is None:
            # No DFTB driver to read partial results from (symmetrize path)
            s = None
            print("Problem in DFTB Geometry optimization", id)
        else:
            # Convergence error in the geometry optimization:
            # keep the last structure and energy that were computed
            my.calc.read_results()
            eng = my.calc.results["energy"]
            s = my.struc
    except Exception:
        s = None
        print("Problem in DFTB Geometry optimization", id)
        xtal.to_file("bug.cif")
    os.chdir(cwd)

    if s is None:
        return None, None, False

    # Return the optimized structure, not the input one
    c = pyxtal()
    c.from_seed(s)
    if eng is None:
        eng = s.get_potential_energy() / len(s)
    else:
        eng /= len(s)
    status = process_xtal(id, c, eng, criteria)
    print(c.get_xtal_string())
    return c, eng, status
def vasp_opt_single(id, xtal, path, cmd, criteria):
    """
    Single VASP optimization for a given atomic xtal.

    The calculation runs in ``<path>/g<id>``. The original working directory
    is restored afterwards, whether the optimization succeeded, reported an
    error, or raised.

    Args:
        id (int): id of the given xtal
        xtal: pyxtal instance
        path (str): parent calculation folder
        cmd (str): vasp command
        criteria (dict): validity criteria passed to ``process_xtal``

    Returns:
        tuple: ``(xtal, eng, status)`` as returned by the VASP interface;
        ``status`` is False when VASP reported an error.
    """
    from pyxtal.interface.vasp import optimize as vasp_opt

    cwd = os.getcwd()
    path = os.path.join(path, f"g{id}")
    status = False
    try:
        xtal, eng, _, error = vasp_opt(xtal, path, cmd=cmd, walltime="59m")
        if not error:
            status = process_xtal(id, xtal, eng, criteria)
    finally:
        # Always go back to where we started so later relative paths work
        os.chdir(cwd)
    return xtal, eng, status
def gulp_opt_single(id, xtal, ff_lib, path, criteria):
    """
    Perform a single GULP optimization for a given crystal structure.

    Args:
        id (int): Identifier for the current structure.
        xtal: PyXtal instance representing the crystal to be optimized.
        ff_lib (str): Force field library for GULP, e.g., 'reaxff', 'tersoff'.
        path (str): Parent folder; the calculation runs in ``<path>/g<id>``.
        criteria (dict): Dictionary to check the validity of the opt_structure.

    Returns:
        tuple:
            - xtal: Optimized PyXtal instance.
            - eng (float): Energy of the optimized structure.
            - status (bool): Whether the optimization process is successful.

    Behavior:
        After the optimization, the structure's validity is checked and the
        calculation folder is removed if it is empty.
    """
    from pyxtal.interface.gulp import single_optimize as gulp_opt

    path = os.path.join(path, f"g{id}")
    xtal, eng, _, error = gulp_opt(
        xtal,
        ff=ff_lib,
        label=str(id),
        path=path,
        symmetry=True,
    )

    status = False
    if not error:
        status = process_xtal(id, xtal, eng, criteria)

    try:
        os.rmdir(path)
    except OSError:
        # rmdir only removes empty folders; leftover files are kept on purpose
        print("Folder is not empty", path)
    return xtal, eng, status
def mace_opt_single(id, xtal, step, fmax, criteria):
    """
    Perform a single MACE optimization for a given atomic crystal structure.

    Args:
        id (int): Identifier for the current structure.
        xtal: PyXtal instance representing the crystal structure.
        step (int): Maximum number of relaxation steps.
        fmax (float): fmax for relaxation
        criteria (dict): Dictionary to check the validity of the optimized structure.

    Returns:
        tuple:
            - xtal: Optimized PyXtal instance (or None if optimization failed).
            - eng (float): Energy/atom of the opt_structure (or None if it failed).
            - status (bool): Whether the optimization was successful.
    """
    from pyxtal.interface.ase_opt import ASE_relax as mace_opt

    logger = logging.getLogger()
    atoms = xtal.to_ase(resort=False)
    # Time budget passed to ASE_relax, scaled up for systems above 200 atoms
    max_time = 9.0 * max(1, len(atoms) / 200)
    s = mace_opt(atoms, 'MACE', opt_cell=True, step=step, fmax=fmax,
                 max_time=max_time, label=str(id))
    if s is None:
        logger.info(f"mace_opt_single Failure {id}")
        return None, None, False

    try:
        xtal = pyxtal()
        xtal.from_seed(s)
        eng = s.get_potential_energy() / len(s)
        status = process_xtal(id, xtal, eng, criteria)
    except Exception:
        # Keep the worker alive, but record the traceback for debugging
        logger.exception(f"mace_opt_single Bug {id}")
        return None, None, False
    logger.info(f"mace_opt_single Success {id}")
    return xtal, eng, status
def process_xtal(id, xtal, eng, criteria):
    """
    Validate an optimized structure and print it if it passes.

    Args:
        id (int): identifier used as a 4-wide header in the printout
        xtal: structure providing ``check_validity`` and ``get_xtal_string``
        eng (float): energy reported alongside the structure
        criteria (dict or None): validity criteria; None accepts everything

    Returns:
        The value of ``xtal.check_validity(criteria)``, or True when
        ``criteria`` is None.
    """
    if criteria is None:
        status = True
    else:
        status = xtal.check_validity(criteria)

    if status:
        summary = {"validity": status, "energy": eng}
        print(xtal.get_xtal_string(header=f"{id:4d}", dicts=summary))
    return status
def make_entry_from_pyxtal(xtal):
    """
    Generate an entry from a molecular PyXtal object.

    Args:
        xtal: PyXtal object whose ``xtal.tag`` provides ``"smiles"``,
            ``"csd_code"`` and ``"ccdc_number"``.

    Returns:
        tuple or None: ``(ase_atoms, entry_dict, None)``
            - ase_atoms: ASE Atoms object converted from the PyXtal structure.
            - entry_dict (dict): A dictionary containing information
            - None: Placeholder for future use.
        None is returned if ``xtal.valid`` is False or if RDKit cannot parse
        ``xtal.tag["smiles"]``.

    Structure of `entry_dict`:
        - "csd_code" (str): CSD code of the crystal structure.
        - "mol_smi" (str): SMILES representation of the molecule.
        - "ccdc_number" (str): CCDC identifier number.
        - "space_group" (str): Space group symbol of the crystal.
        - "spg_num" (int): Space group number.
        - "Z" (int): Number of molecules in the unit cell.
        - "Zprime" (float): Z' value of the crystal.
        - "url" (str): URL link to the CCDC database entry for the crystal.
        - "mol_formula" (str): Molecular formula of the structure.
        - "mol_weight" (float): Molecular weight of the structure.
        - "mol_name" (str): Name of the molecule (the CSD code).
        - "l_type" (str): Lattice type of the structure.

    Example:
        entry = make_entry_from_pyxtal(xtal_instance)
        if entry is not None:
            ase_atoms, entry_dict, _ = entry

    Notes:
        - The CCDC link is generated using the structure's CCDC number.
    """
    if not xtal.valid:
        return None

    # RDKit is only needed for valid structures, so it is imported lazily
    from rdkit import Chem
    from rdkit.Chem.Descriptors import ExactMolWt
    from rdkit.Chem.rdMolDescriptors import CalcMolFormula

    smiles = xtal.tag["smiles"]
    m = Chem.MolFromSmiles(smiles)
    if m is None:
        # MolFromSmiles returns None (instead of raising) on invalid SMILES
        print("Cannot parse SMILES", smiles)
        return None

    url0 = "https://www.ccdc.cam.ac.uk/structures/Search?Ccdcid="
    kvp = {
        "csd_code": xtal.tag["csd_code"],
        "mol_smi": smiles,
        "ccdc_number": xtal.tag["ccdc_number"],
        "space_group": xtal.group.symbol,
        "spg_num": xtal.group.number,
        "Z": sum(xtal.numMols),
        "Zprime": xtal.get_zprime()[0],
        "url": url0 + str(xtal.tag["ccdc_number"]),
        "mol_formula": CalcMolFormula(m),
        "mol_weight": ExactMolWt(m),
        "mol_name": xtal.tag["csd_code"],
        "l_type": xtal.lattice.ltype,
    }
    # ASE Atoms, the entry dictionary, and None as a placeholder for data
    return (xtal.to_ase(), kvp, None)
def make_entry_from_CSD_web(code, number, smiles, name=None):
    """
    Make an entry dictionary from the CSD web service
    (https://www.ccdc.cam.ac.uk/structures).

    This is not implemented yet.

    Args:
        code: CSD style letter entry
        number: ccdc number
        smiles: the corresponding molecular smiles
        name: name of the compound

    Raises:
        NotImplementedError: always
    """
    message = "To do in future"
    raise NotImplementedError(message)
def make_entry_from_CSD(code):
    """
    Make an entry from a single CSD code.

    Args:
        code (str): a CSD code

    Returns:
        The result of ``make_entry_from_pyxtal`` for the downloaded structure,
        or None when loading the code raises ``CSDError``.
    """
    from pyxtal.msg import CSDError

    crystal = pyxtal(molecular=True)
    try:
        crystal.from_CSD(code)
        entry = make_entry_from_pyxtal(crystal)
    except CSDError as err:
        print("CSDError", code, err.message)
        entry = None
    return entry
def make_db_from_CSD(dbname, codes):
    """
    Build (or extend) a database from a list of CSD codes.

    Args:
        dbname (str): db file name
        codes (list): CSD codes to add

    Returns:
        database: the database object the entries were added to
    """
    db = database(dbname)
    for index, code in enumerate(codes):
        entry = make_entry_from_CSD(code)
        if entry is None:
            continue
        db.add(entry)
        print(index, code)
    return db
class database:
    """
    Database of molecular crystals stored in an ASE ``*.db`` file.

    Entries are identified by their CSD code. Each row carries the key-value
    pairs listed in ``self.keys`` and may hold per-calculator results in
    ``row.data`` under ``"<calculator>_info"`` keys, where the calculator
    names are listed in ``self.calculators``.

    Args:
        db_name: `*.db` format from ase database
    """

    def __init__(self, db_name):
        self.db_name = db_name
        self.db = connect(db_name, serial=True)
        # Also removes rows whose CSD code is duplicated (see get_all_codes)
        self.codes = self.get_all_codes()
        self.keys = [
            "csd_code",
            "space_group",
            "spg_num",
            "Z",
            "Zprime",
            "url",
            "mol_name",
            "mol_smi",
            "mol_formula",
            "mol_weight",
            "l_type",
        ]
        self.calculators = [
            "charmm",
            "gulp",
            "ani",
            "dftb_D3",
            "dftb_TS",
        ]

    def vacuum(self):
        """Call ``vacuum()`` on the underlying ASE database."""
        self.db.vacuum()

    def get_all_codes(self, group=None):
        """
        Collect the CSD codes stored in the database.

        Rows whose code was already collected are treated as duplicates and
        deleted. Deletion happens after the scan so that rows are not removed
        while ``select()`` is still being iterated.

        Args:
            group: if given, only rows whose ``group`` value equals it are
                collected

        Returns:
            list: unique CSD codes, in database order
        """
        codes = []
        seen = set()
        duplicates = []
        for row in self.db.select():
            if row.csd_code in seen:
                print("find duplicate! remove", row.id, row.csd_code)
                duplicates.append(row.id)
            elif group is None or row.group == group:
                seen.add(row.csd_code)
                codes.append(row.csd_code)
        if duplicates:
            self.db.delete(duplicates)
        return codes

    def add(self, entry):
        """
        Write an ``(atoms, kvp, data)`` entry unless its CSD code is stored.

        Entries whose ``kvp`` lacks any key of ``self.keys`` are skipped;
        ``process_kvp`` reports the missing key.
        """
        (atom, kvp, data) = entry
        if kvp["csd_code"] in self.codes:
            return
        kvp0 = self.process_kvp(kvp)
        if kvp0 is None:
            return
        self.db.write(atom, key_value_pairs=kvp0, data=data)
        self.codes.append(kvp["csd_code"])

    def add_from_code(self, code):
        """Fetch ``code`` from the CSD and add it; report codes without an entry."""
        entry = make_entry_from_CSD(code)
        if entry is not None:
            self.add(entry)
        else:
            print(f"{code:s} is not a valid entry")

    def process_kvp(self, kvp):
        """
        Keep only the key-value pairs listed in ``self.keys``.

        Returns:
            dict or None: the filtered pairs, or None (after printing the
            first missing key) if any key of ``self.keys`` is absent
        """
        kvp0 = {}
        for key in self.keys:
            if key not in kvp:
                print("Error, cannot find ", key, " from the input")
                return None
            kvp0[key] = kvp[key]
        return kvp0

    def check_status(self, show=False):
        """
        Check the current status of each entry.

        Args:
            show (bool): print a benchmark summary (``view``) of every
                complete entry

        Returns:
            list: ids of rows holding one data item per calculator; codes of
            the other rows are printed
        """
        ids = []
        for row in self.db.select():
            if len(row.data.keys()) == len(self.calculators):
                ids.append(row.id)
                if show:
                    row_info = self.get_row_info(id=row.id)
                    self.view(row_info)
            else:
                print(row.csd_code)
        return ids

    def copy(self, db_name, csd_codes):
        """
        Copy the entries to another db.

        Args:
            db_name: db file name (must differ from this database's file)
            csd_codes: list of codes
        """
        if db_name == self.db_name:
            raise RuntimeError("Cannot use the same db file for copy")
        with connect(db_name, serial=True) as db:
            for csd_code in csd_codes:
                (atom, kvp, data) = self.get_row_info(code=csd_code)
                db.write(atom, key_value_pairs=kvp, data=data)

    def view(self, row_info):
        """
        Print the summary of benchmark results.

        Args:
            row_info (tuple): ``(atoms, kvp, data)`` as returned by
                ``get_row_info``
        """
        from pyxtal.representation import representation

        (atom, kvp, data) = row_info
        # Reference structure
        xtal = self.get_pyxtal(kvp["csd_code"])
        rep = xtal.get_1D_representation()
        print("\n", kvp["csd_code"], kvp["mol_smi"], xtal.lattice.volume)
        print(rep.to_string() + " reference")

        # One line per calculator; keys look like "<calc>_info"
        for key in data:
            calc = key[:-5]
            time = data[key]["time"]
            rep = data[key]["rep"]
            if type(rep[0]) is not list:
                rep = [rep]
            rep = representation(rep, kvp["mol_smi"]).to_string()
            (dv, msd1, msd2) = data[key]["diff"]
            strs = f"{rep:s} {calc:8s} {time / 60:6.2f} {dv:6.3f}"
            if msd1 is not None:
                strs += f"{msd1:6.3f}{msd2:6.3f}"
            print(strs)

    def get_row_info(self, id=None, code=None):
        """
        Return ``(atoms, kvp, data)`` for the row selected by ``id`` or ``code``.

        ``kvp`` holds the ``self.keys`` pairs; ``data`` holds the
        ``"<calculator>_info"`` items that are present.

        Raises:
            RuntimeError: if no row matches
        """
        row = None
        if id is not None:
            row = next(iter(self.db.select(id=id)), None)
        elif code is not None:
            row = next(iter(self.db.select(csd_code=code)), None)
        if row is None:
            raise RuntimeError(f"cannot find the entry from {id} {code}")

        kvp = {key: row.key_value_pairs[key] for key in self.keys}
        data0 = {}
        for calc in self.calculators:
            key = calc + "_info"
            if key in row.data:
                data0[key] = row.data[key]
        atom = self.db.get_atoms(id=row.id)
        return (atom, kvp, data0)

    def get_row(self, code):
        """Return the first row with ``csd_code == code``; RuntimeError if none."""
        for row in self.db.select(csd_code=code):
            return row
        raise RuntimeError(f"cannot find the entry from {code}")

    def get_pyxtal(self, code):
        """
        Build a molecular pyxtal for ``code`` from the stored atoms and SMILES.

        Multiple molecules are given as a "."-separated SMILES. If reading the
        seed fails with ``ReadSeedError``, it is retried with ``add_H=True``.
        """
        from pyxtal.msg import ReadSeedError

        row = self.get_row(code)
        atom = self.db.get_atoms(id=row.id)
        # Reference
        pmg = ase2pymatgen(atom)
        molecules = [smile + ".smi" for smile in row.mol_smi.split(".")]
        xtal = pyxtal(molecular=True)
        try:
            xtal.from_seed(pmg, molecules=molecules)
        except ReadSeedError:
            xtal.from_seed(pmg, molecules=molecules, add_H=True)
        return xtal

    def compute(self, row, work_dir, skf_dir):
        """
        Placeholder for computing the missing calculator data of ``row``.

        Parameter generation and the computation itself are disabled
        (commented out below), so this method only loads the row's atoms
        (and converts them to pymatgen when ``"gulp_info"`` is present). It
        does not update the database.
        """
        if len(row.data.keys()) < len(self.calculators):
            # not label information, run antechamber
            atom = self.db.get_atoms(id=row.id)
            if "gulp_info" not in row.data:
                # pmg, c_info, g_info = get_parameters(row, atom)
                # row.data = {"charmm_info": c_info, "gulp_info": g_info}
                pass
            else:
                pmg = ase2pymatgen(atom)  # noqa: F841 (unused while disabled)

            # data = compute(row, pmg, work_dir, skf_dir)
            # self.db.update(row.id, data=data)
            # print("updated the data for", row.csd_code)
class database_topology:
    """
    Database of atomic crystal structures stored in an ASE ``*.db`` file.

    Args:
        db_name (str): `*.db` format from ase database
        rank (int): rank of this process when rows are split across
            processes (default 0)
        size (int): number of processes sharing the work (default 1)
        ltol (float): lattice tolerance of the structure matcher
        stol (float): site tolerance of the structure matcher
        atol (float): angle tolerance of the structure matcher
        log_file (str): file receiving the root-logger output
    """

    def __init__(self, db_name, rank=0, size=1, ltol=0.05, stol=0.05, atol=3, log_file='db.log'):
        # Work-splitting bookkeeping (rows are assigned by id % size == rank)
        self.rank, self.size = rank, size
        self.db_name = db_name
        self.db = connect(db_name, serial=True)
        # Row attributes that are copied between db rows and pyxtal objects
        self.keys = [
            "space_group_number",
            "pearson_symbol",
            "similarity0",
            "similarity",
            "density",
            "dof",
            "topology",
            "topology_detail",
            "dimension",
            "wps",
            "ff_energy",
            "ff_lib",
            "ff_relaxed",
            "mace_energy",
            "mace_relaxed",
            "dftb_energy",
            "dftb_relaxed",
            "vasp_energy",
            "vasp_relaxed",
        ]
        self.matcher = sm.StructureMatcher(ltol=ltol, stol=stol, angle_tol=atol)

        # Route root-logger output to log_file; existing handlers are dropped
        # so that basicConfig actually installs the new file handler.
        self.log_file = log_file
        root_logger = logging.getLogger()
        root_logger.handlers.clear()
        logging.basicConfig(format="%(asctime)s| %(message)s",
                            filename=self.log_file,
                            level=logging.INFO)
        self.logging = logging
[docs] def vacuum(self): self.db.vacuum()
[docs] def print_memory_usage(self): import psutil process = psutil.Process(os.getpid()) mem = process.memory_info().rss / 1024 ** 2 self.logging.info(f"Rank {self.rank} memory: {mem:.1f} MB") print(f"Rank {self.rank} memory: {mem:.1f} MB")
[docs] def get_pyxtal(self, id, use_relaxed=None, tol=1e-4): """ Get pyxtal based on row_id, if use_relaxed, get pyxtal from ff_relaxed Args: id (int): row id use_relaxed (str): 'ff_relaxed', 'vasp_relaxed' """ from pymatgen.core import Structure from pyxtal import pyxtal from pyxtal.util import ase2pymatgen from pyxtal.symmetry import Group row = self.db.get(id) #print(row.space_group_number, row.topology, row.pearson_symbol, row.wps, row.mace_energy) if use_relaxed is not None: if hasattr(row, use_relaxed): xtal_str = getattr(row, use_relaxed) pmg = Structure.from_str(xtal_str, fmt="cif") else: print(f"No {use_relaxed} attributes for structure", id) atom = self.db.get_atoms(id=id) pmg = ase2pymatgen(atom) else: #hn = Group(row.space_group_number).hall_number xtal1 = pyxtal() #xtal1.from_seed('1.cif', tol=tol)#, hn=hn) atom = self.db.get_atoms(id=id) #print(row.topology) #atom.write('1.cif', format='cif')#, direct=True, vasp5=True) xtal1.from_seed(atom, tol=tol)#, hn=hn) pmg = ase2pymatgen(atom) xtal = pyxtal() try: xtal.from_seed(pmg, tol=tol) #if xtal.group.number != row.space_group_number: print(xtal); import sys; sys.exit() if xtal is not None and xtal.valid: for key in self.keys: if hasattr(row, key): setattr(xtal, key, getattr(row, key)) return xtal except: print(xtal_str) #import sys; sys.exit() print("Cannot load the structure")
[docs] def get_all_xtals(self, include_energy=False): """ Get all pyxtal instances from the current db """ xtals = [] for row in self.db.select(): xtal = self.get_pyxtal(id=row.id) if xtal is not None: if include_energy and hasattr(row, 'vasp_energy'): xtal.energy = row.vasp_energy xtals.append(xtal) return xtals
[docs] def add_xtal(self, xtal, kvp={}): """ Add new xtal to the given db """ spg_num = xtal.group.number density = xtal.get_density() dof = xtal.get_dof() wps = [s.wp.get_label() for s in xtal.atom_sites] _kvp = { "space_group_number": spg_num, "pearson_symbol": xtal.get_Pearson_Symbol(), "wps": str(wps), "density": density, "dof": dof, } kvp.update(_kvp) atoms = xtal.to_ase(resort=False) self.db.write(atoms, key_value_pairs=kvp)
[docs] def add_strucs_from_db(self, db_file, check=False, id_min=0, id_max=None, tol=1e-3, freq=50, use_relaxed=None, sort=None, max_count=None, criteria=None, min_atoms=0, max_atoms=250, ignore_check='vasp_energy', same_number=False): """ Add new structures from another database file. Args: db_file (str): Path to the source database file check (bool): Whether to check if structure already exists before adding id_min (int): Starting ID to import from source database. Default is 0 id_max (int): Ending ID to import from source database. Default is None tol (float): Tolerance in Angstroms for symmetry detection. Default is 1e-3 freq (int): Print progress message every N structures. Default is 50 use_relaxed (str): Relaxed structure to use - 'ff_relaxed', 'vasp_relaxed' Default is None to use unrelaxed structures sort (str): key to sort the structure, e.g. 'mace_energy' Default is None to use row.id max_count (int): Number of maximum structure to add Default is None to all all structures criteria (dict): criteria to check if a valid. 
""" cifname = 'my_add.cif' print(f"\nAdding new strucs from {db_file:s}") count = 0 with connect(db_file, serial=True) as db: if id_max is None: id_max = db.count() for row in db.select(sort=sort): if id_min <= row.id <= id_max: xtal = pyxtal() if use_relaxed is None: atoms = row.toatoms() try: xtal.from_seed(atoms, tol=tol) except: xtal = None content = (row.mace_energy, row.pearson_symbol, row.wps) print("Faild to load xtal", content) else: with open(cifname, 'w') as f: f.write(getattr(row, use_relaxed)) try: xtal.from_seed(cifname, tol=tol) except: xtal = None print("Faild to load xtal", row.mace_energy, row.pearson_symbol, row.wps) if xtal is not None and xtal.valid: if criteria is not None and not xtal.check_validity(criteria): print("hit a invalid structure", row.id, xtal.get_xtal_string()) print(xtal) continue #print(xtal) mace_eng = None if not hasattr(row, 'mace_energy') else row.mace_energy atoms = xtal.to_ase() add = True if check: if not hasattr(row, ignore_check): add = self.check_new_structure(xtal, mace_eng, max_atoms=max_atoms, min_atoms=min_atoms, same_number=same_number) if add: kvp = {} for key in self.keys: if key == "space_group_number": kvp[key] = xtal.group.number elif key == "density": kvp[key] = xtal.get_density() elif key == "dof": kvp[key] = xtal.get_dof() elif key == "wps": kvp[key] = str([s.wp.get_label() for s in xtal.atom_sites]) elif key == "pearson_symbol": kvp[key] = xtal.get_Pearson_Symbol() elif hasattr(row, key): kvp[key] = getattr(row, key) self.db.write(atoms, key_value_pairs=kvp) count += 1 if count % freq == 0: print(f"Adding {count:4d} strucs from {db_file:s}") if max_count is not None and count == max_count: break else: print("Fail to convert xtal")
[docs] def check_new_structure(self, xtal, eng=None, same_group=False, same_number=False, d_tol=2e-1, e_tol=1e-2, max_atoms=250, min_atoms=0, return_id=False): """ Check if the input crystal structure already exists in the database. Args: xtal: PyXtal object representing the crystal structure to check eng (float, optional): Energy of the structure to compare same_group (bool): Whether to only compare structures with same space group d_tol (float): Tolerance for density comparison e_tol (float): Tolerance for energy comparison max_atoms (int): maximum number of atoms for checking min_atoms (int): minimum number of atoms for checking Returns: bool: True if structure is new/unique, False if it matches an existing structure Note: Compares structures based on: - Space group number (if same_group=True) - Density (within d_tol) - Energy (within e_tol if provided) - Structure similarity via pymatgen.analysis.structure_matcher """ if sum(xtal.numIons) > max_atoms or sum(xtal.numIons) < min_atoms: return True s_pmg = xtal.to_pymatgen() for row in self.db.select(sort='-id'): # Sort by id in descending order if row.natoms > max_atoms or row.natoms < min_atoms: continue if same_number and row.natoms != len(s_pmg): continue if eng is not None and abs(eng - row.mace_energy) > e_tol: continue if same_group and row.space_group_number != xtal.group.number: continue if abs(row.density - xtal.get_density()) > d_tol: continue ref = self.db.get_atoms(id=row.id) ref_pmg = ase2pymatgen(ref) if self.matcher.fit(s_pmg, ref_pmg, symmetric=True): print("skip the duplicate", xtal.get_xtal_string()) #print(row.id, row.space_group_number, row.wps, row.mace_energy, row.density) if return_id: return False, row.id else: return False if return_id: return True, None else: return True
[docs] def clean_structures_spg_topology(self, dim=None): """ Clean up the db by removing duplicate structures based on their properties. Args: dim (int, optional): Filter structures by dimension. Only keep structures with this dimension if specified. Defaults to None. The function removes structures that have identical: - Number of atoms - Space group - Topology - Wyckoff positions (wps) """ unique_rows = [] to_delete = [] for row in self.db.select(): unique = True # Ignore unwanted dimension if dim is not None and hasattr(row, "dimension") and row.dimension != dim: # print(row.dimension, dim) unique = False else: for prop in unique_rows: (natoms, spg, wps, topology) = prop if (natoms == row.natoms and spg == row.space_group_number and wps == row.wps) and hasattr( row, "topology" ): if row.topology == "aaa": if row.topology_detail == topology: unique = False break elif row.topology == topology: unique = False break if unique: if hasattr(row, "topology"): unique_rows.append( ( row.natoms, row.space_group_number, row.wps, row.topology if row.topology != "aaa" else row.topology_detail, ) ) else: unique_rows.append( (row.natoms, row.space_group_number, row.wps, None)) else: to_delete.append(row.id) if len(to_delete) > 0: print(len(to_delete), "structures were deleted", to_delete) self.db.delete(to_delete)
[docs] def get_row(self, id): for row in self.db.select(id=id): return row raise RuntimeError(msg)
[docs] def clean_structures(self, ids=(None, None), dtol=2e-3, etol=1e-3, criteria=None, eng_key='mace_energy'): """ Clean up the db by removing the duplicate structures Here we check the follow criteria - same number of atoms - same density - same energy Args: dtol (float): tolerance of density etol (float): tolerance of energy criteria (dict): including """ unique_rows = [] to_delete = [] ids, xtals = self.select_xtals(ids) for id, xtal in zip(ids, xtals): row = self.db.get(id) xtal = self.get_pyxtal(id) unique = True if criteria is not None: if not xtal.check_validity(criteria, True): unique = False print( "Found unsatisfied criteria", row.id, row.space_group_number, row.wps, ) if unique and ( "MAX_energy" in criteria and hasattr( row, eng_key) and getattr(row, key) > criteria["MAX_energy"] ): unique = False print( "Unsatisfied energy", row.id, getattr(row, eng_key), row.space_group_number, row.wps, ) if unique and ( "MAX_similarity" in criteria and hasattr(row, "similarity") and row.similarity > criteria["MAX_similarity"] ): unique = False print( "Unsatisfied similarity", row.id, row.similarity, row.space_group_number, row.wps, ) if unique and ( "BAD_topology" in criteria and hasattr(row, "topology") and row.topology[:3] in criteria["BAD_topology"] ): unique = False print( "Unsatisfied topology", row.id, row.topology, row.space_group_number, row.wps, ) if unique and ( "BAD_dimension" in criteria and hasattr(row, "dimension") and row.dimension in criteria["BAD_dimension"] ): unique = False print( "Unsatisfied dimension", row.id, row.topology, row.space_group_number, row.wps, ) if unique: for prop in unique_rows: (natoms, spg, wps, den, energy) = prop if natoms == row.natoms and spg == row.space_group_number and wps == row.wps: if hasattr(row, eng_key) and energy is not None: if abs(getattr(row, eng_key) - energy) < etol: unique = False break elif abs(den - row.density) < dtol: unique = False break if unique: if hasattr(row, eng_key): unique_rows.append( ( 
row.natoms, row.space_group_number, row.wps, row.density, row.ff_energy, ) ) else: unique_rows.append( (row.natoms, row.space_group_number, row.wps, row.density, None)) else: to_delete.append(row.id) print(len(to_delete), "structures were deleted", to_delete) self.db.delete(to_delete)
[docs] def clean_structures_pmg(self, ids=(None, None), min_id=None, dtol=5e-2, criteria=None): """ Clean up the database by removing duplicate structures based on density and pymatgen matcher. This method checks for duplicates by comparing structure density within a tolerance and using pymatgen's StructureMatcher. It can also filter structures based on various criteria like coordination numbers, energies, topology, etc. ids (tuple, optional): Range of IDs (min, max) to process. Defaults to (None, None). min_id (int, optional): Minimum ID to consider. Structures with lower IDs won't be deleted. Defaults to None. dtol (float, optional): Density tolerance for comparing structures. Defaults to 5e-2. criteria (dict, optional): Dictionary of filtering criteria. Defaults to None. Supported criteria keys: - 'CN': Dict of required coordination numbers per element - 'cutoff': Float, cutoff distance for connectivity - 'MAX_energy': Float, maximum allowed energy - 'MAX_similarity': Float, maximum allowed similarity value - 'BAD_topology': List of forbidden topology types - 'BAD_dimension': List of forbidden dimensionality values Example criteria: { 'CN': {'C': 3}, 'BAD_dimension': [0, 2] } Returns: None. Modifies database in place by deleting duplicate/invalid structures. 
""" unique_rows = [] to_delete = [] ids, xtals = self.select_xtals(ids) if min_id is None: min_id = min(ids) for id, xtal in zip(ids, xtals): row = self.db.get(id) xtal = self.get_pyxtal(id) unique = True if id > min_id and criteria is not None: if not xtal.check_validity(criteria, True): unique = False print( "Found unsatisfied criteria", row.id, row.space_group_number, row.wps, ) if unique and ( "MAX_energy" in criteria and hasattr( row, "ff_energy") and row.ff_energy > criteria["MAX_energy"] ): unique = False print( "Unsatisfied energy", row.id, row.ff_energy, row.space_group_number, row.wps, ) if unique and ( "MAX_similarity" in criteria and hasattr(row, "similarity") and row.similarity > criteria["MAX_similarity"] ): unique = False print( "Unsatisfied similarity", row.id, row.similarity, row.space_group_number, row.wps, ) if unique and ( "BAD_topology" in criteria and hasattr(row, "topology") and row.topology[:3] in criteria["BAD_topology"] ): unique = False print( "Unsatisfied topology", row.id, row.topology, row.space_group_number, row.wps, ) if unique and ( "BAD_dimension" in criteria and hasattr(row, "dimension") and row.dimension in criteria["BAD_dimension"] ): unique = False print( "Unsatisfied dimension", row.id, row.topology, row.space_group_number, row.wps, ) if unique and id > min_id: for prop in unique_rows: (rowid, den) = prop if abs(den - row.density) < dtol: ref_pmg = xtal.to_pymatgen() s_pmg = ase2pymatgen(self.db.get_atoms(id=rowid)) # , symmetric=True): if self.matcher.fit(s_pmg, ref_pmg): print( "Found duplicate", row.id, row.space_group_number, row.wps, ) unique = False break if unique: unique_rows.append((row.id, row.density)) else: to_delete.append(row.id) print(len(to_delete), "structures were deleted", to_delete) self.db.delete(to_delete)
[docs] def get_max_id(self): """ Get the maximum row id """ max_id = None for row in self.db.select(): if max_id is None or row.id > max_id: max_id = row.id + 1 return max_id
[docs] def select_xtals(self, ids, N_atoms=(None, None), overwrite=False, attribute=None, use_relaxed=None): """ Extract xtals based on attribute name. Args: ids (tuple): Minimum and maximum row IDs to extract, e.g. (1, 10) N_atoms (tuple): Minimum and maximum number of atoms to extract, e.g. (2, 100) overwrite (bool): Whether to overwrite existing entries attribute (str): Attribute name to check for extraction use_relaxed (str): Type of relaxed structure to use ('ff_relaxed' or 'vasp_relaxed') Returns: tuple: (ids, xtals) where ids is a list of row IDs and xtals is a list of corresponding pyxtal objects """ (min_id, max_id) = ids if min_id is None: min_id = 1 if max_id is None: max_id = self.get_max_id() (min_atoms, max_atoms) = N_atoms if min_atoms is None: min_atoms = 1 if max_atoms is None: max_atoms = 5000 ids, xtals = [], [] for row in self.db.select(): if attribute is None or (overwrite and not hasattr(row, attribute)): if min_id <= row.id <= max_id and min_atoms < row.natoms <= max_atoms: xtal = self.get_pyxtal(row.id, use_relaxed) ids.append(row.id) xtals.append(xtal) if len(xtals) % 100 == 0: print("Loading xtals from db", len(xtals)) return ids, xtals
[docs] def select_xtal(self, ids, N_atoms=(None, None), overwrite=False, attribute=None, use_relaxed=None): """ Lazy extraction of selected xtals from the database. Args: ids (tuple): Minimum and maximum row IDs to extract, e.g. (1, 10) N_atoms (tuple): Minimum and maximum number of atoms to extract, e.g. (2, 100) overwrite (bool): Whether to overwrite existing entries attribute (str): Attribute name to check for extraction use_relaxed (str): Type of relaxed structure to use ('ff_relaxed' or 'vasp_relaxed') Yields: tuple: (id, xtal) where id is the row ID and xtal is the corresponding pyxtal object """ (min_id, max_id) = ids if min_id is None: min_id = 1 if max_id is None: max_id = self.get_max_id() (min_atoms, max_atoms) = N_atoms if min_atoms is None: min_atoms = 1 if max_atoms is None: max_atoms = 5000 ids, xtals = [], [] for row in self.db.select(): if not overwrite and hasattr(row, attribute): continue #print(attribute, overwrite, hasattr(row, attribute), getattr(row, attribute)) id, natoms = row.id, row.natoms if min_id <= id <= max_id and \ min_atoms < natoms <= max_atoms \ and id % self.size== self.rank: xtal = self.get_pyxtal(id, use_relaxed) yield id, xtal
def update_row_energy(
    self,
    calculator='GULP',
    ids=(None, None),
    N_atoms=(None, None),
    ncpu=1,
    criteria=None,
    symmetrize=False,
    overwrite=False,
    write_freq=100,
    ff_lib='reaxff',
    steps=250,
    fmax=0.1,
    use_relaxed=None,
    cmd=None,
    calc_folder=None,
    skf_dir=None,
):
    """
    Update the row energy in the database for a given calculator.

    Args:
        calculator (str): 'GULP', 'MACE', 'VASP', 'DFTB'
        ids (tuple): A tuple specifying row IDs to update (e.g., (0, 100)).
        N_atoms (tuple): Atom-count window passed to ``select_xtal``.
        ncpu (int): number of parallel processes
        criteria (dict, optional): Criteria when selecting structures.
        symmetrize (bool): symmetrize the structure before calculation (DFTB)
        overwrite (bool): overwrite the existing energy attributes.
        write_freq (int): frequency to update db for ncpu=1
        ff_lib (str): Force field to use for GULP ('reaxff' by default).
        steps (int): Number of optimization steps for DFTB/MACE (default 250).
        fmax (float): force tolerance for MACE (default is 0.1)
        use_relaxed (str, optional): Use relaxed structures (e.g. 'ff_relaxed')
        cmd (str, optional): Command for VASP calculations
        calc_folder (str, optional): calc_folder for GULP/VASP calculations;
            defaults to "<calculator>_calc" (lower case).
        skf_dir (str, optional): Directory for DFTB potential files

    Functionality:
        Using the selected calculator, it updates the energy rows of the
        database. If `ncpu > 1`, run in parallel; otherwise in serial.
        Energies are stored under "ff_energy" for GULP and under
        "<calculator>_energy" (lower case) otherwise.

    Calculator Options:
        - 'GULP': Uses a force field (e.g., 'reaxff').
        - 'MACE': Uses the MACE calculator.
        - 'DFTB': Uses DFTB+ with symmetrization options.
        - 'VASP': Uses VASP, with a specified command (`cmd`).

    Raises:
        ValueError: If ``calculator`` is not supported. This is checked
            before any calculation folder is created.
    """
    if calc_folder is None:
        calc_folder = calculator.lower() + "_calc"

    # Build the per-calculator arguments first so that an unsupported
    # calculator is rejected before anything is written to disk.
    args_up = []
    if calculator == 'GULP':
        args = [calculator, ff_lib, calc_folder, criteria]
        args_up = [ff_lib]
    elif calculator == 'MACE':
        args = [calculator, steps, fmax, criteria]
    elif calculator == 'DFTB':
        args = [calculator, skf_dir, steps, symmetrize, criteria]
    elif calculator == 'VASP':
        args = [calculator, calc_folder, cmd, criteria]
    else:
        raise ValueError(f"Unsupported calculator: {calculator}")

    # GULP results live under the generic force-field key.
    label = 'ff_energy' if calculator == 'GULP' else calculator.lower() + "_energy"

    # MACE runs in-process and needs no scratch folder.
    if calculator != 'MACE':
        os.makedirs(calc_folder, exist_ok=True)

    # Lazily yields (id, xtal) for the rows that still need this energy.
    generator = self.select_xtal(ids, N_atoms, overwrite, label, use_relaxed)

    # Perform calculation serially or in parallel
    self.logging.info(f"Rank-{self.rank} row_energy {calculator} {self.db_name}")
    if ncpu == 1:
        self.update_row_energy_serial(generator, write_freq, args, args_up)
    else:
        self.update_row_energy_mproc(ncpu, generator, args, args_up)
    self.logging.info(f"Rank-{self.rank} complete update_row_energy")
def update_row_energy_serial(self, generator, write_freq, args, args_up):
    """
    Perform a serial update of row energies.

    Args:
        generator (generator): Yielding tuples of (id, xtal), where:
            - `id` (int): Unique identifier for the structure.
            - `xtal` (object): pyxtal instance; ``None`` entries are skipped.
        write_freq (int): Number of converged results to collect before
            writing them to the database.
        args (list): Additional arguments to the function `opt_single`;
            ``args[0]`` is the calculator name.
        args_up (list): Additional arguments for function `_update_db`.

    Functionality:
        It iterates over structures provided by `generator`, optimizes them
        using `opt_single`, and collects results that have converged
        (`status == True`). Once the number of results reaches `write_freq`,
        it updates the database; leftovers are written at the end.
    """
    pending = []
    for row_id, xtal in generator:
        # Structures that could not be rebuilt are skipped, matching the
        # filtering done in update_row_energy_mproc.
        if xtal is None:
            continue
        self.logging.info(f"Processing {row_id} {xtal.lattice} {args[0]}")
        print(f"Processing {row_id} {xtal.lattice} {args[0]}")
        xtal, eng, status = opt_single(row_id, xtal, *args)
        if status:
            pending.append((row_id, xtal, eng))
        if len(pending) >= write_freq:
            self._update_db(pending, args[0], *args_up)
            pending = []
            self.print_memory_usage()

    if pending:
        self._update_db(pending, args[0], *args_up)
def update_row_energy_mproc(self, ncpu, generator, args, args_up):
    """
    Perform parallel row energy updates by optimizing atomic structures.

    Args:
        ncpu (int): Number of worker processes.
        generator (generator): yielding tuples of (id, xtal), where:
            - `id` (int): Unique identifier for the structure.
            - `xtal` (object): pyxtal instance; ``None`` entries are skipped.
        args (list): Additional arguments passed to `call_opt_single`;
            ``args[0]`` is the calculator name.
        args_up (list): Additional arguments for function `_update_db`.

    Functionality:
        Structures are grouped into chunks of ``ncpu * 10`` with a local
        `chunkify` helper. Each chunk is optimised through
        `Pool.imap_unordered(call_opt_single, ...)`. Successful results
        (energy not None) are written to the database before the next chunk
        starts; for VASP they are additionally flushed every ``ncpu``
        successes, followed by a memory-usage report.

        The pool is closed and joined after processing. If an exception
        escapes, the pool is terminated before re-raising.
    """
    from multiprocessing import Pool

    def chunkify(iterable, chunk_size):
        """Yield successive lists of at most ``chunk_size`` items."""
        chunk = []
        for item in iterable:
            chunk.append(item)
            if len(chunk) == chunk_size:
                yield chunk
                chunk = []
        if chunk:
            yield chunk

    self.logging.info(f"Parallel optimizations {ncpu}")
    pool = Pool(processes=ncpu, initializer=setup_worker_logger, initargs=(self.log_file,))
    try:
        for chunk in chunkify(generator, ncpu * 10):
            myargs = [tuple([_id, xtal] + args) for _id, xtal in chunk if xtal is not None]
            if not myargs:
                # Every structure in this chunk is missing; nothing to submit
                # (and myargs[0] below would fail).
                continue

            results = []
            self.logging.info(f"Start minicycle: {myargs[0][0]}-{myargs[-1][0]}")
            for result in pool.imap_unordered(call_opt_single, myargs, chunksize=1):
                if result is None:
                    continue
                myid, xtal, eng = result
                if eng is None:
                    continue
                results.append(result)
                numIons = sum(xtal.numIons)
                count = len(results)
                self.logging.info(f"Add {myid:4d} {eng:.3f} *{numIons} {count}")

                # Only do frequent update for slow calculator VASP
                if len(results) >= ncpu and args[0] == 'VASP':
                    self._update_db(results, args[0], *args_up)
                    self.logging.info(f"Finish minibatch: {len(results)}")
                    self.print_memory_usage()
                    results = []
            self.logging.info(f"Done minicycle: {myargs[0][0]}-{myargs[-1][0]}")

            # Flush this chunk's remaining results before `results` is reset
            # for the next chunk.
            if results:
                self.logging.info(f"Start Update db: {len(results)}")
                self._update_db(results, args[0], *args_up)
                self.logging.info(f"Finish Update db: {len(results)}")
    except BaseException:
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
def _update_db(self, results, calc, *args):
    """
    Write a batch of calculation results back to the database.

    All updates happen inside a single ``with self.db:`` block, as
    recommended in
    https://wiki.fysik.dtu.dk/ase/ase/db/db.html#writing-and-updating-many-rows-efficiently

    Args:
        results: list of (id, xtal, eng) tuples; entries whose xtal is None
            are skipped
        calc (str): calculator. 'GULP' writes ff_energy, ff_lib and
            ff_relaxed; 'MACE', 'VASP' and 'DFTB' write <name>_energy and
            <name>_relaxed (lower-case name). Any other value writes nothing.
        *args: for 'GULP', args[0] is the force-field library name
    """
    column_prefix = {'GULP': 'ff', 'MACE': 'mace', 'VASP': 'vasp', 'DFTB': 'dftb'}.get(calc)
    extra_fields = {'ff_lib': args[0]} if calc == 'GULP' else {}

    with self.db:
        for row_id, xtal, eng in results:
            if xtal is None or column_prefix is None:
                continue
            fields = {
                f"{column_prefix}_energy": eng,
                f"{column_prefix}_relaxed": xtal.to_file(),
            }
            fields.update(extra_fields)
            self.db.update(row_id, **fields)
def update_row_topology(self, StructureType="Auto", overwrite=True, prefix=None, ref_dim=3, timeout=60):
    """
    Update row topology using CrystalNets.jl via subprocess (faster than juliacall).

    Structures are written to temporary CIF files in batches of 100, and each
    batch is analysed by a single ``julia`` call. For every processed row the
    keys ``topology``, ``dimension`` and ``topology_detail`` are written back
    to the database. Rows whose analysis fails are stored with ``topology``
    set to "error" or "timeout", ``dimension`` 3 and the reason in
    ``topology_detail``. The temporary script and CIF files are removed
    afterwards, also when an error occurs.

    Args:
        StructureType (str): Type of structure to analyze. Options are:
            - 'Zeolite': For zeolite structures
            - 'MOF': For metal-organic frameworks
            - 'Auto': For automatic detection
        overwrite (bool): Whether to overwrite existing topology attributes.
        prefix (str): Prefix for the temporary Julia script and CIF files.
            Defaults to "tmp" when None.
        ref_dim (int): Rows whose detected dimensionality equals this value
            are printed.
        timeout (int): Timeout in seconds for each Julia call (one call per
            batch of up to 100 structures). Default is 60.
    """
    import json
    import subprocess
    from time import time

    # Collect the work first so nothing is written to disk if there is none.
    rows_to_process = [
        row.id for row in self.db.select()
        if overwrite or not hasattr(row, "topology")
    ]
    if not rows_to_process:
        self.logging.info("No rows to process for topology update")
        return

    if prefix is None:
        prefix = "tmp"
    script_path = prefix + "_process_topology.jl"
    batch_size = 100

    # Julia script for batch topology processing (one result per ARGS entry)
    julia_script = f"""
using CrystalNets
using JSON

CrystalNets.toggle_warning(false)
CrystalNets.toggle_export(false)

structure_type = "{StructureType}"
if structure_type == "Zeolite"
    option = CrystalNets.Options(structure=CrystalNets.StructureType.Zeolite)
elseif structure_type == "MOF"
    option = CrystalNets.Options(structure=CrystalNets.StructureType.MOF)
else
    option = CrystalNets.Options(structure=CrystalNets.StructureType.Auto)
end

function process_one(cif_file)
    try
        result = CrystalNets.determine_topology(cif_file, option)
        output = []
        results_list = length(result) > 1 ? collect(result) : [result[1]]
        for res in results_list
            name = string(res[1])
            count = res[2]
            genome = res[1][CrystalNets.Clustering.Auto]
            dim = CrystalNets.ndims(CrystalNets.PeriodicGraph(genome))
            push!(output, Dict("dim" => dim, "name" => name, "count" => count))
        end
        return output
    catch e
        return Dict("error" => string(e))
    end
end

# Process all input files and return an array aligned to ARGS
function process_batch(files)
    results = Vector{{Any}}()
    for f in files
        push!(results, process_one(f))
    end
    return results
end

if length(ARGS) > 0
    println(JSON.json(process_batch(ARGS)))
end
"""

    def parse_topology(topology_list):
        """Reduce per-cluster topology dicts to (dimension, name, detail)."""
        dim = 0
        name = ""
        detail = "None"
        for i, topo in enumerate(topology_list):
            d = topo["dim"]
            n = topo["name"]
            if d > dim:
                dim = d
            tmp = n.split(",")[0]
            if tmp.startswith("UNKNOWN"):
                detail = tmp[7:]
                tmp = "aaa"
            elif tmp.startswith("unstable"):
                tmp = "unstable"
            name += tmp
            if topo["count"] > 1:
                name += f"({topo['count']})"
            if i + 1 < len(topology_list):
                name += "-"
        return dim, name, detail

    def write_cifs_for_batch(batch_ids, files):
        """Write one CIF per row id, recording each path in ``files``."""
        for row_id in batch_ids:
            atoms = self.db.get_atoms(row_id)
            cif_file = f"{prefix}_{row_id}.cif"
            # Record before writing so a partially written file is cleaned up.
            files.append(cif_file)
            atoms.write(cif_file, format="cif", parallel=False)

    def cleanup_files(files):
        """Best-effort removal of temporary files."""
        for f in files:
            try:
                if os.path.exists(f):
                    os.remove(f)
            except OSError:
                pass

    def call_julia(batch_ids, cif_files, label):
        """Run Julia once on ``cif_files``; return (id, name, dim, detail) tuples."""
        t0 = time()
        result = subprocess.run(
            ["julia", script_path, *cif_files],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        elapsed = time() - t0

        if result.returncode != 0:
            self.logging.warning(f"Julia failed for batch {label}: {result.stderr}")
            return [(row_id, "error", 3, "julia_error") for row_id in batch_ids]

        # Expect a JSON array of results aligned with the inputs
        try:
            batch_output = json.loads(result.stdout.strip())
        except ValueError as e:
            self.logging.warning(f"Failed to parse JSON for batch {label}: {e}")
            return [(row_id, "error", 3, "json_error") for row_id in batch_ids]

        # A single object instead of an array: normalise to a list
        if isinstance(batch_output, dict) and ("error" in batch_output or "dim" in batch_output):
            batch_output = [batch_output]

        # Fallback for one JSON object per line
        if not isinstance(batch_output, list):
            lines = [line for line in result.stdout.splitlines() if line.strip()]
            batch_output = []
            for line in lines:
                try:
                    batch_output.append(json.loads(line))
                except ValueError:
                    batch_output.append({"error": "line_parse_error"})

        print(f"Batch output length: {len(batch_output)}, expected: {len(batch_ids)}, elapsed: {elapsed:.2f}s")
        updates = []
        for idx, row_id in enumerate(batch_ids):
            if idx >= len(batch_output):
                updates.append((row_id, "error", 3, "missing_output"))
                continue
            out = batch_output[idx]
            # CrystalNets failures come back as {"error": msg}; record them
            # for this row only instead of failing the whole batch.
            if isinstance(out, dict) and "error" in out:
                updates.append((row_id, "error", 3, str(out["error"])[:100]))
                continue
            # out should be a list of topo dicts for this file
            topo_list = out if isinstance(out, list) else [out]
            dim, name, detail = parse_topology(topo_list)
            if name.startswith("FAILED"):
                name = '0-dimensional'

            # Optional verbose line for matches
            if dim == ref_dim:
                row = self.db.get(row_id)
                print(f"Row {row_id}: {row.space_group_number} {row.wps} dim={dim} {name}")
            updates.append((row_id, name, dim, detail))
        return updates

    def run_batch(batch_ids, label):
        """Write CIFs, analyse them once, and return this batch's DB updates."""
        cif_files = []
        try:
            # CIF-writing errors propagate to the caller, but files that were
            # already written are still removed by the finally clause.
            write_cifs_for_batch(batch_ids, cif_files)
            try:
                return call_julia(batch_ids, cif_files, label)
            except subprocess.TimeoutExpired:
                self.logging.warning(f"Timeout for batch {label} after {timeout}s")
                return [(row_id, "timeout", 3, "timeout") for row_id in batch_ids]
            except Exception as e:
                self.logging.warning(f"Error processing batch {label}: {e}")
                return [(row_id, "error", 3, str(e)[:100]) for row_id in batch_ids]
        finally:
            cleanup_files(cif_files)

    self.logging.info(f"Processing {len(rows_to_process)} structures for topology")
    with open(script_path, "w") as f:
        f.write(julia_script)

    try:
        for start in range(0, len(rows_to_process), batch_size):
            end = min(start + batch_size, len(rows_to_process))
            updates = run_batch(rows_to_process[start:end], f"{start}-{end}")
            # Write every batch's results, including error markers
            if updates:
                self.logging.info(f"Batch updating {len(updates)} rows")
                with self.db:
                    for rid, tname, tdim, tdetail in updates:
                        self.db.update(rid, topology=tname, dimension=tdim, topology_detail=tdetail)
    finally:
        # Always remove the generated Julia script, even on error.
        if os.path.exists(script_path):
            os.remove(script_path)

    self.logging.info(f"Completed topology update for {len(rows_to_process)} structures")
def update_db_description(self):
    """
    Update database description using robocrys.

    Uses robocrystallographer (https://github.com/hackingmaterials/robocrystallographer)
    to generate natural language descriptions of crystal structures.

    For each row in the database that doesn't have a description:
    1. Converts ASE atoms to pymatgen structure
    2. Uses StructureCondenser to analyze bonding/connectivity
    3. Uses StructureDescriber to generate text description
    4. Updates the database row with the description ("N/A" if robocrys
       raises an error for that structure)

    Rows that already have a description are left unchanged and printed.

    Note:
        Use it with caution, as it may take a long time to run.
    """
    from robocrys import StructureCondenser, StructureDescriber

    condenser = StructureCondenser()
    describer = StructureDescriber()

    for row in self.db.select():
        if hasattr(row, "description"):
            print("\n======Existing\n", row.description)
            continue

        atoms = self.db.get_atoms(row.id)
        pmg = ase2pymatgen(atoms)
        try:
            condensed_structure = condenser.condense_structure(pmg)
            description = describer.describe(condensed_structure)
        except Exception:
            # Record a placeholder rather than aborting the whole pass; unlike
            # a bare except this still lets KeyboardInterrupt stop the run.
            description = "N/A"

        self.db.update(row.id, description=description)
        print("\n======Updating\n", description)
def export_structures(
    self,
    fmt="vasp",
    folder="mof_out",
    criteria=None,
    sort_by="similarity",
    overwrite=True,
    cutoff=None,
    use_relaxed=None,
):
    """
    Export structures from database according to given criteria.

    Every row is summarised. Rows lacking a value for ``sort_by`` are
    dropped, the rest are sorted ascending by ``sort_by``, and the first
    ``cutoff`` are written to ``folder``. File names encode the following:
    - row id, Pearson symbol and space group
    - site labels/coordination and similarity (when available), with this
      part of the name capped at 40 characters
    - density, energy (VASP > MACE > FF) and topology

    Args:
        fmt (str): Output format (``vasp`` or ``cif``)
        folder (str): Path to output folder
        criteria (dict): Dictionary of validity criteria
        sort_by (str): Attribute to sort structures by
        overwrite (bool): Whether to remove existing output folder
        cutoff (int): Maximum number of structures to export
        use_relaxed (str, optional): e.g., ``ff_relaxed`` or ``vasp_relaxed``

    Returns:
        dict: maps each column name to the list of its values over all rows.
        A column is included only if the first row has a value for it. The
        dict is empty when the database has no rows.

    Raises:
        ValueError: If ``sort_by`` is not a supported attribute.
    """
    import shutil

    if cutoff is None:
        cutoff = self.db.count()

    if not os.path.exists(folder):
        os.makedirs(folder)
    elif overwrite:
        shutil.rmtree(folder)
        os.makedirs(folder)

    keys = [
        "id",
        "pearson_symbol",
        "space_group_number",
        "density",
        "dof",
        "similarity",
        "ff_energy",
        "vasp_energy",
        "mace_energy",
        "topology",
    ]
    properties = []
    for row in self.db.select():
        spg = row.space_group_number
        den = row.density
        dof = row.dof
        ps = row.pearson_symbol
        sim = float(row.similarity) if hasattr(row, "similarity") and row.similarity is not None else None
        top = row.topology if hasattr(row, "topology") else None
        ff_eng = float(row.ff_energy) if hasattr(row, "ff_energy") else None
        vasp_eng = float(row.vasp_energy) if hasattr(row, "vasp_energy") else None
        mace_eng = float(row.mace_energy) if hasattr(row, "mace_energy") else None
        properties.append([row.id, ps, spg, den, dof, sim, ff_eng, vasp_eng, mace_eng, top])

    dicts = {}
    if properties:
        for i, key in enumerate(keys):
            if properties[0][i] is not None:
                dicts[key] = [prop[i] for prop in properties]

    if sort_by in keys:
        col = keys.index(sort_by)
    else:
        print("supported attributes", keys)
        raise ValueError("Cannot sort by", sort_by)

    print(f"====Exporting {len(properties)} structures")
    properties = [prop for prop in properties if prop[col] is not None]
    sorted_properties = sorted(properties, key=lambda x: x[col])

    for entry in sorted_properties[:cutoff]:
        [row_id, ps, spg, den, dof, sim, ff_eng, vasp_eng, mace_eng, top] = entry
        row_id = int(row_id)
        spg = int(spg)
        den = float(den)
        dof = int(dof)
        # similarity is optional (e.g. when sorting by an energy column)
        sim = float(sim) if sim is not None else None

        # Prefer the most accurate energy that is available
        if vasp_eng is not None:
            eng = float(vasp_eng)
        elif mace_eng is not None:
            eng = float(mace_eng)
        elif ff_eng is not None:
            eng = float(ff_eng)
        else:
            eng = None

        xtal = self.get_pyxtal(row_id, use_relaxed)
        number, symbol = xtal.group.number, xtal.group.symbol.replace("/", "")
        # Build the file name separately from ``folder`` so the length cap
        # below can never cut into the directory path.
        name = f"{row_id:d}-{xtal.get_Pearson_Symbol():s}-{number:d}-{symbol:s}"
        status = xtal.check_validity(criteria, True) if criteria is not None else True

        if status:
            try:
                xtal.set_site_coordination()
                for site in xtal.atom_sites:
                    wp_label, specie, cn = site.wp.get_label(), site.specie, site.coordination
                    name += f"-{wp_label:s}-{specie:s}{cn:d}"
                if sim is not None:
                    name += f"-S{sim:.3f}"
                if len(name) > 40:
                    name = name[:40]
            except Exception:
                print("Problem in setting site coordination")

            if den is not None:
                name += f"-D{abs(den):.2f}"
            if eng is not None:
                name += f"-E{abs(eng):.3f}"
            if top is not None:
                name += f"-T{top:s}"

            label = os.path.join(folder, name)
            print("====Exporting:", label)
            if fmt == "vasp":
                xtal.to_file(label + ".vasp", fmt="poscar")
            elif fmt == "cif":
                xtal.to_file(label + ".cif")
        else:
            print("====Skippng: ", os.path.join(folder, name))

    return dicts
def get_label(self, i):
    """
    Return the zero-padded worker folder name for index ``i``.

    Indices are padded to at least three digits (``cpu000`` ... ``cpu999``);
    larger indices are used as-is (e.g. ``cpu1234``).

    Args:
        i (int): non-negative worker index

    Returns:
        str: folder name such as ``cpu007``
    """
    return f"cpu{i:03d}"
def get_db_unique(self, db_name=None, prec=3, key='ff_energy', max_N_atoms=64):
    """
    Get a database file containing only unique structures based on density,
    energy and structure matching.

    Args:
        db_name (str, optional): Filename for the new database. If None,
            will use original name with '_unique' suffix.
        prec (int, optional): Precision for rounding density and energy
            values. Default is 3.
        key (str, optional): Energy attribute name to use for filtering.
            Default is 'ff_energy'. Rows without a value for it are ignored.
        max_N_atoms (int, optional): Maximum n_atoms for pmg match. Default is 64.

    Returns:
        int: Number of unique structures in the new database.

    Note:
        Two structures are considered identical if they have:
        - Same density value (within precision)
        - Same energy value (within precision)
        - A pymatgen ``StructureMatcher`` match or, when either structure has
          more than ``max_N_atoms`` atoms, the same space group number and
          Wyckoff positions (the matcher is skipped to save time)
        When duplicates are found, the structure with lower DOF is kept.
    """
    from pymatgen.analysis.structure_matcher import StructureMatcher

    matcher = StructureMatcher(stol=0.3, ltol=0.2, angle_tol=5)

    print(f"The {self.db_name:s} has {self.db.count():d} strucs")
    if db_name is None:
        db_name = self.db_name[:-3] + "_unique.db"
    if os.path.exists(db_name):
        os.remove(db_name)

    lists = []
    for row in self.db.select():
        if not hasattr(row, key) or getattr(row, key) is None:
            continue
        dof, den, energy = row.dof, round(row.density, prec), round(getattr(row, key), prec)
        spg, wps = row.space_group_number, row.wps
        pmg = ase2pymatgen(row.toatoms())
        list_entry = (row.id, dof, den, energy, spg, wps, pmg)

        is_unique = True
        for existing in lists:
            (_id, _dof, _den, _energy, _spg, _wps, _pmg) = existing
            if den != _den or energy != _energy:
                continue
            if len(_pmg) > max_N_atoms or len(pmg) > max_N_atoms:
                # for large structures, skip pymatgen match to save time
                same = spg == _spg and wps == _wps
            else:
                same = matcher.fit(pmg, _pmg)
            if same:
                is_unique = False
                if dof < _dof:
                    print("Updating", row.id, den, energy)
                    # Safe despite iterating `lists`: we break right after.
                    lists.remove(existing)
                    lists.append(list_entry)
                else:
                    print("Duplicate", row.id, den, energy)
                break

        if is_unique:
            print("Adding", row.id, den, energy)
            lists.append(list_entry)

    ids = [entry[0] for entry in lists]
    with connect(db_name, serial=True) as db:
        for row_id in ids:
            row = self.db.get(row_id)
            kvp = {attr: getattr(row, attr) for attr in self.keys if hasattr(row, attr)}
            db.write(row.toatoms(), key_value_pairs=kvp)
    print(f"Created {db_name:s} with {db.count():d} strucs")
    return db.count()
def get_db_unique_topology(self, db_name=None, prec=3, update_topology=True, key='ff_energy'):
    """
    Get a database file containing only unique structures based on topology and energy.

    Args:
        db_name (str, optional): Filename for the new database. If None,
            will use original name with '_unique' suffix.
        prec (int, optional): Precision for rounding energy values. Default is 3.
        update_topology (bool, optional): Whether to update topology before
            filtering. Default is True. When False, every row with a ``key``
            value must already have ``topology`` and ``topology_detail``.
        key (str, optional): Energy attribute name to use for filtering.
            Default is 'ff_energy'. Rows without a value for it are ignored.

    Returns:
        int: Number of unique structures in the new database.

    Note:
        Two structures are considered identical if they have:
        - Same topology
        - Same topology detail
        - Same energy value (within precision)
        When duplicates are found, the structure with lower DOF is kept.
    """
    print(f"The {self.db_name:s} has {self.db.count():d} strucs")
    if db_name is None:
        db_name = self.db_name[:-3] + "_unique.db"
    if os.path.exists(db_name):
        os.remove(db_name)

    # (topology, topology_detail, rounded energy) -> (row id, dof)
    unique_props = {}

    if update_topology:
        # update_row_topology needs a prefix for its temporary Julia script;
        # "tmp" matches the CIF naming it uses when no prefix is given.
        self.update_row_topology(prefix="tmp")

    for row in self.db.select():
        if not hasattr(row, key) or getattr(row, key) is None:
            continue
        top, top_detail = row.topology, row.topology_detail
        dof, energy = row.dof, round(getattr(row, key), prec)
        prop_key = (top, top_detail, energy)

        if prop_key in unique_props:
            _id, _dof = unique_props[prop_key]
            if dof < _dof:
                print("Updating", row.id, top, energy)
                unique_props[prop_key] = (row.id, dof)
            else:
                print("Duplicate", row.id, top, energy)
        else:
            print("Adding", row.id, top, energy)
            unique_props[prop_key] = (row.id, dof)

    ids = [row_id for row_id, _ in unique_props.values()]
    with connect(db_name, serial=True) as db:
        for row_id in ids:
            row = self.db.get(row_id)
            kvp = {attr: getattr(row, attr) for attr in self.keys if hasattr(row, attr)}
            db.write(row.toatoms(), key_value_pairs=kvp)
    print(f"Created {db_name:s} with {db.count():d} strucs")
    return db.count()
def check_overlap(self, reference_db, etol=2e-3, verbose=True):
    """
    Check the overlap with a reference database.

    Args:
        reference_db (str): Path to the reference database file
        etol (float, optional): Energy tolerance for identifying identical
            structures. Default is 2e-3.
        verbose (bool, optional): Whether to print detailed overlap
            information. Default is True.

    Returns:
        list: List of overlapping structures in database order, where each
        entry contains: (id, pearson_symbol, dof, topology, ff_energy)

    Note:
        Two structures are considered overlapping if they have:
        - Same topology
        - Same topology detail
        - Force field energies within etol of each other
        Only rows with both ``topology`` and ``ff_energy`` are compared.
    """
    db_ref = database_topology(reference_db, log_file=self.log_file)
    print(f"\nCurrent database {self.db_name}: {self.db.count()}")
    print(f"Reference database {db_ref.db_name}: {db_ref.db.count()}")

    # Index reference energies by (topology, topology_detail) so each row is
    # compared only against references sharing its topology instead of
    # scanning the whole reference set.
    ref_energies = {}
    for row in db_ref.db.select():
        if hasattr(row, "topology") and hasattr(row, "ff_energy"):
            ref_energies.setdefault((row.topology, row.topology_detail), []).append(row.ff_energy)

    overlaps = []
    for row in self.db.select():
        if not (hasattr(row, "topology") and hasattr(row, "ff_energy")):
            continue
        candidates = ref_energies.get((row.topology, getattr(row, "topology_detail", None)), [])
        if any(abs(row.ff_energy - ref_eng) < etol for ref_eng in candidates):
            overlaps.append((row.id, row.pearson_symbol, row.dof, row.topology, row.ff_energy))

    strs = f"\nThe number of overlap is: {len(overlaps):d}"
    strs += f"/{self.db.count():d}/{db_ref.db.count():d}"
    print(strs)

    if verbose:
        for entry in sorted(overlaps, key=lambda x: x[-1]):
            print("{:4d} {:6s} {:4d} {:20s} {:10.3f}".format(*entry))
    return overlaps
def print_info(self, excluded_ids=None, cutoff=100):
    """
    Print a summary of the database ordered by force-field energy.

    Mostly used to quickly view the most interesting low-energy structures.
    Only rows having both ``topology`` and ``ff_energy`` are listed, sorted
    ascending by ``ff_energy``.

    Todo:
        show vasp_energy if available

    Args:
        excluded_ids (list): row ids to leave out of the listing
        cutoff (int): maximum number of rows to print
    """
    excluded = [] if excluded_ids is None else excluded_ids

    print(f"\nCurrent database {self.db_name}: {self.db.count()}")

    summary = [
        (row.id, row.pearson_symbol, row.dof, row.topology, row.ff_energy)
        for row in self.db.select()
        if row.id not in excluded and hasattr(row, "topology") and hasattr(row, "ff_energy")
    ]
    summary.sort(key=lambda item: item[-1])

    line_fmt = "{:4d} {:6s} {:4d} {:20s} {:10.3f}"
    for item in summary[:cutoff]:
        print(line_fmt.format(*item))

    print(f"Showed structures: {len(summary)}/{self.db.count()}")
def plot_histogram(self, prop, ax=None, filename=None, xlim=None, nbins=20):
    """
    Plot the histogram of a specified row property.

    Args:
        prop (str): The name of the property to plot (e.g., 'ff_energy').
        ax (matplotlib.axes.Axes, optional): Pre-existing axis to plot on.
            If None, a new ax will be created.
        filename (str, optional): Path to save the plot (e.g., 'plot.png').
            If None, the plot will not be saved. The figure that owns
            ``ax`` is saved.
        xlim (tuple, optional): Limits for the x-axis (e.g., (0, 10)).
            If None, the x-axis will scale automatically.
        nbins (int, optional): Number of bins for the histogram. Default is 20.

    Returns:
        matplotlib.axes.Axes: The axis object with the histogram plotted.

    Raises:
        ValueError: If no row has the property.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots()

    # Get the properties from the database
    props = self.get_properties(prop)
    if not props:
        raise ValueError(f"No rows contain the property '{prop}'.")

    ax.hist(props, nbins, density=True, alpha=0.75)
    if xlim is not None:
        ax.set_xlim(xlim)
    ax.set_xlabel(prop)

    if filename is not None:
        # Save ax's own figure: plt.savefig would save whichever figure is
        # current, which can differ from ax's when ax was passed in.
        ax.figure.savefig(filename)
    return ax
def get_properties(self, prop):
    """
    Retrieve a list of specific property values from the database rows.

    Args:
        prop (str): The property name to retrieve (e.g., 'ff_energy')

    Returns:
        list: Values of ``prop`` for every row that has it, in database
        order. Rows without the property are ignored, so the list may be
        empty.

    Raises:
        Warning: If the database contains no rows at all. A non-empty
            database in which no row has ``prop`` returns an empty list.
    """
    props = [getattr(row, prop) for row in self.db.select() if hasattr(row, prop)]

    # Print summary of rows
    name, count = self.db_name, self.db.count()
    print(f"Database {name} has {prop}: {len(props)}/{count}")

    # Keyed on the row count, not on len(props): plot_histogram checks for an
    # empty list itself and reports a missing property with a ValueError.
    if count == 0:
        raise Warning(f"No rows in the database contain the property '{prop}'.")
    return props
if __name__ == "__main__":
    # Ad-hoc developer demo. The two `if False:` sections are disabled
    # examples; only the final section runs when this module is executed
    # directly.
    # NOTE(review): the published source page flattened all line breaks, so
    # the nesting below is reconstructed. Confirm that the final section is
    # meant to sit outside the second `if False:` block.

    # open
    if False:
        db = database("test.db")
        print("Total number of entries", len(db.codes))

        # view structure
        c = db.get_pyxtal("HXMTAM")
        print(c)

    if False:
        # NOTE(review): the paths below are machine-specific.
        db = database_topology("../MOF-Builder/C-sp2/sp2-sacada-0506.db")
        # xtal = db.get_pyxtal(1)
        # print(xtal)
        # db.add_xtal(xtal, kvp={'similarity': 0.1})
        # db.update_row_ff_energy(ids=(0, 2), overwrite=True)
        # db.update_row_ff_energy(ncpu=2, ids=(2, 20), overwrite=True)

        # brew install coreutils to get `timeout` on macOS
        # os.environ['ASE_DFTB_COMMAND'] = 'timeout 1m /Users/qzhu8/opt/dftb+/bin/dftb+ > PREFIX.out'
        os.environ["ASE_DFTB_COMMAND"] = "/Users/qzhu8/opt/dftb+/bin/dftb+ > PREFIX.out"
        skf_dir = "/Users/qzhu8/GitHub/MOF-Builder/3ob-3-1/"
        # db.update_row_dftb_energy(skf_dir, ncpu=1, ids=(0, 2), overwrite=True)
        # NOTE(review): this part of the file only shows
        # update_row_energy(calculator='DFTB', ...); confirm that
        # update_row_dftb_energy still exists.
        db.update_row_dftb_energy(
            skf_dir, ncpu=1, ids=(17, 17), overwrite=True)

    # get_db_unique writes "<name>_unique.db", i.e. total_unique.db and
    # sp2_sacada_unique.db here.
    db = database_topology("total.db")
    db.get_db_unique()
    db1 = database_topology("sp2_sacada.db")
    db1.get_db_unique()
    # Report structures of total_unique.db that also appear in
    # sp2_sacada_unique.db (same topology/detail, ff_energy within etol).
    db = database_topology("total_unique.db")
    db.check_overlap("sp2_sacada_unique.db")
    db1.export_structures(folder="mof_out_sacada")