diff --git a/Demo/rpc/rpc.py b/Demo/rpc/rpc.py index d1c2c5e599e..00397dd852e 100644 --- a/Demo/rpc/rpc.py +++ b/Demo/rpc/rpc.py @@ -1,4 +1,4 @@ -# Implement (a subset of) Sun RPC, version 2 -- RFC1057. +# Sun RPC version 2 -- RFC1057. # XXX There should be separate exceptions for the various reasons why # XXX an RPC can fail, rather than using RuntimeError for everything @@ -177,8 +177,8 @@ class Client: self.port = port self.makesocket() # Assigns to self.sock self.bindsocket() - self.sock.connect((host, port)) - self.lastxid = 0 + self.connsocket() + self.lastxid = 0 # XXX should be more random? self.addpackers() self.cred = None self.verf = None @@ -191,6 +191,10 @@ class Client: # This MUST be overridden raise RuntimeError, 'makesocket not defined' + def connsocket(self): + # Override this if you don't want/need a connection + self.sock.connect((self.host, self.port)) + def bindsocket(self): # Override this to bind to a different port (e.g. reserved) self.sock.bind(('', 0)) @@ -200,6 +204,21 @@ class Client: self.packer = Packer().init() self.unpacker = Unpacker().init('') + def make_call(self, proc, args, pack_func, unpack_func): + # Don't normally override this (but see Broadcast) + if pack_func is None and args is not None: + raise TypeError, 'non-null args with null pack_func' + self.start_call(proc) + if pack_func: + pack_func(args) + self.do_call() + if unpack_func: + result = unpack_func() + else: + result = None + self.unpacker.done() + return result + def start_call(self, proc): # Don't override this self.lastxid = xid = self.lastxid + 1 @@ -209,14 +228,10 @@ class Client: p.reset() p.pack_callheader(xid, self.prog, self.vers, proc, cred, verf) - def do_call(self, *rest): + def do_call(self): # This MUST be overridden raise RuntimeError, 'do_call not defined' - def end_call(self): - # Don't override this - self.unpacker.done() - def mkcred(self): # Override this to use more powerful credentials if self.cred == None: @@ -230,9 +245,7 @@ class Client: return self.verf def Null(self): # Procedure 0 is always like this - self.start_call(0) - self.do_call(0) - self.end_call() + return self.make_call(0, None, None, None) # Record-Marking standard support @@ -293,23 +306,14 @@ def bindresvport(sock, host): raise RuntimeError, 'can\'t assign reserved port' -# Raw TCP-based client +# Client using TCP to a specific port class RawTCPClient(Client): def makesocket(self): self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - def start_call(self, proc): - self.lastxid = xid = self.lastxid + 1 - cred = self.mkcred() - verf = self.mkverf() - p = self.packer - p.reset() - p.pack_callheader(xid, self.prog, self.vers, proc, cred, verf) - - def do_call(self, *rest): - # rest is used for UDP buffer size; ignored for TCP + def do_call(self): call = self.packer.get_buf() sendrecord(self.sock, call) reply = recvrecord(self.sock) @@ -321,41 +325,25 @@ class RawTCPClient(Client): raise RuntimeError, 'wrong xid in reply ' + `xid` + \ ' instead of ' + `self.lastxid` - def end_call(self): - self.unpacker.done() - -# Raw UDP-based client +# Client using UDP to a specific port class RawUDPClient(Client): def makesocket(self): self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - def start_call(self, proc): - self.lastxid = xid = self.lastxid + 1 - cred = self.mkcred() - verf = self.mkverf() - p = self.packer - p.reset() - p.pack_callheader(xid, self.prog, self.vers, proc, cred, verf) - - def do_call(self, *rest): + def do_call(self): + call = self.packer.get_buf() + self.sock.send(call) try: from select import select except ImportError: print 'WARNING: select not found, RPC may hang' select = None - if len(rest) == 0: - bufsize = 8192 - elif len(rest) > 1: - raise TypeError, 'too many args' - else: - bufsize = rest[0] + 512 - call = self.packer.get_buf() + BUFSIZE = 8192 # Max UDP buffer size timeout = 1 count = 5 - self.sock.send(call) while 1: r, w, x = [self.sock], [], [] if select: @@ -367,7 +355,7 @@ class RawUDPClient(Client): ## print 'RESEND', timeout, count self.sock.send(call) continue - reply = self.sock.recv(bufsize) + reply = self.sock.recv(BUFSIZE) u = self.unpacker u.reset(reply) xid, verf = u.unpack_replyheader() @@ -376,14 +364,71 @@ class RawUDPClient(Client): continue break - def end_call(self): - self.unpacker.done() + +# Client using UDP broadcast to a specific port + +class RawBroadcastUDPClient(RawUDPClient): + + def init(self, bcastaddr, prog, vers, port): + self = RawUDPClient.init(self, bcastaddr, prog, vers, port) + self.reply_handler = None + self.timeout = 30 + return self + + def connsocket(self): + # Don't connect -- use sendto + self.sock.allowbroadcast(1) + + def set_reply_handler(self, reply_handler): + self.reply_handler = reply_handler + + def set_timeout(self, timeout): + self.timeout = timeout # Use None for infinite timeout + + def make_call(self, proc, args, pack_func, unpack_func): + if pack_func is None and args is not None: + raise TypeError, 'non-null args with null pack_func' + self.start_call(proc) + if pack_func: + pack_func(args) + call = self.packer.get_buf() + self.sock.sendto(call, (self.host, self.port)) + try: + from select import select + except ImportError: + print 'WARNING: select not found, broadcast will hang' + select = None + BUFSIZE = 8192 # Max UDP buffer size (for reply) + replies = [] + if unpack_func is None: + def dummy(): pass + unpack_func = dummy + while 1: + r, w, x = [self.sock], [], [] + if select: + if self.timeout is None: + r, w, x = select(r, w, x) + else: + r, w, x = select(r, w, x, self.timeout) + if self.sock not in r: + break + reply, fromaddr = self.sock.recvfrom(BUFSIZE) + u = self.unpacker + u.reset(reply) + xid, verf = u.unpack_replyheader() + if xid <> self.lastxid: +## print 'BAD xid' + continue + reply = unpack_func() + self.unpacker.done() + replies.append((reply, fromaddr)) + if self.reply_handler: + self.reply_handler(reply, fromaddr) + return replies # Port mapper interface -# XXX CALLIT is not implemented - # Program number, version and (fixed!) port number PMAP_PROG = 100000 PMAP_VERS = 2 @@ -421,6 +466,13 @@ class PortMapperPacker(Packer): def pack_pmaplist(self, list): self.pack_list(list, self.pack_mapping) + def pack_call_args(self, ca): + prog, vers, proc, args = ca + self.pack_uint(prog) + self.pack_uint(vers) + self.pack_uint(proc) + self.pack_opaque(args) + class PortMapperUnpacker(Unpacker): @@ -434,6 +486,11 @@ class PortMapperUnpacker(Unpacker): def unpack_pmaplist(self): return self.unpack_list(self.unpack_mapping) + def unpack_call_result(self): + port = self.unpack_uint() + res = self.unpack_opaque() + return port, res + class PartialPortMapperClient: @@ -442,35 +499,29 @@ class PartialPortMapperClient: self.unpacker = PortMapperUnpacker().init('') def Set(self, mapping): - self.start_call(PMAPPROC_SET) - self.packer.pack_mapping(mapping) - self.do_call() - res = self.unpacker.unpack_uint() - self.end_call() - return res + return self.make_call(PMAPPROC_SET, mapping, \ + self.packer.pack_mapping, \ + self.unpacker.unpack_uint) def Unset(self, mapping): - self.start_call(PMAPPROC_UNSET) - self.packer.pack_mapping(mapping) - self.do_call() - res = self.unpacker.unpack_uint() - self.end_call() - return res + return self.make_call(PMAPPROC_UNSET, mapping, \ + self.packer.pack_mapping, \ + self.unpacker.unpack_uint) def Getport(self, mapping): - self.start_call(PMAPPROC_GETPORT) - self.packer.pack_mapping(mapping) - self.do_call(4) - port = self.unpacker.unpack_uint() - self.end_call() - return port + return self.make_call(PMAPPROC_GETPORT, mapping, \ + self.packer.pack_mapping, \ + self.unpacker.unpack_uint) def Dump(self): - self.start_call(PMAPPROC_DUMP) - self.do_call(8192-512) - list = self.unpacker.unpack_pmaplist() - self.end_call() - return list + return self.make_call(PMAPPROC_DUMP, None, \ + None, \ + self.unpacker.unpack_pmaplist) + + def Callit(self, ca): + return self.make_call(PMAPPROC_CALLIT, ca, \ + self.packer.pack_call_args, \ + self.unpacker.unpack_call_result) class TCPPortMapperClient(PartialPortMapperClient, RawTCPClient): @@ -487,6 +538,16 @@ class UDPPortMapperClient(PartialPortMapperClient, RawUDPClient): host, PMAP_PROG, PMAP_VERS, PMAP_PORT) +class BroadcastUDPPortMapperClient(PartialPortMapperClient, \ + RawBroadcastUDPClient): + + def init(self, bcastaddr): + return RawBroadcastUDPClient.init(self, \ + bcastaddr, PMAP_PROG, PMAP_VERS, PMAP_PORT) + + +# Generic clients that find their server through the Port mapper + class TCPClient(RawTCPClient): def init(self, host, prog, vers): @@ -509,6 +570,51 @@ class UDPClient(RawUDPClient): return RawUDPClient.init(self, host, prog, vers, port) +class BroadcastUDPClient(Client): + + def init(self, bcastaddr, prog, vers): + self.pmap = BroadcastUDPPortMapperClient().init(bcastaddr) + self.pmap.set_reply_handler(self.my_reply_handler) + self.prog = prog + self.vers = vers + self.user_reply_handler = None + self.addpackers() + return self + + def close(self): + self.pmap.close() + + def set_reply_handler(self, reply_handler): + self.user_reply_handler = reply_handler + + def set_timeout(self, timeout): + self.pmap.set_timeout(timeout) + + def my_reply_handler(self, reply, fromaddr): + port, res = reply + self.unpacker.reset(res) + result = self.unpack_func() + self.unpacker.done() + self.replies.append((result, fromaddr)) + if self.user_reply_handler is not None: + self.user_reply_handler(result, fromaddr) + + def make_call(self, proc, args, pack_func, unpack_func): + self.packer.reset() + if pack_func: + pack_func(args) + if unpack_func is None: + def dummy(): pass + self.unpack_func = dummy + else: + self.unpack_func = unpack_func + self.replies = [] + packed_args = self.packer.get_buf() + dummy_replies = self.pmap.Callit( \ + (self.prog, self.vers, proc, packed_args)) + return self.replies + + # Server classes # These are not symmetric to the Client classes @@ -657,14 +763,9 @@ class UDPServer(Server): # Simple test program -- dump local portmapper status def test(): - import T - T.TSTART() pmap = UDPPortMapperClient().init('') - T.TSTOP() pmap.Null() - T.TSTOP() list = pmap.Dump() - T.TSTOP() list.sort() for prog, vers, prot, port in list: print prog, vers, @@ -674,7 +775,24 @@ def test(): print port -# Server and client test program. +# Test program for broadcast operation -- dump everybody's portmapper status + +def testbcast(): + import sys + if sys.argv[1:]: + bcastaddr = sys.argv[1] + else: + bcastaddr = '' + def rh(reply, fromaddr): + host, port = fromaddr + print host + '\t' + `reply` + pmap = BroadcastUDPPortMapperClient().init(bcastaddr) + pmap.set_reply_handler(rh) + pmap.set_timeout(5) + replies = pmap.Getport((100002, 1, IPPROTO_UDP, 0)) + + +# Test program for server, with corresponding client # On machine A: python -c 'import rpc; rpc.testsvr()' # On machine B: python -c 'import rpc; rpc.testclt()' A # (A may be == B) @@ -709,12 +827,9 @@ def testclt(): # Client for above server class C(UDPClient): def call_1(self, arg): - self.start_call(1) - self.packer.pack_string(arg) - self.do_call() - reply = self.unpacker.unpack_string() - self.end_call() - return reply + return self.make_call(1, arg, \ + self.packer.pack_string, \ + self.unpacker.unpack_string) c = C().init(host, 0x20000000, 1) print 'making call...' reply = c.call_1('hello, world, ')