# NOTE(review): the lines "Newer" / "Older" are web-page navigation residue
# from wherever this file was copied — they are not Python code and would
# raise NameError at import time; remove them when restoring the file.
import operator
import threading
from concurrent.futures import ThreadPoolExecutor
from earthdiagnostics.datafile import StorageStatus, LocalStatus
from earthdiagnostics.diagnostic import (
DiagnosticStatus,
Diagnostic,
DiagnosticOptionError,
)
from earthdiagnostics.utils import Utils, TempFile
"""
Class to produce and control the workflow of EarthDiagnostics
Parameters
----------
config: Config
data_manager: DataManager
"""
self.jobs = {}
self.jobs[DiagnosticStatus.WAITING] = []
self.jobs[DiagnosticStatus.READY] = []
self.jobs[DiagnosticStatus.RUNNING] = []
self.jobs[DiagnosticStatus.COMPLETED] = []
self.jobs[DiagnosticStatus.FAILED] = []
self.time = {}
self.had_errors = False
self._register_diagnostics()
for fulldiag in self.config.get_commands():
Log.info("Adding {0} to diagnostic list", fulldiag)
diag_class = Diagnostic.get_diagnostic(diag_options[0])
if diag_class:
try:
for job in diag_class.generate_jobs(self, diag_options):
self.add_job(subjob)
Log.error(
"Can not configure diagnostic {0}: {1}",
diag_options[0],
ex,
)
Log.error(
"{0} is not an available diagnostic", diag_options[0]
)
def add_job(self, job, old_status=None):
    """Place *job* in the bucket matching its current status.

    Parameters
    ----------
    job: Diagnostic
        Job to track; its ``status`` attribute selects the target bucket.
    old_status: DiagnosticStatus, optional
        Previous status bucket to remove the job from when it has moved.
    """
    if old_status is not None:
        self.jobs[old_status].remove(job)
    bucket = self.jobs[job.status]
    # Guard against registering the same job twice in one bucket.
    if job not in bucket:
        bucket.append(job)
