# Source code for tools

import os
import signal
import subprocess
import time
from ast import literal_eval
from contextlib import contextmanager
from pathlib import Path

import numpy as np
try:
    import rapidjson as json
except ModuleNotFoundError:
    import json

from ase.data import chemical_symbols
from ase.db import connect
from ase.geometry import cell_to_cellpar
from ase.units import Bohr


def is_equal(a, b):
    """Return whether *a* and *b* are equal according to ``np.testing.assert_equal``.

    This works on nested containers and arrays, and NaNs in matching positions
    count as equal.
    """
    equal = True
    try:
        np.testing.assert_equal(a, b)
    except AssertionError:
        equal = False
    return equal
def compare_atoms(atoms1, atoms2, atol=1e-8, pbc=False):
    """Check for system changes since last calculation.

    Parameters
    ----------
    atoms1, atoms2 : ase.Atoms
        Systems to compare. Anything exposing ``numbers``, ``positions``,
        ``cell``, ``pbc``, ``get_initial_magnetic_moments()``,
        ``get_initial_charges()`` and ``len()`` works.
    atol : float
        Absolute tolerance for the floating-point comparisons.
    pbc : bool
        If True, the periodic boundary conditions are compared as well.

    Returns
    -------
    list of str
        Names of the properties that differ. The possible names are
        ``"numbers"``, ``"positions"``, ``"cell"``, ``"pbc"``,
        ``"initial_magmoms"`` and ``"initial_charges"``. The list is empty
        when nothing changed. If the atom counts differ, only
        ``["natoms"]`` is returned.
    """
    # If the atoms objects don't represent the same system the element-wise
    # comparisons below don't make sense, so return early.
    if len(atoms1) != len(atoms2):
        return ["natoms"]

    system_changes = []
    if not np.array_equal(atoms1.numbers, atoms2.numbers):
        system_changes.append("numbers")
    if not np.allclose(atoms1.positions, atoms2.positions, atol=atol):
        system_changes.append("positions")
    if not np.allclose(atoms1.cell, atoms2.cell, atol=atol):
        system_changes.append("cell")
    # pbc flags are booleans, so compare them exactly. The previous code called
    # an undefined name ``equal`` here and raised NameError.
    if pbc and not np.array_equal(atoms1.pbc, atoms2.pbc):
        system_changes.append("pbc")
    if not np.allclose(
        atoms1.get_initial_magnetic_moments(),
        atoms2.get_initial_magnetic_moments(),
        atol=atol,
    ):
        system_changes.append("initial_magmoms")
    if not np.allclose(
        atoms1.get_initial_charges(), atoms2.get_initial_charges(), atol=atol
    ):
        system_changes.append("initial_charges")
    return system_changes
def uniquify(seq):
    """Make a sequence unique whilst preserving order.

    Returns a new list that keeps the first occurrence of every item.
    Items must be hashable.
    """
    already_seen = set()
    unique_items = []
    for item in seq:
        if item in already_seen:
            continue
        already_seen.add(item)
        unique_items.append(item)
    return unique_items
@contextmanager
def timeit():
    """Context manager for timing a piece of code.

    A separator line is printed on entry. When the block exits, the elapsed
    wall-clock time is printed, followed by another separator. The timing is
    reported even if the block raises.
    """
    print("----------------")
    # perf_counter is monotonic and high-resolution. time.time() can jump when
    # the system clock is adjusted.
    start = time.perf_counter()
    try:
        yield
    finally:
        timing = time.perf_counter() - start
        print("Timing: {:.4f} s".format(timing))
        print("----------------")
def read_siteconf():
    """Load the site (cluster) configuration file ``site.json``.

    The file is looked up in ``$STORQ_CONFIG_DIR`` when that variable is set.
    Otherwise ``~/.config/storq/site.json`` is used.
    """
    env_dir = os.environ.get("STORQ_CONFIG_DIR")
    if env_dir is None:
        config_file = Path("~/.config/storq/site.json").expanduser().resolve()
    else:
        config_file = Path(env_dir) / "site.json"
    return read_configuration(config_file)
def read_batchconf():
    """Build the VASP batch configuration by merging ``vasp.json`` files.

    The user-level file is read first. It comes from ``$STORQ_CONFIG_DIR``
    when that variable is set, otherwise from ``~/.config/storq``. A
    ``vasp.json`` in the current working directory is merged on top and
    overrides matching keys. Missing files are skipped.
    """
    env_dir = os.environ.get("STORQ_CONFIG_DIR")
    if env_dir is None:
        user_file = Path("~/.config/storq/vasp.json").expanduser().resolve()
    else:
        user_file = Path(env_dir) / "vasp.json"

    merged = {}
    for candidate in (user_file, "vasp.json"):
        if os.path.exists(candidate):
            merged.update(read_configuration(candidate))
    return merged
def handler(signum, frame):
    """SIGALRM handler used by ``timed_find_file``.

    It raises ``RuntimeError`` to interrupt whatever code is running when
    the alarm fires. Both arguments are ignored.
    """
    raise RuntimeError()
def find_file(name, first=True, follow_symlinks=False, name_type="f"):
    """Locate *name* below the user's home directory using ``find``.

    Parameters
    ----------
    name : str
        Pattern given verbatim to ``find -name``. Wildcards such as ``*.json``
        are matched by ``find`` itself; they are never expanded by a shell.
    first : bool
        If True, stop after the first match (``-print -quit``).
    follow_symlinks : bool
        If True, pass ``-L`` to ``find`` and resolve every returned path
        with ``os.path.realpath``.
    name_type : str
        Argument for ``find -type``, e.g. ``"f"`` for files or ``"d"`` for
        directories.

    Returns
    -------
    str
        Matching path(s), one per line, or ``""`` when nothing was found.
    """
    argv = ["find"]
    if follow_symlinks:
        argv.append("-L")
    # Expand "~" here and pass an argument list (no shell). Previously the
    # command was a shell string with an unquoted name, so patterns were
    # glob-expanded against the cwd and the name could inject shell syntax.
    argv += [os.path.expanduser("~"), "-type", name_type, "-name", name]
    if first:
        argv += ["-print", "-quit"]

    # stderr is captured and discarded; permission errors while walking the
    # home directory are expected. If the call is interrupted by an exception
    # (e.g. the SIGALRM RuntimeError from timed_find_file), subprocess.run
    # kills the child before re-raising, so no orphaned find is left behind.
    result = subprocess.run(argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out = result.stdout.decode().strip()
    if follow_symlinks and out != "":
        # Resolve each path separately; with first=False the output has
        # several lines.
        out = "\n".join(os.path.realpath(path) for path in out.splitlines())
    return out
def timed_find_file(name, first=True, follow_symlinks=False, name_type="f", timeout=3):
    """Timed version of ``find_file``.

    The search is interrupted with ``SIGALRM`` after ``timeout`` seconds.
    Because it relies on SIGALRM, this only works on POSIX systems and must
    be called from the main thread.

    Parameters
    ----------
    name, first, follow_symlinks, name_type
        Forwarded to ``find_file``.
    timeout : int
        Whole seconds before giving up (passed to ``signal.alarm``).

    Returns
    -------
    str
        The result of ``find_file``, or ``""`` if the search timed out or
        raised ``RuntimeError``.
    """
    previous_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(timeout)
    try:
        return find_file(
            name, first=first, follow_symlinks=follow_symlinks, name_type=name_type
        )
    except RuntimeError:
        return ""
    finally:
        # Always cancel the pending alarm and restore the previous handler.
        # Otherwise the alarm fires later and raises RuntimeError in
        # unrelated code.
        signal.alarm(0)
        # signal.signal returns None when the old handler was not installed
        # from Python; it cannot be passed back to signal.signal.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
def find_runs(parent_dir, marker="POSCAR"):
    """Return the immediate subdirectories of *parent_dir* that contain *marker*.

    Parameters
    ----------
    parent_dir : str or Path
        Directory whose direct children are inspected. There is no recursion.
    marker : str
        File name whose presence marks a subdirectory as a run directory.
        NOTE(review): the previous code tested for "POCAR", which is not a
        VASP file name; "POSCAR" is assumed to be what was intended.

    Returns
    -------
    list of Path
        Resolved, sorted paths of the matching subdirectories.
    """
    parent_dir = Path(parent_dir)
    # iterdir() yields paths already joined onto parent_dir. The old os.walk
    # version produced bare name strings, which have no .resolve() and would
    # have been resolved against the cwd rather than parent_dir.
    candidate_dirs = sorted(d.resolve() for d in parent_dir.iterdir() if d.is_dir())
    return [d for d in candidate_dirs if d.joinpath(marker).is_file()]
def read_db(directory):
    """Read the persistent data from the ``storq.json`` database of a calculation.

    Parameters
    ----------
    directory : str or Path
        Path (relative or absolute) to the calculation directory that
        contains ``storq.json``.

    Returns
    -------
    atoms : ase.Atoms
        Atoms object reconstructed from row 1 of the database.
    data : dict
        The ``data`` field of row 1, holding the persistent data.
    """
    database = Path(directory).resolve() / "storq.json"
    # str() for compatibility with ASE versions whose connect() expects a
    # string filename.
    with connect(str(database)) as db:
        # Fetch the row once and derive both return values from it.
        row = db.get(1)
        atoms = row.toatoms()
        data = row.data
    return atoms, data
def write_configuration(fname, config):
    """Write *config* to *fname* as JSON indented by 2 spaces.

    The configuration is serialized before the file is opened. A value that
    cannot be serialized therefore raises without truncating or partially
    overwriting an existing file.
    """
    text = json.dumps(config, indent=2)
    with open(fname, "w") as fp:
        fp.write(text)
def read_configuration(fname):
    """Load and return the JSON document stored in *fname*."""
    with open(fname, "r") as handle:
        return json.load(handle)
def getstatusoutput(*args, **kwargs):
    """Run a command and return ``(returncode, stdout, stderr)``.

    This replaces the old ``commands.getstatusoutput``. All arguments are
    forwarded to ``subprocess.Popen``, for example
    ``getstatusoutput([command], stdin=subprocess.PIPE)``. A stream is
    ``None`` unless the matching ``Popen`` argument was ``subprocess.PIPE``.
    """
    proc = subprocess.Popen(*args, **kwargs)
    # communicate() waits for the process, so returncode is set afterwards.
    captured_out, captured_err = proc.communicate()
    return proc.returncode, captured_out, captured_err
def tail(fname):
    """Quick way to read the last line of a file.

    The file is scanned backwards from the end instead of being read in
    full.

    Returns
    -------
    bytes
        The last line, including its trailing newline if present. A file
        with a single line returns that line, and an empty file returns
        ``b""``.
    """
    with open(fname, "rb") as f:
        try:
            # Start at the second-to-last byte so a trailing newline is not
            # mistaken for the end of the previous line.
            f.seek(-2, os.SEEK_END)
            while f.read(1) != b"\n":  # Until EOL is found...
                f.seek(-2, os.SEEK_CUR)  # ...jump back the read byte plus one more.
        except OSError:
            # We tried to seek before the start of the file. The file is
            # shorter than 2 bytes or contains no earlier newline, so the
            # whole file is the last line.
            f.seek(0)
        return f.readline()
def ascii_atoms(atoms):
    """Blatantly stolen from GPAW. Ascii-art plot of the atoms.

    Atoms are projected onto a character grid (``Grid``) with an oblique view
    in which y is drawn diagonally (see the axis sketch below). Each atom is
    drawn as its chemical symbol. If the cell is diagonal (orthorhombic), its
    edges are also drawn with ``*``, ``.``, ``/``, ``|`` and ``-``.

    Parameters
    ----------
    atoms : ase.Atoms
        System to draw. For a non-diagonal cell a copy is used, re-centred in
        a dummy box; no cell outline is drawn and the original is not
        modified.

    Returns
    -------
    list of str
        The picture's rows, top row first, each prefixed with one space.
    """
    # y
    # |
    # .-- x
    # /
    # z
    cell_cv = atoms.get_cell()
    # Any non-zero off-diagonal element means the cell is not orthorhombic.
    # NOTE(review): if get_cell() returns an ase Cell object rather than an
    # ndarray, confirm it supports .diagonal(); np.diagonal() is used below.
    if (cell_cv - np.diag(cell_cv.diagonal())).any():
        # Non-orthorhombic cell: place the atoms in a dummy box with 2.0
        # vacuum instead, and skip drawing the cell outline.
        atoms = atoms.copy()
        atoms.cell = [1, 1, 1]
        atoms.center(vacuum=2.0)
        cell_cv = atoms.get_cell()
        plot_box = False
    else:
        plot_box = True

    # Work in Bohr, as in the GPAW original.
    cell = np.diagonal(cell_cv) / Bohr
    positions = atoms.get_positions() / Bohr
    numbers = atoms.get_atomic_numbers()

    # Number of grid characters along x, y and z. The y and z axes are
    # compressed (0.25 and 0.5) because y is drawn diagonally and text
    # characters are taller than they are wide.
    s = 1.3
    nx, ny, nz = n = (s * cell * (1.0, 0.25, 0.5) + 0.5).astype(int)
    sx, sy, sz = n / cell
    grid = Grid(nx + ny + 4, nz + ny + 1)
    # Wrap positions into the cell.
    positions = (positions % cell + cell) % cell
    # Project 3D positions to 2D grid indices: x goes right, y goes diagonally
    # up-right, z goes up.
    ij = np.dot(positions, [(sx, 0), (sy, sy), (0, sz)])
    ij = np.around(ij).astype(int)
    for a, Z in enumerate(numbers):
        symbol = chemical_symbols[Z]
        i, j = ij[a]
        # The y coordinate is the depth: smaller y is drawn in front.
        depth = positions[a, 1]
        # NOTE: `n` is reused here as the character index; the grid-size
        # array `n` is not needed after `sx, sy, sz` were computed.
        for n, c in enumerate(symbol):
            grid.put(c, i + n + 1, j, depth)
    if plot_box:
        # The flag `k` marks the first (k == 0) pass of each loop; some edges
        # are drawn only in that pass.
        k = 0
        # Left (x = 0) and right (x = max) faces of the box.
        for i, j in [(1, 0), (1 + nx, 0)]:
            grid.put("*", i, j)
            grid.put(".", i + ny, j + ny)
            if k == 0:
                grid.put("*", i, j + nz)
            grid.put(".", i + ny, j + nz + ny)
            for y in range(1, ny):
                grid.put("/", i + y, j + y, y / sy)
                if k == 0:
                    grid.put("/", i + y, j + y + nz, y / sy)
            for z in range(1, nz):
                if k == 0:
                    grid.put("|", i, j + z)
                grid.put("|", i + ny, j + z + ny)
            k = 1
        # Bottom (z = 0) and top (z = max) horizontal edges along x.
        for i, j in [(1, 0), (1, nz)]:
            for x in range(1, nx):
                if k == 1:
                    grid.put("-", i + x, j)
                grid.put("-", i + x + ny, j + ny)
            k = 0
    # Transpose so rows are grid lines, and reverse so the top row comes first.
    tmp = ["".join([chr(x) for x in line]) for line in np.transpose(grid.grid)[::-1]]
    return [" {}".format(elem) for elem in tmp]
class Grid:
    """Character canvas with a depth buffer, used by ``ascii_atoms``.

    Each cell holds one character code (stored as ``np.int8``) and the depth
    of what was last drawn there. A new character replaces a cell's content
    only if it is strictly closer, i.e. has a smaller depth.
    """

    def __init__(self, i, j):
        # (i, j) canvas, initially filled with blanks.
        self.grid = np.full((i, j), ord(" "), dtype=np.int8)
        # The initial depth 1e10 is larger than put()'s default depth of 1e9,
        # so a default-depth character always shows on an empty cell.
        self.depth = np.full((i, j), 1e10)

    def put(self, c, i, j, depth=1e9):
        """Draw character *c* at cell (i, j) if *depth* is in front of what is there."""
        # Written as `not <` (not `>=`) so a NaN depth is never drawn.
        if not depth < self.depth[i, j]:
            return
        self.grid[i, j] = ord(c)
        self.depth[i, j] = depth