import os
import signal
import subprocess
import time
from ast import literal_eval
from contextlib import contextmanager
from pathlib import Path
import numpy as np
try:
import rapidjson as json
except ModuleNotFoundError:
import json
from ase.data import chemical_symbols
from ase.db import connect
from ase.geometry import cell_to_cellpar
from ase.units import Bohr
[docs]
def is_equal(a, b):
    """Return True when ``a`` and ``b`` compare equal, False otherwise.

    The comparison is delegated to ``np.testing.assert_equal``, which
    handles scalars, arrays and nested lists/tuples/dicts. Its
    AssertionError is turned into a boolean result.
    """
    try:
        np.testing.assert_equal(a, b)
    except AssertionError:
        return False
    else:
        return True
[docs]
def compare_atoms(atoms1, atoms2, atol=1e-8, pbc=False):
    """Check for system changes since last calculation.

    Parameters
    ----------
    atoms1, atoms2 : ase.Atoms
        The two configurations to compare.
    atol : float
        Absolute tolerance used for every floating-point comparison.
    pbc : bool
        If True, the periodic boundary conditions are compared as well.

    Returns
    -------
    system_changes : list of str
        Names of the properties that differ, in this order: "numbers",
        "positions", "cell", "pbc", "initial_magmoms", "initial_charges".
        If the number of atoms differs, only ["natoms"] is returned.
    """
    # If the atoms objects don't represent the same system
    # the comparison doesn't make sense and we return early.
    if len(atoms1) != len(atoms2):
        return ["natoms"]

    system_changes = []
    if not np.array_equal(atoms1.numbers, atoms2.numbers):
        system_changes.append("numbers")
    if not np.allclose(atoms1.positions, atoms2.positions, atol=atol):
        system_changes.append("positions")
    if not np.allclose(atoms1.cell, atoms2.cell, atol=atol):
        system_changes.append("cell")
    # The original called an undefined name ``equal`` here, which made
    # pbc=True raise NameError; compare the boolean pbc arrays exactly.
    if pbc and not np.array_equal(atoms1.pbc, atoms2.pbc):
        system_changes.append("pbc")
    if not np.allclose(
        atoms1.get_initial_magnetic_moments(),
        atoms2.get_initial_magnetic_moments(),
        atol=atol,
    ):
        system_changes.append("initial_magmoms")
    if not np.allclose(
        atoms1.get_initial_charges(), atoms2.get_initial_charges(), atol=atol
    ):
        system_changes.append("initial_charges")
    return system_changes
[docs]
def uniquify(seq):
    """Return the distinct items of ``seq`` as a list, in first-seen order.

    Items must be hashable. When several items compare equal, the first
    one encountered is the one kept.
    """
    # dict keys are unique, keep insertion order, and a repeated key keeps
    # the object from its first insertion.
    return list(dict.fromkeys(seq))
[docs]
@contextmanager
def timeit():
    """Context manager that prints the time spent in its ``with`` body.

    The elapsed time is printed between two separator lines. It is reported
    even if the body raises; the exception then propagates unchanged.
    """
    print("----------------")
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump when the system clock is adjusted.
    start = time.perf_counter()
    try:
        yield
    finally:
        timing = time.perf_counter() - start
        print("Timing: {:.4f} s".format(timing))
        print("----------------")
[docs]
def read_siteconf():
    """Load the site (cluster) configuration from ``site.json``.

    The file is looked up in the directory named by ``$STORQ_CONFIG_DIR``
    when that environment variable is set, and otherwise in
    ``~/.config/storq``. The parsed content is returned as produced by
    ``read_configuration``.
    """
    config_dir = os.environ.get("STORQ_CONFIG_DIR")
    if config_dir is None:
        config_file = Path("~/.config/storq/site.json").expanduser().resolve()
    else:
        config_file = Path(config_dir) / "site.json"
    return read_configuration(config_file)
[docs]
def read_batchconf():
    """Merge the batch configuration from the available ``vasp.json`` files.

    Two candidates are checked, in this order:

    1. ``$STORQ_CONFIG_DIR/vasp.json`` if the variable is set, otherwise
       ``~/.config/storq/vasp.json``;
    2. ``vasp.json`` in the current working directory.

    Missing files are skipped. Keys from a later file override those from
    an earlier one. Returns an empty dict when neither file exists.
    """
    config_dir = os.environ.get("STORQ_CONFIG_DIR")
    if config_dir is None:
        user_file = Path("~/.config/storq/vasp.json").expanduser().resolve()
    else:
        user_file = Path(config_dir) / "vasp.json"

    merged = {}
    for candidate in (user_file, "vasp.json"):
        if os.path.exists(candidate):
            merged.update(read_configuration(candidate))
    return merged
[docs]
def handler(signum, frame):
    """SIGALRM handler used by ``timed_find_file``.

    Raises RuntimeError unconditionally, which aborts the call that was
    running when the alarm fired.
    """
    del signum, frame  # required by the signal-handler signature, unused
    raise RuntimeError
[docs]
def find_file(name, first=True, follow_symlinks=False, name_type="f", root="~"):
    """Locate ``name`` below ``root`` using the external ``find`` command.

    Parameters
    ----------
    name : str
        Pattern for ``find -name``. Shell-style wildcards are interpreted by
        ``find`` itself, never by a shell.
    first : bool
        Stop after the first match (``-print -quit``).
    follow_symlinks : bool
        Pass ``-L`` to ``find`` and return the resolved real path(s).
    name_type : str
        Argument for ``find -type`` (e.g. ``"f"`` for files, ``"d"`` for
        directories).
    root : str or path-like
        Directory to search. ``~`` is expanded. Defaults to the home
        directory, as before.

    Returns
    -------
    out : str
        The matching path(s), newline separated, or ``""`` if none was found.
    """
    # Build an argv list instead of a shell string. The pattern must reach
    # find verbatim: an unquoted "*.json" would otherwise be glob-expanded
    # by the shell in the current directory. A list also prevents shell
    # injection through ``name``.
    cmd = ["find"]
    if follow_symlinks:
        cmd.append("-L")
    cmd += [os.path.expanduser(root), "-type", name_type, "-name", name]
    if first:
        cmd += ["-print", "-quit"]
    # stderr (e.g. permission-denied noise) is discarded, as before.
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
    out = result.stdout.decode().strip()
    if follow_symlinks and out:
        # Resolve every reported path, not the whole multi-line string.
        out = "\n".join(os.path.realpath(p) for p in out.splitlines())
    return out
[docs]
def timed_find_file(name, first=True, follow_symlinks=False, name_type="f", timeout=3):
    """Run ``find_file`` but give up after ``timeout`` seconds.

    The time limit is enforced with SIGALRM through ``handler``. Per the
    Python ``signal`` docs, ``signal.alarm`` is Unix-only and handlers can
    only be installed from the main thread.

    Parameters
    ----------
    name, first, follow_symlinks, name_type
        Forwarded to ``find_file``.
    timeout : int
        Number of seconds before the search is abandoned (default 3).

    Returns
    -------
    out : str
        The result of ``find_file``, or ``""`` if the search timed out.
    """
    out = ""
    previous = signal.signal(signal.SIGALRM, handler)
    try:
        signal.alarm(timeout)
        out = find_file(
            name, first=first, follow_symlinks=follow_symlinks, name_type=name_type
        )
    except RuntimeError:
        # Raised by ``handler`` when the alarm fires: treat as "not found".
        pass
    finally:
        # Always cancel the pending alarm. Otherwise a search that finished
        # in time would still be interrupted later by a RuntimeError raised
        # in unrelated code.
        signal.alarm(0)
        # Restore whatever handler was installed before (None means it was
        # not set from Python and cannot be passed back to signal.signal).
        if previous is not None:
            signal.signal(signal.SIGALRM, previous)
    return out
[docs]
def find_runs(parent_dir):
    """Return the immediate subdirectories of ``parent_dir`` that hold a run.

    A subdirectory counts as a run directory when it contains a regular file
    named ``POCAR``. Only direct children are inspected (no recursion).

    Parameters
    ----------
    parent_dir : str or path-like
        Directory whose children are inspected.

    Returns
    -------
    run_dirs : list of pathlib.Path
        Resolved paths of the run directories. Empty when ``parent_dir``
        does not exist, is not a directory, or has no matching children.
    """
    parent_dir = Path(parent_dir)
    # os.walk yields bare names relative to parent_dir. Join them before
    # resolving, otherwise they are plain strings (no .resolve()) and would
    # be relative to the current directory. The default avoids a
    # StopIteration when parent_dir does not exist.
    _, subdir_names, _ = next(os.walk(parent_dir), (None, [], []))
    candidate_dirs = [parent_dir.joinpath(d).resolve() for d in subdir_names]
    # NOTE(review): "POCAR" is not a standard VASP file name (POSCAR/POTCAR
    # are) -- confirm the intended marker file.
    run_dirs = [d for d in candidate_dirs if d.joinpath("POCAR").is_file()]
    return run_dirs
[docs]
def read_db(directory):
    """Read the persistent data stored in ``storq.json``.

    Parameters
    ----------
    directory : str or path-like
        Path (relative or absolute) to the calculation directory.

    Returns
    -------
    atoms : ase.Atoms
        The atoms stored in row 1 of ``<directory>/storq.json``.
    data : dict
        The ``data`` field of that same row.
    """
    directory = Path(directory).resolve()
    database = directory.joinpath("storq.json")
    with connect(database) as db:
        # Fetch the row once instead of querying the database twice.
        row = db.get(1)
        data = row.data
        atoms = row.toatoms()
    return atoms, data
[docs]
def write_configuration(fname, config):
    """Serialize ``config`` as JSON (two-space indent) into ``fname``.

    An existing file is overwritten.
    """
    with open(fname, mode="w") as handle:
        json.dump(config, handle, indent=2)
[docs]
def read_configuration(fname):
    """Parse the JSON file ``fname`` and return the decoded object."""
    with open(fname, mode="r") as handle:
        return json.load(handle)
[docs]
def getstatusoutput(*args, **kwargs):
    """Run a command and collect its exit status and output.

    Replacement for the old ``commands.getstatusoutput``. All arguments are
    forwarded to ``subprocess.Popen``, e.g.
    ``getstatusoutput([command], stdout=subprocess.PIPE)``. The call blocks
    until the process has finished.

    Returns
    -------
    tuple
        ``(returncode, stdout, stderr)``, where ``stdout`` and ``stderr`` are
        exactly what ``Popen.communicate()`` returned.
    """
    proc = subprocess.Popen(*args, **kwargs)
    out, err = proc.communicate()
    return proc.returncode, out, err
[docs]
def tail(fname):
    """Return the last line of a file, as bytes.

    The file is scanned backwards from its end, so only the tail is read.
    The returned line keeps its trailing newline, if it has one.

    Parameters
    ----------
    fname : str or path-like
        File to read.

    Returns
    -------
    last : bytes
        The last line, the whole content for a single-line file, or ``b""``
        for an empty file.
    """
    with open(fname, "rb") as f:
        try:
            # Start at the second-to-last byte so that a trailing newline is
            # not mistaken for the end of the previous line.
            f.seek(-2, os.SEEK_END)
            # Step back one byte at a time until the newline that precedes
            # the last line has just been read.
            while f.read(1) != b"\n":
                f.seek(-2, os.SEEK_CUR)
        except OSError:
            # Seeking before the start of the file (empty, one-byte or
            # single-line file): there is no earlier newline, so the last
            # line starts at offset 0.
            f.seek(0)
        return f.readline()
[docs]
def ascii_atoms(atoms):
    """Blatantly stolen from GPAW.
    Ascii-art plot of the atoms.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to draw. It is only read. When its cell has off-diagonal
        elements, a re-centred copy is drawn instead and no box is shown.

    Returns
    -------
    lines : list of str
        Rows of the drawing, top row first, each prefixed with one space.
    """
    # Projection (see the ``ij`` computation below): x runs to the right,
    # z runs up, and y is the depth axis, drawn diagonally up-right.
    #   z
    #   |  y
    #   | /
    #   .-- x
    cell_cv = atoms.get_cell()
    if (cell_cv - np.diag(cell_cv.diagonal())).any():
        # Cell with off-diagonal elements: draw a copy in a fresh diagonal
        # cell (re-centred via center(vacuum=2.0)) and skip the box.
        atoms = atoms.copy()
        atoms.cell = [1, 1, 1]
        atoms.center(vacuum=2.0)
        cell_cv = atoms.get_cell()
        plot_box = False
    else:
        plot_box = True
    # NOTE(review): a diagonal cell with a zero entry (e.g. no cell set)
    # takes the branch above and then divides by zero in ``n / cell`` and
    # ``positions % cell`` -- confirm callers always pass a proper cell.
    # Work in units of Bohr.
    cell = np.diagonal(cell_cv) / Bohr
    positions = atoms.get_positions() / Bohr
    numbers = atoms.get_atomic_numbers()
    # Grid size along x, y, z: s characters per Bohr, with y and z further
    # scaled by 0.25 and 0.5 (+0.5 then truncation rounds to nearest).
    s = 1.3
    nx, ny, nz = n = (s * cell * (1.0, 0.25, 0.5) + 0.5).astype(int)
    # Characters per Bohr along each axis.
    sx, sy, sz = n / cell
    # Extra columns/rows presumably leave room for the diagonal y offset
    # and for the element symbols written to the right of each atom.
    grid = Grid(nx + ny + 4, nz + ny + 1)
    # Wrap positions back into the cell.
    positions = (positions % cell + cell) % cell
    # Project to the 2D grid: column i = x*sx + y*sy, row j = y*sy + z*sz.
    ij = np.dot(positions, [(sx, 0), (sy, sy), (0, sz)])
    ij = np.around(ij).astype(int)
    for a, Z in enumerate(numbers):
        symbol = chemical_symbols[Z]
        i, j = ij[a]
        # Grid.put keeps the smallest depth, so atoms with smaller y are
        # drawn in front of those behind them.
        depth = positions[a, 1]
        # This loop rebinds ``n``; the grid-size array is not used after
        # this point.
        for n, c in enumerate(symbol):
            grid.put(c, i + n + 1, j, depth)
    if plot_box:
        # Draw the cell edges: corners as '*' (front) and '.' (back, offset
        # by ny along the diagonal), '/' along y, '|' along z, '-' along x.
        # ``k`` distinguishes the first pass from the second so that some
        # front edges/corners are drawn only once.
        k = 0
        for i, j in [(1, 0), (1 + nx, 0)]:
            grid.put("*", i, j)
            grid.put(".", i + ny, j + ny)
            if k == 0:
                grid.put("*", i, j + nz)
            grid.put(".", i + ny, j + nz + ny)
            for y in range(1, ny):
                grid.put("/", i + y, j + y, y / sy)
                if k == 0:
                    grid.put("/", i + y, j + y + nz, y / sy)
            for z in range(1, nz):
                if k == 0:
                    grid.put("|", i, j + z)
                grid.put("|", i + ny, j + z + ny)
            k = 1
        for i, j in [(1, 0), (1, nz)]:
            for x in range(1, nx):
                if k == 1:
                    grid.put("-", i + x, j)
                grid.put("-", i + x + ny, j + ny)
            k = 0
    # The grid is indexed [column, row]; transpose it and reverse the rows
    # so the highest row comes first, then turn character codes into text.
    tmp = ["".join([chr(x) for x in line]) for line in np.transpose(grid.grid)[::-1]]
    return [" {}".format(elem) for elem in tmp]
[docs]
class Grid:
    """Character canvas with a depth buffer, used by ``ascii_atoms``.

    ``grid[i, j]`` stores the character code drawn at ``(i, j)``, and
    ``depth[i, j]`` the depth of whatever is currently drawn there. A new
    character only replaces an existing one if it is nearer (smaller depth).
    """

    def __init__(self, i, j):
        # Blank canvas (spaces) and an "infinitely far" depth buffer, so the
        # first put() at any cell always succeeds.
        self.grid = np.full((i, j), ord(" "), dtype=np.int8)
        self.depth = np.full((i, j), 1e10)

    def put(self, c, i, j, depth=1e9):
        """Draw character ``c`` at ``(i, j)`` unless something nearer is there.

        The default depth 1e9 is nearer than an empty cell (1e10) but
        farther than any explicit depth below 1e9.
        """
        if not depth < self.depth[i, j]:
            return
        self.grid[i, j] = ord(c)
        self.depth[i, j] = depth