"""
Run all the diagnostics
Returns
-------
bool
Only True if all diagnostic were correctly executed
"""
if (
not self.jobs[DiagnosticStatus.WAITING]
and not self.jobs[DiagnosticStatus.READY]
):
Log.result("No diagnostics to run")
return not self.had_errors
start_time = datetime.datetime.now()
Log.info("Starting to compute at {0}", start_time)
self.threads = Utils.available_cpu_count()
if 0 < self.config.max_cores < self.threads:
self.threads = self.config.max_cores
self.downloader = Downloader()
self.uploader = ThreadPoolExecutor(self.config.parallel_uploads)
self.executor = ThreadPoolExecutor(self.threads)
self.lock.acquire()
for job in self.jobs[DiagnosticStatus.WAITING].copy():
job.subscribe(self, self._job_status_changed)
if job.can_skip_run():
Log.info("Diagnostic {0} already done. Skipping !", job)
job.status = DiagnosticStatus.COMPLETED
for file_object in self.data_manager.requested_files.values():
file_object.subscribe(self, self._file_object_status_changed)
if file_object.download_required():
self.downloader.submit(file_object)
self.downloader.start()
printer_lock = threading.Lock()
printer_lock.acquire()
printer = Thread(
target=self._print_status, name="Printer", args=(printer_lock,)
)
self.lock.acquire()
printer_lock.release()
self.downloader.shutdown()
self.executor.shutdown()
TempFile.clean()
finish_time = datetime.datetime.now()
Log.result("Diagnostics finished at {0}", finish_time)
Log.result("Elapsed time: {0}\n", finish_time - start_time)
self._print_errors()
self._print_stats()
def _print_status(self, lock):
# Periodically print a job-count summary until *lock* is released by run().
# NOTE(review): indentation was stripped and the Python-2 branch looks
# truncated (no ``return`` after a successful acquire, no polling loop
# bounded by ``interval``) — verify against version control.
interval = 60
while True:
if six.PY3:
# Py3: wait up to ``interval`` seconds; acquiring means run() finished.
if lock.acquire(blocking=True, timeout=interval):
return
else:
# Py2: Lock.acquire has no timeout, so poll in small steps.
step = 0.5
if lock.acquire(False):
# NOTE(review): a ``return`` probably belongs here — confirm.
time.sleep(step)
Log.info("Current status:")
Log.info("===============")
Log.info(
"Waiting: {0:4}", len(self.jobs[DiagnosticStatus.WAITING])
)
Log.info(
"Ready to run: {0:4}", len(self.jobs[DiagnosticStatus.READY])
)
Log.info(
"Running: {0:4}", len(self.jobs[DiagnosticStatus.RUNNING])
)
Log.info(
"Completed: {0:4}",
len(self.jobs[DiagnosticStatus.COMPLETED]),
)
Log.info(
"Failed: {0:4}", len(self.jobs[DiagnosticStatus.FAILED])
)
Log.info("===============")
def _job_status_changed(self, job, old_status):
# Callback wired via job.subscribe(): re-bucket the job and react.
self.add_job(job, old_status)
if job.status == DiagnosticStatus.READY:
self.executor.submit(self._run_job, job)
# NOTE(review): a guard (probably ``if job.status in (COMPLETED,
# FAILED):``) seems to be missing before this request cleanup, and any
# exit/pause checks were dropped from the excerpt — confirm against VCS.
for request in job._requests:
if request.only_suscriber(self):
del self.data_manager.requested_files[request.remote_file]
request.unsubscribe(self)
def _file_object_status_changed(self, file_object):
# Callback for datafile status changes: queue pending downloads and
# uploads, record upload failures, and drop files this manager is the
# last suscriber of once they are idle.
Log.debug(
"Checking file {0}. Local status {0.local_status} "
"Storage status{0.storage_status}",
file_object,
)
if file_object.download_required():
self.downloader.submit(file_object)
return
if file_object.upload_required():
file_object.storage_status = StorageStatus.UPLOADING
self.uploader.submit(file_object.upload)
return
if file_object.storage_status == StorageStatus.FAILED:
self.uploads_failed.append(file_object)
# Forget files that are neither being computed nor uploaded once this
# manager is their only remaining suscriber.
if (
file_object.local_status != LocalStatus.COMPUTING
and file_object.storage_status != StorageStatus.UPLOADING
and file_object.only_suscriber(self)
):
del self.data_manager.requested_files[file_object.remote_file]
file_object.unsubscribe(self)
# NOTE(review): the ``def`` header for this method (likely a ``_check_exit``
# helper that decides whether run() may stop waiting) is missing from the
# excerpt — restore from version control.
if self._jobs_running_or_ready():
return False
if self._data_downloading_or_uploading():
return False
# Nothing left to do: release the main lock so run() can finish.
try:
self.lock.release()
except threading.ThreadError:
# Already released by a concurrent callback — safe to ignore.
pass
return True
def _jobs_running_or_ready(self):
# NOTE(review): this body looks truncated — a ``return True`` for the
# ready/running case and a final return are missing, and the waiting-jobs
# loop probably sits outside the first ``if``. Verify against VCS.
if (
self.jobs[DiagnosticStatus.READY]
or self.jobs[DiagnosticStatus.RUNNING]
):
for job in self.jobs[DiagnosticStatus.WAITING]:
if job.all_requests_in_storage():
return True
def _data_downloading_or_uploading(self):
# NOTE(review): every branch body is missing from this excerpt — the
# method presumably returns True when any requested file is still
# downloading/uploading (or needs to) and False otherwise. Restore
# the full body from version control.
for request in self.data_manager.requested_files.values():
if request.storage_status == StorageStatus.UPLOADING:
if request.local_status == LocalStatus.DOWNLOADING:
if request.upload_required():
if request.download_required():
def _pause_downloader_if_required(self):
    """Throttle the downloader while a large backlog of ready jobs exists.

    Puts downloads on hold whenever more than 20 jobs are already waiting
    to execute, and resumes them once the backlog shrinks.
    """
    ready_jobs = self.jobs[DiagnosticStatus.READY]
    self.downloader.on_hold = len(ready_jobs) > 20
