Source code for pyxtal.optimize.QRS

"""
Global Optimizer based on Quasi-Random Sampling
"""
from __future__ import annotations
from time import time
from typing import TYPE_CHECKING, Any

import numpy as np
from scipy.stats import qmc
from pymatgen.analysis.structure_matcher import StructureMatcher

from pyxtal.optimize.base import GlobalOptimize
from pyxtal.representation import representation
from pyxtal.lattice import Lattice

if TYPE_CHECKING:
    from pyxtal.molecule import pyxtal_molecule

def generate_qrs_cell(sampler, cell_bounds, ref_volume, ltype):
    """
    Draw quasi-random cell parameters until the cell volume is acceptable.

    Args:
        sampler: quasi-random sampler; each `sampler.random()` draw is
            rescaled to `cell_bounds` with `qmc.scale`
        cell_bounds (list): one (lower, upper) pair per cell parameter
        ref_volume (float): reference volume; a cell is accepted when
            0.75*ref_volume < volume < 2.5*ref_volume
        ltype: lattice type passed to `Lattice.from_1d_representation`

    Returns:
        list: the first sampled 1D cell representation that is accepted

    Raises:
        ValueError: if 1000 consecutive samples are all rejected
    """
    # Acceptance window for the cell volume
    vol_low = 0.75 * ref_volume
    vol_high = 2.5 * ref_volume

    lower = [bound[0] for bound in cell_bounds]
    upper = [bound[1] for bound in cell_bounds]

    for _ in range(1000):
        point = sampler.random()
        params = qmc.scale(point, lower, upper)[0].tolist()
        volume = Lattice.from_1d_representation(params, ltype).volume
        if vol_low < volume < vol_high:
            return params

    raise ValueError("Cannot generate valid cell with 1000 attempts")
def generate_qrs_xtals(cell, wp_bounds, N_pop, smiles, comp, sampler_wp=None, d_tol=0.85):
    """
    Get the qrs xtal samples

    The Wyckoff-site variables of all sites are sampled together as one
    flat quasi-random vector, which is then cut back into per-site pieces.
    A candidate is kept only if it has no special site and no short
    distances at tolerance `d_tol`.

    Args:
        cell (list): [hall, a, b, c]
        wp_bounds (list): [[wp0], [wp1], ...], each entry being a list of
            (lower, upper) pairs for one site
        N_pop (int): number of valid candidates
        smiles (list): []
        comp (list): [1]
        sampler_wp: sampler; if None, an unscrambled Sobol sampler with one
            dimension per site variable is created
        d_tol (float): short distance tolerance value

    Returns:
        list: (xtal, "QRandom") tuples; contains fewer than `N_pop` entries
        when the sampling budget of 2**m points is exhausted first
    """
    # Flatten the per-site bounds and remember the width of every site
    lb, ub, seqs = [], [], []
    for wp_bound in wp_bounds:
        lb.extend(b[0] for b in wp_bound)
        ub.extend(b[1] for b in wp_bound)
        seqs.append(len(wp_bound))

    if sampler_wp is None:
        sampler_wp = qmc.Sobol(d=len(lb), scramble=False)

    # Sampling budget: 2**m points, with at least 2**9 attempts
    m = max(int(np.log2(N_pop)) + 3, 9)

    xtals = []
    for _ in range(2**m):
        sample_wp = qmc.scale(sampler_wp.random(), lb, ub)[0].tolist()

        x = [cell]
        start = 0
        for seq in seqs:
            x.append([0] + sample_wp[start:start + seq] + [0])
            # Advance by the width just consumed. The offset must be
            # cumulative; assigning `seq` here would give the third and
            # later sites a wrong, overlapping slice.
            start += seq

        rep = representation(x, smiles)
        xtal = rep.to_pyxtal(composition=comp)
        if not xtal.has_special_site() and len(xtal.check_short_distances(r=d_tol)) == 0:
            xtals.append((xtal, "QRandom"))

        if len(xtals) == N_pop:
            return xtals

    return xtals
