Source code for proton_decay_study.generators.threaded_multi_file

from proton_decay_study.generators.base import BaseDataGenerator
from proton_decay_study.generators.single_file import SingleFileDataGenerator
import logging
import threading
import Queue
import traceback
import signal
import sys


class SingleFileThread(threading.Thread):
    """
    Represents a single file that is read asynchronously.
    """

    # Locks the class for the duration of the parent's lifetime.
    # Hierarchical threading is not currently supported.
    threadLock = threading.Lock()

    # Set this to 0 to kill all threads once prefetching is finished.
    __ThreadExitFlag__ = 1

    # Holds the current thread queue.
    queue = Queue.Queue(10000)

    # Locks the thread queue.
    queueLock = threading.Lock()

    logger = logging.getLogger('pdk.gen.singlefilethread')

    activeThreads = []

    def __init__(self, datasetname, labelsetname, batch_size):
        """
        :param datasetname: The name used to pull the dataset from file
        :type datasetname: str
        :param labelsetname: The name used to pull the labelset from file
        :type labelsetname: str
        :param batch_size: The number of images and truths to pull
        :type batch_size: int
        """
        super(SingleFileThread, self).__init__()
        self.datasetname = datasetname
        self.labelsetname = labelsetname
        self.batch_size = batch_size
        # A non-None buffer means prefetched data is ready.
        self._buffer = None
        self.single_thread_lock = threading.Lock()
        self._filegen = None

    @property
    def get(self):
        """
        Pops the prefetched batch, or returns None if nothing is ready.
        """
        self.single_thread_lock.acquire()
        if self._buffer is None:
            self.single_thread_lock.release()
            return None
        tmp = self._buffer
        self._buffer = None
        self.single_thread_lock.release()
        return tmp
    def run(self):
        """
        Loops over the queue to accept new configurations.
        """
        while SingleFileThread.__ThreadExitFlag__:
            # Skip this cycle if the consumer currently holds the lock.
            if self.single_thread_lock.locked():
                continue
            else:
                self.single_thread_lock.acquire()
                if self._buffer is not None:
                    # Prefetched data is still waiting to be consumed.
                    self.single_thread_lock.release()
                    continue
                self.single_thread_lock.release()
            if self._filegen is None or self._filegen.reused:
                # Pull the next file name off the shared queue.
                self.queueLock.acquire()
                if not self.queue.empty():
                    self._filegen = SingleFileDataGenerator(
                        SingleFileThread.queue.get(),
                        self.datasetname,
                        self.labelsetname,
                        self.batch_size)
                    self.logger.info("Moving to file: {}".format(self._filegen))
                self.queueLock.release()
            else:
                try:
                    try:
                        self.single_thread_lock.acquire()
                        self._buffer = self._filegen.next()
                    except StopIteration:
                        self.logger.warning("Hit StopIteration")
                        self._buffer = None
                        self._filegen = None
                    self.single_thread_lock.release()
                except Exception:
                    exc_type, exc_value, exc_traceback = sys.exc_info()
                    self.logger.error(repr(traceback.format_exception(
                        exc_type, exc_value, exc_traceback)))
                    # Don't leave the lock held after a failure.
                    if self.single_thread_lock.locked():
                        self.single_thread_lock.release()
    @staticmethod
    def killRunThreads(signum, frame):
        """
        Sets the thread kill flag for each of the ongoing analysis threads.
        """
        SingleFileThread.logger.info("Killing single file threads...")
        SingleFileThread.__ThreadExitFlag__ = 0
        sys.exit(signum)
    @staticmethod
    def startThreads(nThreads, datasetname, labelsetname, batch_size):
        msg = "Starting {} single file threads".format(nThreads)
        SingleFileThread.logger.info(msg)
        for i in range(nThreads):
            thread = SingleFileThread(datasetname, labelsetname, batch_size)
            thread.start()
            SingleFileThread.activeThreads.append(thread)
        SingleFileThread.logger.info("Threads successfully started")
        # Return the pool so callers can keep a handle on the workers.
        return SingleFileThread.activeThreads
    @staticmethod
    def waitTillComplete(callback=None):
        cls = SingleFileThread
        if callback is None:
            while not cls.queue.empty() and cls.__ThreadExitFlag__:
                sys.stdout.flush()
        else:
            while not cls.queue.empty() and cls.__ThreadExitFlag__:
                callback()
        # Notify threads that it's time to exit.
        SingleFileThread.__ThreadExitFlag__ = 0
        # Wait for all threads to complete.
        for t in SingleFileThread.activeThreads:
            t.join()
        # Deallocate the thread pool.
        SingleFileThread.activeThreads = []


signal.signal(signal.SIGINT, SingleFileThread.killRunThreads)
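
The worker pool can also be driven directly through these static methods. Below is a minimal sketch, assuming HDF5-style input files; every path and dataset name in it is hypothetical:

# Hypothetical file list and dataset/labelset names, for illustration only.
files = ["run0.h5", "run1.h5"]

SingleFileThread.startThreads(2, "image_dataset", "label_dataset", 16)

# Queue the files for the workers to prefetch.
SingleFileThread.queueLock.acquire()
for f in files:
    SingleFileThread.queue.put(f)
SingleFileThread.queueLock.release()

# Spin until the filename queue drains, then join and clear the pool.
SingleFileThread.waitTillComplete()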
class ThreadedMultiFileDataGenerator(BaseDataGenerator):
    """
    Uses threads to pull data asynchronously from multiple files.
    """

    logger = logging.getLogger("pdk.gen.threaded_multi")

    def __init__(self, datapaths, datasetname, labelsetname,
                 batch_size=1, nThreads=2):
        SingleFileThread.threadLock.acquire()
        self._threads = SingleFileThread.startThreads(nThreads, datasetname,
                                                      labelsetname, batch_size)
        SingleFileThread.queueLock.acquire()
        for config in datapaths:
            SingleFileThread.queue.put(config)
        SingleFileThread.queueLock.release()
        self.datapaths = datapaths
        self.logger.info("Threaded multi file generator ready for generation")
        self.current_thread_index = 0

    def __del__(self):
        SingleFileThread.__ThreadExitFlag__ = 0
        for t in SingleFileThread.activeThreads:
            t.join()
        del self._threads
        SingleFileThread.threadLock.release()

    @property
    def output(self):
        x, y = self.next()
        return x[0].shape

    @property
    def input(self):
        x, y = self.next()
        return y[0].shape[0]

    def __len__(self):
        """
        Iterates over files to create the total summed length of the
        datasets in each file.
        """
        # The total length is not currently computed.
        return 0
    def next(self):
        # See if there's any prefetched data.
        # If the filename queue is empty, fill it back up again.
        # This ensures that all files are used up before they
        # are iterated over again.
        self.logger.debug("Getting next")
        SingleFileThread.queueLock.acquire()
        if SingleFileThread.queue.empty():
            for i in self.datapaths:
                SingleFileThread.queue.put(i)
        SingleFileThread.queueLock.release()
        tmp_buffer = None
        while tmp_buffer is None:
            # Cycle round-robin through the worker threads.
            if self.current_thread_index >= len(SingleFileThread.activeThreads):
                self.current_thread_index = 0
            thread = SingleFileThread.activeThreads[self.current_thread_index]
            thread.single_thread_lock.acquire()
            if thread._buffer is not None:
                new_buff = thread._buffer
                thread._buffer = None
                thread.single_thread_lock.release()
                if new_buff is None:
                    self.logger.warning("Found null dataset")
                    return self.next()
                x, y = new_buff
                if not len(x) == len(y):
                    self.logger.warning("Found incorrect dataset size")
                    return self.next()
                return new_buff
            else:
                self.current_thread_index += 1
                thread.single_thread_lock.release()
            self.logger.warning("Threaded buffer found queue is empty. Trying again")
            return self.next()
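
A minimal usage sketch of the generator, assuming HDF5 data files laid out the way SingleFileDataGenerator expects; the paths and dataset names below are hypothetical:

# Hypothetical paths and dataset/labelset names, for illustration only.
gen = ThreadedMultiFileDataGenerator(["run0.h5", "run1.h5"],
                                     "image_dataset",
                                     "label_dataset",
                                     batch_size=16,
                                     nThreads=4)

# next() spins (and recurses) until one of the worker threads has a
# prefetched (x, y) batch ready.
for _ in range(10):
    x, y = gen.next()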