Log.info("Time consumed by each diagnostic class")
Log.info("--------------------------------------")
times = {}
for job in (
self.jobs[DiagnosticStatus.COMPLETED]
+ self.jobs[DiagnosticStatus.FAILED]
):
job_type = job.alias
if job_type in times.keys():
times[job_type] += job.consumed_time
else:
times[job_type] = job.consumed_time
for diag in sorted(times, key=operator.itemgetter(1)):
# NOTE(review): the ``def _print_errors(self):`` header is missing above, and
# an ``if self.uploads_failed:`` guard appears to be missing before the
# failed-uploads section below. Verify against version control.
if self.jobs[DiagnosticStatus.FAILED]:
self.had_errors = True
Log.error("Failed jobs")
Log.error("-----------")
for job in self.jobs[DiagnosticStatus.FAILED]:
Log.error("{0}: {0.message}", job)
# Sum time spent on failed jobs (timedelta() as sum() start value).
Log.error(
"Total wasted time: {0}",
sum(
[
job.consumed_time
for job in self.jobs[DiagnosticStatus.FAILED]
],
datetime.timedelta(),
),
)
Log.info("")
self.had_errors = True
Log.error("Failed uploads")
Log.error("--------------")
for file_object in self.uploads_failed:
Log.error("{0}", file_object.remote_file)
Log.info("")
# NOTE(review): truncated excerpt of what is presumably ``_run_job`` — the
# ``def`` header, the ``try`` block that executes the job (setting
# ``start_time``) and the ``except`` line binding ``ex`` / ``exc_type`` /
# ``exc_traceback`` are all missing. Restore from version control.
job.consumed_time = datetime.datetime.now() - start_time
job.message = "{0}\n{1}".format(
ex, "".join(traceback.format_tb(exc_traceback))
)
Log.error("Job {0} failed ({2}): {1}", job, job.message, exc_type)
job.status = DiagnosticStatus.FAILED
return False
# Success path: record timing and mark the job completed.
job.consumed_time = datetime.datetime.now() - start_time
job.status = DiagnosticStatus.COMPLETED
return True
@staticmethod
def _register_diagnostics():
    """Register every diagnostic family supported by EarthDiagnostics."""
    # Same registration order as always: ocean, general, statistics.
    for register_family in (
        WorkManager._register_ocean_diagnostics,
        WorkManager._register_general_diagnostics,
        WorkManager._register_stats_diagnostics,
    ):
        register_family()
@staticmethod
def _register_stats_diagnostics():
    """Register the statistics diagnostics.

    Fix: ``Discretize`` was imported but never registered, which left its
    command unavailable to ``Diagnostic.get_diagnostic``.
    """
    from earthdiagnostics.statistics.monthlypercentile import (
        MonthlyPercentile,
    )
    from earthdiagnostics.statistics.climatologicalpercentile import (
        ClimatologicalPercentile,
    )
    from earthdiagnostics.statistics.daysoverpercentile import (
        DaysOverPercentile,
    )
    from earthdiagnostics.statistics.discretize import Discretize

    Diagnostic.register(MonthlyPercentile)
    Diagnostic.register(ClimatologicalPercentile)
    Diagnostic.register(DaysOverPercentile)
    # Previously imported but never registered.
    Diagnostic.register(Discretize)
@staticmethod
def _register_general_diagnostics():
    """Register the general-purpose (non-ocean) diagnostics."""
    from earthdiagnostics.general.attribute import Attribute
    from earthdiagnostics.general.timemean import (
        DailyMean,
        MonthlyMean,
        YearlyMean,
    )
    from earthdiagnostics.general.module import Module
    from earthdiagnostics.general.rewrite import Rewrite
    from earthdiagnostics.general.scale import Scale
    from earthdiagnostics.general.verticalmeanmetersiris import (
        VerticalMeanMetersIris,
    )
    from earthdiagnostics.general.simplify_dimensions import (
        SimplifyDimensions,
    )

    # Register in the same order the explicit calls used to run.
    for diagnostic in (
        DailyMean,
        MonthlyMean,
        YearlyMean,
        Rewrite,
        Scale,
        Attribute,
        Module,
        VerticalMeanMetersIris,
        SimplifyDimensions,
    ):
        Diagnostic.register(diagnostic)