class QRS(GlobalOptimize):
    """
    Quasi-Random Sampling

    Args:
        smiles (str): smiles string
        workdir (str): path of working directory
        sg (int or list): space group number or list of spg numbers
        tag (string): job prefix
        info (dict): forwarded unchanged to `GlobalOptimize`
        ff_opt (bool): activate on the fly FF mode
        ff_style (str): automated force style (`gaff` or `openff`)
        ff_parameters (str or list): ff parameter xml file or list
        reference_file (str): path of reference xml data for FF training
        ref_criteria (dict): forwarded unchanged to `GlobalOptimize`
        N_gen (int): number of generation (default: `10`)
        N_pop (int): number of populations (default: `10`)
        N_cpu (int): number of cpus for parallel calculation (default: `1`)
        cif (str): cif file name to store all structure information
        block: block mode
        num_block: list of blocks
        composition: list of composition, (default is [1]*Num_mol)
        lattice (bool): whether or not supply the lattice
        torsions: list of torsion angle
        molecules (list): list of pyxtal_molecule objects
        sites (list): list of wp sites, e.g., [['4a']]
        use_hall (bool): whether or not use hall number (default: False)
        skip_ani (bool): flag controlling the use of ani (default: True)
        factor (float): forwarded unchanged to `GlobalOptimize`
        eng_cutoff (float): the cutoff energy for FF training
        E_max (float): maximum energy defined as an invalid structure
        N_survival (int): accepted for interface compatibility; not used
            by this class
        verbose (bool): show more details
        random_state (int): accepted for interface compatibility; `None`
            is passed to `GlobalOptimize` instead of this value
        max_time (float): forwarded unchanged to `GlobalOptimize`
        matcher (StructureMatcher): forwarded unchanged to `GlobalOptimize`
        early_quit (bool): forwarded unchanged to `GlobalOptimize`
        check_stable (bool): forwarded unchanged to `GlobalOptimize`
        use_mpi (bool): forwarded to `GlobalOptimize`; when `self.use_mpi`
            is true, `_run` broadcasts candidates and the stop flag from
            rank 0 (default: False)
    """

    def __init__(
        self,
        smiles: str,
        workdir: str,
        sg: int | list,
        tag: str = 'test',
        info: dict[Any, Any] | None = None,
        ff_opt: bool = False,
        ff_style: str = "openff",
        ff_parameters: str = "parameters.xml",
        reference_file: str = "references.xml",
        ref_criteria: dict[Any, Any] | None = None,
        N_gen: int = 10,
        N_pop: int = 10,
        N_cpu: int = 1,
        cif: str | None = None,
        block: list[Any] | None = None,
        num_block: list[Any] | None = None,
        composition: list[Any] | None = None,
        lattice: Lattice | None = None,
        torsions: list[Any] | None = None,
        molecules: list[pyxtal_molecule] | None = None,
        sites: list[Any] | None = None,
        use_hall: bool = False,
        skip_ani: bool = True,
        factor: float = 1.1,
        eng_cutoff: float = 5.0,
        E_max: float = 1e10,
        N_survival: int = 20,
        verbose: bool = False,
        random_state: int | None = None,
        max_time: float | None = None,
        matcher: StructureMatcher | None = None,
        early_quit: bool = False,
        check_stable: bool = False,
        use_mpi: bool = False,
    ):
        # POPULATION parameters:
        self.N_gen = N_gen  # Number of lattice points
        self.N_pop = N_pop  # Number of wp varieties
        self.verbose = verbose
        self.name = 'QRS'

        # initialize other base parameters
        GlobalOptimize.__init__(
            self,
            smiles,
            workdir,
            sg,
            tag,
            info,
            ff_opt,
            ff_style,
            ff_parameters,
            reference_file,
            ref_criteria,
            N_cpu,
            cif,
            block,
            num_block,
            composition,
            lattice,
            torsions,
            molecules,
            sites,
            use_hall,
            skip_ani,
            factor,
            eng_cutoff,
            E_max,
            None,  # random_state is deliberately not forwarded
            max_time,
            matcher,
            early_quit,
            check_stable,
            use_mpi,
        )

        # Setup the stats [N_gen, Npop, (E, matches)]
        self.stats = np.zeros([self.N_gen, self.N_pop, 2])
        self.stats[:, :, 0] = self.E_max

        if self.rank == 0:
            strs = self.full_str()
            self.logging.info(strs)
            print(strs)

    def full_str(self):
        """
        Return the base summary string extended with the QRS settings.
        """
        s = str(self)
        s += "\nMethod    : Deterministic Quasi-Random Sampling"
        s += f"\nGeneration: {self.N_gen:4d}"
        s += f"\nPopulation: {self.N_pop:4d}"
        # The rest base information from now on
        return s

    def _run(self, pool=None):
        """
        The main code to run QRS prediction

        Generation 0 optimizes placeholder `(None, "Random")` candidates
        and derives the sampling bounds from the first structure of its
        summary. Later generations draw candidates from Sobol sequences
        within those bounds.

        Args:
            pool: passed through to `self.local_optimization`

        Returns:
            success_rate: the latest value from `self.success_count`; it
            stays 0 when `self.ref_pmg` is None

        Raises:
            RuntimeError: if generation 0 yields no valid structure, so
                that no sampling bounds can be derived
        """
        self.ref_volumes = []
        success_rate = 0
        print(f"Rank {self.rank} starts QRS in {self.tag}")

        for gen in range(self.N_gen):
            self.generation = gen
            cur_xtals = None
            self.logging.info(f"Gen {gen} starts in Rank {self.rank}")

            if self.rank == 0:
                print(f"\nGeneration {gen:d} starts")
                self.logging.info(f"Generation {gen:d} starts")
                t0 = time()

                # Initialize
                cur_xtals = [(None, "Random")] * self.N_pop

                # QRS update
                if gen > 0:
                    if self.lattice is not None:
                        # Fixed lattice: self.sampler spans the wp
                        # variables and is advanced across generations
                        cell = [self.hall_number] + self.lattice.encode()
                        sampler = self.sampler
                    else:
                        # Free lattice: self.sampler spans the cell
                        # parameters; a fresh wp sampler is built inside
                        # generate_qrs_xtals
                        cell = generate_qrs_cell(self.sampler,
                                                 self.cell_bounds,
                                                 self.ref_volumes[-1],
                                                 self.ltype)
                        cell = [self.hall_number] + cell
                        sampler = None

                    cur_xtals = generate_qrs_xtals(cell,
                                                   self.wp_bounds,
                                                   self.N_pop,
                                                   self.smiles,
                                                   self.composition,
                                                   sampler)
                    strs = f"Cell parameters in Gen-{gen:d}: "
                    print(strs, cell, self.ref_volumes[-1], len(cur_xtals))

            # Broadcast
            if self.use_mpi:
                cur_xtals = self.comm.bcast(cur_xtals, root=0)

            # Local optimization
            gen_results = self.local_optimization(cur_xtals, qrs=True, pool=pool)
            self.logging.info(f"Rank {self.rank} finishes local_opt")

            # Summary and Ranking
            stop = False
            if self.rank == 0:
                cur_xtals, matches, engs = self.gen_summary(t0, gen_results, cur_xtals)

                # Volumes of all valid structures in this generation
                vols = [xtal.lattice.volume for xtal, _ in cur_xtals if xtal is not None]

                # update best volume
                if vols:
                    ref_volume = np.array(vols).mean()
                elif self.ref_volumes:
                    # The mean of an empty array is NaN, which would make
                    # every volume test in generate_qrs_cell fail; keep
                    # the previous reference volume instead.
                    ref_volume = self.ref_volumes[-1]
                    self.logging.info(
                        f"Gen {gen} has no valid structure; reuse previous reference volume"
                    )
                else:
                    # Only reachable in generation 0: there is no
                    # structure to derive the sampling bounds from.
                    raise RuntimeError(
                        "QRS cannot derive sampling bounds: no valid structure in generation 0"
                    )
                self.ref_volumes.append(ref_volume)

                if gen == 0:
                    best_xtal = cur_xtals[0][0]
                    self.cell_bounds = best_xtal.lattice.get_bounds(2.5, 25)
                    self.ltype = best_xtal.lattice.ltype
                    self.wp_bounds = [site.get_bounds() for site in best_xtal.mol_sites]
                    self.hall_number = best_xtal.group.hall_number
                    if self.lattice is not None:
                        len_reps = sum(len(bound) for bound in self.wp_bounds)
                    else:
                        len_reps = len(self.cell_bounds)
                    self.sampler = qmc.Sobol(d=len_reps, scramble=False)

                if self.ref_pmg is not None:
                    success_rate = self.success_count(cur_xtals, matches)
                    if self.early_termination(success_rate):
                        stop = True
                elif self.ref_pxrd is not None:
                    self.count_pxrd_match(cur_xtals, matches)

            if self.use_mpi:
                stop = self.comm.bcast(stop, root=0)
                self.comm.Barrier()
                self.logging.info(f"Gen {gen} Finish in Rank {self.rank}")

            if stop:
                self.logging.info(f"Early Termination in Rank {self.rank}")
                return success_rate

        return success_rate
# Minimal smoke output when the module is executed as a script.
if __name__ == "__main__":
    print("test")