# Source code for pyxtal.optimize.base

"""
A base class for global optimization including:

- WFS: Width First Sampling
- DFS: Depth First Sampling
- QRS: Quasi Random Sampling
"""
from __future__ import annotations
from multiprocessing import Pool
from concurrent.futures import TimeoutError
import signal

import logging
import os
from time import time
from typing import TYPE_CHECKING

import numpy as np
from pymatgen.analysis.structure_matcher import StructureMatcher
from numpy.random import Generator

from pyxtal.molecule import find_rotor_from_smile, pyxtal_molecule
from pyxtal.representation import representation
from pyxtal.util import new_struc
from pyxtal.optimize.common import optimizer, randomizer
from pyxtal.optimize.common import optimizer_par, optimizer_single
from pyxtal.lattice import Lattice
from pyxtal.symmetry import Group

def setup_worker_logger(log_file):
    """
    Configure the root logger of a worker process to write to ``log_file``.

    Existing root handlers are dropped first because ``logging.basicConfig``
    is a no-op when the root logger already has handlers.

    Args:
        log_file (str): path of the log file for this worker
    """
    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format="%(asctime)s| %(message)s",
    )
# The logger is passed in explicitly so that each worker writes to the file
# configured for it by setup_worker_logger.
def run_optimizer_with_timeout(args, logger):
    """
    Run ``optimizer_par`` under a SIGALRM-based timeout.

    This function is executed by each worker process (or by the calling
    process when no pool is available). It relies on ``signal.alarm`` and
    therefore only works on POSIX systems and in the main thread.

    Args:
        args (tuple): positional arguments of ``optimizer_par`` followed by
            two trailing entries: the MPI rank (``args[-2]``) and the timeout
            in seconds (``args[-1]``)
        logger: logger used to report progress and timeouts

    Returns:
        the result of ``optimizer_par``, or None if the call timed out
    """
    def handler(signum, frame):
        raise TimeoutError("Optimization timed out")

    cwd = os.getcwd()
    rank = args[-2]
    # signal.alarm(0) would *disable* the alarm, so never go below 1 second.
    timeout = max(1, int(args[-1]))
    previous_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(timeout)
    try:
        if rank > 0:
            logger.info(f"Rank-{rank} running optimizer_par for PID {os.getpid()}")
        result = optimizer_par(*args[:-2])
        signal.alarm(0)  # Disable the alarm as soon as the work is done
        if rank > 0:
            logger.info(f"Rank-{rank} finished optimizer_par for PID {os.getpid()}")
        return result
    except TimeoutError:
        logger.info(f"Rank-{rank} Process {os.getpid()} timed out after {timeout} seconds.")
        # The optimizer may have changed directory before being interrupted
        os.chdir(cwd)
        return None  # placeholder for timed-out results
    finally:
        # Always disarm the alarm (also when optimizer_par raised something
        # else) so that it cannot fire later inside unrelated code, and put
        # back whatever SIGALRM handler was installed before.
        signal.alarm(0)
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
# Pool entry point: forwards the task to run_optimizer_with_timeout together
# with the (per-process) root logger.
def process_task(args):
    """
    Execute one optimization task inside a worker process.

    Args:
        args (tuple): task arguments, see ``run_optimizer_with_timeout``

    Returns:
        the optimizer result, or None on timeout
    """
    return run_optimizer_with_timeout(args, logging.getLogger())
