import inspect
import uuid
import subprocess as sp
import dill
import tcmu
from tcmu import cache, connect
from typing import List, Any
import jsonpickle
import os
import hashlib
import json
def hash(obj: Any) -> str:
    '''
    Hash any python object using SHA-256.

    Dictionaries are serialised with JSON using sorted keys, so that equal dictionaries
    give the same hash regardless of insertion order. If JSON cannot serialise the
    dictionary, jsonpickle is used instead. Strings are hashed as UTF-8, bytes-like
    objects are hashed as-is and every other object is encoded with jsonpickle first.

    Note that this function intentionally keeps its name even though it shadows the
    builtin ``hash`` inside this module.

    Args:
        obj: the object to hash.

    Returns:
        The hexadecimal SHA-256 digest of the object.
    '''
    if isinstance(obj, dict):
        try:
            obj = json.dumps(obj, sort_keys=True, indent=4)
        except (TypeError, ValueError):
            # TypeError: unserialisable values or unsortable keys, ValueError: circular references
            obj = jsonpickle.encode(obj)
    elif isinstance(obj, (bytes, bytearray, memoryview)):
        # already binary data, there is nothing to encode
        return hashlib.sha256(obj).hexdigest()
    elif not isinstance(obj, str):
        # arbitrary objects have no ``encode`` method, so give them a text form first
        obj = jsonpickle.encode(obj)

    return hashlib.sha256(obj.encode('utf-8')).hexdigest()
@cache
def _python_path(server: connect.Server = connect.Local()) -> str:
    """
    Find the path of the Python executable on a server.

    Sometimes it is necessary to have the full Python path, as some environments do not
    have it on their path. ``which python`` is tried first and ``which python3`` second.
    A candidate is only accepted if it is non-empty and exists according to the server.

    Args:
        server: the server on which to look for the Python executable.

    Returns:
        The first valid path that was found, or simply ``"python"`` if none was found.
    """
    for command in ("which python", "which python3"):
        try:
            candidate = server.execute(command).strip()
        except sp.CalledProcessError:
            # this executable is not known on the server, try the next one
            continue

        if candidate and server.path_exists(candidate):
            return candidate

    # we default to the python executable
    return "python"
