# -*- Mode: Python; tab-width: 4 -*-

VERSION_STRING = '$Id: //depot/main/findmail/src/coroutine/coro.py#50 $'

# Copyright 1999, 2000 by eGroups, Inc.
#
#                         All Rights Reserved
#
# Permission to use, copy, modify, and distribute this software and
# its documentation for any purpose and without fee is hereby
# granted, provided that the above copyright notice appear in all
# copies and that both that copyright notice and this permission
# notice appear in supporting documentation, and that the name of
# eGroups not be used in advertising or publicity pertaining to
# distribution of the software without specific, written prior
# permission.
#
# EGROUPS DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN
# NO EVENT SHALL EGROUPS BE LIABLE FOR ANY SPECIAL, INDIRECT OR
# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import bisect
import os
import select
import socket
import string
import sys
import time
import whrandom

# poll(2) support is hard-wired off.  Importing 'poll' here is a problem
# because of /findmail/src/poll.py (presumably that module would be found
# instead of the poll extension -- confirm).  While this stays off, no
# module-level 'poll' name exists, so poll_with_poll() is unusable.
## try:
##     import poll
##     USE_POLL = 1
## except ImportError:
##     USE_POLL = 0
USE_POLL = 0

# Thread logging: valid levels are 0 .. LOG_LEVELS-1, and higher levels
# are more severe errors.  A thread writes a message only when its level
# is strictly above the thread's threshold, which starts at
# PRINT_LOG_LEVELS.
LOG_LEVELS = 10                 # Do not change this - jeske
PRINT_LOG_LEVELS = 5
LOG_ERROR, LOG_VERBOSE = 7, 3

# Values delivered to a thread suspended in wait_for_read() or
# wait_for_write(): TIMEOUT_VALUE when its timeout event fires, and
# BADF_ERROR when the event loop finds its descriptor unusable by select().
TIMEOUT_VALUE, BADF_ERROR = 0xb00b, 0xc00c

# Default for the 'timeout' argument of wait_for_read() and
# wait_for_write(): "use the socket's own timeout".
USE_DEFAULT_TIMEOUT = -1

