# -*- Mode: Python -*-

# Copyright 1999 by eGroups, Inc.
#
#                         All Rights Reserved
#
# Permission to use, copy, modify, and distribute this software and
# its documentation for any purpose and without fee is hereby
# granted, provided that the above copyright notice appear in all
# copies and that both that copyright notice and this permission
# notice appear in supporting documentation, and that the name of
# eGroups not be used in advertising or publicity pertaining to
# distribution of the software without specific, written prior
# permission.
#
# EGROUPS DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN
# NO EVENT SHALL EGROUPS BE LIABLE FOR ANY SPECIAL, INDIRECT OR
# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

VERSION_STRING = '$Id: //depot/main/findmail/src/coroutine/coromysql.py#39 $'

# Note: this should be split into two modules, one which is unaware of
# the distinction between blocking and coroutine sockets.

# Strategies:
# 1) Simple; schedule at the socket level.
#    Just like the mysql client library, when a coroutine accesses
#    the mysql client object, it will automatically detach when
#    the socket gets EWOULDBLOCK.
# 2) Smart; schedule at the request level.
#    Use a separate coroutine to manage the mysql connection.
#    A client coroutine will resume when a response is available.
# 3) Sophisticated; schedule at the row level.
#    Allow a client coroutine to peel off rows one at a time.
#
#
# Currently I am trying to emulate MySQLmodule.c as closely as possible
# so it can be used as a drop-in replacement, from Mysqldb.py and up. If
# all the commands are not there, they are being added on an as needed
# basis.
#
# The one place where the strategy takes a different route than the module
# is auto reconnects. The module does not perform auto reconnect, so it is
# left to higher layers like Mysqldb.py, which uses sleep. Unfortunately
# regular sleep will block an entire process. (bad) So we are adding
# auto-reconnect to this module. -Libor 10/10/99
#

import exceptions
import math
import socket
import string
import sys

import coro


def log(msg):
    """Write *msg* followed by a newline to stderr.

    *msg* is formatted with ``%s`` rather than concatenated, so callers
    may pass any object (a type, an exception, ...) and not only a
    string.
    """
    sys.stderr.write('%s\n' % (msg,))


# Upper bound on reconnect attempts and the per-attempt sleep increment
# (seconds) used by the retry loop in mysql_client.
MAX_RECONNECT_RETRY = 10.0
RECONNECT_RETRY_GRAIN = 0.1

# Number of bytes requested from the socket on each recv().
DEFAULT_RECV_SIZE = 0x8000
# Size of the packet header: 3-byte length plus 1-byte sequence number.
MYSQL_HEADER_SIZE = 0x4


class InternalError (exceptions.Exception):
    """Raised inside the client for protocol and connection failures."""
    pass


class error (exceptions.Exception):
    """Public exception raised to callers of the compatibility API."""
    pass


# ===========================================================================
#                           Authentication
# ===========================================================================