class WorkFlow:
    '''
    The ``WorkFlow`` class is used to generate Python scripts of functions and running/submitting them.
    ``WorkFlow`` objects act as decorators for functions and supports writing the function as a python script that
    can be submitted. It also supports checking for the status of previous runs of the workflow. If there are return
    statements in the script, they will be written to an output file, loaded, and returned to the user.

    Args:
        server: Server object that provides sbatch defaults and standard pre- and postambles.
        delete_files: Whether to delete files after the workflow has finished running.
            This will not delete output files that are needed for the return statements in the function.
        preambles: Preambles used when running the script remotely.
        postambles: Postambles used when running the script remotely.
        sbatch: Sbatch settings used when running the script remotely.

    Example usage:

    .. code-block:: python

        @WorkFlow()
        def optimize(molecule: str) -> "plams.Molecule":
            import tcmu
            from scm import plams

            with tcmu.DFTBJob() as job:
                job.molecule(molecule)
                job.optimization()

            return plams.Molecule(job.output_mol_path)

        optimized_mol = optimize('example.xyz')
        print(optimized_mol)

    Gives as output:

    .. code-block::

        Atoms:
          1   C   -1.560651   0.018909  -0.343850
          2   C   -0.564164   0.536286   0.648202
          3   H    0.579596  -0.095326   0.334690
          4  Cl   -0.849164   0.123036   2.289769
          5   H   -0.271691   1.576823   0.520219
          6   H   -1.197774   0.186010  -1.358023
          7   H   -2.514720   0.543213  -0.231771
          8   H   -1.740146  -1.045976  -0.197945
          9   F    1.478087  -0.563635   0.032020
    '''

    def __init__(self, server: "connect.Server" = None, delete_files: bool = True, preambles: List[str] = None, postambles: List[str] = None, sbatch: dict = None):
        self.server = server
        if server is None:
            self.server = tcmu.connect.get_current_server()()

        # The server defaults are copied before use. They are shared objects, so appending to
        # them (or letting ``execute`` add options) would leak into every other WorkFlow.
        if preambles is None:
            self.preambles = list(self.server.preamble_defaults.get('AMS', []))
            self.preambles.append(self.server.program_modules.get('AMS', {}).get('latest', ''))
        else:
            self.preambles = preambles

        if postambles is None:
            self.postambles = list(self.server.postamble_defaults.get('AMS', []))
        else:
            self.postambles = postambles

        if sbatch is None:
            self.sbatch = dict(self.server.sbatch_defaults)
        else:
            self.sbatch = sbatch

        self.delete_files = delete_files

    def __call__(self, *args, **kwargs):
        # the first call decorates a function, every later call executes the workflow
        return self._call_method(*args, **kwargs)

    def _call_method(*args, **kwargs):
        '''
        Register the decorated function. This is only used for the very first call of the object,
        afterwards the instance attribute ``_call_method`` points to :meth:`execute`.
        '''
        self, func = args[:2]
        self.func = func
        self.name = func.__name__
        self.parameters = inspect.signature(func).parameters
        self._call_method = self.execute
        return self

    def __str__(self):
        '''
        Give an overview of the parameters of the workflow. Parameters without a default value are marked as required.
        '''
        labels = []
        for param in self.parameters.values():
            label = param.name
            if param.annotation is not param.empty:
                label += f': {param.annotation}'
            labels.append((param, label))

        # default=0 keeps this working for functions that do not take any parameters
        width = max((len(label) for _, label in labels), default=0)

        s = f'WorkFlow({self.name}):\n'
        s += ' Parameters:\n'
        for param, label in labels:
            s += f' {label.ljust(width)}'
            if param.kind in [param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY] and param.default is param.empty:
                s += ' #REQUIRED'
            s += '\n'
        return s

    def _write_files(self, args: dict):
        '''
        Write the Python script and the shell script that runs it.

        Args:
            args: mapping from variable names to the values that are made available in the script.
        '''
        with self.server.open_file(self.py_path, 'w+') as script:
            script.write('#====== LOAD STATE ========#\n')
            script.write('import dill\n')
            script.write('import jsonpickle\n\n')
            for arg_name, arg_val in args.items():
                if arg_name in self.parameters and self.parameters[arg_name].annotation is not inspect.Parameter.empty:
                    annotation = self.parameters[arg_name].annotation
                    if isinstance(annotation, str):
                        annotation = '"' + annotation + '"'
                    else:
                        # not every annotation object has a __name__, fall back to its repr
                        annotation = getattr(annotation, '__name__', repr(annotation))
                    script.write(f'# type: {annotation}\n')
                else:
                    script.write(f'# type: {type(arg_val).__name__}\n')

                try:
                    script.write(f'{arg_name} = dill.loads({dill.dumps(arg_val)})\n\n')
                except dill.PickleError:
                    # repr gives a valid Python literal, also if the JSON holds quotes or backslashes
                    script.write(f'{arg_name} = jsonpickle.decode({jsonpickle.encode(arg_val, unpicklable=True)!r})\n\n')

            script.write('#========= SCRIPT =========#\n')
            script.write('import tcmu\nimport atexit\n\n\n')
            # the text below ends up in the generated script, so it carries its own indentation
            script.write(f'''@atexit.register
def on_exception():
    if tcmu.job.workflow_db.get_status("{self.hash}") == "RUNNING":
        tcmu.job.workflow_db.set_failed("{self.hash}")
def __end_workflow__():
    tcmu.job.workflow_db.set_finished("{self.hash}")
    exit()\n\n\n''')

            code_lines = extract_func_code(self.func)
            code_lines = handle_return_statements(code_lines, self.return_path)
            script.write(code_lines)
            script.write('\n\n\n# indicate to the db that this wf has finished:\n__end_workflow__()\n')

        with self.server.open_file(self.sh_path, 'w') as file:
            file.write('#!/bin/bash\n\n')
            for line in self.preambles:
                file.write(line + '\n')
            file.write(f'{_python_path(self.server)} {self.py_path}\n')
            for line in self.postambles:
                file.write(line + '\n')

            if self.delete_files:
                file.write(f'rm {self.py_path}\n')
                file.write(f'rm {self.out_path}\n')
                file.write(f'rm {self.sh_path}\n')

    def execute(self, *args, dependency=None, restart=False, **kwargs):
        '''
        Run the workflow with the given arguments, or skip it if the database says it can be skipped.

        Args:
            *args: positional arguments for the decorated function.
            dependency: objects this run should wait for. Objects without a ``slurm_job_id`` attribute are ignored.
            restart: whether to delete the data known for this run before checking if it can be skipped.
            **kwargs: keyword arguments for the decorated function.

        Returns:
            A :class:`WorkFlowOutput` object pointing to the file holding the return value.

        Raises:
            Exception: if the generated script does not pass the Ruff check.
        '''
        self.hash = hash({'wf': self.func, 'args': args, 'kwargs': kwargs})
        file_name = '.' + self.name + '_' + self.hash
        self.sh_path = f'{file_name}.sh'
        self.py_path = f'{file_name}.py'
        self.out_path = f'{file_name}.out'
        self.return_path = f'{file_name}.json'

        # if a restart was requested we simply delete known data
        if restart:
            tcmu.job.workflow_db.delete_data(self.hash)

        # Options added below are specific to this run (the output file depends on the hash),
        # so they go into a copy and do not carry over to the next call.
        sbatch = dict(self.sbatch)

        # set up dependencies between jobs
        if dependency is None:
            dependency = []

        if dependency:
            # extend the option the user already gave, both the short and long spelling are respected
            dep_key = 'd' if 'd' in sbatch and 'dependency' not in sbatch else 'dependency'
            sbatch.setdefault(dep_key, 'afterany')
            for dep in dependency:
                if not hasattr(dep, 'slurm_job_id'):
                    continue
                sbatch[dep_key] = sbatch[dep_key] + f':{dep.slurm_job_id}'

        if tcmu.job.workflow_db.can_skip(self.hash):
            status = tcmu.job.workflow_db.get_status(self.hash)
            if status == 'RUNNING':
                tcmu.log.info('Workflow is currently running.')
            elif status == 'SUCCESS':
                tcmu.log.info('Workflow run has already been completed!')
            elif status == 'FAILED':
                tcmu.log.info('Workflow was run but failed')

            # Add slurm_job_id to self if it is skippable
            temp_data = tcmu.job.workflow_db.get_data(self.hash)
            if 'slurm_job_id' in temp_data:
                self.slurm_job_id = temp_data["slurm_job_id"]

            box = f'WorkFlow({self.name}):\n args = (\n'
            for arg in args:
                box += f' {repr(arg)},\n'
            box += ' )\n kwargs = {\n'
            for k, v in kwargs.items():
                box += f' {k}: {repr(v)},\n'
            box += ' }\n'
            box += f' hash = {self.hash}'
            tcmu.log.boxed(box)

            return self.__load_return()

        # collect everything the script needs: given arguments, defaults and referenced globals
        _args = {}
        for param_name, arg in zip(self.parameters, args):
            _args[param_name] = arg
        _args.update(kwargs)
        for param_name, param in self.parameters.items():
            if param.default != param.empty:
                _args.setdefault(param_name, param.default)

        for glob_name, glob in inspect.getclosurevars(self.func).globals.items():
            _args[glob_name] = glob

        self._write_files(_args)
        if not ruff_check_script(self.py_path, ignored_codes=['E402', 'F811', 'F401']):
            raise Exception('Python script will fail!')

        if tcmu.slurm.has_slurm():
            # only choose an output file if the user did not give one, in either spelling
            if not any(option in sbatch for option in ["o", "output"]):
                sbatch["o"] = self.out_path

            sbatch_result = tcmu.slurm.sbatch(self.sh_path, self.server, **sbatch)
            self.slurm_job_id = sbatch_result.id
            tcmu.job.workflow_db.set_running(self.hash, slurm_job_id=self.slurm_job_id)
        else:
            runfile_dir, runscript = os.path.split(self.sh_path)
            if runfile_dir == '':
                runfile_dir = '.'

            command = ["./" + runscript] if os.name == "posix" else ["sh", runscript]
            self.server.chmod(744, self.sh_path)
            with open(self.out_path, "w+") as out:
                sp.run(command, cwd=runfile_dir, stdout=out, shell=True)

        return self.__load_return()

    def __load_return(self):
        return WorkFlowOutput(self.return_path)
