import atexit
import io
import logging
import os
import queue
import random
import socket
import tempfile
import threading
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Type, Dict, Optional, Tuple, Set, List
import psutil
from psij import JobExecutor, Job, JobState, JobStatus
logger = logging.getLogger(__name__)
_MAX_FILE_AGE_DAYS = 30
class SingletonThread(threading.Thread):
    """
    A convenience class to return a thread that is guaranteed to be unique to this process.

    This is intended to work with fork() to ensure that each os.getpid() value is associated
    with at most one thread. This is not safe. The safe thing, as pointed out by the fork()
    man page, is to not use fork() with threads. However, this is here in an attempt to make
    it slightly safer for when users really really want to take the risk against all advice.

    This class is meant as an abstract class and should be used by subclassing and
    implementing the `run` method.
    """

    # Maps pid -> {subclass: instance}. Keying on the pid means a forked child
    # (which sees a different os.getpid()) builds fresh instances instead of
    # reusing the parent's threads, which do not survive fork().
    _instances: Dict[int, Dict[type, 'SingletonThread']] = {}
    _lock = threading.RLock()

    def __init__(self, name: Optional[str] = None, daemon: bool = False) -> None:
        """
        Instantiation of this class or one of its subclasses should be done through the
        :meth:`get_instance` method rather than directly.

        Parameters
        ----------
        name
            An optional name for this thread.
        daemon
            A daemon thread does not prevent the process from exiting.
        """
        super().__init__(name=name, daemon=daemon)

    @classmethod
    def get_instance(cls: Type['SingletonThread']) -> 'SingletonThread':
        """Returns a started instance of this thread.

        The instance is guaranteed to be unique for this process. This method also guarantees
        that a forked process will get a separate instance of this thread from the parent.
        """
        with cls._lock:
            # One sub-dictionary per pid; created lazily on first use in this process.
            classes = cls._instances.setdefault(os.getpid(), {})
            if cls not in classes:
                instance = cls()
                classes[cls] = instance
                # Start the thread before returning it so callers always see a
                # running singleton.
                instance.start()
            return classes[cls]