# Note: I have ignored the stuff to support an older version of the protocol.
#
# The code is based on the file mysql-3.21.33/client/password.c
#
# The auth scheme is challenge/response.  Upon connection the server
# sends an 8-byte challenge message.  This is hashed with the password
# to produce an 8-byte response.  The server side performs an identical
# hash to verify the password is correct.
class random_state: def __init__ (self, seed, seed2): self.max_value = 0x3FFFFFFF self.seed = seed % self.max_value self.seed2 = seed2 % self.max_value return None def rnd (self): self.seed = (self.seed * 3 + self.seed2) % self.max_value self.seed2 = (self.seed + self.seed2 + 33) % self.max_value return float(self.seed)/ float(self.max_value) def hash_password (password): nr=1345345333L add=7 nr2=0x12345671L for ch in password: if (ch == ' ') or (ch == '\t'): continue tmp = ord(ch) nr = nr ^ (((nr & 63) + add) * tmp) + (nr << 8) nr2 = nr2 + ((nr2 << 8) ^ nr) add = add + tmp return (nr & ((1L<<31)-1L), nr2 & ((1L<<31)-1L)) def scramble (message, password): hash_pass = hash_password (password) hash_mess = hash_password (message) r = random_state (hash_pass[0] ^ hash_mess[0], hash_pass[1] ^ hash_mess[1]) to = [] for ch in message: to.append (int (math.floor ((r.rnd() * 31) + 64))) extra = int (math.floor (r.rnd()*31)) for i in range(len(to)): to[i] = to[i] ^ extra return to # =========================================================================== # Packet Protocol # =========================================================================== def unpacket (p): # 3-byte length, one-byte packet number, followed by packet data a,b,c,s = map (ord, p[:4]) l = a | (b << 8) | (c << 16) # s is a sequence number return l, s def packet (data, s = 0): l = len(data) a, b, c = l & 0xff, (l>>8) & 0xff, (l>>16) & 0xff h = map (chr, [a,b,c,s]) return string.join (h,'') + data def n_byte_num (data, n, pos=0): result = 0 for i in range(n): result = result | (ord(data[pos+i])<<(8*i)) return result def decode_length (data, pos=0): n = ord(data[pos]) if n < 251: return n, 1 elif n == 251: return 0, 1 elif n == 252: return n_byte_num (data, 2, pos+1), 3 elif n == 253: return n_byte_num (data, 3, pos+1), 4 else: # libmysql adds 6, why? 
return n_byte_num (data, 4, pos+1), 5 # used to generate the dumps below def dump_hex (s): r1 = [] r2 = [] for ch in s: r1.append (' %02x' % ord(ch)) if (ch in string.letters) or (ch in string.digits): r2.append (' %c' % ch) else: r2.append (' ') return string.join (r1, ''), string.join (r2, '') # =========================================================================== # generic utils # =========================================================================== def is_disconnect(reason): if string.find(string.lower(repr(reason.args)), "lost connection") != -1: return 1 if string.find(string.lower(repr(reason.args)), "no connection") != -1: return 1 if string.find(string.lower(repr(reason.args)),"server has gone away") != -1: return 1 return 0 def null_func(): pass class mysql_client: def __init__ (self, username, password, address = ('127.0.0.1', 3306), debug = 0, timeout=None, connect_timeout=None): # remember this for reconnect self.username = username self.password = password self.address = address self._database = None self._connected = 0 self._recv_buffer = '' self._recv_length = 0 self._lock = 0 self._debug = debug if not self._debug: self._timer_cond = coro.coroutine_cond() self._lock_cond = coro.coroutine_cond() self._timeout = timeout self._connect_timeout = connect_timeout def sleep(self, *args, **kwargs): if self._debug: return apply(time.sleep, args, kwargs) else: return apply(self._timer_cond.wait, args, kwargs) def make_socket(self, *args, **kwargs): if self._debug: return apply(socket.socket, args, kwargs) else: # keyword args for timeout passed in should override the values # set in __init__() if not kwargs.has_key('timeout'): kwargs['timeout'] = self._timeout if not kwargs.has_key('connect_timeout'): kwargs['connect_timeout'] = self._connect_timeout return apply(coro.make_socket, args, kwargs) def lock_wait(self, *args, **kwargs): if not self._debug: return apply(self._lock_cond.wait, args, kwargs) def lock_wake(self, *args, **kwargs): if not 
self._debug: return apply(self._lock_cond.wake_one, args, kwargs) def _lock_connection(self): while self._lock: self.lock_wait() self._lock = 1 return None def _unlock_connection(self): self._lock = 0 self.lock_wake() # # yield to give someone else a chance # self.sleep(0.0) return None def _connect (self): try: self.socket = self.make_socket(socket.AF_INET, socket.SOCK_STREAM) self.socket.connect (self.address) except socket.error, msg: raise InternalError, "No connection to MySQL server." self._recv_buffer = '' self._recv_length = 0 return None def recv(self): try: data = self.socket.recv (DEFAULT_RECV_SIZE) except socket.error, msg: raise InternalError, "No connection to MySQL server during query" if not data: raise InternalError, "Lost connection to MySQL server during query" else: self._recv_buffer = self._recv_buffer + data self._recv_length = self._recv_length + len(data) return None def write (self, data): ln = len(data) while data: try: n = self.socket.send (data) except socket.error, msg: raise InternalError, "No connection to MySQL server during query" if not n: raise InternalError, "Lost connection to MySQL server during query" else: data = data[n:] return ln debug = 0 def send_packet (self, data, sequence=0): if self.debug: print '--> %03d' % sequence a, b = dump_hex (data) print a print b self.write (packet (data, sequence)) return None def get_header(self): if self._recv_length < MYSQL_HEADER_SIZE: return None, None else: # 3-byte length, one-byte packet number, followed by packet data a,b,c, seq = map (ord, self._recv_buffer[:MYSQL_HEADER_SIZE]) length = a | (b << 8) | (c << 16) return length, seq def read_packet (self): packet_len, seq = self.get_header() while MYSQL_HEADER_SIZE > self._recv_length or \ packet_len + MYSQL_HEADER_SIZE > self._recv_length: self.recv() if packet_len is None: packet_len, seq = self.get_header() # # now we have at least one packet # data = self._recv_buffer[MYSQL_HEADER_SIZE:MYSQL_HEADER_SIZE+packet_len] 
self._recv_buffer = self._recv_buffer[MYSQL_HEADER_SIZE+packet_len:] self._recv_length = self._recv_length - (MYSQL_HEADER_SIZE+packet_len) return seq, data def _login (self): seq, data = self.read_packet() # unpack the greeting protocol_version = ord(data[0]) eos = string.find (data, '\000') mysql_version = data[1:eos] #thread_id = n_byte_num (data[eos+1:eos+5], 4, eos) thread_id = n_byte_num (data, 4, eos+1) challenge = data[eos+5:eos+13] auth = (protocol_version, mysql_version, thread_id, challenge) lp = self.build_login_packet (challenge) # seems to require a sequence number of one self.send_packet (lp, 1) # # read the response, which will check for errors # response_tuple = self.read_reply_header() if response_tuple != (0, 0, 0): raise InternalError, 'unknow header response: <%s>' % \ (repr(response_tuple)) # # mark that we are now connected # return None def check_connection(self): if not self._connected: self._connect() self._login() if self._database is not None: self.cmd_use(self._database) self._connected = 1 return None def build_login_packet (self, challenge): auth = string.join (map (chr, scramble (challenge, self.password)), '') # 2 bytes of client_capability # 3 bytes of max_allowed_packet # no idea what they are return '\005\000\000\000\020' + self.username + '\000' + auth def command (self, command_type, command): q = chr(decode_db_cmds[command_type]) + command self.send_packet (q, 0) return None def unpack_data (self, d): r = [] i = 0 while i < len(d): fl = ord(d[i]) if fl > 250: fl, scoot = decode_length (d, i) i = i + scoot else: i = i + 1 r.append (d[i:i+fl]) i = i + fl return r def unpack_int(self, data_str): if len(data_str) > 4: raise TypeError, 'data too long to be an int32: <%d>' % len(data_str) value = 0 while len(data_str): i = ord(data_str[len(data_str)-1]) data_str = data_str[:len(data_str)-1] value = value + (i << (8 * len(data_str))) return value def read_reply_header(self): # # read in the reply header and return the results. 
# seq, data = self.read_packet() rows_in_set = 0 affected_rows = 0 insert_id = 0 if data[0] == chr(0xff): error_num = ord(data[1]) + (ord(data[2]) << 8) error_msg = data[3:] raise InternalError, 'ERROR %d: %s' % (error_num, error_msg) elif data[0] == chr(0xfe): raise InternalError, 'unknown header <%s>' % (repr(data)) else: rows_in_set, move = decode_length(data, 0) data = data[move:] if len(data): affected_rows, move = decode_length(data, 0) data = data[move:] insert_id, move = decode_length(data, 0) data = data[move:] msg = data return rows_in_set, affected_rows, insert_id # # Internal mysql client requests to get raw data from db (cmd_*) # def cmd_use (self, database): self.command ('init_db', database) rows, affected, insert_id = self.read_reply_header() if rows != 0 or affected != 0 or insert_id != 0: msg = 'unexpected header: <%d> <%d> <%d>' % (rows, affected, insert_id) raise InternalError, msg self._database = database return None def cmd_query (self, query): # print 'coro mysql query: "%s"' % (repr(query)) self.command ('query', query) # # read in the header # nfields, affected, insert_id = self.read_reply_header() if not nfields: return statement([], [], affected, insert_id) decoders = range(nfields) fields = [] i = 0 while 1: seq, data = self.read_packet() if data == chr(0xfe): break else: field = self.unpack_data (data) decoders[i] = decode_type_map[ord(field[3])] fields.append (field) i = i + 1 if len(fields) != nfields: raise InternalError, "number of fields did not match" # read rows rows = [] field_range = range(nfields) while 1: seq, data = self.read_packet() if data == chr(0xfe): break else: row = self.unpack_data (data) # apply decoders for i in field_range: row[i] = decoders[i](row[i]) rows.append (row) return statement(fields, rows) def cmd_quit (self): self.command ('quit', '') # # no reply! 
# return None def cmd_shutdown (self): self.command ('shutdown', '') seq, data = self.read_packet() print "shutdown: seq: <%s> data: <%s>" % (repr(seq), repr(data)) return None def cmd_drop (self, db_name): self.command ('drop_db', db_name) nfields, affected, insert_id = self.read_reply_header() return None def cmd_listfields(self, cmd): self.command ('field_list', cmd) rows = [] # # read data line until we get 255 which is error or 254 which is # end of data ( I think :-) # while 1: seq, data = self.read_packet() # # terminal cases. # if data[0] == chr(0xff): raise InternalError, data[3:] elif data[0] == chr(0xfe): return rows else: row = self.unpack_data(data) table_name = row[0] field_name = row[1] field_size = self.unpack_int(row[2]) field_type = decode_type_names[ord(row[3])] field_flag = self.unpack_int(row[4]) field_val = row[5] flag_value = '' if field_flag & decode_flag_value['pri_key']: flag_value = flag_value + decode_flag_name['pri_key'] if field_flag & decode_flag_value['not_null']: flag_value = flag_value + ' ' + decode_flag_name['not_null'] if field_flag & decode_flag_value['unique_key']: flag_value = flag_value + ' ' + decode_flag_name['unique_key'] if field_flag & decode_flag_value['multiple_key']: flag_value = flag_value + ' ' + decode_flag_name['multiple_key'] if field_flag & decode_flag_value['auto']: flag_value = flag_value + ' ' + decode_flag_name['auto'] # # for some reason we do not pass back the default value (row[5]) # rows.append([field_name, table_name, field_type, field_size, flag_value]) return None def cmd_create(self, name): self.command ('create_db', name) # # response # nfields, affected, insert_id = self.read_reply_header() return None def _execute_with_retry(self, method_name, args = ()): method = getattr(self, method_name) retry_count = 0 error_msg = None # # lock down the connection while we are performing a query. 
# self._lock_connection() try: while error_msg is None: try: self.check_connection() retval = apply(method, args) except InternalError, msg: if not is_disconnect(msg): error_msg = msg[0] elif retry_count > MAX_RECONNECT_RETRY: error_msg = "connect retry limit reached <%d>" % \ (MAX_RECONNECT_RETRY) else: sleep_time = retry_count * RECONNECT_RETRY_GRAIN log("<%s> lost connection, sleeping <%0.1f>" % \ (self.address[0], sleep_time)) retry_count = retry_count + 1 self._connected = 0 self.sleep(sleep_time) else: break finally: # # unlock the connection ( I must be retarded for not having this in # a finally clause. I was loosing the connect thread, and everyone # else was frozen on the lock.) # self._unlock_connection() if error_msg is not None: raise error, error_msg return retval # # MySQL module compatibility, properly wraps raw client requests, # to format the return types. # # use_result option is currently not implemented, if anyone has the # time, please add support for it. Libor 4/2/00 # def selectdb(self, database, use_result = 0): return self._execute_with_retry('cmd_use', (database,)) def query (self, q, use_result = 0): return self._execute_with_retry('cmd_query', (q,)) def listtables (self, wildcard = None): if wildcard is None: cmd = "show tables" else: cmd = "show tables like '%s'" % (wildcard) o = self._execute_with_retry('cmd_query', (cmd,)) return o.fetchrows() def listfields (self, table_name, wildcard = None): if wildcard is None: cmd = "%s\000\000" % (table_name) else: cmd = "%s\000%s\000" % (table_name, wildcard) return self._execute_with_retry('cmd_listfields', (cmd,)) def drop(self, database, use_result = 0): return self._execute_with_retry('cmd_drop', (database,)) def create(self, db_name, use_result = 0): return self._execute_with_retry('cmd_create', (db_name,)) def close(self): return self.cmd_quit() # compatibility layer, avoid it if you can by using cmd_query directly. # incomplete and hackish. 
# perhaps a better solution would be to implement
# the DB API ourselves rather than using Mysqldb.py

