"""
Database class
"""
import os
import logging
import numpy as np
import pymatgen.analysis.structure_matcher as sm
from ase.calculators.calculator import CalculationFailed
from ase.db import connect
from pyxtal import pyxtal
from pyxtal.util import ase2pymatgen
[docs]
def setup_worker_logger(log_file):
    """
    Configure the root logger of a worker process to write to `log_file`.

    Handlers inherited from the parent process are removed first, because
    `logging.basicConfig` does nothing when the root logger already has
    handlers.

    Args:
        log_file (str): path of the log file used by this worker.
    """
    root = logging.getLogger()
    root.handlers.clear()
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format="%(asctime)s| %(message)s",
    )
[docs]
def call_opt_single(p):
    """
    Optimize a single structure (worker entry point for process pools).

    Args:
        p (tuple): `(id, xtal, calc, *args)`; it is forwarded unchanged to
            `opt_single`.

    Returns:
        tuple: `(id, xtal, eng)` where `xtal` is the structure returned by
        `opt_single` and `eng` its energy (None if the optimization failed).
        The validity status reported by `opt_single` is dropped.
    """
    struc_id = p[0]
    opt_xtal, energy, _status = opt_single(*p)
    return struc_id, opt_xtal, energy
[docs]
def opt_single(id, xtal, calc, *args):
    """
    Optimize a structure using the specified calculator.

    Args:
        id (int): Identifier of the structure to be optimized.
        xtal: Crystal structure object to be optimized.
        calc (str): The calculator to use ('GULP', 'DFTB', 'VASP', 'MACE').
        *args: Additional positional arguments forwarded to the
            calculator-specific function (`gulp_opt_single`,
            `dftb_opt_single`, `vasp_opt_single` or `mace_opt_single`).

    Returns:
        tuple: `(xtal, energy, status)` as returned by the selected
        calculator function.

    Raises:
        ValueError: If an unsupported calculator is specified.
    """
    if calc == "GULP":
        return gulp_opt_single(id, xtal, *args)
    if calc == "DFTB":
        return dftb_opt_single(id, xtal, *args)
    if calc == "VASP":
        return vasp_opt_single(id, xtal, *args)
    if calc == "MACE":
        return mace_opt_single(id, xtal, *args)
    # Single formatted message (the old two-argument form printed as a tuple)
    raise ValueError(f"Cannot support this calculator: {calc}")
[docs]
def dftb_opt_single(id, xtal, skf_dir, steps, symmetrize, criteria, kresol=0.05):
    """
    Single DFTB optimization for a given atomic xtal.

    Args:
        id (int): id of the given xtal
        xtal: pyxtal instance
        skf_dir (str): path of skf files
        steps (int): number of relaxation steps
        symmetrize (bool): if True, relax with `DFTB_relax` (symmetrized);
            otherwise run two consecutive `DFTB` vc-relax stages
            (loose then tight SCC tolerance).
        criteria (dict): criteria passed to `process_xtal` to check the
            relaxed structure (None to skip the check)
        kresol (float): k-point resolution

    Returns:
        tuple: `(xtal, eng, status)` with the relaxed pyxtal instance, the
        energy per atom and the validity status, or `(None, None, False)`
        if the optimization failed.
    """
    from pyxtal.interface.dftb import DFTB, DFTB_relax

    cwd = os.getcwd()
    atoms = xtal.to_ase(resort=False)
    eng = None
    my = None  # DFTB driver of the non-symmetrized path (None otherwise)
    try:
        if symmetrize:
            s = DFTB_relax(
                atoms,
                skf_dir,
                True,
                int(steps / 2),
                kresol=kresol,
                folder=".",
                scc_error=1e-5,
                scc_iter=100,
                logfile="ase.log",
            )
        else:
            # Stage 1: coarse k-mesh and loose SCC tolerance
            my = DFTB(
                atoms,
                skf_dir,
                kresol=kresol * 1.5,
                folder=".",
                scc_error=0.1,
                scc_iter=100,
            )
            s, eng = my.run(mode="vc-relax", step=int(steps / 2))
            # Stage 2: refine with the target k-mesh and a tight SCC tolerance
            my = DFTB(s, skf_dir, kresol=kresol, folder=".",
                      scc_error=1e-4, scc_iter=100)
            s, eng = my.run(mode="vc-relax", step=int(steps / 2))
            s = my.struc
    except CalculationFailed:
        if my is not None:
            # Convergence error in geometry optimization:
            # simply read the last energy and structure
            my.calc.read_results()
            eng = my.calc.results["energy"]
            s = my.struc
        else:
            s = None
            print("Problem in DFTB Geometry optimization", id)
    except Exception:
        s = None
        print("Problem in DFTB Geometry optimization", id)
        xtal.to_file("bug.cif")
    finally:
        # The DFTB drivers may change directory; always come back
        os.chdir(cwd)

    if s is None:
        return None, None, False

    # `eng` from `DFTB.run`/`read_results` is divided by the atom count
    # below, i.e. it is treated as a total energy
    if eng is None:
        eng = s.get_potential_energy() / len(s)
    else:
        eng /= len(s)

    # Return the relaxed structure, not the input one
    opt_xtal = pyxtal()
    opt_xtal.from_seed(s)
    status = process_xtal(id, opt_xtal, eng, criteria)
    return opt_xtal, eng, status
[docs]
def vasp_opt_single(id, xtal, path, cmd, criteria):
    """
    Single VASP optimization for a given atomic xtal.

    The calculation runs in the sub-folder `<path>/g<id>`. The current
    working directory is restored afterwards, whether or not VASP succeeded.

    Args:
        id (int): id of the given xtal
        xtal: pyxtal instance
        path (str): calculation folder
        cmd (str): vasp command
        criteria (dict): criteria passed to `process_xtal` to check the
            structure (None to skip the check)

    Returns:
        tuple: `(xtal, eng, status)` as produced by the VASP interface; the
        status is False when VASP reported an error.
    """
    from pyxtal.interface.vasp import optimize as vasp_opt

    cwd = os.getcwd()
    path = os.path.join(path, "g" + str(id))
    try:
        xtal, eng, _, error = vasp_opt(xtal, path, cmd=cmd, walltime="59m")
    finally:
        os.chdir(cwd)

    status = False
    if not error:
        status = process_xtal(id, xtal, eng, criteria)
    return xtal, eng, status
[docs]
def gulp_opt_single(id, xtal, ff_lib, path, criteria):
    """
    Perform a single GULP optimization for a given crystal structure.

    Args:
        id (int): Identifier for the current structure.
        xtal: PyXtal instance representing the crystal to be optimized.
        ff_lib (str): Force field library for GULP, e.g., 'reaxff', 'tersoff'.
        path (str): Path to the folder where the calculation is stored;
            the run itself uses the sub-folder `<path>/g<id>`.
        criteria (dict): Dictionary to check the validity of the opt_structure.

    Returns:
        tuple:
            - xtal: Optimized PyXtal instance.
            - eng (float): Energy of the optimized structure.
            - status (bool): Whether the optimization process is successful.

    Behavior:
        After the optimization, it checks the validity of the structure and
        attempts to remove the calculation folder if it is empty.
    """
    from pyxtal.interface.gulp import single_optimize as gulp_opt

    path = os.path.join(path, "g" + str(id))
    xtal, eng, _, error = gulp_opt(
        xtal,
        ff=ff_lib,
        label=str(id),
        path=path,
        symmetry=True,
    )

    status = False
    if not error:
        status = process_xtal(id, xtal, eng, criteria)

    # Best-effort cleanup: only an empty folder is removed
    try:
        os.rmdir(path)
    except FileNotFoundError:
        pass  # nothing left to clean up
    except OSError:
        print("Folder is not empty", path)
    return xtal, eng, status
[docs]
def mace_opt_single(id, xtal, step, fmax, criteria):
    """
    Perform a single MACE optimization for a given atomic crystal structure.

    Args:
        id (int): Identifier for the current structure.
        xtal: PyXtal instance representing the crystal structure.
        step (int): Maximum number of relaxation steps.
        fmax (float): fmax for relaxation
        criteria (dict): Dictionary to check the validity of the optimized structure.

    Returns:
        tuple:
            - xtal: Optimized PyXtal instance (or None if optimization failed).
            - eng (float): Energy/atom of the opt_structure (or None if it failed).
            - status (bool): Whether the optimization was successful.
    """
    from pyxtal.interface.ase_opt import ASE_relax as mace_opt

    logger = logging.getLogger()
    atoms = xtal.to_ase(resort=False)
    # Time budget scales linearly with system size beyond 200 atoms
    max_time = 9.0 * max(1, len(atoms) / 200)
    s = mace_opt(atoms,
                 'MACE',
                 opt_cell=True,
                 step=step,
                 fmax=fmax,
                 max_time=max_time,
                 label=str(id))
    if s is None:
        logger.info(f"mace_opt_single Failure {id}")
        return None, None, False

    try:
        opt_xtal = pyxtal()
        opt_xtal.from_seed(s)
        eng = s.get_potential_energy() / len(s)
        status = process_xtal(id, opt_xtal, eng, criteria)
    except Exception:
        # Keep the worker alive but record the traceback for debugging
        logger.info(f"mace_opt_single Bug {id}", exc_info=True)
        return None, None, False

    logger.info(f"mace_opt_single Success {id}")
    return opt_xtal, eng, status
[docs]
def process_xtal(id, xtal, eng, criteria):
    """
    Validate a structure and print a one-line summary when it is valid.

    Args:
        id (int): identifier printed as a 4-wide header.
        xtal: pyxtal instance to check.
        eng (float): energy reported in the summary.
        criteria (dict): passed to `xtal.check_validity`; None means that
            every structure is considered valid.

    Returns:
        The validity status (True when `criteria` is None).
    """
    if criteria is None:
        status = True
    else:
        status = xtal.check_validity(criteria)
    if not status:
        return status
    summary = {"validity": status, "energy": eng}
    print(xtal.get_xtal_string(header=f"{id:4d}", dicts=summary))
    return status