class WorkFlowOutput:
    '''
    Handle to a file that holds a jsonpickle-encoded value.
    The file is only read at the moment the value is requested.

    Args:
        path: the path to the file holding the encoded value.
    '''

    def __init__(self, path: str):
        self.path = path

    @property
    def is_available(self):
        '''Whether the file at ``path`` currently exists.'''
        return os.path.exists(self.path)

    @property
    def value(self):
        '''The decoded content of the file, or ``None`` as long as the file does not exist.'''
        if self.is_available:
            with open(self.path) as return_file:
                encoded = return_file.read()
            return jsonpickle.decode(encoded)
        return None

    def __str__(self):
        return 'WorkFlowOutput({}) = {}'.format(self.path, self.value)
def handle_return_statements(lines: List[str], return_file: str) -> str:
    '''
    Replace the return statements in a script by code that writes the returned value to a file.

    Every line holding a ``return`` statement is replaced by three lines (at the same indentation) that
    write the jsonpickle-encoded value to ``return_file`` and then call ``__end_workflow__()``.
    A bare ``return`` writes ``None``. All other lines are kept as they are.

    The lines are handled one by one, so a return statement that spans multiple lines
    or that has a trailing comment is not supported.

    Args:
        lines: the lines of the script.
        return_file: the path of the file the returned value will be written to.

    Returns:
        The new script as a single string, with the lines joined by newlines.
    '''
    keyword = 'return'
    new_lines = []
    for line in lines:
        stripped = line.strip()
        remainder = stripped[len(keyword):]
        # names such as ``return_value`` or ``returned`` only start with the keyword
        is_identifier = remainder[:1].isalnum() or remainder[:1] == '_'
        if not stripped.startswith(keyword) or is_identifier:
            new_lines.append(line)
            continue

        # the keyword is taken off the stripped line, as return statements are usually indented
        return_variable = remainder.strip() or 'None'
        indent = line[:len(line) - len(line.lstrip())]
        # repr makes sure the path is a valid string literal in the generated code
        new_lines.append(indent + f"with open({return_file!r}, 'w+') as ret:")
        new_lines.append(indent + f" ret.write(jsonpickle.encode({return_variable}))")
        new_lines.append(indent + '__end_workflow__()')

    return '\n'.join(new_lines)