class statement:
    """In-memory result of a query: field descriptions plus rows.

    A negative *affected_rows* means "use the number of rows".  The
    fetch methods consume rows from an internal cursor that starts at
    the first row.
    """

    def __init__(self, fields, rows, affected_rows=-1, insert_id=0):
        self._fields = fields
        self._rows = rows

        if affected_rows < 0:
            self._affected_rows = len(rows)
        else:
            self._affected_rows = affected_rows

        self._index = 0
        self._insert_id = insert_id

    # =======================================================================
    # internal methods
    # =======================================================================
    def _fetchone(self):
        """Return the next row, or [] when the rows are exhausted."""
        if self._index >= len(self._rows):
            return []
        result = self._rows[self._index]
        self._index = self._index + 1
        return result

    def _fetchmany(self, size):
        """Return up to *size* rows from the cursor and advance it."""
        result = self._rows[self._index:self._index + size]
        self._index = self._index + len(result)
        return result

    def _fetchall(self):
        """Return all remaining rows and move the cursor to the end."""
        result = self._rows[self._index:]
        self._index = self._index + len(result)
        return result

    # =======================================================================
    # external methods
    # =======================================================================
    def affectedrows(self):
        """Return the affected-row count."""
        return self._affected_rows

    def numrows(self):
        """Return the total number of rows held."""
        return len(self._rows)

    def numfields(self):
        """Return the number of fields."""
        return len(self._fields)

    def fields(self):
        """Return a list of (field name, table name, type name, flag byte).

        The flag value is the first byte of the raw fifth element only.
        """
        # raw format:
        # table, fieldname, ??? (flags?), datatype
        # ['groupmap', 'gid', '\013\000\000', '\003', '\013B\000']
        # MySQL returns
        # ['gid', 'groupmap', 'long', 11, 'pri notnull auto_inc mkey']
        return [
            (x[1], x[0], decode_type_names[ord(x[3])], ord(x[4][0]))
            for x in self._fields
            ]

    def fetchrows(self, size=0):
        """Return the next *size* rows, or all remaining rows if size is 0."""
        if size:
            return self._fetchmany(size)
        return self._fetchall()

    # [{'groupmap.podid': 2,
    #   'groupmap.listname': 'medusa',
    #   'groupmap.active': 'y',
    #   'groupmap.gid': 116225,
    #   'groupmap.locked': 'n'}]
    def fetchdict(self, size=0):
        """Like fetchrows(), but each row is a dict keyed 'table.field'."""
        keys = ["%s.%s" % (x[0], x[1]) for x in self._fields]
        range_len_keys = range(len(keys))
        result = []
        for row in self.fetchrows(size):
            d = {}
            for j in range_len_keys:
                d[keys[j]] = row[j]
            result.append(d)
        return result

    def insert_id(self):
        """Return the insert id reported with the result."""
        # i have no idea what this is
        return self._insert_id