[docs]
def make_entry_from_pyxtal(xtal):
    """
    Build a database entry from a molecular PyXtal object.

    The object must carry the SMILES (`xtal.tag["smiles"]`), the CSD code
    (`xtal.tag["csd_code"]`) and the CCDC number (`xtal.tag["ccdc_number"]`).

    Args:
        xtal: PyXtal object.

    Returns:
        tuple: `(ase_atoms, entry_dict, None)`, or None when `xtal.valid`
        is False.
            - ase_atoms: ASE Atoms object converted from the PyXtal structure.
            - entry_dict (dict): keys "csd_code", "mol_smi", "ccdc_number",
              "space_group", "spg_num", "Z", "Zprime", "url",
              "mol_formula", "mol_weight", "mol_name" (the CSD code) and
              "l_type" (lattice type).
            - None: placeholder for the data field.

    Example:
        ase_atoms, entry_dict, _ = make_entry_from_pyxtal(xtal_instance)
    """
    from rdkit import Chem
    from rdkit.Chem.Descriptors import ExactMolWt
    from rdkit.Chem.rdMolDescriptors import CalcMolFormula

    if not xtal.valid:
        return None

    tag = xtal.tag
    # Molecular weight and formula derived from the SMILES via RDKit
    molecule = Chem.MolFromSmiles(tag["smiles"])
    weight = ExactMolWt(molecule)
    formula = CalcMolFormula(molecule)

    ccdc_link = "https://www.ccdc.cam.ac.uk/structures/Search?Ccdcid="
    entry = dict(
        csd_code=tag["csd_code"],
        mol_smi=tag["smiles"],
        ccdc_number=tag["ccdc_number"],
        space_group=xtal.group.symbol,
        spg_num=xtal.group.number,
        Z=sum(xtal.numMols),
        Zprime=xtal.get_zprime()[0],
        url=ccdc_link + str(tag["ccdc_number"]),
        mol_formula=formula,
        mol_weight=weight,
        mol_name=tag["csd_code"],
        l_type=xtal.lattice.ltype,
    )
    return (xtal.to_ase(), entry, None)
[docs]
def make_entry_from_CSD_web(code, number, smiles, name=None):
    """
    Make an entry dictionary from the CSD web service
    (https://www.ccdc.cam.ac.uk/structures).

    Not implemented yet: this function always raises NotImplementedError.

    Args:
        code: CSD style letter entry
        number: ccdc number
        smiles: the corresponding molecular smiles
        name: name of the compound
    """
    # TODO: build a molecular pyxtal from the web data and pass it to
    # make_entry_from_pyxtal.
    raise NotImplementedError("To do in future")
[docs]
def make_entry_from_CSD(code):
    """
    Make an entry from a single CSD code.

    Args:
        code (str): a CSD code.

    Returns:
        The tuple produced by `make_entry_from_pyxtal`, or None when the
        structure is invalid or a `CSDError` is raised.
    """
    from pyxtal.msg import CSDError

    crystal = pyxtal(molecular=True)
    try:
        crystal.from_CSD(code)
        entry = make_entry_from_pyxtal(crystal)
    except CSDError as err:
        print("CSDError", code, err.message)
        entry = None
    return entry
[docs]
def make_db_from_CSD(dbname, codes):
    """
    Create (or extend) a database from a list of CSD codes.

    Args:
        dbname (str): db file name
        codes (list): CSD codes to add; invalid codes are skipped.

    Returns:
        The `database` instance wrapping `dbname`.
    """
    db = database(dbname)
    for idx, code in enumerate(codes):
        entry = make_entry_from_CSD(code)
        if entry is None:
            continue
        db.add(entry)
        print(idx, code)
    return db
[docs]
class database:
    """
    Database of molecular crystals (typically imported from the CSD)
    stored in an ASE `*.db` file.

    Args:
        db_name: `*.db` format from ase database

    Attributes:
        db_name (str): path of the database file.
        db: ASE database connection (opened with `serial=True`).
        codes (list): CSD codes currently stored in the db.
        keys (list): key-value pairs required for every entry.
        calculators (list): calculators whose results are stored in
            `row.data` under the key `<calculator>_info`.
    """

    def __init__(self, db_name):
        self.db_name = db_name
        self.db = connect(db_name, serial=True)
        self.codes = self.get_all_codes()
        self.keys = [
            "csd_code",
            "space_group",
            "spg_num",
            "Z",
            "Zprime",
            "url",
            "mol_name",
            "mol_smi",
            "mol_formula",
            "mol_weight",
            "l_type",
        ]
        self.calculators = [
            "charmm",
            "gulp",
            "ani",
            "dftb_D3",
            "dftb_TS",
        ]

    def vacuum(self):
        """Compact the underlying database file."""
        self.db.vacuum()

    def get_all_codes(self, group=None):
        """
        Collect the CSD codes stored in the db and remove duplicate rows.

        Args:
            group: if given, only rows whose `group` attribute equals this
                value contribute their code.

        Returns:
            list: CSD codes in row order. A later row whose code is
            already in this list is deleted from the db.
        """
        codes = []
        seen = set()  # mirrors `codes` for O(1) membership tests
        duplicates = []
        for row in self.db.select():
            code = row.csd_code
            if code in seen:
                print("find duplicate! remove", row.id, code)
                duplicates.append(row.id)
            elif group is None or row.group == group:
                codes.append(code)
                seen.add(code)
        # Delete after the iteration instead of while the cursor is open
        if duplicates:
            self.db.delete(duplicates)
        return codes

    def add(self, entry):
        """
        Write an entry to the db unless its CSD code is already present.

        Args:
            entry (tuple): `(atoms, kvp, data)` as produced by
                `make_entry_from_pyxtal`. Entries whose `kvp` lacks any of
                `self.keys` are skipped.
        """
        (atom, kvp, data) = entry
        if kvp["csd_code"] in self.codes:
            return
        kvp0 = self.process_kvp(kvp)
        if kvp0 is None:
            # process_kvp already reported the missing key
            return
        self.db.write(atom, key_value_pairs=kvp0, data=data)
        self.codes.append(kvp["csd_code"])

    def add_from_code(self, code):
        """Fetch `code` from the CSD and add it to the db."""
        entry = make_entry_from_CSD(code)
        if entry is not None:
            self.add(entry)
        else:
            print(f"{code:s} is not a valid entry")

    def process_kvp(self, kvp):
        """
        Restrict `kvp` to `self.keys`.

        Returns:
            dict with exactly the keys in `self.keys`, or None (after
            printing an error) if any of them is missing.
        """
        kvp0 = {}
        for key in self.keys:
            if key in kvp:
                kvp0[key] = kvp[key]
            else:
                print("Error, cannot find ", key, " from the input")
                return None
        return kvp0

    def check_status(self, show=False):
        """
        Check the current status of each entry.

        Args:
            show (bool): if True, print a summary of every complete entry.

        Returns:
            list: ids of rows that have one data record per calculator.
            The CSD codes of the remaining rows are printed.
        """
        ids = []
        for row in self.db.select():
            if len(row.data.keys()) == len(self.calculators):
                ids.append(row.id)
                if show:
                    row_info = self.get_row_info(id=row.id)
                    self.view(row_info)
            else:
                print(row.csd_code)
        return ids

    def copy(self, db_name, csd_codes):
        """
        Copy the entries to another db.

        Args:
            db_name: db file name
            csd_codes: list of codes

        Raises:
            RuntimeError: if `db_name` is the current db file.
        """
        if db_name == self.db_name:
            raise RuntimeError("Cannot use the same db file for copy")
        with connect(db_name, serial=True) as db:
            for csd_code in csd_codes:
                (atom, kvp, data) = self.get_row_info(code=csd_code)
                db.write(atom, key_value_pairs=kvp, data=data)

    def view(self, row_info):
        """
        Print the summary of benchmark results.

        Args:
            row_info (tuple): `(atoms, kvp, data)` as returned by
                `get_row_info`; `data` is keyed by `<calculator>_info`.
        """
        from pyxtal.representation import representation

        (atom, kvp, data) = row_info
        # Reference structure
        xtal = self.get_pyxtal(kvp["csd_code"])
        rep = xtal.get_1D_representation()
        print("\n", kvp["csd_code"], kvp["mol_smi"], xtal.lattice.volume)
        print(rep.to_string() + " reference")

        # One line per calculator
        for key in data:
            calc = key[:-5]  # strip the "_info" suffix
            time = data[key]["time"]
            rep = data[key]["rep"]
            if not isinstance(rep[0], list):
                rep = [rep]
            rep = representation(rep, kvp["mol_smi"]).to_string()
            (dv, msd1, msd2) = data[key]["diff"]
            strs = f"{rep:s} {calc:8s} {time / 60:6.2f} {dv:6.3f}"
            if msd1 is not None:
                strs += f"{msd1:6.3f}{msd2:6.3f}"
            print(strs)

    def get_row_info(self, id=None, code=None):
        """
        Return `(atoms, kvp, data)` for the row matching `id` (preferred)
        or `code`.

        Raises:
            RuntimeError: if no matching row is found.
        """
        row = None
        if id is not None:
            row = next(iter(self.db.select(id=id)), None)
        elif code is not None:
            row = next(iter(self.db.select(csd_code=code)), None)
        if row is None:
            raise RuntimeError(f"cannot find the entry from id={id}, code={code}")

        kvp = {key: row.key_value_pairs[key] for key in self.keys}
        data0 = {}
        for calc in self.calculators:
            key = calc + "_info"
            if key in row.data:
                data0[key] = row.data[key]
        atom = self.db.get_atoms(id=row.id)
        return (atom, kvp, data0)

    def get_row(self, code):
        """
        Return the first row whose `csd_code` equals `code`.

        Raises:
            RuntimeError: if no such row exists.
        """
        row = next(iter(self.db.select(csd_code=code)), None)
        if row is None:
            raise RuntimeError(f"cannot find the entry from {code}")
        return row

    def get_pyxtal(self, code):
        """
        Build a molecular pyxtal for `code` using the stored SMILES.

        Hydrogens are added when the first attempt raises `ReadSeedError`.
        """
        from pyxtal import pyxtal
        from pyxtal.msg import ReadSeedError
        from pyxtal.util import ase2pymatgen

        row = self.get_row(code)
        atom = self.db.get_atoms(id=row.id)
        pmg = ase2pymatgen(atom)
        molecules = [smile + ".smi" for smile in row.mol_smi.split(".")]
        xtal = pyxtal(molecular=True)
        try:
            xtal.from_seed(pmg, molecules=molecules)
        except ReadSeedError:
            xtal.from_seed(pmg, molecules=molecules, add_H=True)
        return xtal

    def compute(self, row, work_dir, skf_dir):
        """
        Placeholder for computing the missing calculator data of `row`.

        NOTE(review): the actual computation and db update are not
        implemented (they were commented out), so this method currently
        has no effect on the db. `work_dir` and `skf_dir` are unused.
        """
        if len(row.data.keys()) < len(self.calculators):
            atom = self.db.get_atoms(id=row.id)
            if "gulp_info" in row.data:
                # TODO: compute the remaining data from `pmg` and store it
                # with self.db.update(row.id, data=...)
                pmg = ase2pymatgen(atom)  # noqa: F841