class GlobalOptimize:
    """
    Base-class for all global optimization methods

    Args:
        smiles (str): smiles string; several components are separated by "."
        workdir (str): path of working directory (created if missing)
        sg (int or list): space group number or list of spg numbers
        tag (string): job prefix (stored in lower case)
        info (dict): precomputed atom_info; when given, the force-field setup
            is skipped and ``ff_opt`` is forced to False
        ff_opt (bool): activate on the fly FF mode
        ff_style (str): automated force style (`gaff` or `openff`)
        ff_parameters (str or list): ff parameter xml file, or a list of two
            files (optimization, refinement); paths are relative to workdir
        reference_file (str): reference xml data for FF training, relative
            to workdir
        ref_criteria (dict): criteria passed to ``clean_ref_dics`` when the
            references are exported
        N_cpu (int): number of cpus for parallel calculation (default: `1`)
        cif (str): suffix of the cif file storing all structure information
        block: block mode
        num_block: list of blocks
        composition: list of composition, (default is [1]*Num_mol)
        lattice: a ``Lattice`` object, 6 cell parameters or a 3x3 cell
            matrix; when given, the lattice is not optimized
        torsions: list of torsion angle
        molecules (list): list of pyxtal_molecule objects
        sites (list): list of wp sites, e.g., [['4a']]
        use_hall (bool): whether or not use hall number (default: False)
        skip_ani (bool): whether or not use ani or not (default: True)
        factor (float): stored as ``self.factor`` (default: 1.1)
        eng_cutoff (float): the cutoff energy for FF training
        E_max (float): maximum energy defined as an invalid structure
        random_state: seed or ``numpy.random.Generator``
        max_time (float): time budget of one optimization call in seconds;
            defaults to 300 (ANI or stability check enabled) or 60
        matcher : structurematcher from pymatgen
        early_quit: whether quit the program early when the target is found
        check_stable (bool): passed to the local optimizer; always disabled
            when a fixed lattice is supplied
        use_mpi (bool): distribute the work over ``mpi4py.MPI.COMM_WORLD``
        pre_opt: whether pre_optimize the structure or not

    Note:
        ``__init__`` reads ``self.N_pop`` and ``self.name``, and other methods
        read ``self.generation``, ``self.N_gen``, ``self.stats``,
        ``self.verbose``, ``self.fracs`` and ``self._run``. None of them are
        defined in this class, so subclasses must provide them (``N_pop`` and
        ``name`` before this ``__init__`` runs).
    """

    def __init__(
        self,
        smiles: str,
        workdir: str,
        sg: int | list[int],
        tag: str,
        info: dict[any, any] | None = None,
        ff_opt: bool = False,
        ff_style: str = "openff",
        ff_parameters: str = "parameters.xml",
        reference_file: str = "references.xml",
        ref_criteria: dict[any, any] | None = None,
        N_cpu: int = 1,
        cif: str | None = None,
        block: list[any] | None = None,
        num_block: list[any] | None = None,
        composition: list[any] | None = None,
        lattice: Lattice | None = None,
        torsions: list[any] | None = None,
        molecules: list[pyxtal_molecule] | None = None,
        sites: list[any] | None = None,
        use_hall: bool = False,
        skip_ani: bool = True,
        factor: float = 1.1,
        eng_cutoff: float = 5.0,
        E_max: float = 1e10,
        random_state=None,
        max_time: float | None = None,
        matcher: StructureMatcher | None = None,
        early_quit: bool = True,
        check_stable: bool = False,
        use_mpi: bool = False,
        pre_opt: bool = False,
    ):
        self.ncpu = N_cpu
        self.use_mpi = use_mpi
        if self.use_mpi:
            from mpi4py import MPI

            self.comm = MPI.COMM_WORLD
            self.rank = self.comm.Get_rank()
            self.size = self.comm.Get_size()
        else:
            self.rank = 0
            self.size = self.ncpu

        # General information
        if isinstance(random_state, Generator):
            # Use an independent child stream rather than the caller's own
            self.random_state = random_state.spawn(1)[0]
        else:
            self.random_state = np.random.default_rng(random_state)

        # Molecular information
        self.smile = smiles
        self.smiles = self.smile.split(".")  # list
        self.torsions = torsions
        self.molecules = molecules
        self.block = block
        self.num_block = num_block
        self.composition = [1] * len(self.smiles) if composition is None else composition
        self.N_torsion = 0
        for smi, comp in zip(self.smiles, self.composition):
            self.N_torsion += len(find_rotor_from_smile(smi)) * int(max([comp, 1]))

        # Crystal information
        self.pre_opt = pre_opt
        self.sg = [sg] if isinstance(sg, (int, np.integer)) else sg
        self.use_hall = use_hall
        self.factor = factor
        self.sites = sites
        if lattice is None or isinstance(lattice, Lattice):
            self.lattice = lattice
        else:
            ltype = Group(self.sg[0]).lattice_type
            if len(lattice) == 6:
                self.lattice = Lattice.from_para(lattice, ltype=ltype)
            elif np.shape(lattice) == (3, 3):
                # 3x3 cell matrix, e.g. as written by ``save``
                self.lattice = Lattice.from_matrix(np.array(lattice, dtype=float), ltype=ltype)
            else:
                raise ValueError("input lattice is invalid", lattice)
        self.opt_lat = self.lattice is None
        self.ref_criteria = ref_criteria
        self.eng_cutoff = eng_cutoff

        # Generation and Optimization
        self.workdir = workdir
        os.makedirs(self.workdir, exist_ok=True)
        self.log_file = self.workdir + "/loginfo"
        if self.rank > 0:
            self.log_file += f"-{self.rank}"
        self.skip_ani = skip_ani
        self.check_stable = check_stable
        if not self.opt_lat:
            self.check_stable = False

        # Setup timeout for each optimization call (the raw input is kept)
        self.max_time = max_time
        if max_time is None:
            max_time = 300.0 if (not self.skip_ani or self.check_stable) else 60.0
        self.timeout = max_time * self.N_pop / self.ncpu
        self.ff_opt = ff_opt
        self.ff_style = ff_style

        # Setup logger
        logging.getLogger().handlers.clear()
        logging.basicConfig(format="%(asctime)s| %(message)s", filename=self.log_file, level=logging.INFO)
        self.logging = logging

        # Force field: only rank 0 builds ForceFieldParameters; the resulting
        # atom_info is broadcast to the other ranks.
        self.parameters = None
        if info is not None:
            self.atom_info = info
            self.ff_opt = False
        else:
            if isinstance(ff_parameters, (list, tuple)):
                self.ff_parameters = [self.workdir + "/" + para for para in ff_parameters]
            else:
                self.ff_parameters = self.workdir + "/" + ff_parameters
            self.reference_file = self.workdir + "/" + reference_file

            atom_info = None
            if self.rank == 0:
                from pyocse.parameters import ForceFieldParameters

                self.parameters = ForceFieldParameters(self.smiles, style=ff_style, ncpu=self.ncpu)
                if self.ff_opt:
                    self.parameters.set_ref_evaluator('mace')

                if isinstance(self.ff_parameters, list):
                    # Preload two sets of FF parameters: [0] opt, [1] refinement
                    assert len(self.ff_parameters) == 2
                    for para_file in self.ff_parameters:
                        if not os.path.exists(para_file):
                            raise RuntimeError("File not found", para_file)
                    params0, dic = self.parameters.load_parameters(self.ff_parameters[0])
                    if "ff_style" in dic:
                        assert dic["ff_style"] == self.ff_style
                    params1, dic = self.parameters.load_parameters(self.ff_parameters[1])
                    if "ff_style" in dic:
                        assert dic["ff_style"] == self.ff_style
                    atom_info = self._prepare_chm_info(params0, params1)
                else:
                    if os.path.exists(self.ff_parameters):
                        self.print("Preload the existing FF parameters from", self.ff_parameters)
                        params0, _ = self.parameters.load_parameters(self.ff_parameters)
                    else:
                        self.print("No FF parameter file exists, using the default setting", ff_style)
                        params0 = self.parameters.params_init.copy()
                        self.parameters.export_parameters(self.ff_parameters, params0)
                    atom_info = self._prepare_chm_info(params0, suffix='pyxtal')

            if self.use_mpi:
                self.atom_info = self.comm.bcast(atom_info, root=0)
            else:
                self.atom_info = atom_info

        # Structure matcher
        if matcher is None:
            self.matcher = StructureMatcher(ltol=0.3, stol=0.3, angle_tol=5)
        else:
            self.matcher = matcher

        # I/O stuff
        self.early_quit = early_quit
        self.N_min_matches = 10  # The min_num_matches for early termination
        self.E_max = E_max
        self.tag = tag.lower()
        self.suffix = f"{self.workdir:s}/{self.name:s}-{self.ff_style:s}"
        self.matched_cif = self.suffix + "-matched.cif"
        if self.rank == 0:
            self.cif = self.suffix + ('.cif' if cif is None else cif)
            with open(self.cif, "w") as f:
                f.writelines(str(self))

        # Some necessary trackers
        self.matches = []
        self.best_reps = []
        self.reps = []
        self.engs = []

    def print(self, *args, **kwargs):
        """Utility method to print only from rank 0."""
        if self.rank == 0:
            print(*args, **kwargs)

    def __str__(self):
        s = "\n-------Global Crystal Structure Prediction------"
        s += f"\nsmile : {self.smile:s}"
        s += f"\nZprime : {self.composition!s:s}"
        s += f"\nN_torsion : {self.N_torsion:d}"
        s += f"\nsg : {self.sg!s:s}"
        s += f"\nncpu : {self.size:d}"
        s += f"\ndirectory : {self.workdir:s}"
        s += f"\nopt_lat : {self.opt_lat!s:s}"
        s += f"\nuse_mpi : {self.use_mpi!s:s}\n"
        s += "Mode : Production\n" if self.early_quit else "Mode : Sampling\n"
        s += f"cif : {self.cif:s}\n"
        s += "forcefield: Sample-training\n" if self.ff_opt else "forcefield: Predefined\n"
        if self.parameters is not None:
            s += f"ff_style : {self.ff_style:s}\n"
            if isinstance(self.ff_parameters, list):
                for para in self.ff_parameters:
                    s += f"ff_params : {para:s}\n"
            else:
                s += f"ff_params : {self.ff_parameters:s}\n"
            s += f"references: {self.reference_file:s}\n"
            s += str(self.parameters)
        return s

    def __repr__(self):
        return str(self)

    def new_struc(self, xtal, xtals):
        """Return whether ``xtal`` is new w.r.t. ``xtals`` (pyxtal.util.new_struc)."""
        return new_struc(xtal, xtals)

    def run(self, ref_pmg=None, ref_pxrd=None):
        """
        The main code to run Sampling

        Args:
            ref_pmg: reference pmg structure (not modified; H atoms are
                removed from an internal copy)
            ref_pxrd: reference pxrd profile in 2D array

        Returns:
            success_rate or None
        """
        t0 = time()
        if ref_pmg is not None:
            ref_pmg = ref_pmg.copy()
            ref_pmg.remove_species("H")
        self.ref_pmg = ref_pmg
        self.ref_pxrd = ref_pxrd

        pool = None
        if self.ncpu > 1:
            pool = Pool(processes=self.ncpu, initializer=setup_worker_logger, initargs=(self.log_file,))
        try:
            results = self._run(pool)
        finally:
            # Release the worker processes even if the search raised
            if pool is not None:
                pool.close()
                pool.join()

        if self.rank == 0:
            t = (time() - t0) / 60
            strs = f"{self.name:s} {self.workdir} COMPLETED "
            strs += f"in {t:.1f} mins {self.N_struc:d} strucs."
            print(strs)
        if self.use_mpi:
            self.comm.Barrier()
        return results

    def select_xtals(self, ref_xtals, ids, N_max):
        """
        Select only unique structures

        Args:
            ref_xtals: list of (xtal, tag) tuples; failed entries hold None
            ids: order in which ``ref_xtals`` are visited
            N_max (int): maximum number of structures to return

        Returns:
            list of at most ``N_max`` unique xtals with energy <= E_max
        """
        xtals = []
        for idx in ids:
            if len(xtals) >= N_max:
                break
            xtal, _ = ref_xtals[idx]
            if xtal is None:  # failed optimization
                continue
            if xtal.energy <= self.E_max and self.new_struc(xtal, xtals):
                xtals.append(xtal)
        return xtals

    def count_pxrd_match(self, xtals, matches):
        """
        Wrap up the matched PXRD results

        Args:
            xtals: list of (xtal, tag) tuples
            matches (list): list of XRD matches (similarities)
        """
        gen = self.generation
        for i, match in enumerate(matches):
            if match > 0.85:
                xtal, tag = xtals[i]
                with open(self.matched_cif, "a+") as f:
                    e = xtal.energy / sum(xtal.numMols)
                    label = self.tag + "-g" + str(gen) + "-p" + str(i)
                    try:
                        label += f"-e{e:.3f}-{tag:s}-{match:4.2f}"
                    except (TypeError, ValueError):
                        # fall back to the basic label if a field is malformed
                        print("Error in e, tag, match", e, tag, match)
                    f.writelines(xtal.to_file(header=label))
                self.matches.append((gen, i, xtal, e, match, tag))

    def success_count(self, xtals, matches):
        """
        Wrap up the matched results and count success rate.

        Args:
            xtals: list of (xtal, tag) tuples
            matches (list): list of matches [True, False, ..]

        Return:
            success_rate (percentage of matches among all structures; 0 when
            no structure has been recorded yet)
        """
        gen = self.generation
        for i, match in enumerate(matches):
            if match:
                xtal, tag = xtals[i]
                with open(self.matched_cif, "a+") as f:
                    d1, d2 = self._print_match(xtal, self.ref_pmg)
                    e = xtal.energy / sum(xtal.numMols)
                    label = self.tag + "-g" + str(gen) + "-p" + str(i)
                    try:
                        label += f"-e{e:.3f}-{tag:s}-{d1:4.2f}-{d2:4.2f}"
                    except (TypeError, ValueError):
                        # d1/d2 are None when the matcher returns no rms distance
                        print("Error in e, tag, d1, d2", e, tag, d1, d2)
                    f.writelines(xtal.to_file(header=label))
                self.matches.append((gen, i, xtal, e, d1, d2, tag))

        success_rate = len(self.matches) / self.N_struc * 100 if self.N_struc > 0 else 0.0
        gen_out = f"Success rate @ Gen {gen:3d}: {success_rate:7.4f}%"
        self.logging.info(gen_out)
        print(gen_out)
        return success_rate

    def early_termination(self, success_rate):
        """
        Check if the calculation can be terminated early.

        In production mode (``early_quit``) any match stops the search; in
        sampling mode it stops once the success rate exceeds 2.5% or at least
        ``N_min_matches`` matches were found.
        """
        if success_rate > 0:
            if self.early_quit:
                msg = "Early termination since a match is found"
                print(msg)
                self.logging.info(msg)
                return True
            elif success_rate > 2.5 or len(self.matches) >= self.N_min_matches:
                msg = "Early termination with a high success rate"
                print(msg)
                self.logging.info(msg)
                return True
        return False

    def export_references(self, xtals, engs, N_min=50, dE=2.5, FMSE=2.5):
        """
        Add training data for FF optimization

        Args:
            xtals: a list of (xtal, tag) tuples
            engs: a list of energies
            N_min (int): minimum number of configs to add
            dE (float): the cutoff energy value
            FMSE (float): the cutoff Force MSE value
        """
        cwd = os.getcwd()
        params, _ = self.parameters.load_parameters(self.ff_parameters)
        N_max = min([int(self.N_pop * 0.6), 50])
        ids = np.argsort(engs)
        _xtals = self.select_xtals(xtals, ids, N_max)
        print("Select structures for FF optimization", len(_xtals))

        # Initialize references
        if os.path.exists(self.reference_file):
            ref_dics = self.parameters.load_references(self.reference_file)
            ref_ground_states = self.parameters.get_gs_from_ref_dics(ref_dics)
        else:
            ref_dics = []
            ref_ground_states = []

        # Add references; the FF tools write into the working directory
        os.chdir(self.workdir)
        try:
            # NOTE(review): ``check`` is not set in this class; presumably a
            # subclass defines it. Without it the error-based cut is skipped.
            if len(ref_dics) > 0 and getattr(self, "check", False):
                ref_dics = self.parameters.cut_references_by_error(ref_dics, params, dE=dE, FMSE=FMSE)
            if self.ref_criteria is not None:
                ref_dics = self.parameters.clean_ref_dics(ref_dics, self.ref_criteria)

            t0 = time()
            N_selected = min([N_min, self.ncpu, 20])
            _ref_dics = self.parameters.add_references(_xtals, ref_ground_states, N_selected)
            ref_dics.extend(_ref_dics)
            aug_dics = self.parameters.augment_references(_ref_dics)
            ref_dics.extend(aug_dics)
            t1 = (time() - t0) / 60
            print(f"Ref. update usage: {len(_ref_dics)}/{len(aug_dics)} strucs in {t1:.2f} min")

            ff_dics, ref_dics = self.parameters.evaluate_ff_references(ref_dics, params)
            if abs(params[-1]) < 1e-3:
                # The energy offset (last parameter) has not been fitted yet
                params = self.parameters.optimize_offset(ref_dics, ff_dics)
                self.parameters.update_ff_parameters(params)
                self.parameters.export_parameters(self.ff_parameters.split('/')[-1])
                for ff_dic in ff_dics:
                    ff_dic['energy'] += params[-1]

            # Export FF performances
            gen_prefix = self.get_label(self.generation, 'gen_')
            performance_fig = f"FF_performance_{gen_prefix}.png"
            self.parameters.plot_ff_results(performance_fig, ref_dics, [params], labels=gen_prefix, ff_dics=ff_dics)
            t2 = (time() - t0) / 60 - t1
            print(f"FF performance evaluation usage in {t2:.2f} min")
        finally:
            os.chdir(cwd)
        self.parameters.export_references(ref_dics, self.reference_file)

    def _prepare_chm_info(self, params0, params1=None, folder="calc", suffix="pyxtal0"):
        """
        Prepare the CHARMM files and atom_info from the given params.

        Args:
            params0 (array or list): FF parameters array
            params1 (array or list): FF parameters array
            folder (str): folder path (relative to workdir)
            suffix (str): suffix of the temporary file

        Returns:
            atom_info of the last written parameter set
        """
        pwd = os.getcwd()
        os.chdir(self.workdir)
        try:
            os.makedirs(folder, exist_ok=True)
            suffix = folder + '/' + suffix

            # To remove the old files
            for ext in (".rtf", ".prm"):
                if os.path.exists(suffix + ext):
                    os.remove(suffix + ext)

            ase_with_ff = self.parameters.get_ase_charmm(params0)
            ase_with_ff.write_charmmfiles(base=suffix)
            if params1 is not None:
                # NOTE(review): writes to the same base as params0, so the
                # params1 files replace the params0 ones -- confirm intent.
                ase_with_ff = self.parameters.get_ase_charmm(params1)
                ase_with_ff.write_charmmfiles(base=suffix)
        finally:
            os.chdir(pwd)

        return ase_with_ff.get_atom_info()

    def get_label(self, i, label='cpu'):
        """Return ``label`` followed by ``i`` zero-padded to 3 digits, e.g. cpu007."""
        return f"{label}{i:03d}"

    def print_matches(self, header=None):
        """
        Formatted output for the matched structures with xtal rep and eng rank

        Returns:
            the best (lowest) energy rank among the matches (0 if none) on
            rank 0, None on the other ranks
        """
        if self.rank != 0:
            return None

        all_engs = np.sort(np.array(self.engs))
        ranks = []
        xtals = []
        if self.ref_pxrd is not None:
            matches = sorted(self.matches, key=lambda x: -x[4])  # similarity
        else:
            matches = sorted(self.matches, key=lambda x: x[3])  # eng

        for match_data in matches:
            d1, match = None, None
            if self.ref_pxrd is not None:
                (_, _, xtal, e, match, tag) = match_data
                add = self.new_struc(xtal, xtals)
                if add:
                    xtals.append(xtal)
            else:
                (_, _, xtal, e, d1, d2, tag) = match_data
                add = True

            if add:
                rep0 = xtal.get_1D_representation()
                strs = header if header is not None else ""
                strs += rep0.to_string(eng=xtal.energy / sum(xtal.numMols))
                if d1 is not None:
                    strs += f"{d1:6.3f}{d2:6.3f} Match "
                if match is not None:
                    strs += f" {match:4.2f} "
                if e is not None:
                    rank = len(all_engs[all_engs < (e - 1e-3)]) + 1
                    strs += f" {rank:d}/{self.N_struc:d} {tag:s}"
                    ranks.append(rank)
                print(strs)

        if len(ranks) == 0:
            ranks = [0]
        return min(ranks)

    def _print_match(self, xtal, ref_pmg):
        """
        Print the matched structure

        Args:
            xtal: pyxtal structure
            ref_pmg: reference pmg structure

        Returns:
            (rms, max_dist) from the structure matcher, or (None, None)
        """
        rep0 = xtal.get_1D_representation()
        pmg_s1 = xtal.to_pymatgen()
        pmg_s1.remove_species("H")
        strs = rep0.to_string(eng=xtal.energy / sum(xtal.numMols))
        rmsd = self.matcher.get_rms_dist(ref_pmg, pmg_s1)
        if rmsd is None:
            return None, None
        strs += f"{rmsd[0]:6.3f}{rmsd[1]:6.3f} Match Ref"
        print(strs)
        return rmsd[0], rmsd[1]

    def _torsion_vector(self, rep, N_id=8):
        """
        Collect the torsion angles of a 1D representation into one vector.

        Each site entry ``rep[j]`` (j >= 1) longer than ``N_id`` is read as
        ``N_id`` leading values, its torsions and one trailing value; the
        torsions are packed contiguously into an array of length N_torsion.
        Sites of length <= N_id (e.g. Cl- ions) carry no torsions.
        """
        tor = np.zeros(self.N_torsion)
        if self.N_torsion == 0:
            return tor
        count = 0
        for site in rep[1:]:
            if len(site) > N_id:
                n_tor = len(site) - N_id - 1
                tor[count: count + n_tor] = site[N_id:-1]
                count += n_tor
        return tor

    def _apply_gaussian(self, reps, engs, h1=0.1, h2=0.1, w1=0.2, w2=3):
        """
        Apply Gaussian to discourage the sampling of already visited configs.
        Consider both lattice abc and torsion

        Args:
            reps: list of 1D representation vectors (``rep.x``) or None
            engs: energies (or negated similarities) aligned with ``reps``
            h1, w1: height and width of the cell Gaussian
            h2, w2: height and width of the torsion Gaussian

        Returns:
            float array of engs with the Gaussian bias added
        """
        # Descriptors of the visited structures are computed once, not per rep
        visited = [
            (ref[0][0], np.array(ref[0][1:]), self._torsion_vector(ref))
            for ref in self.best_reps
        ]
        engs_gau = np.array(engs, dtype=float)
        for i, rep in enumerate(reps):
            if rep is None or not (engs[i] < 9999):
                continue
            sg1, abc1 = rep[0][0], np.array(rep[0][1:])
            tor1 = self._torsion_vector(rep)
            gau = 0
            for sg2, abc2, tor2 in visited:
                g1 = 0  # cell
                if sg1 == sg2:
                    diff1 = np.sum((abc1 - abc2) ** 2) / w1**2
                    g1 = h1 * np.exp(-0.5 * diff1)
                g2 = 0  # torsion
                if len(tor1) > 0:
                    diff2 = np.sum((tor1 - tor2) ** 2) / w2**2
                    g2 = h2 * np.exp(-0.5 * diff2)
                gau += g1 + g2
            engs_gau[i] += gau
        return engs_gau

    def check_ref(self, reps=None, reference=None, filename="pyxtal.cif"):
        """
        Check if ground state structure is found.

        The low-energy representations are deduplicated, written to
        ``filename`` and compared with the reference structure.

        Args:
            reps: 2D array of representations (last column is the energy);
                defaults to ``self.reps``
            reference: [pmg, eng]; when eng is None, min(eng) + 0.25 is used
            filename: filename

        Returns:
            True if a written structure matches the reference, else False
        """
        if os.path.exists(filename):
            os.remove(filename)
        if reference is None:
            return False  # nothing to compare against

        pmg0, eng = reference
        pmg0 = pmg0.copy()  # do not modify the caller's structure
        pmg0.remove_species("H")
        print("check if ground state structure is found")

        reps = np.array(self.reps) if reps is None else np.asarray(reps)
        if eng is None:
            eng = np.min(reps[:, -1]) + 0.25
        reps = reps[reps[:, -1] < (eng + 0.1)]
        reps = reps[np.argsort(reps[:, -1])]

        unique = []  # (energy, H-free pymatgen structure) of kept reps
        for rep in reps:
            eng1 = rep[-1]
            rep0 = representation(rep[:-1], self.smiles)
            xtal = rep0.to_pyxtal()
            pmg_s1 = xtal.to_pymatgen()
            pmg_s1.remove_species("H")

            # The cheap energy test runs first to skip most matcher calls
            if any(abs(eng1 - eng2) < 1e-2 and self.matcher.fit(pmg_s1, pmg_s2)
                   for eng2, pmg_s2 in unique):
                continue

            unique.append((eng1, pmg_s1))
            header = f"{len(unique):d}: {eng1:12.4f}"
            xtal.to_file(filename, header=header, permission="a+")
            strs = rep0.to_string(eng=eng1)
            rmsd = self.matcher.get_rms_dist(pmg0, pmg_s1)
            if rmsd is not None:
                strs += f"{rmsd[0]:6.3f}{rmsd[1]:6.3f} True"
                print(strs)
                return True
            print(strs)
        return False

    def _get_local_optimization_args(self):
        """
        Get the shared (per-generation constant) arguments of the local
        optimization calls, in the order expected by the optimizer functions.
        """
        return [
            randomizer,
            optimizer,
            self.smiles,
            self.block,
            self.num_block,
            self.atom_info,
            self.workdir + "/" + "calc",
            self.sg,
            self.composition,
            self.lattice,
            self.torsions,
            self.molecules,
            self.sites,
            self.ref_pmg,
            self.matcher,
            self.ref_pxrd,
            self.use_hall,
            self.skip_ani,
            self.check_stable,
            self.pre_opt,
        ]

    def local_optimization(self, xtals, qrs=False, pool=None):
        """
        Perform the local optimization of each structure in one generation,
        dispatching to the MPI, serial or multiprocessing implementation.

        Args:
            xtals : list of (xtal, tag) tuples
            qrs (bool): Force mutation or not (related to QRS)
            pool: multiprocessing pool (may be None)

        Returns:
            list of (id, xtal, match) tuples
        """
        if self.use_mpi:
            return self.local_optimization_mpi(xtals, qrs=qrs, pool=pool)
        if self.ncpu == 1:
            return self.local_optimization_serial(xtals, qrs)
        print(f"Local optimization by multi-threads {self.ncpu}")
        return self.local_optimization_mproc(xtals, self.ncpu, qrs=qrs, pool=pool)

    def local_optimization_serial(self, xtals, qrs=False):
        """
        Perform optimization for each structure in each generation.

        Args:
            xtals : list of (xtal, tag) tuples
            qrs (bool): Force mutation or not (related to QRS)

        Returns:
            list of (id, xtal, match) tuples, one per input structure
        """
        args = self._get_local_optimization_args()
        gen = self.generation
        gen_results = [(None, None, None)] * len(xtals)
        for pop in range(len(xtals)):
            xtal = xtals[pop][0]
            job_tag = self.tag + "-g" + str(gen) + "-p" + str(pop)
            mutated = False if qrs else xtal is not None
            xtal, match = optimizer_single(xtal, pop, mutated, job_tag, *args)
            gen_results[pop] = (pop, xtal, match)
        return gen_results

    def local_optimization_mpi(self, xtals, qrs, pool):
        """
        Perform MPI optimization for each structure in each generation.

        Structures are distributed round-robin over the ranks; the results
        are gathered on rank 0 and broadcast back to every rank.

        Args:
            xtals : list of (xtal, tag) tuples
            qrs (bool): Force mutation or not (related to QRS)
            pool: multiprocessing pool of this rank (may be None)

        Returns:
            list of (id, xtal, match) tuples, ordered by id
        """
        gen = self.generation
        self.print("Local optimization enabled by MPI", self.size, self.ncpu)

        # Distribute the structures across the available ranks
        local_xtals = xtals[self.rank::self.size]
        local_ids = list(range(len(xtals)))[self.rank::self.size]

        self.logging.info(f"Rank {self.rank} gets {len(local_xtals)} strucs")
        results = self.local_optimization_mproc(local_xtals, self.ncpu, local_ids, qrs, pool)

        # Synchronize before gathering
        self.logging.info(f"Rank {self.rank} finish local_optimization_mproc")
        self.comm.Barrier()

        # Gather all results at the root process
        self.logging.info(f"Rank {self.rank} in MPI_Gather at gen {gen}")
        all_results = self.comm.gather(results, root=0)
        self.logging.info(f"Rank {self.rank} done MPI_Gather at gen {gen}")

        gen_results = None
        if self.rank == 0:
            gen_results = [(None, None, None)] * len(xtals)
            for result_set in all_results:
                for res_id, xtal, match in result_set:
                    gen_results[res_id] = (res_id, xtal, match)

        # Broadcast
        self.logging.info(f"Rank {self.rank} MPI_bcast at gen {gen}")
        gen_results = self.comm.bcast(gen_results, root=0)
        return gen_results

    def local_optimization_mproc(self, xtals, ncpu, ids=None, qrs=False, pool=None):
        """
        Perform optimization for each structure in multiprocess mode.

        Args:
            xtals : list of (xtal, tag) tuples
            ncpu (int): number of parallel python processes (= number of chunks)
            ids (list): list of ids of the associated xtals
            qrs (bool): Force mutation or not (related to QRS)
            pool: multiprocessing pool; when None the chunks are run one
                after another in the calling process

        Returns:
            list of (id, xtal, match) tuples; chunks that timed out are omitted
        """
        gen = self.generation
        args = self._get_local_optimization_args()
        if ids is None:
            ids = range(len(xtals))
        N_cycle = int(np.ceil(len(xtals) / ncpu))

        def generate_args_lists():
            # One task per worker: a contiguous chunk of at most N_cycle xtals
            for i in range(ncpu):
                id1 = i * N_cycle
                id2 = min([id1 + N_cycle, len(xtals)])
                if id1 >= id2:
                    continue  # no structures left for this worker
                _ids = list(ids[id1:id2])
                job_tags = [self.tag + "-g" + str(gen) + "-p" + str(_id) for _id in _ids]
                _xtals = [xtals[k][0] for k in range(id1, id2)]
                mutates = [False if qrs else xtal is not None for xtal in _xtals]
                yield (_xtals, _ids, mutates, job_tags, *args, self.rank, self.timeout)

        if pool is None:
            results = map(process_task, generate_args_lists())
        else:
            results = pool.imap_unordered(process_task, generate_args_lists())

        gen_results = []
        for result in results:
            if result is not None:
                gen_results.extend(result)
        return gen_results

    def gen_summary(self, t0, gen_results, xtals):
        """
        Write the generic summary for each generation.

        Args:
            t0 (float): time stamp
            gen_results: list of results (id, xtal, match)
            xtals: list of (xtal, tag) tuples

        Returns:
            (new_xtals, matches, engs) where engs include the Gaussian bias
        """
        matches = [False] if self.ref_pxrd is None else [0.0]
        matches *= self.N_pop
        eng0s = [self.E_max] * self.N_pop
        reps = [None] * self.N_pop
        new_xtals = [(None, None)] * len(xtals)
        gen = self.generation

        for res_id, xtal, match in gen_results:
            if xtal is None:
                continue
            e_per_mol = xtal.energy / sum(xtal.numMols)
            new_xtals[res_id] = (xtal, xtals[res_id][1])
            eng0s[res_id] = e_per_mol
            reps[res_id] = xtal.get_1D_representation()
            matches[res_id] = match

            # Don't write bad structure
            if self.cif is not None and xtal.energy < 9999:
                if self.verbose:
                    print("Add qualified structure", res_id, xtal.energy)
                with open(self.cif, "a+") as f:
                    label = self.tag + "-g" + str(gen) + "-p" + str(res_id)
                    f.writelines(xtal.to_file(header=label))
                self.engs.append(e_per_mol)
                self.stats[gen][res_id][0] = e_per_mol
                self.stats[gen][res_id][1] = match

        # np.min fails on an empty list; E_max marks "no valid structure yet"
        self.min_energy = np.min(np.array(self.engs)) if self.engs else self.E_max
        self.N_struc = len(self.engs)
        strs = f"Generation {gen:d} finishes: {len(self.engs):d} strucs"
        print(strs)
        self.logging.info(strs)
        t1 = time()

        # Apply Gaussian
        reps_x = [rep.x if rep is not None else None for rep in reps]
        if self.ref_pxrd is None:
            engs = self._apply_gaussian(reps_x, eng0s)
        else:
            engs = self._apply_gaussian(reps_x, -1 * np.array(matches))

        # Store the best (up to 3 unique) structures
        count = 0
        ref_xtals = []
        for idx in np.argsort(engs):
            xtal, tag = new_xtals[idx]
            if xtal is None:  # failed optimization
                continue
            rep, eng = reps[idx], eng0s[idx]
            if self.new_struc(xtal, ref_xtals):
                ref_xtals.append(xtal)
                self.best_reps.append(rep.x)
                try:
                    strs = rep.to_string(None, eng, f"{tag:8s}")
                    out = f"{gen:3d} {strs:s} Top"
                    if self.ref_pxrd is not None:
                        out += f" {matches[idx]:6.3f}"
                    print(out)
                except Exception:
                    # best-effort report only
                    print('Error', xtal)
                count += 1
                if count == 3:
                    break

        t2 = time()
        gen_out = f"Gen{gen:3d} time usage: "
        gen_out += f"{t1 - t0:5.1f}[Calc] {t2 - t1:5.1f}[Proc]"
        print(gen_out)
        return new_xtals, matches, engs

    def plot_results(self, save=True, figsize=(8.0, 5.0), figname=None, ylim=None):
        """
        Plot the results

        Args:
            save (bool): whether or not save the data
            figsize: e.g. (8.5, 5.0)
            figname (str): output figure (default: <suffix>-results.pdf)
            ylim: e.g. (0, 1.0)
        """
        if self.rank != 0:
            return

        import matplotlib.pyplot as plt
        import seaborn as sns

        sns.set_theme()
        sns.set_context("talk", font_scale=0.9)
        if figname is None:
            figname = self.suffix + "-results.pdf"

        # data1 rows: (gen_id, pop_id, eng[, similarity])
        # data2 rows: (gen_id, pop_id, eng[, similarity]) of the matches
        data1 = []
        data2 = []
        for i in range(self.N_gen):
            for j in range(self.N_pop):
                if self.ref_pxrd is not None:
                    data1.append([i, j, self.stats[i, j, 0], self.stats[i, j, 1]])
                else:
                    data1.append([i, j, self.stats[i, j, 0]])
        for match in self.matches:
            if self.ref_pxrd is not None:
                data2.append([match[0], match[1], match[3], match[4]])
            else:
                data2.append([match[0], match[1], match[3]])

        fig = plt.figure(figsize=figsize)
        plt.ylabel("Lattice Energy (kcal)")
        data1 = np.array(data1)
        if self.ref_pxrd is not None:
            # (similarity, eng, gen_id)
            x1, y1, z1 = data1[:, 3], data1[:, 2], data1[:, 0]
            plt.xlabel("XRD Similarity")
        else:
            x1, y1, z1 = data1[:, 1], data1[:, 2], data1[:, 0]
            plt.xlabel("Population ID")

        # Plot of all samples (PopID, Engs/Similarity)
        scatter = plt.scatter(x1, y1, s=10, c=z1, cmap='winter', alpha=0.5, label='Samples')
        cbar = plt.colorbar(scatter)
        cbar.set_label('Generation ID')
        if ylim is None:
            y1 = np.array(y1)
            ymin = y1.min() - 0.25
            ymax = min([ymin + 10.0, y1.max()])
            ylim = (ymin, ymax)

        if len(data2) > 0:
            data2 = np.array(data2)
            if len(data2.shape) == 1:
                data2 = data2.reshape(1, -1)  # a single row
            if self.ref_pxrd is not None:
                x2, y2, z2 = data2[:, 3], data2[:, 2], data2[:, 0]
            else:
                x2, y2, z2 = data2[:, 1], data2[:, 2], data2[:, 0]
            plt.scatter(x2, y2, s=10, c='red', marker='x', label='Matches')

        plt.legend(loc=1)
        plt.ylim(ylim)
        plt.title(f"{self.name:s}-{self.ff_style:s}")
        plt.tight_layout()
        plt.savefig(figname)
        plt.close(fig)  # avoid accumulating open figures across calls

        if save:
            if self.ref_pxrd is not None:
                header = "#Generation, Population, Energy, Similarity"
            else:
                header = "#Generation, Population, Energy"
            data_txt = self.suffix + "-data.txt"
            np.savetxt(data_txt, data1, header=header)
            if len(data2) > 0:
                match_txt = self.suffix + "-match.txt"
                np.savetxt(match_txt, data2, header=header)

    def save(self, filename):
        """
        Save the base-class settings to an XML file (rank 0 only).

        The lattice, if any, is stored as its 3x3 cell matrix.
        """
        if self.rank != 0:
            return

        import xml.etree.ElementTree as ET
        from pyxtal.util import prettify

        lattice_matrix = None if self.lattice is None else self.lattice.matrix

        root = ET.Element("GO")
        ET.SubElement(root, "smile").text = self.smile
        ET.SubElement(root, "tag").text = self.tag
        ET.SubElement(root, "cif").text = self.cif
        ET.SubElement(root, "workdir").text = self.workdir
        ET.SubElement(root, "reference_file").text = self.reference_file
        ET.SubElement(root, "ff_style").text = self.ff_style
        ET.SubElement(root, "ff_parameters").text = self.ff_parameters
        ET.SubElement(root, "sg").text = str(self.sg)
        ET.SubElement(root, "N_pop").text = str(self.N_pop)
        ET.SubElement(root, "N_gen").text = str(self.N_gen)
        ET.SubElement(root, "E_max").text = str(self.E_max)
        ET.SubElement(root, "early_quit").text = str(self.early_quit)
        ET.SubElement(root, "ff_opt").text = str(self.ff_opt)
        ET.SubElement(root, "use_mpi").text = str(self.use_mpi)
        ET.SubElement(root, "verbose").text = str(self.verbose)
        ET.SubElement(root, "skip_ani").text = str(self.skip_ani)
        ET.SubElement(root, "check_stable").text = str(self.check_stable)
        ET.SubElement(root, "pre_opt").text = str(self.pre_opt)
        ET.SubElement(root, "use_hall").text = str(self.use_hall)
        ET.SubElement(root, "N_cpu").text = str(self.ncpu)
        ET.SubElement(root, "factor").text = str(self.factor)
        ET.SubElement(root, "eng_cutoff").text = str(self.eng_cutoff)
        ET.SubElement(root, "max_time").text = str(self.max_time)
        ET.SubElement(root, "fracs").text = str(self.fracs)
        ET.SubElement(root, "composition").text = str(self.composition)
        ET.SubElement(root, "lattice").text = arr_to_text(lattice_matrix)

        # Use prettify to get a pretty-printed XML string
        pretty_xml = prettify(root)
        with open(filename, "w") as f:
            f.writelines(pretty_xml)