# ======================================================================
# decoding MySQL data types
# ======================================================================

def _decode_decimal(value):
    """Decode a DECIMAL column: int when possible, otherwise float.

    Plain int() raises ValueError for values with a fractional part
    such as '12.50'; integral values still decode to int as before.
    """
    try:
        return int(value)
    except ValueError:
        return float(value)


# from mysql-3.21.33/include/mysql_com.h.in

# by default leave as a string
decode_type_map = [str] * 256
decode_type_names = ['unknown'] * 256

# Many of these are not correct!  Note especially
# the time/date types...  If you want to write a real decoder
# for any of these, just replace 'str' with your function.

for code, cast, name in (
    (0, _decode_decimal, 'decimal'),
    (1, int, 'tiny'),
    (2, int, 'short'),
    (3, int, 'long'),
    (4, float, 'float'),
    (5, float, 'double'),
    (6, str, 'null'),
    (7, str, 'timestamp'),
    #(8, long, 'longlong'),
    (8, str, 'unhandled'),  # Mysqldb expects unhandled.  strange.
    (9, int, 'int24'),
    (10, str, 'date'),      # looks like YYYY-MM-DD ??
    (11, str, 'time'),      # looks like HH:MM:SS
    (12, str, 'datetime'),
    (13, str, 'year'),
    (14, str, 'newdate'),
    (247, str, 'enum'),
    (248, str, 'set'),
    (249, str, 'tiny_blob'),
    (250, str, 'medium_blob'),
    (251, str, 'long_blob'),
    (252, str, 'blob'),
    (253, str, 'varchar'),  # in the C code it is VAR_STRING
    (254, str, 'string')
    ):
    decode_type_map[code] = cast
    decode_type_names[code] = name

#
# we need flag mappings also
#
decode_flag_value = {}
decode_flag_name = {}

for value, flag, name in (
    (1, 'not_null', 'notnull'),        # Field can not be NULL
    (2, 'pri_key', 'pri'),             # Field is part of a primary key
    (4, 'unique_key', 'ukey'),         # Field is part of a unique key
    (8, 'multiple_key', 'mkey'),       # Field is part of a key
    (16, 'blob', 'unused'),            # Field is a blob
    (32, 'unsigned', 'unused'),        # Field is unsigned
    (64, 'zerofill', 'unused'),        # Field is zerofill
    (128, 'binary', 'unused'),
    (256, 'enum', 'unused'),           # field is an enum
    (512, 'auto', 'auto_inc'),         # field is a autoincrement field
    (1024, 'timestamp', 'unused'),     # Field is a timestamp
    (2048, 'set', 'unused'),           # field is a set
    (16384, 'part_key', 'unused'),     # Intern; Part of some key
    (32768, 'group', 'unused'),        # Intern: Group field
    (65536, 'unique', 'unused')        # Intern: Used by sql_yacc
    ):
    decode_flag_value[flag] = value
    decode_flag_name[flag] = name

#
# database commands
#
decode_db_cmds = {}

for value, name in (
    (0, 'sleep'),
    (1, 'quit'),
    (2, 'init_db'),
    (3, 'query'),
    (4, 'field_list'),
    (5, 'create_db'),
    (6, 'drop_db'),
    (7, 'refresh'),
    (8, 'shutdown'),
    (9, 'statistics'),
    (10, 'process_info'),
    (11, 'connect'),
    (12, 'process_kill'),
    (13, 'debug')
    ):
    decode_db_cmds[name] = value

