Source code for proton_decay_study.generators.threaded_gen3d

# -*- coding: utf-8 -*-
from .base import BaseDataGenerator
from .single_file_3d import Gen3D
import logging
import threading
import queue
import traceback
import signal
import sys
import random


[docs]class SingleFileThread(threading.Thread): """ Wrapper thread for buffering data from a single file """ # Locks class for the duration of parent lifetime # Currently, hierarchal threading is not supported threadLock = threading.Lock() # Set this to 0 to kill all threads once prefetch is finished. __ThreadExitFlag__ = 1 # Holds the current request queue queue = queue.Queue(10000) # Locks the thread queue queueLock = threading.Lock() logger = logging.getLogger('pdk.gen.singlefilethread') activeThreads = [] def __init__(self, datasetname, labelsetname, batch_size): """ :param: datasetname The name used to pull the dataset from file :type: datasetname string :param: labelsetname The name used to pull the labelset from file :type: labelsetname string :param: batch_size The number of images and truths to pull :type: batch_size int """ super(SingleFileThread, self).__init__() self.datasetname = datasetname self.labelsetname = labelsetname self.batch_size = batch_size self._buffer = None self.single_thread_lock = threading.Lock() # This holds the file generator pattern self._filegen = None
[docs] def run(self): """ Loops over queue to accept new configurations """ self.logger.info("Starting thread: {}".format(self)) while SingleFileThread.__ThreadExitFlag__: # The thread loop should exit once it sees # the stop thread flag if self.single_thread_lock.locked() or self._buffer is not None: # If this is currently being visited by parent # Just continue continue # At this point, we need to fill the buffer # If we don't have filegen, get one if self._filegen is None: self.queueLock.acquire() if not self.queue.empty(): name = SingleFileThread.queue.get() try: self._filegen = Gen3D(name, self.datasetname, self.labelsetname, self.batch_size) self.logger.info("Moving to file: {}".format(self._filegen._filename)) except Exception as e: self.logger.warning(e) self._filegen = None self.queueLock.release() return None else: self.queueLock.release() continue self.queueLock.release() # Now, fill the buffer and release self.single_thread_lock.acquire() try: self._buffer = self._filegen.next() except StopIteration: self._buffer = None self._filegen = None
self.single_thread_lock.release()
[docs] def visit(self, parent): # wait until we have it if self.single_thread_lock.locked(): return None # Now grab it self.single_thread_lock.acquire() # Copy the data handle out (not the data itself) ret = self._buffer self._buffer = None self.single_thread_lock.release() parent.check_and_refill()
return ret
[docs] @staticmethod def killRunThreads(signum, frame): """ Sets the thread kill flag to each of the ongoing analysis threads """ SingleFileThread.logger.info("Killing Single File threads...") SingleFileThread.__ThreadExitFlag__ = 0
sys.exit(signum)
[docs] @staticmethod def startThreads(nThreads, datasetname, labelsetname, batch_size): msg = "Starting {} Single File threads".format(nThreads) SingleFileThread.logger.info(msg) for i in range(nThreads): thread = SingleFileThread(datasetname, labelsetname, batch_size) thread.start() SingleFileThread.activeThreads.append(thread)
SingleFileThread.logger.info("Threads successfully started")
[docs] @staticmethod def waitTillComplete(callback=None): queue = SingleFileThread if callback is None: while not queue.queue.empty() and queue.__ThreadExitFlag__: sys.stdout.flush() else: while not queue.queue.empty() and queue.__ThreadExitFlag__: callback() # Notify threads it's time to exit SingleFileThread.__ThreadExitFlag__ = 0 # Wait for all threads to complete for t in SingleFileThread.activeThreads: t.join() # dealloc
SingleFileThread.activeThreads = []
[docs] @staticmethod def status(): SingleFileThread.logger.debug("ThreadLock: {}".format(SingleFileThread.threadLock.locked())) SingleFileThread.logger.debug("QueueLock: {}".format(SingleFileThread.queueLock.locked()))
SingleFileThread.logger.debug("Flag: {}".format(SingleFileThread.__ThreadExitFlag__))
[docs] def single_status(self):
self.logger.debug("Single Thread Lock: {}".format(self.single_thread_lock.locked())) signal.signal(signal.SIGINT, SingleFileThread.killRunThreads)
[docs]class ThreadedMultiFileDataGenerator(BaseDataGenerator): """ Uses threads to pull asynchronously from files """ logger = logging.getLogger("pdk.gen.threaded_multi") def __init__(self, datapaths, datasetname, labelsetname, batch_size=1, nThreads=8): SingleFileThread.threadLock.acquire() self.datapaths = [i for i in datapaths] for i in range(len(datapaths)): random.shuffle(self.datapaths) self.check_and_refill() SingleFileThread.startThreads(nThreads, datasetname, labelsetname, batch_size) self.current_thread_index = 0 self.logger.info("Threaded multi file generator ready for generation") def __del__(self): SingleFileThread.__ThreadExitFlag__ = 0 for t in SingleFileThread.activeThreads: t.join() SingleFileThread.activeThreads = [] SingleFileThread.threadLock.release() @property def output(self): x, y = self.next() return x[0].shape @property def input(self): x, y = self.next() return y[0].shape[0] def __len__(self): """ Iterates over files to create the total sum length of the datasets in each file. """ return 0
[docs] def status(self): self.logger.debug("Filenames: {}".format(self.datapaths)) self.logger.debug("Active threads: {}".format(SingleFileThread.activeThreads)) SingleFileThread.status() for i in SingleFileThread.activeThreads:
i.single_status()
[docs] def check_and_refill(self): SingleFileThread.queueLock.acquire() if SingleFileThread.queue.empty(): for i in self.datapaths: SingleFileThread.queue.put(i)
SingleFileThread.queueLock.release()
[docs] def next(self): # see if there's any pre-fetched data # If the filename queue is empty, fill it back up again. # This ensures that files are all used up before they # are iterated over again. if self.current_thread_index >= len(SingleFileThread.activeThreads): self.current_thread_index = 0 thread = SingleFileThread.activeThreads[self.current_thread_index] self.current_thread_index+=1 ret = thread.visit(self) while ret is None: if self.current_thread_index == len(SingleFileThread.activeThreads): self.current_thread_index = 0 thread = SingleFileThread.activeThreads[self.current_thread_index] self.current_thread_index+=1 ret = thread.visit(self)
return ret