[docs]
class database_topology:
"""
This is a database class to process atomic crystal data
Args:
db_name (str): `*.db` format from ase database
rank (int): default 0
size (int): default 1
ltol (float): lattice tolerance
stol (float): site tolerance
atol (float): angle tolerance
log_file (str): log_file
"""
def __init__(self, db_name, rank=0, size=1, ltol=0.05, stol=0.05, atol=3,
             log_file='db.log'):
    """
    Open `db_name`, set up the structure matcher and the log file.

    `rank`/`size` are used by `select_xtal` to split rows between
    processes (row id modulo `size` must equal `rank`).
    """
    self.rank, self.size = rank, size
    self.db_name = db_name
    self.db = connect(db_name, serial=True)
    # Key-value pairs copied/managed by this class
    self.keys = [
        "space_group_number",
        "pearson_symbol",
        "similarity0",
        "similarity",
        "density",
        "dof",
        "topology",
        "topology_detail",
        "dimension",
        "wps",
        "ff_energy",
        "ff_lib",
        "ff_relaxed",
        "mace_energy",
        "mace_relaxed",
        "dftb_energy",
        "dftb_relaxed",
        "vasp_energy",
        "vasp_relaxed",
    ]
    self.matcher = sm.StructureMatcher(ltol=ltol, stol=stol, angle_tol=atol)

    # Route the root logger to `log_file`; existing handlers are removed
    # first so that basicConfig takes effect.
    self.log_file = log_file
    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    logging.basicConfig(
        format="%(asctime)s| %(message)s",
        filename=self.log_file,
        level=logging.INFO,
    )
    self.logging = logging
[docs]
def vacuum(self):
    """Compact the underlying ASE database file."""
    connection = self.db
    connection.vacuum()
[docs]
def print_memory_usage(self):
    """Log and print the resident memory (in MB) of the current process."""
    import psutil

    rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
    message = f"Rank {self.rank} memory: {rss_mb:.1f} MB"
    self.logging.info(message)
    print(message)
[docs]
def get_pyxtal(self, id, use_relaxed=None, tol=1e-4):
    """
    Get pyxtal based on row_id; if use_relaxed, get pyxtal from that
    relaxed CIF string instead of the stored atoms.

    Args:
        id (int): row id
        use_relaxed (str): 'ff_relaxed', 'vasp_relaxed', ...; falls back to
            the stored atoms when the row has no such attribute.
        tol (float): symmetry tolerance passed to `from_seed`.

    Returns:
        A valid pyxtal instance carrying the row attributes listed in
        `self.keys`, or None if the structure cannot be loaded or is invalid.
    """
    from pymatgen.core import Structure
    from pyxtal import pyxtal
    from pyxtal.util import ase2pymatgen

    row = self.db.get(id)
    xtal_str = None
    if use_relaxed is not None and hasattr(row, use_relaxed):
        xtal_str = getattr(row, use_relaxed)
        pmg = Structure.from_str(xtal_str, fmt="cif")
    else:
        if use_relaxed is not None:
            print(f"No {use_relaxed} attributes for structure", id)
        pmg = ase2pymatgen(self.db.get_atoms(id=id))

    xtal = pyxtal()
    try:
        xtal.from_seed(pmg, tol=tol)
    except Exception:
        if xtal_str is not None:
            print(xtal_str)
        print("Cannot load the structure")
        return None

    if not xtal.valid:
        return None
    # Expose the stored db attributes on the pyxtal instance
    for key in self.keys:
        if hasattr(row, key):
            setattr(xtal, key, getattr(row, key))
    return xtal
[docs]
def get_all_xtals(self, include_energy=False):
    """
    Load every structure of the db as a pyxtal instance.

    Args:
        include_energy (bool): if True, copy `vasp_energy` (when present)
            to the `energy` attribute of each pyxtal.

    Returns:
        list: pyxtal instances; rows that fail to load are skipped.
    """
    collected = []
    for row in self.db.select():
        candidate = self.get_pyxtal(id=row.id)
        if candidate is None:
            continue
        if include_energy and hasattr(row, 'vasp_energy'):
            candidate.energy = row.vasp_energy
        collected.append(candidate)
    return collected
[docs]
def add_xtal(self, xtal, kvp=None):
    """
    Add a new xtal to the db.

    Args:
        xtal: pyxtal instance to write.
        kvp (dict, optional): extra key-value pairs to store. The dict is
            not modified; the symmetry-derived keys below override any
            value given for the same key.
    """
    # Copy instead of mutating the caller's dict (or a shared default)
    kvp = {} if kvp is None else dict(kvp)
    kvp.update({
        "space_group_number": xtal.group.number,
        "pearson_symbol": xtal.get_Pearson_Symbol(),
        "wps": str([s.wp.get_label() for s in xtal.atom_sites]),
        "density": xtal.get_density(),
        "dof": xtal.get_dof(),
    })
    atoms = xtal.to_ase(resort=False)
    self.db.write(atoms, key_value_pairs=kvp)
[docs]
def add_strucs_from_db(self, db_file, check=False,
                       id_min=0, id_max=None, tol=1e-3,
                       freq=50, use_relaxed=None, sort=None,
                       max_count=None, criteria=None,
                       min_atoms=0, max_atoms=250,
                       ignore_check='vasp_energy',
                       same_number=False):
    """
    Add new structures from another database file.

    Args:
        db_file (str): Path to the source database file
        check (bool): Whether to check if structure already exists before adding
        id_min (int): Starting ID to import from source database. Default is 0
        id_max (int): Ending ID (inclusive) to import from source database.
            Default is None, i.e. no upper bound
        tol (float): Tolerance in Angstroms for symmetry detection. Default is 1e-3
        freq (int): Print progress message every N structures. Default is 50
        use_relaxed (str): Relaxed structure to use - 'ff_relaxed', 'vasp_relaxed'
            Default is None to use unrelaxed structures
        sort (str): key to sort the structure, e.g. 'mace_energy'
            Default is None to use row.id
        max_count (int): Number of maximum structure to add
            Default is None to add all structures
        criteria (dict): criteria for `check_validity`; failing structures
            are skipped.
        min_atoms (int): passed to `check_new_structure`
        max_atoms (int): passed to `check_new_structure`
        ignore_check (str): rows that already carry this attribute are
            added without the duplicate check. Default is 'vasp_energy'
        same_number (bool): passed to `check_new_structure`
    """
    cifname = 'my_add.cif'  # scratch file for relaxed CIF strings

    def _load(row):
        # Convert `row` into a pyxtal; None when symmetry detection fails.
        xtal = pyxtal()
        if use_relaxed is None:
            atoms = row.toatoms()
            try:
                xtal.from_seed(atoms, tol=tol)
            except Exception:
                content = (getattr(row, 'mace_energy', None),
                           getattr(row, 'pearson_symbol', None),
                           getattr(row, 'wps', None))
                print("Faild to load xtal", content)
                return None
        else:
            with open(cifname, 'w') as f:
                f.write(getattr(row, use_relaxed))
            try:
                xtal.from_seed(cifname, tol=tol)
            except Exception:
                print("Faild to load xtal",
                      getattr(row, 'mace_energy', None),
                      getattr(row, 'pearson_symbol', None),
                      getattr(row, 'wps', None))
                return None
        return xtal

    def _make_kvp(xtal, row):
        # Symmetry-derived keys come from `xtal`; the others from `row`.
        kvp = {}
        for key in self.keys:
            if key == "space_group_number":
                kvp[key] = xtal.group.number
            elif key == "density":
                kvp[key] = xtal.get_density()
            elif key == "dof":
                kvp[key] = xtal.get_dof()
            elif key == "wps":
                kvp[key] = str([s.wp.get_label() for s in xtal.atom_sites])
            elif key == "pearson_symbol":
                kvp[key] = xtal.get_Pearson_Symbol()
            elif hasattr(row, key):
                kvp[key] = getattr(row, key)
        return kvp

    print(f"\nAdding new strucs from {db_file:s}")
    count = 0
    with connect(db_file, serial=True) as db:
        for row in db.select(sort=sort):
            if row.id < id_min or (id_max is not None and row.id > id_max):
                continue

            xtal = _load(row)
            if xtal is None or not xtal.valid:
                print("Fail to convert xtal")
                continue

            if criteria is not None and not xtal.check_validity(criteria):
                print("hit a invalid structure", row.id, xtal.get_xtal_string())
                print(xtal)
                continue

            add = True
            if check and not hasattr(row, ignore_check):
                add = self.check_new_structure(xtal,
                                               getattr(row, 'mace_energy', None),
                                               max_atoms=max_atoms,
                                               min_atoms=min_atoms,
                                               same_number=same_number)
            if not add:
                continue

            self.db.write(xtal.to_ase(), key_value_pairs=_make_kvp(xtal, row))
            count += 1
            if count % freq == 0:
                print(f"Adding {count:4d} strucs from {db_file:s}")
            if max_count is not None and count == max_count:
                break