# a magic 'number' for the JoinTracker class.  Really, this is just
# useful for being a single fixed object - we do an 'is' comparison on it.
MAGIC_JOIN_TIMEOUT_ARG = ('spam', 'eggs', 'Guido') from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, ENOTCONN import coroutine import exceptions class CoroutineSocketError (exceptions.Exception): pass class CoroutineCondError (exceptions.Exception): pass class CoroutineThreadError (exceptions.Exception): pass class TimeoutError (exceptions.Exception): pass # =========================================================================== # Coroutine Socket # =========================================================================== class coroutine_socket: "socket that automatically suspends/resumes instead of blocking." def __init__ (self, sock=None, timeout=None, connect_timeout=None): """Timeout semantics for __init__(): If you don't pass in a keyword arg for 'timeout' (or pass in a value of None), then the socket is blocking.""" self.socket = sock if sock: self.socket.setblocking (0) self.connected = 1 self.set_fileno() self._closed = 0 self._timeout = timeout self._connect_timeout = connect_timeout # getters and setters for those that need access (for example, # container objects that manage sockets returned by accept()) def timeout_(self, o): self._timeout = o; return self def timeout(self): return self._timeout def connect_timeout_(self, o): self._connect_timeout = o; return self def connect_timeout(self): return self._connect_timeout def set_fileno (self): self._fileno = self.socket.fileno() def fileno (self): return self._fileno def create_socket (self, family, type): self.socket = socket.socket (family, type) self.socket.setblocking (0) self.set_fileno() def wait_for_read (self, timeout=USE_DEFAULT_TIMEOUT): """Timeout semantics: No timeout keyword arg given means that the default given to __init__() is used. 
A timeout of None means that we will act as a blocking socket.""" if timeout == USE_DEFAULT_TIMEOUT: timeout = self._timeout me = current_thread() if timeout is not None: triple = the_event_list.insert_event (me, time.time()+timeout, TIMEOUT_VALUE) if me is None: raise CoroutineSocketError, "coroutine sockets cannot run in 'main'" else: read_set[self._fileno] = me try: result = me.yield() except: raise CoroutineSocketError, "coroutine socket could not yield" if result == TIMEOUT_VALUE: # i.e, we timed out if read_set.has_key(self._fileno): del read_set[self._fileno] raise TimeoutError, "request timed out in recv (%s secs)" % timeout elif timeout is not None: try: # remove event the_event_list.remove_event (triple) except: print "Error removing: result=%s" % result if result == BADF_ERROR: # ie, bad socket for select raise socket.error, "Invalid Socket" else: return 1 def wait_for_write (self, timeout=USE_DEFAULT_TIMEOUT): """Timeout semantics: No timeout keyword arg given means that the default given to __init__() is used. 
A timeout of None means that we will act as a blocking socket.""" if timeout == USE_DEFAULT_TIMEOUT: timeout = self._timeout me = current_thread() if timeout is not None: triple = the_event_list.insert_event (me, time.time()+timeout, TIMEOUT_VALUE) if me is None: raise CoroutineSocketError, "coroutine sockets cannot run in 'main'" else: write_set[self._fileno] = me try: result = me.yield() except: raise CoroutineSocketError, "coroutine socket could not yield" if result == TIMEOUT_VALUE: # i.e, we timed out if write_set.has_key(self._fileno): del write_set[self._fileno] raise TimeoutError, "request timed out in send (%s secs)" % timeout elif timeout is not None: # remove event the_event_list.remove_event (triple) if result == BADF_ERROR: # ie, bad socket for select raise socket.error, "Invalid Socket" else: return 1 def connect (self, address): try: return self.socket.connect (address) except socket.error, why: if why[0] in (EINPROGRESS, EWOULDBLOCK): self.wait_for_write (timeout=self._connect_timeout) ret = self.socket.getsockopt (socket.SOL_SOCKET, socket.SO_ERROR) if ret != 0: raise socket.error, (ret, os.strerror(ret)) return elif why[0] == EALREADY: return else: raise socket.error, why def recv (self, buffer_size): self.wait_for_read() return self.socket.recv (buffer_size) def recvfrom (self, buffer_size): self.wait_for_read() return self.socket.recvfrom (buffer_size) # Things we try to avoid: # 1) continually slicing a huge string into slightly smaller pieces # [e.g., 1MB, 1MB-8KB, 1MB-16KB, ...] # 2) forcing the kernel to copy huge strings for the syscall # # So, we try to send reasonably-sized slices of a large string. # If we were really smart we might try to adapt the amount we try to send # based on how much got through the last time. 
_max_send_block = 64 * 1024 def send (self, data): old = ld = len(data) while ld: self.wait_for_write() block = min (ld, self._max_send_block) start = old-ld end = start + block n = self.socket.send (data[start:end]) ld = ld - n return old def sendfile (self, fd, offset, size, headers=None, trailers=None): # damn, this is cool! import sendfile if headers: hl = len(headers) else: hl = 0 self.wait_for_write() sent = sendfile.sendfile (fd, self._fileno, offset, size, headers, trailers) if sent < hl: raise SystemError, "header didn't fit" else: offset = offset + (sent - hl) size = size - (sent - hl) while size > 0: self.wait_for_write() sent = sendfile.sendfile (fd, self._fileno, offset, size, headers, trailers) offset = offset + sent size = size - sent def sendto (self, data, where): if self.wait_for_write(): return self.socket.sendto (data, where) else: return 0 def bind (self, address): return self.socket.bind (address) def listen (self, queue_length): return self.socket.listen (queue_length) def accept (self): # accept() should not have a timeout self.wait_for_read(timeout=None) conn, addr = self.socket.accept() return self.__class__ (conn), addr def close (self): if not self._closed: self._closed = 1 if self.socket: return self.socket.close() else: return None def __del__ (self): self.close() def set_reuse_addr (self): # try to re-use a server port if possible try: self.socket.setsockopt ( socket.SOL_SOCKET, socket.SO_REUSEADDR, 1 | self.socket.getsockopt (socket.SOL_SOCKET, socket.SO_REUSEADDR) ) except: pass # =========================================================================== # Condition Variable # =========================================================================== class coroutine_cond: def __init__ (self): self._waiting = {} def __len__(self): return len(self._waiting) def wait (self, timeout = None): thrd = current_thread() tid = thrd.thread_id() self._waiting[tid] = thrd result = thrd.yield(timeout) # If we have passed to here, we are not 
waiting any more, # so remove reference to thread: if self._waiting.has_key(tid): del self._waiting[tid] return result def wake (self, id): thrd = self._waiting.get(id, None) if thrd is None: raise CoroutineCondError, 'unknown thread <%d>' % (id) else: del self._waiting[id] schedule (thrd, ()) def wake_one (self, *args): if len(self._waiting): id = whrandom.choice(self._waiting.keys()) thrd = self._waiting[id] del self._waiting[id] schedule (thrd, args) def wake_all (self, *args): for thrd in self._waiting.values(): schedule (thrd, args) self._waiting = {} # =========================================================================== # Thread Abstraction # =========================================================================== _current_threads = {} class Thread: _thread_count = 0 def __init__ (self, group=None, target=None, name=None, args=(), kwargs={}): if Thread._thread_count == 0x7FFFFFFF: Thread._thread_count = 0 Thread._thread_count = Thread._thread_count + 1 self._thread_id = self._thread_count if name is None: self._name = 'thread_%d' % self._thread_id else: self._name = name self._target = target self._args = args self._kwargs = kwargs self._resume_count = 0 self._total_time = 0 self._alive = 0 self._started = 0 self._profile = 0 self._daemonic = 0 self._co = coroutine.new (self._run, 65536) self._status = 'initialized' self._log_level = PRINT_LOG_LEVELS def resume (self, args): if not _JoinTracker.okay_to_resume(self, args): self.yield() return self._resume(args) def _resume(self, args): if self._profile: self._resume_count = self._resume_count + 1 start_time = time.time() else: start_time = 0 if self._alive: # This first one will create a self-referenced tuple, which causes # a core dump. The second does not. Extremely weird. 
#result = coroutine.resume (self._co, (args,)) newargs = (args,) result = coroutine.resume (self._co, newargs) else: result = coroutine.resume (self._co, ()) if self._profile and start_time: end_time = time.time() self._total_time = self._total_time + (end_time - start_time) return result def start (self): self._started = 1 schedule (self) def _run (self): global _current_threads try: self._alive = 1 self._status = 'alive' _current_threads[self._co] = self if self._target is None: result = apply (self.run, self._args, self._kwargs) else: result = apply (self._target, self._args, self._kwargs) except coroutine.unwind: # kill() will cause this pass except: self._error = compact_traceback() self.log (LOG_ERROR, self._error) del _current_threads[self._co] self._alive = 0 self._status = 'dead' _JoinTracker.joinee_done(self) def __del__ (self): if self._alive: self.kill() def kill (self): _JoinTracker.joinee_done(self) if self._alive: if _current_threads.has_key(self._co): del _current_threads[self._co] coroutine.kill (self._co) def run (self): self.log (0, 'unregistered run method') def profile (self, status): self._profile = status # Higher level == more severe error def log_level_(self, level): if level not in range(LOG_LEVELS): raise CoroutineThreadError, 'error log level out of bounds' else: self._log_level = level def log (self, level, message): if level not in range(LOG_LEVELS): raise CoroutineThreadError, 'error log level out of bounds' if level > self._log_level: time_str = "[%02d/%02d %02d:%02d:%02d]" % ( time.localtime(time.time())[1:6] ) sys.stderr.write ( '%s thread %d: %s\n' % ( time_str, self._thread_id, str(message) ) ) def yield (self, timeout = None): if timeout is not None: triple = the_event_list.insert_event (self, time.time()+timeout, None) return coroutine.main(()) def thread_id (self): return self._thread_id def getName (self): return self._name def setName (self, name): self._name = name def isAlive (self): return self._alive def isDaemon (self): 
return self._daemonic def setDaemon (self, daemonic): self._daemonic = daemonic def status (self): print 'Thread status:' print ' id: ', self._thread_id print ' alive: ', self._alive if self._profile: print ' resume count:', self._resume_count print ' execute time:', self._total_time def join (self, timeout=None): """From the Threads documentation: join([timeout]): Wait until the thread terminates. This blocks the calling thread until the thread whose join() method is called terminates - either normally or through an unhandled exception - or until the optional timeout occurs. When the timeout argument is present and not None, it should be a floating point number specifying a timeout for the operation in seconds (or fractions thereof). A thread can be join()ed many times. A thread cannot join itself because this would cause a deadlock. It is an error to attempt to join() a thread before it has been started. """ caller = current_thread() if caller == self: raise CoroutineThreadError, "Cannot join() myself" if not self._started: msg = "Caller cannot join() an unstarted thread " raise CoroutineThreadError, msg % ( caller.thread_id(), self.thread_id() ) return _JoinTracker.block_until_done(caller, self, timeout) def __repr__ (self): if self._profile: p = ' resume_count: %d execute_time: %s' % ( self._resume_count, self._total_time ) else: p = '' if self._alive: a = 'running' else: a = 'suspended' return '<%s.%s id:%d %s %s%s at %x>' % ( __name__, self.__class__.__name__, self._thread_id, self._status, a, p, id(self) ) # # end class Thread # def compact_traceback (): t,v,tb = sys.exc_info() tbinfo = [] if tb is None: # this should never happen, but then again, lots of things # should never happen but do return (('','',''), str(t), str(v), 'traceback is None!!!') while 1: tbinfo.append ( tb.tb_frame.f_code.co_filename, tb.tb_frame.f_code.co_name, str(tb.tb_lineno) ) tb = tb.tb_next if not tb: break # just to be safe del tb file, function, line = tbinfo[-1] info = '[' + 
string.join ( map ( lambda x: string.join (x, '|'), tbinfo ), '] [' ) + ']' return (file, function, line), str(t), str(v), info # # =========================================================================== # global state and threadish API # =========================================================================== # # file descriptors waiting for a read event read_set = {} # file descriptors waiting for a write event write_set = {} # coroutines that are ready to run pending = {} def _socket (family, type, **kwargs): s = apply(coroutine_socket, (), kwargs) s.create_socket (family, type) return s make_socket = _socket def new (function, *args, **kwargs): return Thread (target=function, args=args, kwargs=kwargs) #def spawn (function, *args): # schedule (coroutine.new (function), args) def spawn (function, *args, **kwargs): Thread (target=function, args=args, kwargs=kwargs).start() def schedule (coroutine, args=None): "schedule a coroutine to run" pending[coroutine] = args #def yield (): # return coroutine.main() def yield(): return current_thread().yield() def thread_list(): return _current_threads.values() def current_thread(): co = coroutine.current() return _current_threads.get (co, None) current = current_thread def insert_thread(thrd): thrd.start() def run_pending(): "run all pending coroutines" while len(pending): try: # some of these will kick off others, thus the loop runnable = pending.items() pending.clear() for c,v in runnable: c.resume (v) except: # XXX can we throw the exception to the coroutine? 
import traceback traceback.print_exc() # uses poll(2) def poll_with_poll (timeout=30.0): if read_set or write_set: u = {} for fd in read_set.keys(): u[fd] = poll.POLLIN for fd in write_set.keys(): if u.has_key(fd): u[fd] = u[fd] | poll.POLLOUT else: u[fd] = poll.POLLOUT u = u.items() #print 'before',u u = poll.poll (u, timeout) #print 'after', u for fd, flags in u: if flags & poll.POLLIN: schedule (read_set[fd]) del read_set[fd] if flags & poll.POLLOUT: schedule (write_set[fd]) del write_set[fd] # uses select(2) def poll_with_select (timeout=30.0): if read_set or write_set: r = read_set.keys() w = write_set.keys() #print 'before: read: %d write: %d' % (len(r),len(w)) r,w,e = select.select (r,w, [], timeout) #print 'after: read: %d write: %d' % (len(r),len(w)) #sys.stdout.write ('- %d %d|' % (len(r),len(w))); sys.stdout.flush() for fd in r: schedule (read_set[fd]) del read_set[fd] for fd in w: schedule (write_set[fd]) del write_set[fd] ###################################################################### # the JoinTracker class # # a 'joiner' is a thread that calls join() on another thread. # a 'joinee' is a thread whose join() method was called. # # Rather than messing with the event list or the dictionary 'pending' # directly, we accomplish the join() functionality only by interfering # where a coroutine can be resumed - the resume() method of coro.Thread. # # It calls our okay_to_resume() call, which checks for events marked # by a sentinel value - MAGIC_JOIN_TIMEOUT_ARG - which is compared via # 'is'. These, which are the timeouts we ourselves have planted, are # allowed through for threads that are otherwise forbidden from being # resumed. # # We have two cases to take care of: if we timeout, we must update # our internal data. If the joinee actually dies, we need to remove # the timeout event from the event list, in addition to scheduling # all joining threads. 
# # PS If you want to extend join() to pass args back, that is easily # done - just modify triple[2] used here. You could even extend it # so you can tell directly whether your joinee died or you timed out. # If you use join() with timeout really excessively, you might # improve performance by changing the values of _joined_by to # dictionaries rather than lists. ###################################################################### class JoinTracker: """Keep track of which threads join() others.""" def __init__(self): # this dict maps joiner thread ids to joinee threads: self._wants_to_join = {} # this dict contains a reverse mapping of sorts: # joinee thread id to a list of triples of # # (timeoutTime, joiner, args), # # 'timeoutTime' the time at which the join() is # scheduled to time out # 'joiner' a joiner thread waiting for the joinee, # 'args' a magic value. # # If join() is called without a timeout, both timeoutTime # and args are None. self._joined_by = {} def okay_to_resume(self, thread, args): """Here is the hackery used to get the rest of the module to behave and not start threads it's not supposed to - and to start them again when it should.""" #print 'args:', args, 'thread id:', thread.thread_id(), ' resume' if not self._wants_to_join.has_key(thread.thread_id()): return 1 elif args[0] is MAGIC_JOIN_TIMEOUT_ARG: # if we are here, it means resume() called us, from # a timeout-induced scheduling. We must remove the # appropriate joiner from our data structures. 
joinee = self._wants_to_join[thread.thread_id()] del self._wants_to_join[thread.thread_id()] joineeId = joinee.thread_id() listOfTriples = self._joined_by[joineeId] numTriples = len(listOfTriples) if numTriples == 1: del self._joined_by[joineeId] else: for i in xrange(numTriples): triple = listOfTriples[i] if triple[1] == thread: del listOfTriples[i] break return 1 else: #return not self._wants_to_join[thread.thread_id()].isAlive() return 0 def block_until_done(self, joiner, joinee, timeout): """Block joiner from running until joinee exits.""" joineeId = joinee.thread_id() joinerId = joiner.thread_id() if self._wants_to_join.has_key(joinerId): msg = "joiner thread already called join()!" % joinerId raise CoroutineThreadError, msg if timeout is None: triple = (None, joiner, None) else: # the following call returns its args, with the first two switched triple = the_event_list.insert_event(joiner, time.time()+timeout, [MAGIC_JOIN_TIMEOUT_ARG]) self._wants_to_join[joinerId] = joinee if not self._joined_by.has_key(joineeId): self._joined_by[joineeId] = [] self._joined_by[joineeId].append(triple) return joiner.yield() def joinee_done(self, joinee): """Schedule all joiners and remove them from the joining dict. This is called by a thread upon its demise. We also remove the timeout events from the event list since we are immediately scheduling the threads.""" joineeId = joinee.thread_id() if self._joined_by.has_key(joineeId): for triple in self._joined_by[joineeId]: joiner = triple[1] joinerId = joiner.thread_id() if self._wants_to_join.has_key(joinerId): del self._wants_to_join[joinerId] # triple[0] is the time at which our event was scheduled, # if we had an event to schedule if triple[0] is not None: try: the_event_list.remove_event(triple) except ValueError: pass schedule(joiner) del self._joined_by[joineeId] def __del__(self): self.close() def close(self): """Clean out the dictionaries used to track joining relationships. 
This avoids problems caused by a thread's __del__ method calling joinee_done() when we are trying to shut things down.""" self._joined_by = {} self._wants_to_join = {} # end class JoinTracker _JoinTracker = JoinTracker() class event_list: def __init__ (self): self.events = [] def __nonzero__ (self): return len(self.events) def __len__ (self): return len(self.events) def insert_event (self, co, when, args): triple = (when, co, args) bisect.insort (self.events, triple) return triple def remove_event (self, triple): self.events.remove (triple) def sleep_absolute (self, when, *args): me = current_thread() self.insert_event (me, when, args) me.yield() def sleep_relative (self, delta, *args): me = current_thread() self.insert_event (me, time.time()+delta, args) me.yield() def run_scheduled (self): now = time.time() i = j = 0 while i < len(self.events): when, thread, args = self.events[i] if now >= when: schedule (thread, args) j = i + 1 else: break i = i + 1 self.events = self.events[j:] return None def next_event (self, max_timeout=30.0): now = time.time() if len(self.events): when, thread, args = self.events[0] return min (max_timeout, max(when-now, 0)) else: return max_timeout return None the_event_list = event_list() sleep_absolute = the_event_list.sleep_absolute sleep_relative = the_event_list.sleep_relative exit = 0 def reset_event_loop(): # # dump all events/threads, I'm using this to shut down, so # if the event_loop exits, I want to startup another eventloop # thread set to perform shutdown functions. 
# global the_event_list global read_set global write_set global pending global _JoinTracker the_event_list = event_list() read_set = {} write_set = {} pending = {} _JoinTracker.close() _JoinTracker = JoinTracker() return None def _continue_event_loop(): if exit: return 0 else: return ( len(the_event_list) + len(read_set) + len(write_set) + len(pending) ) def event_loop (max_timeout=30.0): if USE_POLL: max_timeout = int (max_timeout * 1000) poll_fun = poll_with_poll else: poll_fun = poll_with_select while _continue_event_loop(): the_event_list.run_scheduled() run_pending() delta = the_event_list.next_event (max_timeout) try: poll_fun (timeout=delta) except select.error: b, thrd = find_broken_socket() if b is not None and thrd is not None: # print "*** broken socket: %s, %s" % (repr(b), repr(thrd)) schedule (thrd, BADF_ERROR) def find_broken_socket(): # ick. scan through all the sockets to find out which one is bad. # XXX: in every case where this has happened, socket.fileno() yields # -1, which is probably a much cheaper way to find Mr. Broken. for s in read_set.keys(): try: r,w,e = select.select ([s],[s],[s], 0.0) except select.error: thrd = read_set[s] del read_set[s] return s, thrd for s in write_set.keys(): try: r,w,e = select.select ([s],[s],[s], 0.0) except select.error: thrd = write_set[s] del write_set[s] return s, thrd return None, None