import errno
import os
import selectors
import signal
import socket
import struct
import sys
import threading
import warnings

from . import connection
from . import process
from .context import reduction
from . import semaphore_tracker
from . import spawn
from . import util

__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']

#
#
#

MAXFDS_TO_SEND = 256
SIGNED_STRUCT = struct.Struct('q')     # large enough for pid_t

#
# Forkserver class
#

class ForkServer(object):
    '''Client-side state for one fork server process.

    ensure_running() starts the server (a fresh interpreter running main()
    below), and connect_to_new_process() asks it to fork a child.  Inside a
    child forked by the server, _serve_one() fills in the descriptors that
    the child inherited.
    '''

    def __init__(self):
        # AF_UNIX address the server listens on (None until known).
        self._forkserver_address = None
        # Write end of the "alive" pipe.  The server exits once no process
        # holds this write end open any more.
        self._forkserver_alive_fd = None
        # Extra fds handed to this process by the server (set in _serve_one).
        self._inherited_fds = None
        # Serializes ensure_running().
        self._lock = threading.Lock()
        # Modules the server imports before it starts accepting requests.
        self._preload_modules = ['__main__']

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        main() imports each name before it starts serving requests, and
        ImportError is ignored.  The list is baked into the server's command
        line, so it only affects a fork server started after this call.

        Raises TypeError if any element of *modules_names* is not a str.
        '''
        # Validate the *new* value.  Checking self._preload_modules here
        # would only re-check the previous setting and let bad input through.
        # Materialize first, so that a one-shot iterable is not consumed by
        # the check and its %r in ensure_running() is a valid literal.
        modules_names = list(modules_names)
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w).  The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if *fds* together with the four fds always sent
        would not fit in one request.
        '''
        self.ensure_running()
        # Four fds always precede the caller's: child_r, child_w, the alive
        # fd and the semaphore tracker fd.
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            child_r, parent_w = os.pipe()
            allfds = [child_r, child_w, self._forkserver_alive_fd,
                      semaphore_tracker.getfd()]
            allfds += fds
            try:
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except BaseException:
                # On any failure (including KeyboardInterrupt) don't leak
                # our ends of the pipes.
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # The child's ends were sent to the server; we never use them.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process.  Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            semaphore_tracker.ensure_running()
            if self._forkserver_alive_fd is not None:
                return

            # Source run with ``python -c`` in the new server process.
            cmd = ('from multiprocessing.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {x: y for x, y in data.items() if x in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                # Owner-only access to the socket file.
                os.chmod(address, 0o600)
                listener.listen()

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    util.spawnv_passfds(exe, args, fds_to_pass)
                except BaseException:
                    os.close(alive_w)
                    raise
                finally:
                    # Only the server needs the read end.
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w

#
#
#

def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    Entry point executed (via ``python -c``) in the server process that
    ForkServer.ensure_running() starts.

    listener_fd -- fd of the bound, listening AF_UNIX socket.
    alive_r     -- read end of the "alive" pipe.  When it reports EOF the
                   server exits by raising SystemExit.
    preload     -- module names to import up front; ImportError is ignored.
    main_path   -- if given and '__main__' is in *preload*, imported as the
                   main module before the other preloads.
    sys_path    -- accepted but not used by this function.

    Each connection on the listener delivers a list of fds:
    (child_r, child_w, *inherited).  The server forks, writes the child's
    pid to child_w, and later writes the child's return code to the same fd
    when SIGCHLD reports that the child has exited.  The return code is the
    negative signal number if the child was killed by a signal.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                pass

    util._close_stdin()

    sig_r, sig_w = os.pipe()
    os.set_blocking(sig_r, False)
    os.set_blocking(sig_w, False)

    def sigchld_handler(*_unused):
        # Dummy signal handler, doesn't do anything
        pass

    # letting SIGINT through avoids KeyboardInterrupt tracebacks
    # unblocking SIGCHLD allows the wakeup fd to notify our event loop
    handlers = {
        signal.SIGCHLD: sigchld_handler,
        signal.SIGINT: signal.SIG_DFL,
        }
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    # calling os.write() in the Python signal handler is racy
    signal.set_wakeup_fd(sig_w)

    # map child pids to client fds
    pid_to_fd = {}

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)
        selector.register(sig_r, selectors.EVENT_READ)

        while True:
            try:
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    assert os.read(alive_r, 1) == b'', "Not at EOF?"
                    raise SystemExit

                if sig_r in rfds:
                    # Got SIGCHLD
                    os.read(sig_r, 65536)  # exhaust
                    while True:
                        # Scan for child processes
                        try:
                            pid, sts = os.waitpid(-1, os.WNOHANG)
                        except ChildProcessError:
                            break
                        if pid == 0:
                            break
                        child_w = pid_to_fd.pop(pid, None)
                        if child_w is not None:
                            if os.WIFSIGNALED(sts):
                                returncode = -os.WTERMSIG(sts)
                            else:
                                if not os.WIFEXITED(sts):
                                    raise AssertionError(
                                        "Child {0:n} status is {1:n}".format(
                                            pid,sts))
                                returncode = os.WEXITSTATUS(sts)
                            # Send exit code to client process
                            try:
                                write_signed(child_w, returncode)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            os.close(child_w)
                        else:
                            # This shouldn't happen really
                            warnings.warn('forkserver: waitpid returned '
                                          'unexpected pid %d' % pid)

                if listener in rfds:
                    # Incoming fork request
                    with listener.accept()[0] as s:
                        # Receive fds from client
                        fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
                        if len(fds) > MAXFDS_TO_SEND:
                            raise RuntimeError(
                                "Too many ({0:n}) fds to send".format(
                                    len(fds)))
                        child_r, child_w, *fds = fds
                        s.close()
                        pid = os.fork()
                        if pid == 0:
                            # Child
                            code = 1
                            try:
                                listener.close()
                                selector.close()
                                unused_fds = [alive_r, child_w, sig_r, sig_w]
                                unused_fds.extend(pid_to_fd.values())
                                code = _serve_one(child_r, fds,
                                                  unused_fds,
                                                  old_handlers)
                            except Exception:
                                sys.excepthook(*sys.exc_info())
                                sys.stderr.flush()
                            finally:
                                # Never return into the server's loop.
                                os._exit(code)
                        else:
                            # Send pid to client process
                            try:
                                write_signed(child_w, pid)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            pid_to_fd[pid] = child_w
                            os.close(child_r)
                            for fd in fds:
                                os.close(fd)

            except OSError as e:
                if e.errno != errno.ECONNABORTED:
                    raise