[docs]
def check_new_structure(self, xtal, eng=None, same_group=False,
                        same_number=False, d_tol=2e-1, e_tol=1e-2,
                        max_atoms=250, min_atoms=0, return_id=False):
    """
    Check if the input crystal structure already exists in the database.

    Args:
        xtal: PyXtal object representing the crystal structure to check
        eng (float, optional): Energy of the structure, compared with the
            `mace_energy` of each row (rows without it are not filtered)
        same_group (bool): Whether to only compare structures with same space group
        same_number (bool): Whether to only compare structures with the
            same number of atoms
        d_tol (float): Tolerance for density comparison
        e_tol (float): Tolerance for energy comparison
        max_atoms (int): maximum number of atoms for checking
        min_atoms (int): minimum number of atoms for checking
        return_id (bool): if True, also return the id of the matching row

    Returns:
        bool: True if structure is new/unique, False if it matches an
        existing structure. With `return_id=True`, a tuple
        `(is_new, row_id)` where `row_id` is None for new structures.
        Structures outside [min_atoms, max_atoms] are reported as new.

    Note:
        Compares structures based on:
            - Space group number (if same_group=True)
            - Density (within d_tol)
            - Energy (within e_tol if provided)
            - Structure similarity via pymatgen.analysis.structure_matcher
    """
    n_atoms = sum(xtal.numIons)
    if n_atoms > max_atoms or n_atoms < min_atoms:
        return (True, None) if return_id else True

    s_pmg = xtal.to_pymatgen()
    density = xtal.get_density()  # loop invariant
    for row in self.db.select(sort='-id'):  # newest rows first
        if row.natoms > max_atoms or row.natoms < min_atoms:
            continue
        if same_number and row.natoms != len(s_pmg):
            continue
        if (eng is not None and hasattr(row, 'mace_energy')
                and abs(eng - row.mace_energy) > e_tol):
            continue
        if same_group and row.space_group_number != xtal.group.number:
            continue
        if abs(row.density - density) > d_tol:
            continue
        # Cheap filters passed: run the (expensive) structure matcher
        ref_pmg = ase2pymatgen(self.db.get_atoms(id=row.id))
        if self.matcher.fit(s_pmg, ref_pmg, symmetric=True):
            print("skip the duplicate", xtal.get_xtal_string())
            return (False, row.id) if return_id else False
    return (True, None) if return_id else True
[docs]
def clean_structures_spg_topology(self, dim=None):
    """
    Remove duplicate structures based on their stored properties.

    A row is a duplicate of an earlier kept row when it has a `topology`
    attribute and shares with that row:
        - the number of atoms,
        - the space group number,
        - the Wyckoff positions (`wps`),
        - the topology label (`topology_detail` when `topology` is "aaa",
          otherwise `topology`).
    Rows without a `topology` attribute are always kept.

    Args:
        dim (int, optional): if given, rows whose `dimension` attribute
            differs from `dim` are deleted as well.
    """
    kept = []       # (natoms, spg, wps, topology label) of kept rows
    to_delete = []
    for row in self.db.select():
        if dim is not None and hasattr(row, "dimension") and row.dimension != dim:
            to_delete.append(row.id)
            continue

        has_topology = hasattr(row, "topology")
        if has_topology:
            label = row.topology_detail if row.topology == "aaa" else row.topology
        else:
            label = None
        signature = (row.natoms, row.space_group_number, row.wps, label)

        if has_topology and signature in kept:
            to_delete.append(row.id)
        else:
            kept.append(signature)

    if to_delete:
        print(len(to_delete), "structures were deleted", to_delete)
        self.db.delete(to_delete)
[docs]
def get_row(self, id):
    """
    Return the row with the given id.

    Raises:
        RuntimeError: if no such row exists.
    """
    for row in self.db.select(id=id):
        return row
    # `msg` was previously undefined here, turning this into a NameError
    raise RuntimeError(f"cannot find the entry from {id}")
[docs]
def clean_structures(self, ids=(None, None), dtol=2e-3, etol=1e-3, criteria=None, eng_key='mace_energy'):
    """
    Clean up the db by removing duplicate and invalid structures.

    Two rows are considered duplicates when they share the number of
    atoms, the space group number and the Wyckoff positions (`wps`), and
        - their `eng_key` energies differ by less than `etol` (when both
          rows have an energy), or otherwise
        - their densities differ by less than `dtol`.

    Args:
        ids (tuple): (min_id, max_id) range of rows to process.
        dtol (float): tolerance of density
        etol (float): tolerance of energy
        criteria (dict, optional): filters applied before the duplicate
            check. Besides the keys understood by `check_validity`, the keys
            'MAX_energy' (compared with `eng_key`), 'MAX_similarity',
            'BAD_topology' and 'BAD_dimension' are checked against the row.
        eng_key (str): row attribute holding the energy.
    """

    def _rejected(row, xtal):
        # Return True (after printing the reason) if `row` fails `criteria`.
        if not xtal.check_validity(criteria, True):
            print("Found unsatisfied criteria", row.id, row.space_group_number, row.wps)
            return True
        if ("MAX_energy" in criteria and hasattr(row, eng_key)
                and getattr(row, eng_key) > criteria["MAX_energy"]):
            print("Unsatisfied energy", row.id, getattr(row, eng_key),
                  row.space_group_number, row.wps)
            return True
        if ("MAX_similarity" in criteria and hasattr(row, "similarity")
                and row.similarity > criteria["MAX_similarity"]):
            print("Unsatisfied similarity", row.id, row.similarity,
                  row.space_group_number, row.wps)
            return True
        if ("BAD_topology" in criteria and hasattr(row, "topology")
                and row.topology[:3] in criteria["BAD_topology"]):
            print("Unsatisfied topology", row.id, row.topology,
                  row.space_group_number, row.wps)
            return True
        if ("BAD_dimension" in criteria and hasattr(row, "dimension")
                and row.dimension in criteria["BAD_dimension"]):
            print("Unsatisfied dimension", row.id, row.topology,
                  row.space_group_number, row.wps)
            return True
        return False

    unique_rows = []  # (natoms, spg, wps, density, energy) of kept rows
    to_delete = []
    ids, xtals = self.select_xtals(ids)
    for id, xtal in zip(ids, xtals):
        row = self.db.get(id)
        unique = not (criteria is not None and _rejected(row, xtal))

        energy = getattr(row, eng_key) if hasattr(row, eng_key) else None
        if unique:
            for natoms, spg, wps, den, ref_energy in unique_rows:
                if (natoms, spg, wps) != (row.natoms, row.space_group_number, row.wps):
                    continue
                if energy is not None and ref_energy is not None:
                    if abs(energy - ref_energy) < etol:
                        unique = False
                        break
                elif abs(den - row.density) < dtol:
                    unique = False
                    break

        if unique:
            # Store the same energy (`eng_key`) that is compared above
            unique_rows.append(
                (row.natoms, row.space_group_number, row.wps, row.density, energy))
        else:
            to_delete.append(row.id)
    print(len(to_delete), "structures were deleted", to_delete)
    self.db.delete(to_delete)
[docs]
def clean_structures_pmg(self, ids=(None, None), min_id=None, dtol=5e-2, criteria=None,
                         eng_key="ff_energy"):
    """
    Clean up the database by removing duplicate structures based on density
    and the pymatgen StructureMatcher.

    Args:
        ids (tuple, optional): Range of IDs (min, max) to process. Defaults to (None, None).
        min_id (int, optional): Rows with id <= min_id are never deleted
            (but still serve as references for the duplicate check).
            Defaults to the smallest selected id.
        dtol (float, optional): Density tolerance for comparing structures. Defaults to 5e-2.
        criteria (dict, optional): Dictionary of filtering criteria. Defaults to None.
            Supported criteria keys:
                - keys understood by `check_validity` (e.g. 'CN', 'cutoff')
                - 'MAX_energy': Float, maximum allowed `eng_key` energy
                - 'MAX_similarity': Float, maximum allowed similarity value
                - 'BAD_topology': List of forbidden topology types
                - 'BAD_dimension': List of forbidden dimensionality values
            Example criteria:
                {
                    'CN': {'C': 3},
                    'BAD_dimension': [0, 2]
                }
        eng_key (str, optional): row attribute compared with 'MAX_energy'.
            Defaults to 'ff_energy'.

    Returns:
        None. Modifies database in place by deleting duplicate/invalid structures.
    """

    def _rejected(row, xtal):
        # Return True (after printing the reason) if `row` fails `criteria`.
        if not xtal.check_validity(criteria, True):
            print("Found unsatisfied criteria", row.id, row.space_group_number, row.wps)
            return True
        if ("MAX_energy" in criteria and hasattr(row, eng_key)
                and getattr(row, eng_key) > criteria["MAX_energy"]):
            print("Unsatisfied energy", row.id, getattr(row, eng_key),
                  row.space_group_number, row.wps)
            return True
        if ("MAX_similarity" in criteria and hasattr(row, "similarity")
                and row.similarity > criteria["MAX_similarity"]):
            print("Unsatisfied similarity", row.id, row.similarity,
                  row.space_group_number, row.wps)
            return True
        if ("BAD_topology" in criteria and hasattr(row, "topology")
                and row.topology[:3] in criteria["BAD_topology"]):
            print("Unsatisfied topology", row.id, row.topology,
                  row.space_group_number, row.wps)
            return True
        if ("BAD_dimension" in criteria and hasattr(row, "dimension")
                and row.dimension in criteria["BAD_dimension"]):
            print("Unsatisfied dimension", row.id, row.topology,
                  row.space_group_number, row.wps)
            return True
        return False

    unique_rows = []  # (row id, density) of kept rows
    to_delete = []
    ids, xtals = self.select_xtals(ids)
    if min_id is None:
        # Empty selections previously crashed on min([])
        min_id = min(ids) if ids else 0

    for id, xtal in zip(ids, xtals):
        row = self.db.get(id)
        checked = id > min_id  # rows up to min_id are always kept
        unique = True
        if checked and criteria is not None and _rejected(row, xtal):
            unique = False

        if unique and checked:
            s_pmg = None  # built lazily, once per row
            for ref_id, den in unique_rows:
                if abs(den - row.density) >= dtol:
                    continue
                if s_pmg is None:
                    s_pmg = xtal.to_pymatgen()
                ref_pmg = ase2pymatgen(self.db.get_atoms(id=ref_id))
                if self.matcher.fit(ref_pmg, s_pmg):
                    print("Found duplicate", row.id, row.space_group_number, row.wps)
                    unique = False
                    break

        if unique:
            unique_rows.append((row.id, row.density))
        else:
            to_delete.append(row.id)
    print(len(to_delete), "structures were deleted", to_delete)
    self.db.delete(to_delete)