## ======================================================================
##
## SMR - borrowed from daGADFLY.py, moved dict 'constant' out of
##       function definition.
# #quote_for_escape = {'\0': '\\0', "'": "''", '"': '""', '\\': '\\\\'} # martinb - changed to match the behaviour of MySQL: quote_for_escape = {'\0': '\\0', "'": "\\'", '"': '\\"', '\\': '\\\\'} import types def escape(s): quote = quote_for_escape if type(s) == types.IntType: return str(s) elif s == None: return "" elif type(s) == types.StringType: r = range(len(s)) r.reverse() # iterate backwards, so as not to destroy indexing for i in r: if quote.has_key(s[i]): s = s[:i] + quote[s[i]] + s[i+1:] return s else: log(s) log (type(s)) raise MySQLError # # MySQL module compatibility # def connect (host, user, passwd, timeout=None, connect_timeout=None): conn = mysql_client(user, passwd, (host, 3306), debug=0, timeout=timeout, connect_timeout=connect_timeout) # I've found that this is the best way to maximize the number of ultimately # successful requests if many threads (>50) are running. - martinb 99/11/03 try: conn.check_connection() except InternalError, msg: pass return conn # # testing emulation: # import coromysql # coromysql.emulate(); import mlm ; mls = mlm.MLS() ; mls.GetList ('medusa') # def emulate(): "have this module pretend to be the real MySQL module" sys.modules['MySQL'] = sys.modules[__name__] def test (): c = mysql_client ('rushing', 'fnord', ('127.0.0.1', 3306)) print 'connecting...' c.connect() print 'logging in...' 
c.login() print c c.cmd_use ('mysql') print c.cmd_query ('select * from host') c.cmd_quit() if __name__ == '__main__': for i in range(10): coro.spawn (test) coro.event_loop (30.0) # # - mysql_client is analogous to DBH in MySQLmodule.c, and statment is # analogous to STH in MySQLmodule.c # - DBH is the database handler, and STH is the statment handler, # - Here are the methods that the MySQLmodule.c implements, and if they # are at least attempted here in coromysql # # DBH: # # "selectdb" - yes # "do" - no # "query" - yes # "listdbs" - no # "listtables" - yes # "listfields" - yes # "listprocesses" - no # "create" - yes # "stat" - no # "clientinfo" - no # "hostinfo" - no # "serverinfo" - no # "protoinfo" - no # "drop" - yes # "reload" - no # "insert_id" - no # "close" - yes # "shutdown" - no # # STH: # # "fields" - yes # "fetchrows" - yes # "fetchdict" - yes # "seek" - no # "numrows" - yes # "numfields" - yes # "eof" - no # "affectedrows" - yes # "insert_id" - yes