class _StatusUpdater(SingletonThread):
    """Background singleton thread that collects job status updates.

    Updates arrive through two channels, both polled in :meth:`step`:

    * a UDP socket bound to an ephemeral port (``update_port``);
    * an append-only file in ``~/.psij`` (``update_file``), read from the last
      seen offset.

    Each update is a short text line of the form ``<jobid> <status>``; special
    lines starting with ``_SYNC`` are used by :meth:`flush` and lines of the
    form ``<jobid> LOG <message>`` are forwarded to the module logger.
    """

    # we are expecting short messages in the form <jobid> <status>
    RECV_BUFSZ = 2048

    def __init__(self) -> None:
        super().__init__()
        self.name = 'Status Update Thread'
        self.daemon = True
        self.work_directory = Path.home() / '.psij'
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.socket.setblocking(True)
        # Short timeout so the run loop alternates between socket and file polling.
        self.socket.settimeout(0.5)
        # Bind to an ephemeral port on all interfaces.
        self.socket.bind(('', 0))
        self.update_port = self.socket.getsockname()[1]
        self.ips = self._get_ips()
        logger.debug('Local IPs: %s', self.ips)
        logger.debug('Status updater port: %s', self.update_port)
        self._create_update_file()
        logger.debug('Update file: %s', self.update_file.name)
        # Buffers for partially received messages; kept for compatibility.
        self.partial_file_data = ''
        self.partial_net_data = ''
        # job id -> (job, executor) for jobs whose status we are tracking
        self._jobs: Dict[str, Tuple[Job, JobExecutor]] = {}
        self._jobs_lock = threading.RLock()
        # _SYNC tokens seen; consumed by flush()
        self._sync_ids: Set[str] = set()
        self._last_received = ''

    def _get_ips(self) -> List[str]:
        """Return the IPv4 addresses of all non-loopback interfaces."""
        addrs = psutil.net_if_addrs()
        r = []
        for name, l in addrs.items():
            if name == 'lo':
                continue
            for a in l:
                if a.family == socket.AddressFamily.AF_INET:
                    r.append(a.address)
        return r

    def _create_update_file(self) -> None:
        """Create the file-based update channel and remember its current end offset."""
        f = tempfile.NamedTemporaryFile(dir=self.work_directory, prefix='supd_', delete=False)
        name = f.name
        self.update_file_name = name
        # Best-effort cleanup of the update file when the process exits.
        atexit.register(os.remove, name)
        f.close()
        self.update_file = open(name, 'r+b')
        self.update_file.seek(0, io.SEEK_END)
        self.update_file_pos = self.update_file.tell()

    def register_job(self, job: Job, ex: JobExecutor) -> None:
        """Start routing status updates for `job` to `ex`."""
        with self._jobs_lock:
            self._jobs[job.id] = (job, ex)

    def unregister_job(self, job: Job) -> None:
        """Stop routing status updates for `job`."""
        with self._jobs_lock:
            try:
                del self._jobs[job.id]
            except KeyError:
                # There are cases when it's difficult to ensure that this method is only called
                # once for each job. Instead, ignore errors here, since the ultimate goal is to
                # remove the job from the _jobs dictionary.
                pass

    def step(self) -> None:
        """Poll both update channels once; called repeatedly by :meth:`run`."""
        self._poll_file()
        try:
            data = self.socket.recv(_StatusUpdater.RECV_BUFSZ)
            self._process_update_data(data)
        except TimeoutError:
            pass
        except socket.timeout:
            # before 3.10, this was a separate exception from TimeoutError
            pass
        except BlockingIOError:
            pass

    def _poll_file(self) -> None:
        """Process any bytes appended to the update file since the last poll."""
        self.update_file.seek(0, io.SEEK_END)
        pos = self.update_file.tell()
        if pos > self.update_file_pos:
            self.update_file.seek(self.update_file_pos, io.SEEK_SET)
            n = pos - self.update_file_pos
            self._process_update_data(self.update_file.read(n))
            self.update_file_pos = pos

    def run(self) -> None:
        while True:
            try:
                self.step()
            except Exception:
                logger.exception('Exception in status updater thread. Ignoring.')

    def flush(self) -> None:
        # Ensures that, upon return from this call, all updates available before this call have
        # been processed. To do so, we send a UDP packet to the socket to wake it up and wait until
        # it is received. This does not guarantee that file-based updates are necessarily
        # processed, since that depends on many factors.
        # On the minus side, this method, as implemented, can cause deadlocks if the socket
        # reads fail for unexpected reasons. This should probably be accounted for.
        token = '_SYNC ' + str(random.getrandbits(128))
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.sendto(bytes(token, 'utf-8'), ('127.0.0.1', self.update_port))
        self._poll_file()
        # Exponential backoff while waiting for the updater thread to see the token.
        delay = 0.0001
        while token not in self._sync_ids:
            time.sleep(delay)
            delay *= 2
        self._sync_ids.remove(token)

    def _process_update_data(self, data: bytes) -> None:
        """Parse a batch of update lines and dispatch them.

        Recognized line formats:
          ``_SYNC <token>``            - flush synchronization marker
          ``<jobid> LOG <message...>`` - forwarded to the module logger
          ``<jobid> <state>``          - status update for a registered job
        """
        sdata = data.decode('utf-8')
        if sdata == self._last_received:
            # we send UDP packets to all IP addresses of the submit host, which may
            # result in duplicates, so we drop consecutive messages that are identical
            return
        else:
            self._last_received = sdata
        lines = sdata.splitlines()
        for line in lines:
            # BUG FIX: previously tested/stored the whole batch (sdata) instead of
            # the individual line, so a _SYNC token arriving together with other
            # lines was never recognized and flush() could hang.
            if line.startswith('_SYNC '):
                self._sync_ids.add(line)
                continue
            els = line.split()
            if len(els) > 2 and els[1] == 'LOG':
                logger.info('%s %s', els[0], ' '.join(els[2:]))
                continue
            if len(els) != 2:
                logger.warning('Invalid status update message received: %s', line)
                continue
            job_id = els[0]
            state = JobState.from_name(els[1])
            job = None
            with self._jobs_lock:
                try:
                    (job, executor) = self._jobs[job_id]
                except KeyError:
                    # Update for a job we no longer (or never) tracked; drop it.
                    pass
            if job:
                # Notify outside the lock to avoid holding it during callbacks.
                executor._set_job_status(job, JobStatus(state))
class _FileCleaner(SingletonThread):
    """Background singleton thread that deletes stale files from queued directories.

    Directories are submitted through :meth:`clean`; files inside them older
    than ``_MAX_FILE_AGE_DAYS`` (by mtime) are removed.
    """

    def __init__(self) -> None:
        super().__init__()
        self.name = 'File Cleaner'
        self.daemon = True
        # Directories awaiting cleanup, fed by clean() and drained by run().
        self.queue: queue.SimpleQueue[Path] = queue.SimpleQueue()

    def clean(self, path: Path) -> None:
        """Schedule `path` (a directory) for stale-file removal."""
        self.queue.put(path)

    def run(self) -> None:
        while True:
            try:
                path = self.queue.get(block=True, timeout=1)
                try:
                    self._do_clean(path)
                except Exception as ex:
                    # Consistent with the rest of this module, report through the
                    # module logger rather than print().
                    logger.warning('Cannot clean %s: %s', path, ex)
            except queue.Empty:
                pass

    def _do_clean(self, path: Path) -> None:
        """Remove direct children of `path` whose mtime is older than the cutoff."""
        now = datetime.now()
        max_age = timedelta(days=_MAX_FILE_AGE_DAYS)
        for child in path.iterdir():
            # lstat: do not follow symlinks when checking age
            m_time = datetime.fromtimestamp(child.lstat().st_mtime)
            if now - m_time > max_age:
                try:
                    child.unlink()
                except FileNotFoundError:
                    # we try our best
                    pass