[docs]
def get_max_id(self):
    """
    Get one past the maximum row id.

    Returns:
        int: largest row id + 1 (usable as an exclusive or inclusive upper
        bound), or None if the db is empty.
    """
    # The previous loop compared against the already-incremented value and
    # returned either max or max + 1 depending on the id pattern.
    largest = max((row.id for row in self.db.select()), default=None)
    return None if largest is None else largest + 1
[docs]
def select_xtals(self, ids, N_atoms=(None, None), overwrite=False, attribute=None, use_relaxed=None):
    """
    Extract xtals based on attribute name.

    Args:
        ids (tuple): Minimum and maximum row IDs to extract (inclusive),
            e.g. (1, 10); None means unbounded on that side (min defaults to 1).
        N_atoms (tuple): Minimum (exclusive) and maximum (inclusive) number
            of atoms, e.g. (2, 100); defaults to (1, 5000).
        overwrite (bool): If False, rows that already have `attribute` are
            skipped; if True, they are selected as well.
        attribute (str): Attribute name to check for extraction (None to
            select regardless of attributes).
        use_relaxed (str): Type of relaxed structure to use ('ff_relaxed' or 'vasp_relaxed')

    Returns:
        tuple: (ids, xtals) where ids is a list of row IDs and xtals is a list of
        corresponding pyxtal objects (None for rows that failed to load)
    """
    (min_id, max_id) = ids
    if min_id is None:
        min_id = 1
    if max_id is None:
        max_id = float("inf")  # avoids an extra full scan via get_max_id
    (min_atoms, max_atoms) = N_atoms
    if min_atoms is None:
        min_atoms = 1
    if max_atoms is None:
        max_atoms = 5000

    sel_ids, xtals = [], []
    for row in self.db.select():
        # Same selection rule as `select_xtal`
        if attribute is not None and not overwrite and hasattr(row, attribute):
            continue
        if not (min_id <= row.id <= max_id and min_atoms < row.natoms <= max_atoms):
            continue
        xtal = self.get_pyxtal(row.id, use_relaxed)
        sel_ids.append(row.id)
        xtals.append(xtal)
        if len(xtals) % 100 == 0:
            print("Loading xtals from db", len(xtals))
    return sel_ids, xtals
[docs]
def select_xtal(self, ids, N_atoms=(None, None), overwrite=False, attribute=None, use_relaxed=None):
    """
    Lazy extraction of selected xtals from the database.

    Only rows whose id satisfies `id % self.size == self.rank` are yielded,
    so that several processes can share the work.

    Args:
        ids (tuple): Minimum and maximum row IDs to extract (inclusive),
            e.g. (1, 10); None means unbounded on that side (min defaults to 1).
        N_atoms (tuple): Minimum (exclusive) and maximum (inclusive) number
            of atoms, e.g. (2, 100); defaults to (1, 5000).
        overwrite (bool): If False, rows that already have `attribute` are skipped.
        attribute (str): Attribute name to check for extraction (None to
            select regardless of attributes).
        use_relaxed (str): Type of relaxed structure to use ('ff_relaxed' or 'vasp_relaxed')

    Yields:
        tuple: (id, xtal) where id is the row ID and xtal is the corresponding pyxtal object
    """
    (min_id, max_id) = ids
    if min_id is None:
        min_id = 1
    if max_id is None:
        max_id = float("inf")  # avoids an extra full scan via get_max_id
    (min_atoms, max_atoms) = N_atoms
    if min_atoms is None:
        min_atoms = 1
    if max_atoms is None:
        max_atoms = 5000

    for row in self.db.select():
        # hasattr(row, None) raised TypeError with the default attribute=None
        if attribute is not None and not overwrite and hasattr(row, attribute):
            continue
        row_id, natoms = row.id, row.natoms
        if (min_id <= row_id <= max_id
                and min_atoms < natoms <= max_atoms
                and row_id % self.size == self.rank):
            yield row_id, self.get_pyxtal(row_id, use_relaxed)
[docs]
def update_row_energy(
    self,
    calculator='GULP',
    ids=(None, None),
    N_atoms=(None, None),
    ncpu=1,
    criteria=None,
    symmetrize=False,
    overwrite=False,
    write_freq=100,
    ff_lib='reaxff',
    steps=250,
    fmax=0.1,
    use_relaxed=None,
    cmd=None,
    calc_folder=None,
    skf_dir=None,
):
    """
    Update the row energy in the database for a given calculator.

    Args:
        calculator (str): 'GULP', 'MACE', 'VASP', 'DFTB'
        ids (tuple): A tuple specifying row IDs to update (e.g., (0, 100)).
        N_atoms (tuple): Minimum and maximum number of atoms of the rows to
            update, forwarded to ``select_xtal``.
        ncpu (int): number of parallel processes
        criteria (dict, optional): Criteria when selecting structures.
        symmetrize (bool): symmetrize the structure before calculation (DFTB)
        overwrite (bool): overwrite the existing energy attributes.
        write_freq (int): frequency to update db for ncpu=1
        ff_lib (str): Force field to use for GULP ('reaxff' by default).
        steps (int): Number of optimization steps for MACE/DFTB (default is 250).
        fmax (float): force tolerance for MACE (default is 0.1)
        use_relaxed (str, optional): Use relaxed structures (e.g. 'ff_relaxed')
        cmd (str, optional): Command for VASP calculations
        calc_folder (str, optional): calc_folder for GULP/VASP calculations;
            defaults to ``<calculator>_calc``.
        skf_dir (str, optional): Directory for DFTB potential files

    Raises:
        ValueError: If ``calculator`` is not supported. This is checked
            before any folder is created.

    Functionality:
        Using the selected calculator, it updates the energy rows of the
        database. If `ncpu > 1`, run in parallel; otherwise in serial.
        Results are stored under ``ff_energy`` for GULP and
        ``<calculator>_energy`` otherwise.

    Calculator Options:
        - 'GULP': Uses a force field (e.g., 'reaxff').
        - 'MACE': Uses the MACE calculator.
        - 'DFTB': Uses DFTB+ with symmetrization options.
        - 'VASP': Uses VASP, with a specified command (`cmd`).
    """
    label = 'ff_energy' if calculator == 'GULP' else calculator.lower() + "_energy"
    if calc_folder is None:
        calc_folder = calculator.lower() + "_calc"

    # Build the calculator arguments first so an unsupported calculator is
    # rejected before any side effect (e.g. creating a stray folder).
    args_up = []
    if calculator == 'GULP':
        args = [calculator, ff_lib, calc_folder, criteria]
        args_up = [ff_lib]
    elif calculator == 'MACE':
        args = [calculator, steps, fmax, criteria]
    elif calculator == 'DFTB':
        args = [calculator, skf_dir, steps, symmetrize, criteria]
    elif calculator == 'VASP':
        args = [calculator, calc_folder, cmd, criteria]
    else:
        raise ValueError(f"Unsupported calculator: {calculator}")

    # MACE runs in-process and needs no working folder.
    if calculator != 'MACE':
        os.makedirs(calc_folder, exist_ok=True)

    # Lazily generate the structures to compute (skipping rows that already
    # have ``label`` unless overwriting).
    generator = self.select_xtal(ids, N_atoms, overwrite, label, use_relaxed)

    self.logging.info(f"Rank-{self.rank} row_energy {calculator} {self.db_name}")
    if ncpu == 1:
        self.update_row_energy_serial(generator, write_freq, args, args_up)
    else:
        self.update_row_energy_mproc(ncpu, generator, args, args_up)
    self.logging.info(f"Rank-{self.rank} complete update_row_energy")
[docs]
def update_row_energy_serial(self, generator, write_freq, args, args_up):
    """
    Perform a serial update of row energies

    Args:
        generator (generator): Yielding tuples of (id, xtal), where:
            - `id` (int): Unique identifier for the structure.
            - `xtal` (object): pyxtal instance, or None if unavailable.
        write_freq (int): Frequency to update the database.
        args (list): Additional arguments to the function `opt_single`;
            ``args[0]`` is the calculator name.
        args_up (list): Additional arguments for function `_update_db`.

    Functionality:
        It iterates over structures provided by `generator`, skipping entries
        whose structure is None (as the parallel path does), optimizes them
        using `opt_single`, and collects results that have converged
        (`status == True`). Once the number of results reaches `write_freq`,
        it updates the database and prints the memory usage; leftovers are
        written at the end.
    """
    results = []
    for row_id, xtal in generator:
        if xtal is None:
            self.logging.info(f"Skipping {row_id}: structure unavailable")
            continue
        message = f"Processing {row_id} {xtal.lattice} {args[0]}"
        self.logging.info(message)
        print(message)
        xtal, eng, status = opt_single(row_id, xtal, *args)
        if status:
            results.append((row_id, xtal, eng))
        if len(results) >= write_freq:
            self._update_db(results, args[0], *args_up)
            results = []
            self.print_memory_usage()
    if results:
        self._update_db(results, args[0], *args_up)
[docs]
def update_row_energy_mproc(self, ncpu, generator, args, args_up):
    """
    Perform parallel row energy updates by optimizing atomic structures.

    Args:
        ncpu (int): Number of CPUs to use for parallel processing.
        generator (generator): yielding tuples of (id, xtal), where:
            - `id` (int): Unique identifier for the structure.
            - `xtal` (object): pyxtal instance (None entries are skipped).
        args (list): Additional arguments passed to `call_opt_single`;
            ``args[0]`` is the calculator name.
        args_up (list): Additional arguments for function `_update_db`.

    Functionality:
        Structures are taken from `generator` in chunks of ``ncpu * 10`` and
        optimized with a `multiprocessing.Pool` of `ncpu` workers via
        `call_opt_single`. Results with an energy are written to the
        database at the end of every chunk; for the slow VASP calculator
        they are additionally flushed every `ncpu` results, followed by a
        memory-usage report.

        The pool is always shut down: closed and joined on success,
        terminated and joined if an exception occurs, so no worker
        processes are leaked.
    """
    from multiprocessing import Pool

    def chunkify(items, chunk_size):
        """Yield successive lists of at most ``chunk_size`` items."""
        chunk = []
        for item in items:
            chunk.append(item)
            if len(chunk) == chunk_size:
                yield chunk
                chunk = []
        if chunk:
            yield chunk

    calc = args[0]
    self.logging.info(f"Parallel optimizations {ncpu}")
    pool = Pool(processes=ncpu,
                initializer=setup_worker_logger,
                initargs=(self.log_file,))
    try:
        for chunk in chunkify(generator, ncpu * 10):
            myargs = [tuple([_id, xtal] + args) for _id, xtal in chunk if xtal is not None]
            if not myargs:
                # Every structure of this chunk was unavailable.
                continue
            span = f"{myargs[0][0]}-{myargs[-1][0]}"
            self.logging.info(f"Start minicycle: {span}")
            results = []
            for result in pool.imap_unordered(call_opt_single, myargs, chunksize=1):
                if result is None:
                    continue
                myid, xtal, eng = result
                if eng is not None:
                    results.append(result)
                    numIons = sum(xtal.numIons)
                    self.logging.info(f"Add {myid:4d} {eng:.3f} *{numIons} {len(results)}")
                # Only do frequent update for slow calculator VASP
                if len(results) >= ncpu and calc == 'VASP':
                    self._update_db(results, calc, *args_up)
                    self.logging.info(f"Finish minibatch: {len(results)}")
                    self.print_memory_usage()
                    results = []
            self.logging.info(f"Done minicycle: {span}")
            # Flush what is left of this chunk before the next one starts.
            if results:
                self.logging.info(f"Start Update db: {len(results)}")
                self._update_db(results, calc, *args_up)
                self.logging.info(f"Finish Update db: {len(results)}")
    except BaseException:
        # Do not wait for queued (possibly very long) jobs after a failure.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
def _update_db(self, results, calc, *args):
"""
Update db with the calculation_results
https://wiki.fysik.dtu.dk/ase/ase/db/db.html#writing-and-updating-many-rows-efficiently
Args:
results: list of (id, xtal, eng) tuples
calc (str): calculator
"""
#self.logging.info(f"====================Update db: {len(results)}")
if calc == 'GULP':
ff_lib = args[0]
with self.db:
for result in results:
(id, xtal, eng) = result
if xtal is not None:
if calc == 'GULP':
self.db.update(id,
ff_energy=eng,
ff_lib=ff_lib,
ff_relaxed=xtal.to_file())
elif calc == 'MACE':
self.db.update(id,
mace_energy=eng,
mace_relaxed=xtal.to_file())
elif calc == 'VASP':
self.db.update(id,
vasp_energy=eng,
vasp_relaxed=xtal.to_file())
elif calc == 'DFTB':
self.db.update(id,
dftb_energy=eng,
dftb_relaxed=xtal.to_file())
#self.logging.info(f'update_db_{calc}, {id}')
[docs]
def update_row_topology(self, StructureType="Auto", overwrite=True, prefix=None, ref_dim=3, timeout=60):
    """
    Update row topology using CrystalNets.jl via subprocess (faster than juliacall).

    Rows are processed in batches of 100. For each batch:
    - the CIF files are written to the working directory;
    - a single ``julia`` call analyses all of them;
    - the results are stored in the ``topology``, ``dimension`` and
      ``topology_detail`` columns.

    A structure that CrystalNets fails on is marked as an error on its own
    without affecting the rest of its batch. Temporary CIF files and the
    generated Julia script are removed even when processing fails.

    Args:
        StructureType (str): Type of structure to analyze. Options are:
            - 'Zeolite': For zeolite structures
            - 'MOF': For metal-organic frameworks
            - 'Auto': For automatic detection
        overwrite (bool): Whether to overwrite existing topology attributes.
        prefix (str, optional): Prefix for the temporary Julia script and CIF
            files. Defaults to ``"tmp"``.
        ref_dim (int): Rows whose detected dimension equals ``ref_dim`` are
            printed to stdout.
        timeout (int): Timeout in seconds for each Julia call (one call per
            batch). Default is 60.
    """
    import json
    import os
    import subprocess
    from time import time

    if prefix is None:
        # The prefix names both the Julia script and the CIF files.
        prefix = "tmp"

    # Collect rows to process before writing anything to disk.
    rows_to_process = [
        row.id for row in self.db.select()
        if overwrite or not hasattr(row, "topology")
    ]
    if not rows_to_process:
        self.logging.info("No rows to process for topology update")
        return
    self.logging.info(f"Processing {len(rows_to_process)} structures for topology")

    # Julia script for batch topology processing: handles multiple ARGS and
    # prints one JSON array aligned with them.
    julia_script = f"""
    using CrystalNets
    using JSON
    CrystalNets.toggle_warning(false)
    CrystalNets.toggle_export(false)
    structure_type = "{StructureType}"
    if structure_type == "Zeolite"
        option = CrystalNets.Options(structure=CrystalNets.StructureType.Zeolite)
    elseif structure_type == "MOF"
        option = CrystalNets.Options(structure=CrystalNets.StructureType.MOF)
    else
        option = CrystalNets.Options(structure=CrystalNets.StructureType.Auto)
    end
    function process_one(cif_file)
        try
            result = CrystalNets.determine_topology(cif_file, option)
            output = []
            results_list = length(result) > 1 ? collect(result) : [result[1]]
            for res in results_list
                name = string(res[1])
                count = res[2]
                genome = res[1][CrystalNets.Clustering.Auto]
                dim = CrystalNets.ndims(CrystalNets.PeriodicGraph(genome))
                push!(output, Dict("dim" => dim, "name" => name, "count" => count))
            end
            return output
        catch e
            return Dict("error" => string(e))
        end
    end
    # Process all input files and return an array aligned to ARGS
    function process_batch(files)
        results = Vector{{Any}}()
        for f in files
            push!(results, process_one(f))
        end
        return results
    end
    if length(ARGS) > 0
        println(JSON.json(process_batch(ARGS)))
    end
    """

    def parse_topology(topology_list):
        """Parse topology list to get dimension, name, and detail"""
        dim = 0
        name = ""
        detail = "None"
        for i, topo in enumerate(topology_list):
            d = topo["dim"]
            n = topo["name"]
            if d > dim:
                dim = d
            tmp = n.split(",")[0]
            if tmp.startswith("UNKNOWN"):
                detail = tmp[7:]
                tmp = "aaa"
            elif tmp.startswith("unstable"):
                tmp = "unstable"
            name += tmp
            if topo["count"] > 1:
                name += f"({topo['count']})"
            if i + 1 < len(topology_list):
                name += "-"
        return dim, name, detail

    def cleanup_files(files):
        """Best-effort removal of temporary files."""
        for f in files:
            try:
                if os.path.exists(f):
                    os.remove(f)
            except OSError:
                pass

    def row_update(row_id, out):
        """Turn one Julia result entry into a (row_id, name, dim, detail) tuple."""
        if isinstance(out, dict) and "error" in out:
            # CrystalNets failed on this file only; keep the rest of the batch.
            return (row_id, "error", 3, str(out["error"])[:100])
        topo_list = out if isinstance(out, list) else [out]
        try:
            dim, name, detail = parse_topology(topo_list)
        except (KeyError, TypeError, AttributeError) as e:
            return (row_id, "error", 3, f"parse_error: {e}"[:100])
        if name.startswith("FAILED"):
            name = '0-dimensional'
        # Optional verbose line for matches
        if dim == ref_dim:
            row = self.db.get(row_id)
            print(f"Row {row_id}: {row.space_group_number} {row.wps} dim={dim} {name}")
        return (row_id, name, dim, detail)

    def run_batch(batch_ids, start, end):
        """Run Julia on one batch and return one update tuple per row id."""
        cif_files = []
        try:
            # 1) Write all CIFs for the batch
            for row_id in batch_ids:
                cif_file = f"{prefix}_{row_id}.cif"
                # Register the name first so a failed write is still cleaned up.
                cif_files.append(cif_file)
                atoms = self.db.get_atoms(row_id)
                atoms.write(cif_file, format="cif", parallel=False)

            # 2) Call Julia once with all file paths as arguments
            t0 = time()
            result = subprocess.run(
                ["julia", script_path, *cif_files],
                capture_output=True,
                text=True,
                timeout=timeout,
            )
            elapsed = time() - t0
            if result.returncode != 0:
                self.logging.warning(f"Julia failed for batch {start}-{end}: {result.stderr}")
                return [(row_id, "error", 3, "julia_error") for row_id in batch_ids]

            # Expect JSON array of results aligned with inputs
            try:
                batch_output = json.loads(result.stdout.strip())
            except ValueError as e:
                self.logging.warning(f"Failed to parse JSON for batch {start}-{end}: {e}")
                return [(row_id, "error", 3, "json_error") for row_id in batch_ids]

            # A bare object (single result or error) is normalized to a list
            if isinstance(batch_output, dict) and ("error" in batch_output or "dim" in batch_output):
                batch_output = [batch_output]
            # Fallback: one JSON value per line
            if not isinstance(batch_output, list):
                batch_output = []
                for line in result.stdout.splitlines():
                    if not line.strip():
                        continue
                    try:
                        batch_output.append(json.loads(line))
                    except ValueError:
                        batch_output.append({"error": "line_parse_error"})

            # 3) Map results back to row ids
            print(f"Batch output length: {len(batch_output)}, expected: {len(batch_ids)}, elapsed: {elapsed:.2f}s")
            updates = []
            for idx, row_id in enumerate(batch_ids):
                if idx >= len(batch_output):
                    updates.append((row_id, "error", 3, "missing_output"))
                else:
                    updates.append(row_update(row_id, batch_output[idx]))
            return updates
        except subprocess.TimeoutExpired:
            self.logging.warning(f"Timeout for batch {start}-{end} after {timeout}s")
            return [(row_id, "timeout", 3, "timeout") for row_id in batch_ids]
        except Exception as e:
            self.logging.warning(f"Error processing batch {start}-{end}: {e}")
            return [(row_id, "error", 3, str(e)[:100]) for row_id in batch_ids]
        finally:
            # 4) Clean up CIF files for this batch
            cleanup_files(cif_files)

    script_path = prefix + "_process_topology.jl"
    with open(script_path, "w") as f:
        f.write(julia_script)

    batch_size = 100
    try:
        for start in range(0, len(rows_to_process), batch_size):
            end = min(start + batch_size, len(rows_to_process))
            updates = run_batch(rows_to_process[start:end], start, end)
            # 5) Write batch results to DB
            if updates:
                self.logging.info(f"Batch updating {len(updates)} rows")
                with self.db:
                    for rid, tname, tdim, tdetail in updates:
                        self.db.update(rid, topology=tname, dimension=tdim, topology_detail=tdetail)
    finally:
        # Clean up Julia script
        if os.path.exists(script_path):
            os.remove(script_path)
    self.logging.info(f"Completed topology update for {len(rows_to_process)} structures")
[docs]
def update_db_description(self):
    """
    Update database description using robocrys.

    Uses robocrystallographer (https://github.com/hackingmaterials/robocrystallographer)
    to generate natural language descriptions of crystal structures.

    For each row in the database that doesn't have a description:
    1. Converts ASE atoms to pymatgen structure
    2. Uses StructureCondenser to analyze bonding/connectivity
    3. Uses StructureDescriber to generate text description
    4. Updates the database row with the description ("N/A" if robocrys fails)

    Rows that already have a description are only printed.

    Note:
        Use it with caution, as it may take a long time to run. Each row is
        written as soon as it is described, so interrupting (Ctrl-C) keeps
        the descriptions done so far.
    """
    from robocrys import StructureCondenser, StructureDescriber

    condenser = StructureCondenser()
    describer = StructureDescriber()
    for row in self.db.select():
        if hasattr(row, "description"):
            print("\n======Existing\n", row.description)
            continue
        pmg = ase2pymatgen(self.db.get_atoms(row.id))
        try:
            condensed_structure = condenser.condense_structure(pmg)
            description = describer.describe(condensed_structure)
        except Exception:
            # robocrys can fail on unusual structures; store a placeholder
            # but let KeyboardInterrupt/SystemExit stop this long loop.
            description = "N/A"
        self.db.update(row.id, description=description)
        print("\n======Updating\n", description)
[docs]
def export_structures(
    self,
    fmt="vasp",
    folder="mof_out",
    criteria=None,
    sort_by="similarity",
    overwrite=True,
    cutoff=None,
    use_relaxed=None,
):
    """
    Export structures from database according to given criteria.

    Rows are sorted by ``sort_by``; rows without a value for it are skipped.
    At most ``cutoff`` rows are written to ``folder``. Each file name
    contains the row id, Pearson symbol and space group. When available it
    also contains:
    - site coordinations and similarity (this part is truncated to 40
      characters);
    - density;
    - energy (vasp > mace > ff);
    - topology.

    Args:
        fmt (str): Output format (``vasp`` or ``cif``)
        folder (str): Path to output folder
        criteria (dict): Dictionary of validity criteria
        sort_by (str): Attribute to sort structures by
        overwrite (bool): Whether to remove existing output folder
        cutoff (int): Maximum number of structures to export
        use_relaxed (str, optional): e.g., ``ff_relaxed`` or ``vasp_relaxed``

    Returns:
        dict: Maps each attribute name to the list of its values over all
            rows. Attributes missing in the first row are omitted, and the
            dict is empty when the database has no rows.

    Raises:
        ValueError: If ``sort_by`` is not a supported attribute. This is
            checked before the output folder is touched.
    """
    import shutil

    keys = [
        "id",
        "pearson_symbol",
        "space_group_number",
        "density",
        "dof",
        "similarity",
        "ff_energy",
        "vasp_energy",
        "mace_energy",
        "topology",
    ]
    if sort_by not in keys:
        print("supported attributes", keys)
        raise ValueError(f"Cannot sort by {sort_by}")
    col = keys.index(sort_by)

    if cutoff is None:
        cutoff = self.db.count()
    if overwrite and os.path.exists(folder):
        shutil.rmtree(folder)
    os.makedirs(folder, exist_ok=True)

    def as_float(row, attr):
        """Return ``float(row.<attr>)``, or None when it is missing or None."""
        value = getattr(row, attr, None)
        return None if value is None else float(value)

    properties = [
        [
            row.id,
            row.pearson_symbol,
            row.space_group_number,
            row.density,
            row.dof,
            as_float(row, "similarity"),
            as_float(row, "ff_energy"),
            as_float(row, "vasp_energy"),
            as_float(row, "mace_energy"),
            getattr(row, "topology", None),
        ]
        for row in self.db.select()
    ]

    dicts = {}
    if properties:
        for i, key in enumerate(keys):
            if properties[0][i] is not None:
                dicts[key] = [prop[i] for prop in properties]

    print(f"====Exporting {len(properties)} structures")
    candidates = sorted((p for p in properties if p[col] is not None), key=lambda p: p[col])
    for entry in candidates[:cutoff]:
        row_id, _, _, den, _, sim, ff_eng, vasp_eng, mace_eng, top = entry
        row_id = int(row_id)
        den = None if den is None else float(den)
        # Prefer the most accurate energy available: VASP > MACE > force field.
        eng = next((e for e in (vasp_eng, mace_eng, ff_eng) if e is not None), None)

        xtal = self.get_pyxtal(row_id, use_relaxed)
        number = xtal.group.number
        symbol = xtal.group.symbol.replace("/", "")
        # Keep the file name separate from the folder so the 40-character
        # truncation below never cuts into the directory path.
        name = f"{row_id:d}-{xtal.get_Pearson_Symbol():s}-{number:d}-{symbol:s}"
        status = xtal.check_validity(criteria, True) if criteria is not None else True
        if not status:
            print("====Skippng: ", os.path.join(folder, name))
            continue

        try:
            xtal.set_site_coordination()
            for s in xtal.atom_sites:
                _l, _sp, _cn = s.wp.get_label(), s.specie, s.coordination
                name += f"-{_l:s}-{_sp:s}{_cn:d}"
            if sim is not None:
                name += f"-S{sim:.3f}"
            name = name[:40]
        except Exception:
            print("Problem in setting site coordination")
        if den is not None:
            name += f"-D{abs(den):.2f}"
        if eng is not None:
            name += f"-E{abs(eng):.3f}"
        if top is not None:
            name += f"-T{top:s}"

        label = os.path.join(folder, name)
        print("====Exporting:", label)
        if fmt == "vasp":
            xtal.to_file(label + ".vasp", fmt="poscar")
        elif fmt == "cif":
            xtal.to_file(label + ".cif")
    return dicts
[docs]
def get_label(self, i):
    """
    Return the per-CPU folder name for worker index ``i``.

    The index is zero-padded to at least three digits, e.g.
    5 -> 'cpu005', 42 -> 'cpu042', 123 -> 'cpu123'.
    """
    return f"cpu{i:03d}"
[docs]
def get_db_unique(self, db_name=None, prec=3, key='ff_energy', max_N_atoms=64):
    """
    Get a database file containing only unique structures based on density and energy.

    Args:
        db_name (str, optional): Filename for the new database.
            If None, will use original name with '_unique' suffix.
        prec (int, optional): Precision for rounding density and energy values.
            Default is 3.
        key (str, optional): Energy attribute name to use for filtering.
            Default is 'ff_energy'. Rows without it are ignored.
        max_N_atoms (int, optional): Maximum n_atoms for pmg match. Default is 64.

    Returns:
        int: Number of unique structures in the new database.

    Note:
        Two structures are considered identical if they have:
        - Same density value (rounded to ``prec`` decimals)
        - Same energy value (rounded to ``prec`` decimals)
        - A pymatgen StructureMatcher match. When either structure has more
          than ``max_N_atoms`` atoms, the cheaper check of identical space
          group number and Wyckoff positions is used instead.
        When duplicates are found, the structure with lower DOF is kept.
    """
    from pymatgen.analysis.structure_matcher import StructureMatcher

    matcher = StructureMatcher(stol=0.3, ltol=0.2, angle_tol=5)
    print(f"The {self.db_name:s} has {self.db.count():d} strucs")
    if db_name is None:
        db_name = self.db_name[:-3] + "_unique.db"
    if os.path.exists(db_name):
        os.remove(db_name)

    # ``kept`` holds the unique entries in insertion order; a replaced entry
    # moves to the end. ``buckets`` groups them by (density, energy), so a
    # new row is only compared with candidates sharing both rounded values.
    kept = {}
    buckets = {}
    for row in self.db.select():
        energy = getattr(row, key, None)
        if energy is None:
            continue
        dof = row.dof
        den = round(row.density, prec)
        energy = round(energy, prec)
        spg, wps = row.space_group_number, row.wps
        pmg = ase2pymatgen(row.toatoms())
        entry = (row.id, dof, den, energy, spg, wps, pmg)

        bucket = buckets.setdefault((den, energy), [])
        match = None
        for existing in bucket:
            _, _, _, _, _spg, _wps, _pmg = existing
            if len(_pmg) > max_N_atoms or len(pmg) > max_N_atoms:
                # For large structures, skip pymatgen match to save time
                same = spg == _spg and wps == _wps
            else:
                same = matcher.fit(pmg, _pmg)
            if same:
                match = existing
                break

        if match is None:
            print("Adding", row.id, den, energy)
            bucket.append(entry)
            kept[row.id] = entry
        elif dof < match[1]:
            print("Updating", row.id, den, energy)
            bucket.remove(match)
            bucket.append(entry)
            del kept[match[0]]
            kept[row.id] = entry
        else:
            print("Duplicate", row.id, den, energy)

    with connect(db_name, serial=True) as db:
        for row_id in kept:
            row = self.db.get(row_id)
            kvp = {attr: getattr(row, attr) for attr in self.keys if hasattr(row, attr)}
            db.write(row.toatoms(), key_value_pairs=kvp)
    count = db.count()
    print(f"Created {db_name:s} with {count:d} strucs")
    return count