def load_xml(filename, tag='GO'):
    """
    Load the settings written by ``GlobalOptimize.save`` from an XML file.

    Args:
        filename (str): path of the XML file
        tag (str): kept for backward compatibility; the returned tag is the
            one stored in the file

    Returns:
        tuple of positional constructor arguments (unknown entries are None),
        in the order expected by the calling optimizer class

    Raises:
        ValueError: if ``filename`` does not exist
    """
    import xml.etree.ElementTree as ET

    if not os.path.exists(filename):
        raise ValueError("No such file", filename, os.getcwd())

    # NOTE: ElementTree is not hardened against malicious XML; only load
    # files produced by ``save``.
    basic = ET.parse(filename).getroot()

    # Strings
    smile = basic.find("smile").text
    tag = basic.find("tag").text
    ff_style = basic.find("ff_style").text
    workdir = basic.find("workdir").text
    # Paths are stored with the workdir prefix; keep only the file names
    reference_file = basic.find("reference_file").text.split('/')[-1]
    ff_parameters = basic.find("ff_parameters").text.split('/')[-1]
    cif = basic.find("cif").text.split('/')[-1]

    # Boolean (early_quit is stored as "True"/"False" like the others)
    early_quit = text_to_bool(basic.find("early_quit").text)
    ff_opt = text_to_bool(basic.find("ff_opt").text)
    use_mpi = text_to_bool(basic.find("use_mpi").text)
    verbose = text_to_bool(basic.find("verbose").text)
    skip_ani = text_to_bool(basic.find("skip_ani").text)
    check_stable = text_to_bool(basic.find("check_stable").text)
    pre_opt = text_to_bool(basic.find("pre_opt").text)
    use_hall = text_to_bool(basic.find("use_hall").text)

    # Numbers
    N_cpu = int(basic.find("N_cpu").text)
    N_pop = int(basic.find("N_pop").text)
    N_gen = int(basic.find("N_gen").text)
    E_max = float(basic.find("E_max").text)
    eng_cutoff = float(basic.find("eng_cutoff").text)
    factor = float(basic.find("factor").text)
    sg = text_to_list(basic.find("sg").text, int)
    fracs = text_to_arr(basic.find("fracs").text, float)
    composition = text_to_arr(basic.find("composition").text, float)
    lattice = text_to_2darr(basic.find("lattice").text)
    max_time = text_to_float(basic.find("max_time").text)

    return (smile, workdir, sg, tag, None, ff_opt, ff_style, ff_parameters,
            reference_file, None, N_gen, N_pop, N_cpu, fracs, cif, None,
            None, composition, lattice, None, None, None, use_hall, skip_ani,
            factor, eng_cutoff, E_max, verbose, None, max_time, None,
            early_quit, check_stable, use_mpi, pre_opt)