def _serve_one(child_r, fds, unused_fds, handlers):
    '''Body of a child forked by main(); returns its exit code.

    Restores the signal handlers in *handlers*, closes *unused_fds*, and
    records the fds received from the client: the alive fd, the semaphore
    tracker fd, then any inherited fds.  It then runs the process object
    read from *child_r* via spawn._main().
    '''
    # close unnecessary stuff and reset signal handlers
    signal.set_wakeup_fd(-1)
    for sig, val in handlers.items():
        signal.signal(sig, val)
    for fd in unused_fds:
        os.close(fd)

    (_forkserver._forkserver_alive_fd,
     semaphore_tracker._semaphore_tracker._fd,
     *_forkserver._inherited_fds) = fds

    # Run process object received over pipe
    code = spawn._main(child_r)

    return code
#
# Read and write signed numbers
#

def read_signed(fd):
    '''Read one SIGNED_STRUCT-encoded integer from file descriptor *fd*.

    os.read() is called repeatedly until SIGNED_STRUCT.size bytes have been
    gathered, so short reads are tolerated.

    Raises EOFError if *fd* hits end-of-file before a full value arrives.
    '''
    needed = SIGNED_STRUCT.size
    pieces = []
    received = 0
    while received < needed:
        piece = os.read(fd, needed - received)
        if not piece:
            raise EOFError('unexpected EOF')
        pieces.append(piece)
        received += len(piece)
    (value,) = SIGNED_STRUCT.unpack(b''.join(pieces))
    return value

def write_signed(fd, n):
    '''Write integer *n* to file descriptor *fd* as one SIGNED_STRUCT value.

    os.write() is called repeatedly until every packed byte has been written.
    '''
    payload = SIGNED_STRUCT.pack(n)
    offset = 0
    while offset < len(payload):
        written = os.write(fd, payload[offset:])
        if written == 0:
            raise RuntimeError('should not get here')
        offset += written

#
# Module-level singleton and its public entry points
#

_forkserver = ForkServer()

# Bound methods of the singleton, exported as plain module functions
# (the names listed in __all__).
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload