from functools import partial, wraps
from io import IOBase, StringIO
from typing import Any, Sequence, Callable, Iterator
import copy
from contextlib import redirect_stdout, redirect_stderr
import itertools
import multiprocessing as mp
import os
import pickle
import platform
import re
import shutil
import string
import subprocess
import sys
import time
import uuid
import warnings
import joblib
import numpy as np
import pandas as pd
import yaml
pj = os.path.join
# defaults, globals
PANDAS_DEFAULT_ORIENT = "records"
PANDAS_TIME_UNIT = "s"
PSET_HASH_ALG = "sha1"
GIT_ADD_ALL = "git add -A -v"
# Make DeprecationWarning visible to users by default.
warnings.simplefilter("default")
# -----------------------------------------------------------------------------
# helpers
# -----------------------------------------------------------------------------
def system(cmd: str, **kwds) -> subprocess.CompletedProcess:
"""
Call shell command.
Parameters
----------
cmd
shell command
kwds
keywords passed to `subprocess.run`
"""
try:
return subprocess.run(
cmd,
shell=True,
check=True,
stderr=subprocess.STDOUT,
stdout=subprocess.PIPE,
**kwds,
)
except subprocess.CalledProcessError as ex:
print(ex.stdout.decode())
raise ex
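# Usage sketch for system(): run a shell command and inspect the captured
# output (command and values are illustrative). Note that stderr is merged
# into stdout.
#
#   >>> proc = system("echo hello")
#   >>> proc.stdout.decode()
#   'hello\n'
#   >>> proc.returncode
#   0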
# https://github.com/elcorto/pwtools
def makedirs(path: str) -> None:
"""Create `path` recursively, no questions asked."""
if not path.strip() == "":
os.makedirs(path, exist_ok=True)
# https://github.com/elcorto/pwtools
def fullpath(path: str) -> str:
return os.path.abspath(os.path.expanduser(path))
def itr(func: Callable) -> Callable:
"""Wrap `func` to allow pasing args not as sequence.
Assuming ``func()`` requires a sequence as input: ``func([a,b,c])``, allow
passing ``func(a,b,c)``.
"""
@wraps(func)
def wrapper(*args):
# (arg1,)
if len(args) == 1:
arg = args[0]
return func(arg if is_seq(arg) else [arg])
# (arg1,...,argN)
else:
return func(args)
return wrapper
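# Sketch of what the @itr decorator enables, using a toy function that is not
# part of this module:
#
#   >>> @itr
#   ... def total(seq):
#   ...     return sum(seq)
#   >>> total([1, 2, 3])  # pass a sequence, as before
#   6
#   >>> total(1, 2, 3)    # or pass the items individually
#   6
#   >>> total(1)          # a single non-sequence arg is wrapped in a list
#   1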
# https://github.com/elcorto/pwtools
def is_seq(seq) -> bool:
if (
isinstance(seq, str)
or isinstance(seq, IOBase)
or isinstance(seq, dict)
):
return False
else:
try:
iter(seq)
return True
except TypeError:
return False
def flatten(seq):
for item in seq:
if not is_seq(item):
yield item
else:
for subitem in flatten(item):
yield subitem
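# flatten() recurses into nested sequences but treats strings, dicts and
# file-like objects as atoms (values are illustrative):
#
#   >>> list(flatten([1, [2, [3, "abc"]], {"a": 1}]))
#   [1, 2, 3, 'abc', {'a': 1}]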
def file_write(fn: str, txt: str, mode="w"):
makedirs(os.path.dirname(fn))
with open(fn, mode=mode) as fd:
try:
fd.write(txt)
except UnicodeEncodeError:
fd.write(txt.encode("ascii", errors="xmlcharrefreplace").decode())
def file_read(fn: str):
with open(fn, "r") as fd:
return fd.read()
def pickle_write(fn: str, obj):
makedirs(os.path.dirname(fn))
with open(fn, "wb") as fd:
pickle.dump(obj, fd)
def pickle_read(fn: str):
with open(fn, "rb") as fd:
return pickle.load(fd)
class PsweepHashError(TypeError):
pass
def pset_hash(
dct: dict,
method=PSET_HASH_ALG,
raise_error=True,
skip_special_cols=None,
skip_prefix_cols=True,
skip_postfix_cols=True,
):
"""Reproducible hash of a dict for usage in database (hash of a `pset`)."""
# We target "reproducible" hashes, i.e. not what Python's ``hash`` function
# does, for instance for two interpreter sessions::
#
# $ python
# >>> hash("12")
# 8013944793133897043
#
# $ python
# >>> hash("12")
# 4021864388667373027
#
# We can't hash, say, the pickled byte string of some object (e.g.
# ``hash(pickle.dumps(obj))``), b/c that may contain a ref to its memory
# location which is not what we're interested in. Similarly, also using
# ``repr`` is not reproducible::
#
# >>> class Foo:
# ... pass
#
# >>> repr(Foo())
# '<__main__.Foo object at 0x7fcc68aa9d60>'
# >>> repr(Foo())
# '<__main__.Foo object at 0x7fcc732034c0>'
#
# even though for our purpose, we'd consider the two instances of ``Foo``
# to be the same.
#
    # The same observations have also been made elsewhere [1,2]. Esp. [2]
    # points to [3], which in turn mentions joblib.hashing.hash(). Its code
    # shows how complex the problem is, but so far this is our best bet.
#
# [1] https://death.andgravity.com/stable-hashing
# [2] https://ourpython.com/python/deterministic-recursive-hashing-in-python
# [3] https://stackoverflow.com/a/52175075
assert isinstance(dct, dict), f"{dct=} is not a dict but {type(dct)=}"
if skip_special_cols is not None:
warnings.warn(
"skip_special_cols is deprecated, use skip_prefix_cols",
DeprecationWarning,
)
skip_prefix_cols = skip_special_cols
skip_cols_test = None
if skip_prefix_cols and skip_postfix_cols:
skip_cols_test = lambda key: key.startswith("_") or key.endswith("_")
elif skip_prefix_cols:
skip_cols_test = lambda key: key.startswith("_")
elif skip_postfix_cols:
skip_cols_test = lambda key: key.endswith("_")
if skip_cols_test is not None:
_dct = {
key: val for key, val in dct.items() if not skip_cols_test(key)
}
else:
_dct = dct
    # joblib can hash "anything", so we haven't yet found an input that
    # actually fails to hash; TypeError is just a guess here. Note that we
    # deliberately don't catch the ValueError raised when an invalid
    # hash_name is passed (anything other than md5 or sha1).
try:
return joblib.hash(_dct, hash_name=method)
except TypeError as ex:
if raise_error:
raise PsweepHashError(
f"Error in hash calculation of: {dct}"
) from ex
else:
return np.nan
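# Hash behavior sketch: with the defaults, prefix ("_foo") and postfix
# ("foo_") fields are skipped, so book-keeping and result fields don't change
# a pset's hash (field names and values below are illustrative):
#
#   >>> h1 = pset_hash(dict(a=1, b="lala"))
#   >>> h2 = pset_hash(dict(a=1, b="lala", _pset_id="deadbeef", result_=42))
#   >>> h1 == h2
#   True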
def check_calc_dir(calc_dir: str, df: pd.DataFrame):
"""Check calc dir for consistency with database.
Assuming dirs are named::
<calc_dir>/<pset_id1>
<calc_dir>/<pset_id2>
...
check if we have matching dirs to ``_pset_id`` values in the
database.
"""
# fmt: off
pset_ids_disk = set([
m.group()
for m in [
re.match(r"(([0-9a-z]+)-){4}([0-9a-z]+)", x)
for x in os.listdir(calc_dir)
] if m is not None])
# fmt: on
pset_ids_db = set(df._pset_id.values)
return dict(
db_not_disk=pset_ids_db - pset_ids_disk,
disk_not_db=pset_ids_disk - pset_ids_db,
)
def logspace(
start, stop, num=50, offset=0, log_func: Callable = np.log10, **kwds
):
"""
Like ``numpy.logspace`` but
* `start` and `stop` are not exponents but the actual bounds
* tuneable log scale strength
Control the strength of the log scale by `offset`, where we use by default
``log_func=np.log10`` and ``base=10`` and return
``np.logspace(np.log10(start + offset), np.log10(stop + offset)) -
offset``. `offset=0` is equal to ``np.logspace(np.log10(start),
np.log10(stop))``. Higher `offset` values result in more evenly spaced
points.
Parameters
----------
start, stop, num, **kwds :
same as in ``np.logspace``
offset :
Control strength of log scale.
log_func :
Must match `base` (pass that as part of `**kwds`). Default is
``base=10`` as in ``np.logspace`` and so ``log_func=np.log10``. If you
want a different `base`, also provide a matching `log_func`, e.g.
``base=e, log_func=np.log``.
Examples
--------
Effect of different `offset` values:
>>> from matplotlib import pyplot as plt
>>> from psweep import logspace
>>> import numpy as np
>>> for ii, offset in enumerate([1e-16,1e-3, 1,2,3]):
... x=logspace(0, 2, 20, offset=offset)
... plt.plot(x, np.ones_like(x)*ii, "o-", label=f"{offset=}")
>>> plt.legend()
"""
base = kwds.pop("base", 10.0)
# fmt: off
assert np.allclose(log_func(base), 1.0), f"log_func and {base=} don't match"
# fmt: on
return (
np.logspace(
log_func(start + offset),
log_func(stop + offset),
num=num,
base=base,
**kwds,
)
- offset
)
def intspace(*args, dtype=np.int64, **kwds):
"""Like ``np.linspace`` but round to integers.
    The returned array may be shorter than `num` if rounding to ints produces
    duplicates.
Parameters
----------
*args, **kwds
Same as ``np.linspace``
"""
assert "dtype" not in kwds, "Got 'dtype' multiple times."
return np.unique(np.round(np.linspace(*args, **kwds)).astype(dtype))
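# intspace() sketch: rounding np.linspace(0, 4, 9) to ints gives
# [0, 0, 1, 2, 2, 2, 3, 4, 4] (NumPy rounds half to even), so np.unique()
# makes the result shorter than num:
#
#   >>> intspace(0, 4, num=9)
#   array([0, 1, 2, 3, 4])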
def get_uuid(retry=10, existing: Sequence = []) -> str:
    ret = str(uuid.uuid4())
    attempts_left = retry
    while ret in existing:
        ret = str(uuid.uuid4())
        attempts_left -= 1
        if attempts_left == 0:
            raise Exception(
                f"Failed to generate a new UUID after {retry} attempts. "
                f"Existing UUIDs: {existing}"
            )
    return ret
def get_many_uuids(
    num: int, retry=10, existing: Sequence = []
) -> Sequence[str]:
    generate = lambda: set(str(uuid.uuid4()) for _ in range(num))
    ret = generate()
    set_existing = set(existing)
    attempts_left = retry
    while (len(ret) < num) or (len(ret & set_existing) > 0):
        ret = generate()
        attempts_left -= 1
        if attempts_left == 0:
            raise Exception(
                f"Failed to generate {num} unique UUIDs after {retry} attempts. "
                f"Existing UUIDs: {existing}"
            )
    return list(ret)
# -----------------------------------------------------------------------------
# git
# -----------------------------------------------------------------------------
def git_clean():
return system("git status --porcelain").stdout.decode() == ""
def in_git_repo():
# fmt: off
return subprocess.run(
"git status",
check=False,
shell=True,
capture_output=True,
).returncode == 0
# fmt: on
def git_enter(use_git: bool, always_commit=False):
if use_git:
path = os.path.basename(fullpath(os.curdir))
if not in_git_repo():
if always_commit:
system(
f"git init; {GIT_ADD_ALL}; git commit -m 'psweep: {path}: init'"
)
else:
raise Exception("no git repo here, create one first")
if not git_clean():
if always_commit:
print("dirty repo, adding all changes")
system(
f"{GIT_ADD_ALL}; git commit -m 'psweep: {path}: local changes'"
)
else:
raise Exception("dirty repo, commit first")
def git_exit(use_git: bool, df: pd.DataFrame):
if use_git and (not git_clean()):
path = os.path.basename(fullpath(os.curdir))
system(
f"{GIT_ADD_ALL}; git commit -m 'psweep: {path}: run_id={df._run_id.values[-1]}'"
)
# -----------------------------------------------------------------------------
# pandas
# -----------------------------------------------------------------------------
def df_to_json(df: pd.DataFrame, **kwds) -> str:
"""Like `df.to_json` but with defaults for orient, date_unit, date_format,
double_precision.
Parameters
----------
df
DataFrame to convert
kwds
passed to :meth:`df.to_json`
"""
defaults = dict(
orient=PANDAS_DEFAULT_ORIENT,
date_unit=PANDAS_TIME_UNIT,
date_format="iso",
double_precision=15,
)
for key, val in defaults.items():
if key not in kwds.keys():
kwds[key] = val
return df.to_json(**kwds)
def df_write(fn: str, df: pd.DataFrame, fmt="pickle", **kwds) -> None:
"""Write DataFrame to disk.
Parameters
----------
fn
filename
df
DataFrame to write
fmt
``{'pickle', 'json'}``
kwds
passed to ``pickle.dump()`` or :func:`df_to_json`
"""
makedirs(os.path.dirname(fn))
if fmt == "pickle":
with open(fn, "wb") as fd:
pickle.dump(df, fd, **kwds)
elif fmt == "json":
df_to_json(df, path_or_buf=fn, **kwds)
else:
raise ValueError("unknown fmt: {}".format(fmt))
def df_read(fn: str, fmt="pickle", **kwds):
"""Read DataFrame from file `fn`. See :func:`df_write`."""
if fmt == "pickle":
with open(fn, "rb") as fd:
return pickle.load(fd, **kwds)
elif fmt == "json":
orient = kwds.pop("orient", PANDAS_DEFAULT_ORIENT)
return pd.io.json.read_json(
fn, precise_float=True, orient=orient, **kwds
)
else:
raise ValueError("unknown fmt: {}".format(fmt))
def df_print(
df: pd.DataFrame,
index: bool = False,
special_cols=None,
prefix_cols: bool = False,
cols: Sequence[str] = [],
skip_cols: Sequence[str] = [],
):
"""Print DataFrame, by default without the index and prefix columns such
as `_pset_id`.
    Similar logic to `bin/psweep-db2table`, w/o tabulate support but with more
    features (`skip_cols` for instance).
Column names are always sorted, so the order of names in e.g. `cols`
doesn't matter.
Parameters
----------
df
index
include DataFrame index
prefix_cols
        include all prefix columns (`_pset_id` etc.); skipping user-added
        postfix columns (e.g. `result_`) is not supported
cols
explicit sequence of columns, overrides `prefix_cols` when prefix columns
are specified
skip_cols
skip those columns instead of selecting them (like `cols` would), use
either this or `cols`; overrides `prefix_cols` when prefix columns
are specified
Examples
--------
>>> import pandas as pd
>>> df=pd.DataFrame(dict(a=rand(3), b=rand(3), _c=rand(3)))
>>> df
a b _c
0 0.373534 0.304302 0.161799
1 0.698738 0.589642 0.557172
2 0.343316 0.186595 0.822023
>>> ps.df_print(df)
a b
0.373534 0.304302
0.698738 0.589642
0.343316 0.186595
>>> ps.df_print(df, prefix_cols=True)
a b _c
0.373534 0.304302 0.161799
0.698738 0.589642 0.557172
0.343316 0.186595 0.822023
>>> ps.df_print(df, index=True)
a b
0 0.373534 0.304302
1 0.698738 0.589642
2 0.343316 0.186595
>>> ps.df_print(df, cols=["a"])
a
0.373534
0.698738
0.343316
>>> ps.df_print(df, cols=["a"], prefix_cols=True)
a _c
0.373534 0.161799
0.698738 0.557172
0.343316 0.822023
>>> ps.df_print(df, cols=["a", "_c"])
a _c
0.373534 0.161799
0.698738 0.557172
0.343316 0.822023
>>> ps.df_print(df, skip_cols=["a"])
b
0.304302
0.589642
0.186595
"""
if special_cols is not None:
warnings.warn(
"special_cols is deprecated, use prefix_cols",
DeprecationWarning,
)
prefix_cols = special_cols
_prefix_cols = set(x for x in df.columns if x.startswith("_"))
if len(cols) > 0:
if len(skip_cols) > 0:
raise ValueError("Use either skip_cols or cols")
disp_cols = set(cols) | (_prefix_cols if prefix_cols else set())
else:
disp_cols = set(df.columns) - (set() if prefix_cols else _prefix_cols)
if len(skip_cols) > 0:
disp_cols = disp_cols - set(skip_cols)
disp_cols = list(disp_cols)
disp_cols.sort()
print(df[disp_cols].to_string(index=index))
def df_filter_conds(
df: pd.DataFrame,
conds: Sequence[Sequence[bool]],
op: str = "and",
) -> pd.DataFrame:
"""Filter DataFrame using bool arrays/Series/DataFrames in `conds`.
Fuse all bool sequences in `conds` using `op`. For instance, if
``op="and"``, then we logical-and them, which is equal to
>>> df[conds[0] & conds[1] & conds[2] & ...]
but `conds` can be programmatically generated while the expression above
would need to be changed by hand if `conds` changes.
Parameters
----------
df
DataFrame
conds
Sequence of bool masks, each of length `len(df)`.
op
Bool operator, used as ``numpy.logical_{op}``, e.g. "and", "or",
"xor".
Returns
-------
DataFrame
Examples
--------
>>> df=pd.DataFrame({'a': arange(10), 'b': arange(10)+4})
>>> c1=df.a > 3
>>> c2=df.b < 9
>>> c3=df.a % 2 == 0
>>> df[c1 & c2 & c3]
a b
4 4 8
>>> ps.df_filter_conds(df, [c1, c2, c3])
a b
4 4 8
"""
cc = conds if hasattr(conds, "__len__") else list(conds)
if len(cc) == 0:
return df
for ic, c in enumerate(cc):
# fmt: off
assert len(c) == len(df), \
f"Condition at index {ic} has {len(c)=}, expect {len(df)=}"
# fmt: on
if len(cc) == 1:
msk = cc[0]
else:
assert op in (
op_allowed := ["and", "or", "xor"]
), f"{op=} not one of {op_allowed}"
msk = getattr(np, f"logical_{op}").reduce(cc)
return df[msk]
# -----------------------------------------------------------------------------
# building params
# -----------------------------------------------------------------------------
def plist(name: str, seq: Sequence[Any]):
"""Create a list of single-item dicts holding the parameter name and a
value.
>>> plist('a', [1,2,3])
[{'a': 1}, {'a': 2}, {'a': 3}]
"""
return [{name: entry} for entry in seq]
@itr
def merge_dicts(args: Sequence[dict]):
"""Start with an empty dict and update with each arg dict
left-to-right."""
dct = {}
assert is_seq(args), f"input {args=} is no sequence"
for entry in args:
assert isinstance(entry, dict), f"{entry=} is no dict"
dct.update(entry)
return dct
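# merge_dicts() sketch: later dicts win on key collisions, and thanks to @itr
# both call styles are equivalent (values are illustrative):
#
#   >>> merge_dicts(dict(a=1, b=2), dict(b=3))
#   {'a': 1, 'b': 3}
#   >>> merge_dicts([dict(a=1, b=2), dict(b=3)])
#   {'a': 1, 'b': 3}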
def itr2params(loops: Iterator[Any]):
"""Transform the (possibly nested) result of a loop over plists (or
whatever has been used to create psets) to a proper list of psets
by flattening and merging dicts.
Examples
--------
>>> a = ps.plist('a', [1,2])
>>> b = ps.plist('b', [77,88])
>>> c = ps.plist('c', ['const'])
>>> # result of loops
>>> list(itertools.product(a,b,c))
[({'a': 1}, {'b': 77}, {'c': 'const'}),
({'a': 1}, {'b': 88}, {'c': 'const'}),
({'a': 2}, {'b': 77}, {'c': 'const'}),
({'a': 2}, {'b': 88}, {'c': 'const'})]
>>> # flatten into list of psets
>>> ps.itr2params(itertools.product(a,b,c))
[{'a': 1, 'b': 77, 'c': 'const'},
{'a': 1, 'b': 88, 'c': 'const'},
{'a': 2, 'b': 77, 'c': 'const'},
{'a': 2, 'b': 88, 'c': 'const'}]
>>> # also more nested stuff is no problem
>>> list(itertools.product(zip(a,b),c))
[(({'a': 1}, {'b': 77}), {'c': 'const'}),
(({'a': 2}, {'b': 88}), {'c': 'const'})]
>>> ps.itr2params(itertools.product(zip(a,b),c))
[{'a': 1, 'b': 77, 'c': 'const'},
{'a': 2, 'b': 88, 'c': 'const'}]
"""
ret = [merge_dicts(flatten(entry)) for entry in loops]
lens = list(map(len, ret))
assert (
len(np.unique(lens)) == 1
), f"not all psets have same length {lens=}\n {ret=}"
return ret
@itr
def pgrid(plists: Sequence[Sequence[dict]]) -> Sequence[dict]:
"""Convenience function for the most common loop: nested loops with
``itertools.product``: ``ps.itr2params(itertools.product(a,b,c,...))``.
Parameters
----------
plists
List of :func:`plist()` results. If more than one, you can also provide
plists as args, so ``pgrid(a,b,c)`` instead of ``pgrid([a,b,c])``.
Notes
-----
For a single plist arg, you have to use ``pgrid([a])``. ``pgrid(a)`` won't
work. However, this edge case (passing one plist to pgrid) is not super
useful, since
>>> a=ps.plist("a", [1,2,3])
>>> a
[{'a': 1}, {'a': 2}, {'a': 3}]
>>> ps.pgrid([a])
[{'a': 1}, {'a': 2}, {'a': 3}]
Examples
--------
>>> a = ps.plist('a', [1,2])
>>> b = ps.plist('b', [77,88])
>>> c = ps.plist('c', ['const'])
>>> # same as pgrid([a,b,c])
>>> ps.pgrid(a,b,c)
[{'a': 1, 'b': 77, 'c': 'const'},
{'a': 1, 'b': 88, 'c': 'const'},
{'a': 2, 'b': 77, 'c': 'const'},
{'a': 2, 'b': 88, 'c': 'const'}]
>>> ps.pgrid(zip(a,b),c)
[{'a': 1, 'b': 77, 'c': 'const'},
{'a': 2, 'b': 88, 'c': 'const'}]
"""
assert is_seq(plists), f"input {plists=} is no sequence"
return itr2params(itertools.product(*plists))
def filter_params_unique(params: Sequence[dict], **kwds) -> Sequence[dict]:
"""Reduce params to unique psets.
Use pset["_pset_hash"] if present, else calculate hash on the fly.
Parameters
----------
params
kwds
passed to :func:`pset_hash`
"""
get_hash = lambda pset: pset.get("_pset_hash", pset_hash(pset, **kwds))
msk = np.unique([get_hash(pset) for pset in params], return_index=True)[1]
return [params[ii] for ii in np.sort(msk)]
def filter_params_dup_hash(
params: Sequence[dict], hashes: Sequence[str], **kwds
) -> Sequence[dict]:
"""Return params with psets whose hash is not in `hashes`.
Use pset["_pset_hash"] if present, else calculate hash on the fly.
Parameters
----------
params
hashes
kwds
passed to :func:`pset_hash`
"""
get_hash = lambda pset: pset.get("_pset_hash", pset_hash(pset, **kwds))
return [pset for pset in params if get_hash(pset) not in hashes]
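# Sketch: filter_params_unique() keeps the first occurrence of each pset,
# while filter_params_dup_hash() drops psets whose hash is already known,
# e.g. from an existing database (values are illustrative):
#
#   >>> params = [dict(a=1), dict(a=1), dict(a=2)]
#   >>> filter_params_unique(params)
#   [{'a': 1}, {'a': 2}]
#   >>> filter_params_dup_hash(params, [pset_hash(dict(a=1))])
#   [{'a': 2}]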
def stargrid(
const: dict,
vary: Sequence[Sequence[dict]],
vary_labels: Sequence[str] = None,
vary_label_col: str = "_vary",
skip_dups=True,
) -> Sequence[dict]:
"""
Helper to create a specific param sampling pattern.
Vary params in a "star" pattern (and not a full pgrid) around constant
values (middle of the "star").
When doing that, duplicate psets can occur. By default try to filter them
out (using :func:`filter_params_unique`) but ignore hash calculation errors
and return non-reduced params in that case.
Examples
--------
>>> import psweep as ps
>>> const=dict(a=1, b=77, c=11)
>>> a=ps.plist("a", [1,2,3,4])
>>> b=ps.plist("b", [77,88,99])
>>> c=ps.plist("c", [11,22,33,44])
>>> ps.stargrid(const, vary=[a, b])
[{'a': 1, 'b': 77, 'c': 11},
{'a': 2, 'b': 77, 'c': 11},
{'a': 3, 'b': 77, 'c': 11},
{'a': 4, 'b': 77, 'c': 11},
{'a': 1, 'b': 88, 'c': 11},
{'a': 1, 'b': 99, 'c': 11}]
>>> ps.stargrid(const, vary=[a, b], skip_dups=False)
[{'a': 1, 'b': 77, 'c': 11},
{'a': 2, 'b': 77, 'c': 11},
{'a': 3, 'b': 77, 'c': 11},
{'a': 4, 'b': 77, 'c': 11},
{'a': 1, 'b': 77, 'c': 11},
{'a': 1, 'b': 88, 'c': 11},
{'a': 1, 'b': 99, 'c': 11}]
>>> ps.stargrid(const, vary=[a, b], vary_labels=["a", "b"])
[{'a': 1, 'b': 77, 'c': 11, '_vary': 'a'},
{'a': 2, 'b': 77, 'c': 11, '_vary': 'a'},
{'a': 3, 'b': 77, 'c': 11, '_vary': 'a'},
{'a': 4, 'b': 77, 'c': 11, '_vary': 'a'},
{'a': 1, 'b': 88, 'c': 11, '_vary': 'b'},
{'a': 1, 'b': 99, 'c': 11, '_vary': 'b'}]
>>> ps.stargrid(const, vary=[ps.itr2params(zip(a,c)),b], vary_labels=["a+c", "b"])
[{'a': 1, 'b': 77, 'c': 11, '_vary': 'a+c'},
{'a': 2, 'b': 77, 'c': 22, '_vary': 'a+c'},
{'a': 3, 'b': 77, 'c': 33, '_vary': 'a+c'},
{'a': 4, 'b': 77, 'c': 44, '_vary': 'a+c'},
{'a': 1, 'b': 88, 'c': 11, '_vary': 'b'},
{'a': 1, 'b': 99, 'c': 11, '_vary': 'b'}]
>>> ps.stargrid(const, vary=[ps.pgrid([zip(a,c)]),b], vary_labels=["a+c", "b"])
[{'a': 1, 'b': 77, 'c': 11, '_vary': 'a+c'},
{'a': 2, 'b': 77, 'c': 22, '_vary': 'a+c'},
{'a': 3, 'b': 77, 'c': 33, '_vary': 'a+c'},
{'a': 4, 'b': 77, 'c': 44, '_vary': 'a+c'},
{'a': 1, 'b': 88, 'c': 11, '_vary': 'b'},
{'a': 1, 'b': 99, 'c': 11, '_vary': 'b'}]
"""
params = []
if vary_labels is not None:
assert len(vary_labels) == len(
vary
), f"{vary_labels=} and {vary=} must have same length"
for ii, plist in enumerate(vary):
for dct in plist:
if vary_labels is not None:
label = {vary_label_col: vary_labels[ii]}
_dct = merge_dicts(dct, label)
else:
_dct = dct
params.append(merge_dicts(const, _dct))
if skip_dups:
try:
return filter_params_unique(
params,
raise_error=True,
skip_prefix_cols=True,
skip_postfix_cols=True,
)
except PsweepHashError:
return params
else:
return params
# -----------------------------------------------------------------------------
# run study
# -----------------------------------------------------------------------------
# tmpsave: That's cool, but when running in parallel, we lose the ability to
# store the whole state of the study calculated thus far. For that we would
# need an extra thread that periodically checks for or -- even better -- gets
# informed by workers about finished work and collects the so-far written temp
# results into a global df -- maybe useful for monitoring progress.
def worker_wrapper(
pset: dict,
worker: Callable,
*,
tmpsave: bool = False,
verbose: bool | Sequence[str] = False,
simulate: bool = False,
) -> dict:
"""
Add those prefix fields (e.g. `_time_utc`) to `pset` which can be
determined at call time.
    Call `worker` on exactly one pset. Return the updated pset built from
    ``pset.update(worker(pset))``. Do verbose printing if requested.
"""
assert "_pset_id" in pset
assert "_run_id" in pset
assert "_calc_dir" in pset
time_start = pd.Timestamp(time.time(), unit=PANDAS_TIME_UNIT)
pset.update(_time_utc=time_start, _exec_host=platform.node())
if verbose:
df_row_print = pd.DataFrame([pset], index=[time_start])
if isinstance(verbose, bool) and verbose:
df_print(df_row_print, index=True)
elif is_seq(verbose):
df_print(df_row_print, index=True, cols=verbose)
else:
raise ValueError(f"Type of {verbose=} not understood.")
t0 = time.time()
if not simulate:
pset.update(worker(pset))
pset["_pset_runtime"] = time.time() - t0
if tmpsave:
fn = pj(
pset["_calc_dir"],
"tmpsave",
pset["_run_id"],
pset["_pset_id"] + ".pk",
)
pickle_write(fn, pset)
return pset
def capture_logs_wrapper(
pset: dict,
worker: Callable,
capture_logs: str,
db_field: str = "_logs",
) -> dict:
"""Capture and redirect stdout and stderr produced in worker().
Note the limitations mentioned in [1]:
Note that the global side effect on sys.stdout means that this context
manager is not suitable for use in library code and most threaded
applications. It also has no effect on the output of subprocesses.
However, it is still a useful approach for many utility scripts.
    So if users rely on manipulating sys.stdout/stderr in worker(), they
    should not use this feature and should take care of logging themselves.
[1] https://docs.python.org/3/library/contextlib.html#contextlib.redirect_stdout
"""
fn = f"{pset['_calc_dir']}/{pset['_pset_id']}/logs.txt"
if capture_logs == "file":
makedirs(os.path.dirname(fn))
with open(fn, "w") as fd, redirect_stdout(fd), redirect_stderr(fd):
return worker(pset)
elif capture_logs in ["db", "db+file"]:
with StringIO() as fd:
with redirect_stdout(fd), redirect_stderr(fd):
ret = worker(pset)
txt = fd.getvalue()
ret[db_field] = txt
if capture_logs == "db+file":
file_write(fn, txt)
return ret
else:
raise ValueError(f"Illegal value {capture_logs=}")
def run(
worker: Callable,
params: Sequence[dict],
df: pd.DataFrame = None,
poolsize: int = None,
dask_client=None,
save: bool = True,
tmpsave: bool = False,
verbose: bool | Sequence[str] = False,
calc_dir: str = "calc",
simulate: bool = False,
database_dir: str = None,
database_basename: str = "database.pk",
backup: bool = False,
git: bool = False,
skip_dups: bool = False,
capture_logs: str = None,
) -> pd.DataFrame:
"""
Call `worker` for each `pset` in `params`. Populate a DataFrame with rows
from each call ``worker(pset)``.
Parameters
----------
worker
must accept one parameter: `pset` (a dict ``{'a': 1, 'b': 'foo',
        ...}``), return either an update to `pset` or a new dict; the result
        will be processed as ``pset.update(worker(pset))``
params
each dict is a pset ``{'a': 1, 'b': 'foo', ...}``
df
append rows to this DataFrame, if None then either create new one or
read existing database file from disk if found
poolsize
* None : use serial execution
* int : use multiprocessing.Pool (even for ``poolsize=1``)
dask_client
A dask client. Use this or ``poolsize``.
save
save final ``DataFrame`` to ``<database_dir>/<database_basename>``
(pickle format only), default: "calc/database.pk", see also
`database_dir`, `calc_dir` and `database_basename`
tmpsave
save the result dict from each ``pset.update(worker(pset))`` from each
`pset` to ``<calc_dir>/tmpsave/<run_id>/<pset_id>.pk`` (pickle format
only), the data is a dict, not a DataFrame row
verbose
* bool : print the current DataFrame row
* sequence : list of DataFrame column names, print the row but only
those columns
calc_dir
Dir where calculation artifacts can be saved if needed, such as dirs
per pset ``<calc_dir>/<pset_id>``. Will be added to the database in
``_calc_dir`` field.
simulate
run everything in ``<calc_dir>.simulate``, don't call `worker`, i.e. save
what the run would create, but without the results from `worker`,
useful to check if `params` are correct before starting a production run
database_dir
Path for the database. Default is ``<calc_dir>``.
database_basename
``<database_dir>/<database_basename>``, default: "database.pk"
backup
Make backup of ``<calc_dir>`` to ``<calc_dir>.bak_<timestamp>_run_id_<run_id>``
where ``<run_id>`` is the latest ``_run_id`` present in ``df``
git
Use ``git`` to commit all files written and changed by the current run
(``_run_id``). Make sure to create a ``.gitignore`` manually before if
needed.
skip_dups
Skip psets whose hash is already present in `df`. Useful when repeating
(parts of) a study.
capture_logs
{'db', 'file', 'db+file', None}
Redirect stdout and stderr generated in ``worker()`` to database ('db')
column ``_logs``, file ``<calc_dir>/<pset_id>/logs.txt``, or both. If
``None`` then do nothing (default). Useful for capturing per-pset log
text, e.g. ``print()`` calls in `worker` will be captured.
Returns
-------
df
        The database built from `params`.
"""
# Don't in-place alter dicts in params we get as input.
params = copy.deepcopy(params)
database_dir = calc_dir if database_dir is None else database_dir
git_enter(git)
if simulate:
calc_dir_sim = calc_dir + ".simulate"
if os.path.exists(calc_dir_sim):
shutil.rmtree(calc_dir_sim)
makedirs(calc_dir_sim)
old_db = pj(database_dir, database_basename)
if os.path.exists(old_db):
shutil.copy(old_db, pj(calc_dir_sim, database_basename))
else:
warnings.warn(
f"simulate: {old_db} not found, will create new db in "
f"{calc_dir_sim}"
)
database_fn = pj(calc_dir_sim, database_basename)
calc_dir = calc_dir_sim
else:
database_fn = pj(database_dir, database_basename)
if df is None:
if os.path.exists(database_fn):
df = df_read(database_fn)
else:
df = pd.DataFrame()
if len(df) == 0:
pset_seq_old = -1
run_seq_old = -1
else:
# int(): convert np.int64 back to Python int
pset_seq_old = int(df._pset_seq.values.max())
run_seq_old = int(df._run_seq.values.max())
if backup and len(df.index) > 0:
stamp = df._time_utc.max().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
dst = f"{calc_dir}.bak_{stamp}_run_id_{df._run_id.values[-1]}"
assert not os.path.exists(dst), (
"backup destination {dst} exists, seems like there has been no new "
"data in {calc_dir} since the last backup".format(
dst=dst, calc_dir=calc_dir
)
)
shutil.copytree(calc_dir, dst)
for pset in params:
pset["_pset_hash"] = pset_hash(pset)
if skip_dups and len(df) > 0:
params = filter_params_dup_hash(params, df._pset_hash.values)
run_id = get_uuid(existing=df._run_id.values if len(df) > 0 else [])
pset_ids = get_many_uuids(
len(params), existing=df._pset_id.values if len(df) > 0 else []
)
for ii, (pset, pset_id) in enumerate(zip(params, pset_ids)):
pset["_pset_id"] = pset_id
pset["_run_seq"] = run_seq_old + 1
pset["_pset_seq"] = pset_seq_old + ii + 1
pset["_run_id"] = run_id
pset["_calc_dir"] = calc_dir
if capture_logs is not None:
worker = partial(
capture_logs_wrapper, worker=worker, capture_logs=capture_logs
)
worker = partial(
worker_wrapper,
worker=worker,
tmpsave=tmpsave,
verbose=verbose,
simulate=simulate,
)
if (poolsize is None) and (dask_client is None):
results = list(map(worker, params))
else:
assert [poolsize, dask_client].count(
None
) == 1, "Use either poolsize or dask_client."
if dask_client is None:
with mp.Pool(poolsize) as pool:
results = pool.map(worker, params)
else:
futures = dask_client.map(worker, params)
results = dask_client.gather(futures)
df = pd.concat((df, pd.DataFrame(results)), sort=False, ignore_index=True)
if save:
df_write(database_fn, df)
git_exit(git, df)
return df
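# Minimal end-to-end sketch for run(); the worker and param values are
# illustrative. The returned DataFrame is also written to calc/database.pk by
# default.
#
#   >>> import psweep as ps
#   >>> def worker(pset):
#   ...     return dict(result_=pset["a"] * pset["b"])
#   >>> params = ps.pgrid(ps.plist("a", [1, 2]), ps.plist("b", [10, 20]))
#   >>> df = ps.run(worker, params)
#   >>> ps.df_print(df, cols=["a", "b", "result_"])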
# -----------------------------------------------------------------------------
# (HPC cluster) batch runs using file templates
# -----------------------------------------------------------------------------
class Machine:
def __init__(self, machine_dir: str, jobscript_name: str = "jobscript"):
"""
Expected templates layout::
templates/machines/<name>/info.yaml
^^^^^^^^^^^^^^^^^^^^^^^^^------------- machine_dir
templates/machines/<name>/jobscript
^^^^^^^^^--- template.basename
"""
self.name = os.path.basename(os.path.normpath(machine_dir))
self.template = FileTemplate(
pj(machine_dir, jobscript_name), target_suffix="_" + self.name
)
with open(pj(machine_dir, "info.yaml")) as fd:
info = yaml.safe_load(fd)
for key, val in info.items():
assert key not in self.__dict__, f"cannot overwrite '{key}'"
setattr(self, key, val)
def __repr__(self):
return f"{self.name}:{self.template}"
class FileTemplate:
def __init__(self, filename, target_suffix=""):
self.filename = filename
self.basename = os.path.basename(filename)
self.dirname = os.path.dirname(filename)
self.targetname = f"{self.basename}{target_suffix}"
def __repr__(self):
return self.filename
def fill(self, pset):
try:
return string.Template(file_read(self.filename)).substitute(pset)
except:
print(f"Failed to fill template: {self.filename}", file=sys.stderr)
raise
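# FileTemplate sketch: templates are filled via string.Template placeholders
# taken from the pset. The file content and keys below are illustrative.
#
#   templates/calc/input.txt:
#       nsteps=$nsteps
#       param_a=$param_a
#
#   >>> t = FileTemplate("templates/calc/input.txt")
#   >>> print(t.fill(dict(nsteps=100, param_a=0.5)))
#   nsteps=100
#   param_a=0.5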
def gather_calc_templates(calc_templ_dir):
return [
FileTemplate(pj(calc_templ_dir, basename))
for basename in os.listdir(calc_templ_dir)
]
def gather_machines(machine_templ_dir):
return [
Machine(pj(machine_templ_dir, basename))
for basename in os.listdir(machine_templ_dir)
]
# If we ever add a "simulate" kwd here: don't pass it through to run() b/c
# there it prevents worker() from being executed, while here we always want
# worker() to run since it only writes input files. Instead, just set
# calc_dir = calc_dir_sim, copy the database as in run() and go. Don't copy
# the run_*.sh scripts b/c they are generated afresh anyway.
#
def prep_batch(
params: Sequence[dict],
*,
calc_templ_dir: str = "templates/calc",
machine_templ_dir: str = "templates/machines",
git: bool = False,
write_pset: bool = False,
**kwds,
) -> pd.DataFrame:
"""
Write files based on templates.
Parameters
----------
params
See :func:`run`
calc_templ_dir, machine_templ_dir
Dir with templates.
git
Use git to commit local changes.
write_pset
Write the input `pset` to ``<calc_dir>/<pset_id>/pset.pk``.
**kwds
Passed to :func:`run`.
Returns
-------
df
        The database built from `params`.
"""
git_enter(git)
calc_dir = kwds.get("calc_dir", "calc")
calc_templates = gather_calc_templates(calc_templ_dir)
machines = gather_machines(machine_templ_dir)
templates = calc_templates + [m.template for m in machines]
def worker(pset):
for template in templates:
file_write(
pj(calc_dir, pset["_pset_id"], template.targetname),
template.fill(pset),
)
if write_pset:
pickle_write(pj(calc_dir, pset["_pset_id"], "pset.pk"), pset)
return {}
df = run(
worker,
params,
git=False,
**kwds,
)
msk_latest = df._run_seq == df._run_seq.values.max()
msk_old = df._run_seq < df._run_seq.values.max()
for machine in machines:
txt = ""
for pfx, msk in [("# ", msk_old), ("", msk_latest)]:
if msk.any():
txt += "\n"
txt += "\n".join(
f"{pfx}cd $here/{pset_id}; {machine.subcmd} {machine.template.targetname} # run_seq={run_seq} pset_seq={pset_seq}"
for pset_id, pset_seq, run_seq in zip(
df[msk]._pset_id.values,
df[msk]._pset_seq.values,
df[msk]._run_seq.values,
)
)
file_write(
f"{calc_dir}/run_{machine.name}.sh",
f"#!/bin/sh\n\nhere=$(readlink -f $(dirname $0))\n{txt}\n",
)
git_exit(git, df)
return df
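# prep_batch() sketch: assuming templates/calc/* and
# templates/machines/<name>/{jobscript,info.yaml} exist, with info.yaml
# providing the fields used above (e.g. "subcmd", such as "sbatch"), this
# writes the filled files per pset plus one run_<machine>.sh driver script.
# The param names are illustrative.
#
#   >>> import psweep as ps
#   >>> params = ps.pgrid(ps.plist("nsteps", [100, 200]))
#   >>> df = ps.prep_batch(params)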