def text_to_list(text, dtype=float):
    """
    Convert a text to a list.

    Accepts the output of ``str()`` on a list or a numpy array, e.g.
    ``"[1 2 3]"`` or ``"[1, 2, 3]"``.

    Args:
        text (str): text to convert; ``'None'`` (or None) gives None
        dtype (callable): converter applied to every entry

    Returns:
        list or None
    """
    if text is None:
        return None
    # Brackets and commas are separators, so "[14, 15]" parses as well
    text = text.replace("[", " ").replace("]", " ").replace(",", " ")
    if text.strip() == 'None':
        return None
    return [dtype(i) for i in text.split()]
def text_to_arr(text, dtype=float):
    """
    Convert a text to a 1D numpy array.

    Accepts the output of ``str()`` on a list or a numpy array (commas and
    line breaks are treated as separators).

    Args:
        text (str): text to convert; ``'None'`` (or None) gives None
        dtype (callable): converter applied to every entry

    Returns:
        np.ndarray or None
    """
    if text is None or text.strip() == 'None':
        return None
    text = text.replace("[", " ").replace("]", " ").replace(",", " ")
    return np.array([dtype(i) for i in text.split()])
def text_to_2darr(text, dtype=float):
    """
    Convert a text to a 2D numpy array (one row per non-blank line).

    Args:
        text (str): text to convert; ``'None'`` (or None) gives None
        dtype (callable): converter applied to every entry

    Returns:
        np.ndarray or None
    """
    if text is None or text.strip() == 'None':
        return None
    text = text.replace("[", " ").replace("]", " ").replace(",", " ")
    rows = [line.split() for line in text.splitlines()]
    # Blank lines (e.g. from pretty-printed XML) would give ragged rows
    return np.array([[dtype(j) for j in row] for row in rows if row])
def text_to_float(text):
    """
    Convert a text to a float.

    Args:
        text (str): numeric text, or ``'None'``

    Returns:
        float, or None when ``text`` is ``'None'``
    """
    return None if text == 'None' else float(text)
def text_to_bool(text):
    """
    Convert a text string to a boolean value (case-insensitive).

    Args:
        text (str): ``"true"`` or ``"false"`` in any letter case

    Returns:
        bool

    Raises:
        ValueError: for any other string
    """
    lowered = text.lower()
    if lowered not in ('true', 'false'):
        raise ValueError(f"Invalid boolean string: {text}")
    return lowered == 'true'
def arr_to_text(arr):
    """
    Convert a 1D or 2D array (or nested list) to text.

    1D input gives space-separated values; 2D input gives one such line per
    row. None gives ``'None'``.

    Args:
        arr: array-like or None

    Returns:
        str
    """
    if arr is None:
        return 'None'
    # asarray lets plain (nested) lists through as well as numpy arrays
    arr = np.asarray(arr)
    if arr.ndim == 1:
        return ' '.join(str(i) for i in arr)
    return '\n'.join(' '.join(str(j) for j in row) for row in arr)
if __name__ == "__main__": print("test")