@staticmethod
def _register_ocean_diagnostics():
    """Register the ocean diagnostics.

    Fix: ``RegionSum``, ``Rotation``, ``Sivolume`` and ``Sivol2d`` were
    imported but never registered, which left their commands unavailable
    to ``Diagnostic.get_diagnostic``.
    """
    from .ocean.mixedlayerheatcontent import MixedLayerHeatContent
    from .ocean.mixedlayersaltcontent import MixedLayerSaltContent
    from .ocean.siasiesiv import Siasiesiv
    from .ocean.verticalmean import VerticalMean
    from .ocean.verticalmeanmeters import VerticalMeanMeters
    from .ocean.verticalgradient import VerticalGradient
    from .ocean.interpolate import Interpolate
    from .ocean.interpolatecdo import InterpolateCDO
    from .ocean.moc import Moc
    from .ocean.areamoc import AreaMoc
    from .ocean.maxmoc import MaxMoc
    from .ocean.psi import Psi
    from .ocean.gyres import Gyres
    from .ocean.convectionsites import ConvectionSites
    from .ocean.cutsection import CutSection
    from .ocean.averagesection import AverageSection
    from .ocean.heatcontentlayer import HeatContentLayer
    from .ocean.heatcontent import HeatContent
    from .ocean.regionmean import RegionMean
    from .ocean.regionsum import RegionSum
    from .ocean.rotation import Rotation
    from .ocean.sivolume import Sivolume
    from .ocean.sivol2d import Sivol2d

    Diagnostic.register(MixedLayerSaltContent)
    Diagnostic.register(Siasiesiv)
    Diagnostic.register(VerticalMean)
    Diagnostic.register(VerticalMeanMeters)
    Diagnostic.register(Interpolate)
    Diagnostic.register(InterpolateCDO)
    Diagnostic.register(Moc)
    Diagnostic.register(AreaMoc)
    Diagnostic.register(MaxMoc)
    Diagnostic.register(Psi)
    Diagnostic.register(Gyres)
    Diagnostic.register(ConvectionSites)
    Diagnostic.register(CutSection)
    Diagnostic.register(AverageSection)
    Diagnostic.register(MixedLayerHeatContent)
    Diagnostic.register(HeatContentLayer)
    Diagnostic.register(HeatContent)
    Diagnostic.register(RegionMean)
    Diagnostic.register(VerticalGradient)
    # Previously imported but never registered.
    Diagnostic.register(RegionSum)
    Diagnostic.register(Rotation)
    Diagnostic.register(Sivolume)
    Diagnostic.register(Sivol2d)
class Downloader(object):
    """
    Download manager for EarthDiagnostics

    We are not using a ThreadPoolExecutor because we want to be able to
    control priorities in the download
    """
    # Fix: the class docstring above was opened but never closed in the
    # original text, which made the following code part of the string
    # literal; the closing quotes are restored here.

    def __init__(self):
        # Queue of datafiles pending download, consumed by the worker thread.
        self._downloads = []
        # Protects _downloads against concurrent submit()/worker access.
        self._lock = threading.Lock()
        # Cooperative flags read by the worker loop.
        self.stop = False
        self.on_hold = False
def start(self):
    """Spawn the background thread that performs the queued downloads."""
    worker = threading.Thread(target=self._downloader)
    self._thread = worker
    worker.start()
def submit(self, datafile):
    """Add a datafile to the download queue"""
    # Context manager guarantees the lock is released even if append raises.
    with self._lock:
        self._downloads.append(datafile)
# NOTE(review): the ``def _downloader(self):`` worker-loop header is missing
# above this ``try``, and the loop body is truncated — ``datafile`` is used
# without being selected from ``self._downloads``, and the empty/on-hold/stop
# branch has no sleep or ``continue``. Restore from version control.
try:
while True:
with self._lock:
if not self._downloads or self.on_hold:
if self.stop:
return
# Sort so the highest-priority file comes first (see _prioritize).
self._downloads.sort(
key=cmp_to_key(Downloader._prioritize)
)
self._downloads.remove(datafile)
datafile.download()
except Exception as ex:
Log.critical(
"Unhandled error at downloader: {0}\n{1}",
ex,
traceback.print_exc(),
)
@staticmethod
def _suscribers_waiting(datafile):
    """Count Diagnostic suscribers waiting only for this datafile.

    A diagnostic with exactly one pending request becomes runnable as soon
    as this file arrives, so such suscribers raise download priority.
    """
    return sum(
        1
        for suscriber in datafile.suscribers
        if isinstance(suscriber, Diagnostic)
        and suscriber.pending_requests() == 1
    )
@staticmethod
def _prioritize(datafile1, datafile2):
# NOTE(review): truncated comparator — the ``if waiting:`` branch and its
# return are missing, ``suscribers1`` is computed but never used while the
# undefined ``suscribers`` is tested, and the ``else:`` below has no
# matching ``if``. Restore the full body from version control.
waiting = Downloader._suscribers_waiting(
datafile1
) - Downloader._suscribers_waiting(datafile2)
else:
suscribers1 = len(datafile1.suscribers)
if datafile2.suscribers is None:
if suscribers:
return -suscribers
return 0
def shutdown(self):
    """Stop the downloader after all downloads have finished"""
    # Signal the worker loop to exit once its queue drains, then wait
    # for the thread to terminate before returning to the caller.
    self.stop = True
    self._thread.join()