[docs]
def get_db_unique_topology(self, db_name=None, prec=3, update_topology=True, key='ff_energy'):
    """
    Get a database file containing only unique structures based on topology and energy.

    Args:
        db_name (str, optional): Filename for the new database.
            If None, will use original name with '_unique' suffix.
        prec (int, optional): Precision for rounding energy values. Default is 3.
        update_topology (bool, optional): Whether to update topology before filtering.
            Default is True.
        key (str, optional): Energy attribute name to use for filtering.
            Default is 'ff_energy'. Rows without it are ignored.

    Returns:
        int: Number of unique structures in the new database.

    Note:
        Two structures are considered identical if they have:
        - Same topology
        - Same topology detail
        - Same energy value (within precision)
        When duplicates are found, the structure with lower DOF is kept.
    """
    print(f"The {self.db_name:s} has {self.db.count():d} strucs")
    if db_name is None:
        db_name = self.db_name[:-3] + "_unique.db"
    if os.path.exists(db_name):
        os.remove(db_name)

    if update_topology:
        # Explicit prefix for the temporary script/CIF files written by
        # update_row_topology.
        self.update_row_topology(prefix="tmp")

    # (topology, detail, rounded energy) -> (row id, dof) of the kept row
    unique_props = {}
    for row in self.db.select():
        energy = getattr(row, key, None)
        if energy is None:
            continue
        top, top_detail = row.topology, row.topology_detail
        dof, energy = row.dof, round(energy, prec)
        prop_key = (top, top_detail, energy)
        if prop_key not in unique_props:
            print("Adding", row.id, top, energy)
            unique_props[prop_key] = (row.id, dof)
        elif dof < unique_props[prop_key][1]:
            print("Updating", row.id, top, energy)
            unique_props[prop_key] = (row.id, dof)
        else:
            print("Duplicate", row.id, top, energy)

    ids = [row_id for row_id, _ in unique_props.values()]
    with connect(db_name, serial=True) as db:
        for row_id in ids:
            row = self.db.get(row_id)
            kvp = {attr: getattr(row, attr) for attr in self.keys if hasattr(row, attr)}
            db.write(row.toatoms(), key_value_pairs=kvp)
    count = db.count()
    print(f"Created {db_name:s} with {count:d} strucs")
    return count
[docs]
def check_overlap(self, reference_db, etol=2e-3, verbose=True):
    """
    Check the overlap with a reference database.

    Args:
        reference_db (str): Path to the reference database file
        etol (float, optional): Energy tolerance for identifying identical structures.
            Default is 2e-3.
        verbose (bool, optional): Whether to print detailed overlap information.
            Default is True.

    Returns:
        list: List of overlapping structures, where each entry contains:
            (id, pearson_symbol, dof, topology, ff_energy)

    Note:
        Two structures are considered overlapping if they have:
        - Same topology
        - Same topology detail (a missing detail is treated as None)
        - Force field energies within etol of each other
        Only rows with both ``topology`` and ``ff_energy`` are considered.
        Reference energies are grouped by (topology, detail), so each row is
        compared only with references of the same topology.
    """
    from collections import defaultdict

    db_ref = database_topology(reference_db, log_file=self.log_file)
    print(f"\nCurrent database {self.db_name}: {self.db.count()}")
    print(f"Reference database {db_ref.db_name}: {db_ref.db.count()}")

    ref_energies = defaultdict(list)
    for row in db_ref.db.select():
        if hasattr(row, "topology") and hasattr(row, "ff_energy"):
            ref_key = (row.topology, getattr(row, "topology_detail", None))
            ref_energies[ref_key].append(row.ff_energy)

    overlaps = []
    for row in self.db.select():
        if not (hasattr(row, "topology") and hasattr(row, "ff_energy")):
            continue
        row_key = (row.topology, getattr(row, "topology_detail", None))
        candidates = ref_energies.get(row_key, ())
        if any(abs(row.ff_energy - e) < etol for e in candidates):
            overlaps.append((row.id, row.pearson_symbol, row.dof, row.topology, row.ff_energy))

    strs = f"\nThe number of overlap is: {len(overlaps):d}"
    strs += f"/{self.db.count():d}/{db_ref.db.count():d}"
    print(strs)
    if verbose:
        for entry in sorted(overlaps, key=lambda x: x[-1]):
            print("{:4d} {:6s} {:4d} {:20s} {:10.3f}".format(*entry))
    return overlaps
[docs]
def print_info(self, excluded_ids=None, cutoff=100):
    """
    Print the lowest-energy structures of the database.

    Only rows carrying both ``topology`` and ``ff_energy`` whose id is not
    in ``excluded_ids`` are listed. They are sorted by ``ff_energy`` and at
    most ``cutoff`` lines are shown, followed by the number of matching rows
    over the database size.
    Todo: show vasp_energy if available

    Args:
        excluded_ids (list): list of unwanted row ids
        cutoff (int): the cutoff value for the print
    """
    skipped = [] if excluded_ids is None else excluded_ids
    print(f"\nCurrent database {self.db_name}: {self.db.count()}")
    records = (
        (r.id, r.pearson_symbol, r.dof, r.topology, r.ff_energy)
        for r in self.db.select()
        if r.id not in skipped and hasattr(r, "topology") and hasattr(r, "ff_energy")
    )
    ranked = sorted(records, key=lambda rec: rec[-1])
    line_fmt = "{:4d} {:6s} {:4d} {:20s} {:10.3f}"
    for rec in ranked[:cutoff]:
        print(line_fmt.format(*rec))
    print(f"Showed structures: {len(ranked)}/{self.db.count()}")
[docs]
def plot_histogram(self, prop, ax=None, filename=None, xlim=None, nbins=20):
    """
    Plot the histogram of a specified row property.

    Args:
        prop (str): The name of the property to plot (e.g., 'ff_energy').
        ax (matplotlib.axes.Axes, optional): Pre-existing axis to plot on.
            If None, a new ax will be created.
        filename (str, optional): Path to save the plot (e.g., 'plot.png').
            The figure that owns ``ax`` is saved. If None, nothing is saved.
        xlim (tuple, optional): Limits for the x-axis (e.g., (0, 10)).
            If None, the x-axis will scale automatically.
        nbins (int, optional): Number of bins for the histogram. Default is 20.

    Returns:
        matplotlib.axes.Axes: The axis object with the histogram plotted.

    Raises:
        ValueError: If no row contains ``prop``. This is checked before any
            figure is created.
    """
    # Validate the data first so a missing property does not leave an empty
    # figure open.
    props = self.get_properties(prop)
    if not props:
        raise ValueError(f"No rows contain the property '{prop}'.")

    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots()
    ax.hist(props, nbins, density=True, alpha=0.75)
    if xlim is not None:
        ax.set_xlim(xlim)
    ax.set_xlabel(prop)
    if filename is not None:
        # plt.savefig would save the *current* figure, which is not ax's
        # figure when a caller passes an axis from another figure.
        ax.figure.savefig(filename)
    return ax
[docs]
def get_properties(self, prop):
    """
    Retrieve a list of specific property values from the database rows.

    Args:
        prop (str): The property name to retrieve (e.g., 'ff_energy')

    Returns:
        list: A list of property values for rows that have the specified
            property. Rows without the property are ignored. The list is
            empty when no row carries ``prop``.

    Warns:
        UserWarning: If no rows in the database contain the specified property.
    """
    import warnings

    props = [getattr(row, prop) for row in self.db.select() if hasattr(row, prop)]

    # Print summary of rows
    name, count = self.db_name, self.db.count()
    print(f"Database {name} has {prop}: {len(props)}/{count}")

    # Warn (rather than raise) so callers such as plot_histogram can handle
    # the empty list themselves.
    if not props:
        warnings.warn(f"No rows in the database contain the property '{prop}'.")
    return props
if __name__ == "__main__":
    # Ad-hoc manual driver. The two ``if False:`` sections are disabled
    # examples kept for reference; only the deduplication/overlap/export
    # steps at the end run. All database paths are relative to the
    # current working directory.
    # open
    if False:
        db = database("test.db")
        print("Total number of entries", len(db.codes))
        # view structure
        c = db.get_pyxtal("HXMTAM")
        print(c)
    if False:
        db = database_topology("../MOF-Builder/C-sp2/sp2-sacada-0506.db")
        # xtal = db.get_pyxtal(1)
        # print(xtal)
        # db.add_xtal(xtal, kvp={'similarity': 0.1})
        # db.update_row_ff_energy(ids=(0, 2), overwrite=True)
        # db.update_row_ff_energy(ncpu=2, ids=(2, 20), overwrite=True)
        # brew install coreutils to get `timeout` on macOS
        # os.environ['ASE_DFTB_COMMAND'] = 'timeout 1m /Users/qzhu8/opt/dftb+/bin/dftb+ > PREFIX.out'
        # Machine-specific DFTB+ binary and Slater-Koster parameter paths.
        os.environ["ASE_DFTB_COMMAND"] = "/Users/qzhu8/opt/dftb+/bin/dftb+ > PREFIX.out"
        skf_dir = "/Users/qzhu8/GitHub/MOF-Builder/3ob-3-1/"
        # db.update_row_dftb_energy(skf_dir, ncpu=1, ids=(0, 2), overwrite=True)
        # NOTE(review): update_row_dftb_energy is not among the methods visible
        # in this part of the file; the current entry point appears to be
        # update_row_energy(calculator='DFTB', skf_dir=...) -- confirm before
        # re-enabling this section.
        db.update_row_dftb_energy(
            skf_dir, ncpu=1, ids=(17, 17), overwrite=True)

    # get_db_unique() writes "<name>_unique.db" next to each input database.
    db = database_topology("total.db")
    db.get_db_unique()
    db1 = database_topology("sp2_sacada.db")
    db1.get_db_unique()
    # Compare the deduplicated databases, then export the (non-deduplicated)
    # sacada structures.
    db = database_topology("total_unique.db")
    db.check_overlap("sp2_sacada_unique.db")
    db1.export_structures(folder="mof_out_sacada")