Commit 1c972920 authored by sparonuz

Merge branch 'master' into Issue#86

parents 8a1af14c 6a143db8
import paramiko
import os
import stat
import re
from UtilsRPE.Functions import load_vault
import pickle
def remote_isdir(attr):
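    # Directory test on an SFTPAttributes entry via its mode bits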
return stat.S_ISDIR(attr.st_mode)
class SSH:
def __init__(self, user, host, remote_scratch=""):
# Init
self.user = user
self.host = host
self._host_config_id = None
self._ssh = None
self._ssh_config = None
self._user_config_file = None
self._host_config = None
self.sftp = None
self.remote_scratch = remote_scratch
# SSH
self.connect()
self.init_transport()
def init_transport(self):
self.sftp = self._ssh.open_sftp()
def get(self, remote_path, local_path):
if not self.sftp:
self.init_transport()
self.sftp.get(remote_path, local_path)
def put(self, local_path, remote_path):
if not self.sftp:
self.init_transport()
self.sftp.put(local_path, remote_path)
def connect(self):
"""
Creates ssh connection to host
:return: True if connection is created, False otherwise
:rtype: bool
"""
try:
self._ssh = paramiko.SSHClient()
self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self._ssh_config = paramiko.SSHConfig()
self._user_config_file = os.path.expanduser("~/.ssh/config")
if os.path.exists(self._user_config_file):
with open(self._user_config_file) as f:
# noinspection PyTypeChecker
self._ssh_config.parse(f)
self._host_config = self._ssh_config.lookup(self.host)
if 'identityfile' in self._host_config:
self._host_config_id = self._host_config['identityfile']
if 'proxycommand' in self._host_config:
self._proxy = paramiko.ProxyCommand(self._host_config['proxycommand'])
self._ssh.connect(self._host_config['hostname'], 22, username=self.user,
key_filename=self._host_config_id, sock=self._proxy)
else:
self._ssh.connect(self._host_config['hostname'], 22, username=self.user,
key_filename=self._host_config_id)
return True
        except IOError as e:
            print("Cannot create ssh connection to {0}: {1}".format(self.host, e.strerror))
            return False
def list_experiments_(self, expid):
if not self.sftp:
self.init_transport()
experiment_path = "/gpfs/scratch/pr1egi00/pr1egi23/%s" % expid
list_of_experiments = []
# Enter in experiment folders
experiment_pattern = "[0-9]+"
member_pattern = "[a-z][0-9]+"
variable_pattern = "ref_[0-9]+"
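        # Walk experiment -> member -> ref_* variable folders, collecting every matching path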
for attr in list(self.sftp.listdir_attr(experiment_path)):
filename = attr.filename
if re.match(experiment_pattern, filename):
if remote_isdir(attr):
member_path = "%s/%s" % (experiment_path, filename)
member_list = list(self.sftp.listdir_attr(member_path))
for member in member_list:
if re.match(member_pattern, member.filename):
if remote_isdir(member):
member_filepath = "%s/%s" % (member_path, member.filename)
list_of_experiments.append(member_filepath)
variable_list = list(self.sftp.listdir_attr(member_filepath))
for variable in variable_list:
if remote_isdir(variable):
if re.match(variable_pattern, variable.filename):
variable_filepath = "%s/%s" % (member_filepath, variable.filename)
list_of_experiments.append(variable_filepath)
return list_of_experiments
def find_results(self, expid):
        namelist_pattern = r"emulator_variable_precisions\(([0-9]+)\) *= *([0-9]+)"
experiments = self.list_experiments_(expid)
database = {"error": {}, "precisions": {}}
for experiment in experiments:
accuracy_filepath = "%s/accuracy.txt" % experiment
namelist_filepath = "%s/namelist_precisions" % experiment
with self.sftp.open(accuracy_filepath) as f:
error = float(f.read())
database["error"][experiment] = error
database["precisions"][experiment] = {}
            with self.sftp.open(namelist_filepath) as f:
                for line in f:
                    # paramiko SFTP file objects yield bytes lines in Python 3
                    m = re.match(namelist_pattern, line.decode())
                    if m:
                        var_id, precision = [int(x) for x in m.groups()]
                        database["precisions"][experiment][var_id] = precision
return database
def execute(self, command):
return self._ssh.exec_command(command)
def check_completed_experiments(self):
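        # Same folder walk as list_experiments_, but rooted at the configured remote scratch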
if not self.sftp:
self.init_transport()
experiment_path = self.remote_scratch
list_of_experiments = []
# Enter in experiment folders
experiment_pattern = "[0-9]+"
member_pattern = "[a-z][0-9]+"
variable_pattern = "ref_[0-9]+"
for attr in list(self.sftp.listdir_attr(experiment_path)):
filename = attr.filename
if re.match(experiment_pattern, filename):
if remote_isdir(attr):
member_path = "%s/%s" % (experiment_path, filename)
member_list = list(self.sftp.listdir_attr(member_path))
for member in member_list:
if re.match(member_pattern, member.filename):
if remote_isdir(member):
member_filepath = "%s/%s" % (member_path, member.filename)
list_of_experiments.append(member_filepath)
variable_list = list(self.sftp.listdir_attr(member_filepath))
for variable in variable_list:
if remote_isdir(variable):
if re.match(variable_pattern, variable.filename):
variable_filepath = "%s/%s" % (member_filepath, variable.filename)
list_of_experiments.append(variable_filepath)
return list_of_experiments
if __name__ == "__main__":
# Load vault
vault = load_vault("/home/Earth/otintopr/RPE-ROMS/AutoRPE/vault.pkl")
# SSH to remote machine and download results
hostname = "mn1.bsc.es"
user = "pr1egi23"
expid = "a1bo"
s = SSH(user=user, host=hostname)
results = s.find_results(expid)
variables = [variable for variable in vault.variables if variable.id is not None]
var_dict = {}
for variable in variables:
var_dict[variable.id] = variable
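    # Start every variable from a sentinel error larger than any measured result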
for variable in variables:
variable.error = 1000
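    # For each run in which a variable was reduced to precision 10, keep the lowest error seen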
for key, error in results["error"].items():
for var_id, precision in results["precisions"][key].items():
if precision == 10:
if error < var_dict[var_id].error:
var_dict[var_id].error = error
    ordered_vars = sorted(var_dict.values(), key=lambda x: -x.error)
for var in ordered_vars:
print("%4i\t\t%8.2e" % (var.id, var.error))
    save_vault = True
    if save_vault:
        # Workaround: raise the recursion limit so pickle can serialise the
        # deeply nested vault; the risks of doing so are not yet clear
        import sys
        sys.setrecursionlimit(2000)
        new_vault_name = "vault_with_processed_a1bo.pkl"
        with open(new_vault_name, "wb") as f:
            pickle.dump(vault, f)
def completed_jobs_have_successful_defined(all_jobs):
# Check if completed jobs have a defined is_successful property
for job in all_jobs:
if job.is_completed:
if job.is_successful is None:
                raise AssertionError("Completed job without is_successful defined!")
def successful_job_descendants(all_jobs):
# A successful job should not have a non-banned descendant that is not successful
for job in all_jobs:
if job.is_completed:
if job.is_successful is True:
for descendant in job.descendants():
if descendant.is_banned is False:
if descendant.is_successful is not True:
raise AssertionError("Descendant of a successful job unsuccessful")
def non_successful_job_ancestors(all_jobs):
    # An ancestor of a non-successful job should not be successful
    for job in all_jobs:
        # For any given completed job that is not banned
        if job.is_completed and job.is_banned is False and job.is_successful is False:
            # Look at all the ancestors and assert that none of them is successful
            for ancestor in job.ancestors():
                if ancestor.is_banned is False and ancestor.is_successful is True:
                    raise AssertionError("Ancestor of an unsuccessful job is successful")
def check_duplicates(queue):
# Check for duplicates
duplicated_completed_running = [job for job in queue.completed if job in queue.running]
duplicated_completed_queue = [job for job in queue.completed if job in queue]
duplicated_running_queue = [job for job in queue.running if job in queue]
all_duplicated = duplicated_completed_running + duplicated_completed_queue + duplicated_running_queue
duplicated_jobs = len(all_duplicated)
if duplicated_jobs > 0:
names = " ".join([j.identifier() for j in all_duplicated])
raise AssertionError("Duplicated jobs!: %s" % names)
def check_banned_variables(all_jobs, ban_list):
    # If a banned variable appears in a non-banned job, the whole variable set must be banned
for job in all_jobs:
if job.is_banned is False:
for variable in ban_list:
if variable in job.variable_set:
all_variables_banned = True
for v in job.variable_set:
if v not in ban_list:
all_variables_banned = False
if all_variables_banned is False:
raise AssertionError("Banned variable present in variable set!")
def job_not_present_in_any_list(all_jobs, queue):
for job in all_jobs:
anywhere = job in queue or job in queue.running or job in queue.completed
if not anywhere:
raise AssertionError("Job has been lost!")
def banned_job_has_descendants_running(all_jobs, queue):
for job in all_jobs:
if job.is_banned:
for descendant in job.descendants():
if descendant in queue.running:
                    raise AssertionError("Descendant of a banned job is running.")
def check_job_in_running_is_successful(all_jobs, queue):
for job in all_jobs:
if job.is_successful:
if job in queue.running:
raise AssertionError("Job in running with is_successful defined")
def jobs_in_queue_have_same_hash(all_jobs, queue):
jobs = [job for job in all_jobs if job in queue or job in queue.running]
hashes = []
for job in jobs:
if job.hash in hashes:
raise AssertionError("Two jobs with same hash in queues")
hashes.append(job.hash)
def jobs_in_queue_with_a_non_banned_unsuccessful_descendant(all_jobs, queue):
jobs = [job for job in all_jobs if job in queue or job in queue.running]
for job in jobs:
if job in queue.running:
if job.unsuccessful_descendant:
raise AssertionError("Job running with an unsuccessful non-banned descendant!")
def unsuccessful_unbanned_jobs_with_a_single_variable(all_jobs, queue):
for job in all_jobs:
if not job.child:
if job.is_banned is False and job.is_successful is False:
                raise AssertionError("Unsuccessful, unbanned job with a single variable!")
def no_jobs_left_in_queue(queue):
if len(queue) == 0 and len(queue.running) == 0:
raise EmptyQueuesError("No jobs queuing and no running jobs!")
# Attempt to assert that nothing nasty happens
def check_consistency(seed_set, queue, ban_list):
all_jobs = seed_set.descendants()
all_jobs.append(seed_set)
job_not_present_in_any_list(all_jobs, queue)
completed_jobs_have_successful_defined(all_jobs)
successful_job_descendants(all_jobs)
non_successful_job_ancestors(all_jobs)
check_duplicates(queue)
check_banned_variables(all_jobs, ban_list)
banned_job_has_descendants_running(all_jobs, queue)
check_job_in_running_is_successful(all_jobs, queue)
jobs_in_queue_with_a_non_banned_unsuccessful_descendant(all_jobs, queue)
unsuccessful_unbanned_jobs_with_a_single_variable(all_jobs, queue)
# jobs_in_queue_have_same_hash(all_jobs, queue)
no_jobs_left_in_queue(queue)
def checkpoint(queue):
print("Saving checkpoint!", end="\r")
import pickle
all_jobs = queue.pending + queue.running + queue.completed
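    # Paramiko SSH objects hold live sockets and cannot be pickled,
    # so detach the communicator from every job before dumping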
# Save the communicator
communicator = all_jobs[0].communicator
for job in all_jobs:
job.communicator = None
    # Dump object
    with open("checkpoint.pkl", "wb") as f:
        pickle.dump(queue, f)
# Restore the communicator
for job in all_jobs:
job.communicator = communicator
print("Checkpoint saved!")
def recover_checkpoint(communicator, filename="checkpoint.pkl"):
import pickle as pkl
print("Loading checkpoint!", end="\r")
    with open(filename, "rb") as f:
        queue = pkl.load(f)
all_jobs = queue.pending + queue.running + queue.completed
for job in all_jobs:
job.communicator = communicator
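    # The job with no parent is the root of the job tree; recover the global ban list from it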
job_with_no_parent = [job for job in all_jobs if job.parent is None]
grandfather = job_with_no_parent[0]
ban_list = grandfather.banned_variables
    print("Checkpoint loaded!")
return queue, grandfather, ban_list
def saving_state(_queue, filename="state.pkl"):
import pickle as pkl
log = {"COMPLETED": {}, "RUNNING": [], "QUEUING": []}
for _job in _queue.running:
log["RUNNING"].append(_job.identifier())
for _job in _queue:
log["QUEUING"].append(_job.identifier())
for _job in _queue.completed:
log["COMPLETED"][_job.identifier()] = (_job.is_successful, _job.is_banned)
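    # Protocol 2 keeps the state file readable from Python 2 as well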
with open(filename, "wb") as output_log:
pkl.dump(log, output_log, protocol=2)
class EmptyQueuesError(Exception):
def __init__(self, message):
# Call the base class constructor with the parameters it needs
super(EmptyQueuesError, self).__init__(message)
# Brand new vault
from UtilsRPE.Functions import load_vault
from collector import SSH
import re
import time
import sys
# Signal test
import signal
def hash_from_list(_list):
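    # Hash the sorted list so ordering is irrelevant; the offset keeps the result positive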
_list = list(sorted(_list))
val = (sys.maxsize + 1) * 2
return hash(tuple(_list)) + val
class SimulationError(Exception):
pass
def generate_jobscript(template, parameters):
with open(template) as f:
job = f.read()
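    # Placeholders appear as %NAME% in the jobscript template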
for parameter, value in parameters.items():
job = job.replace("%%%s%%" % parameter, value)
return job
def submit_job(job, remote_jobscript, communicator):
with communicator.sftp.file(remote_jobscript, "w") as jobscript:
jobscript.write(job)
command = "sbatch %s" % remote_jobscript
stdin, stdout, stderr = communicator.execute(command)
    output = stdout.read().decode()
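    # sbatch reports "Submitted batch job <id>"; extract the numeric id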
jobid_pattern = "Submitted batch job +([0-9]+)"
jobid = re.search(jobid_pattern, output).group(1)
return jobid
def job_status(job_id, communicator):
command = 'sacct --job %s --format="State,Elapsed"' % job_id
stdin, stdout, stderr = communicator.execute(command)
output = stdout.read().decode()
try:
status = output.split("\n")[2].split()[0]
time = output.split("\n")[2].split()[1]
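    # sacct prints no data row until the job is registered; treat that as PENDING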
except IndexError:
status = "PENDING"
time = "00:00"
return status, time
def obtain_result(remote_rundir, communicator):
with communicator.sftp.open("%s/accuracy.txt" % remote_rundir) as f:
error = float(f.read())
return error
def generate_namelist(variable_set, filename, _vault=None):
if _vault:
vault = _vault
variables = list(sorted([v for v in vault.variables if v.id is not None], key=lambda x: -x.error))
for index, v in enumerate(variables):
if v.id in variable_set:
v.precision = 10
else:
v.precision = 52
vault.generate_namelist_precisions(filename=filename, exclude_double_precision=True)
def signal_handler(sig, frame):
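    # Relies on the module-level job_ids list and remote communicator set by the driver script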
jobs_to_cancel = " ".join(job_ids)
    print('\n\nCanceling simulation(s) %s !!' % jobs_to_cancel)
command = "scancel %s" % jobs_to_cancel
stdin, stdout, stderr = remote.execute(command)
print(stdout.read().decode())
exit(1)
class Job:
def __init__(self, variable_set, remote_communicator, vault=None):
# Init variables
self.vault = vault
self.variable_set = variable_set
self.communicator = remote_communicator
self.remote_scratch = self.communicator.remote_scratch
self.remote_logdir = "%s/LOGS" % self.remote_scratch
self.remote_namelist_dir = "%s/Namelists" % self.remote_logdir
self.remote_rundir = "%s/RefinementTest%i" % (self.remote_scratch, self.hash)
self.remote_jobscript_path = "%s/refinement_%i.cmd" % (self.remote_logdir, self.hash)
self.remote_namelist = "%s/refinement_namelist_%i" % (self.remote_namelist_dir, self.hash)
self.local_namelist = "NewWorlds/namelist_precisions"
self.template = "Templates/singleJobscript.sh"
# parameters to replace:
self.parameters = {"JOBNAME": "refinement_test",
"RUNDIR": self.remote_rundir,
"LOGDIR": self.remote_logdir,
"NAMELIST_PATH": self.remote_namelist,
"NUMPROC": "48"
}
self.jobscript = generate_jobscript(self.template, self.parameters)
self.is_completed = False
self.is_successful = None
self.elapsed_time = None
self.job_id = None
self._result = None
self.submitted = False
@property
def status(self):
if self.submitted is False:
return "NOT SUBMITTED", "00:00:00"
elif self.is_completed is False:
_status, _time = job_status(self.job_id, self.communicator)
_status = _status.strip()
if _status == "COMPLETED":
# self.is_completed = True
self.elapsed_time = _time
return _status, _time
else:
return "COMPLETED", self.elapsed_time
# @property
def result(self, level=0):
if self._result is None:
if level > 10:
return None
try:
self._result = obtain_result(self.remote_rundir, self.communicator)
return self._result
except FileNotFoundError:
# time.sleep(1)
# return self.result(level=level+1)
return None
else:
return self._result
@property
def hash(self):
return hash_from_list(self.variable_set)
def check_if_result_already_exists(self):
self.update_parameters()
accuracy_filename = "%s/accuracy.txt" % self.remote_rundir
try:
self.communicator.sftp.open(accuracy_filename)
# self.is_completed = True
self.submitted = True
# Recovering job_id
job_id_filename = "%s/job_id.txt" % self.remote_rundir
self.job_id = self.communicator.sftp.open(job_id_filename).read().decode().strip()
print("Job %s already finished! Error: %.2e" % (self.job_id, self.result()))
return True
except IOError:
return False
def submit(self):
self.submitted = True
# Check if result already exists
completed = self.check_if_result_already_exists()
if not completed:
self.update_parameters()
# Generate namelist and put it in the remote system
generate_namelist(self.variable_set, self.local_namelist, self.vault)
self.communicator.put(self.local_namelist, self.remote_namelist)
# Submit job and obtain job id
self.job_id = submit_job(self.jobscript, self.remote_jobscript_path, self.communicator)
print("Job %i successfully submitted to the remote machine with job_id %s" % (self.hash, self.job_id))
else:
# self.is_completed = True
pass
def cancel_remote(self):
if self.job_id is not None:
command = "scancel -v %s" % self.job_id
print("Canceling job %s" % self.job_id)
stdin, stdout, stderr = self.communicator.execute(command)
print(stdout.read().decode())
print(stderr.read().decode())