def ruff_check_script(path: str, ignored_codes=None) -> bool:
    '''
    Check if a python script in `path` will run according to Ruff.
    If the check fails, the output of Ruff and the content of the script are logged.

    Args:
        path: a path to the python script to check.
        ignored_codes: the Ruff warning and error codes to ignore.

    Returns:
        A boolean specifying if the return code is 0 or not.
        If Ruff could not be started at all this also counts as a failed check.
    '''
    # an argument list (without a shell) keeps paths with spaces or special characters intact
    command = ['ruff', 'check', path]
    if ignored_codes:
        # the option is only given if there are codes, as it requires a value
        command += ['--ignore', ','.join(ignored_codes)]

    try:
        out = sp.run(command, stdout=sp.PIPE, stderr=sp.PIPE)
    except FileNotFoundError:
        tcmu.log.log('Found issue when parsing code with Ruff:')
        tcmu.log.boxed('The ruff executable could not be found.')
        return False

    # if the ruff check passed we get a zero exit
    if out.returncode == 0:
        return True

    # simply print the output if we failed
    tcmu.log.log('Found issue when parsing code with Ruff:')
    tcmu.log.boxed(out.stdout.decode())
    # problems with the call itself are reported on stderr
    errors = out.stderr.decode()
    if errors.strip():
        tcmu.log.boxed(errors)

    # and the code itself
    tcmu.log.log(f'Code ({path}):')
    with open(path) as file:
        tcmu.log.boxed(file.read())
    return False
# Small demonstration that is only run when this file is executed directly.
if __name__ == '__main__':
    from scm import plams

    # Decorating with a WorkFlow object turns ``optimize`` into that WorkFlow object,
    # so calling it below goes through WorkFlow.execute instead of running the body here.
    @WorkFlow()
    def optimize(molecule: str) -> "plams.Molecule":
        import tcmu
        from scm import plams
        import time
        time.sleep(10)
        with tcmu.DFTBJob() as job:
            job.molecule(molecule)
            job.optimization()
        return plams.Molecule(job.output_mol_path)

    # ``restart`` is consumed by WorkFlow.execute and is not passed on to the function.
    optimized_mol = optimize('example.xyz', restart=False)
    # The second call receives the WorkFlowOutput object returned by the first call.
    optimized_mol = optimize(optimized_mol, restart=False)
    print(optimized_mol)