# Copyright 2021 National Technology & Engineering Solutions
# of Sandia, LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS,
# the U.S. Government retains certain rights in this software.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functionality for communicating using the builtin :mod:`socket` module.
"""
import collections
import contextlib
import hashlib
import logging
import multiprocessing
import numbers
import os
import pickle
import queue
import select
import socket
import ssl
import tempfile
import threading
import time
import traceback
import urllib.parse
import numpy
import pynetstring
from ..interface import Communicator, Tag, tagname
from .connect import NetstringSocket, Timeout, Timer, direct, getLogger, gettls, geturl, listen, message, rendezvous
[docs]
class BrokenPipe(Exception):
"""Raised trying to send to another player that no longer exists."""
pass
[docs]
class Failed(Exception):
"""Used to indicate that a player process raised an exception."""
def __init__(self, exception, traceback):
self.exception = exception
self.traceback = traceback
def __repr__(self):
return f"Failed(exception={self.exception!r})" # pragma: no cover
[docs]
class NotRunning(Exception):
"""Raised calling an operation after the communicator has been freed."""
pass
[docs]
class Revoked(Exception):
"""Raised calling an operation after the communicator has been revoked."""
pass
[docs]
class Terminated(Exception):
"""Used to indicate that a player process terminated unexpectedly without output."""
def __init__(self, exitcode):
self.exitcode = exitcode
def __repr__(self):
return f"Terminated(exitcode={self.exitcode!r})" # pragma: no cover
[docs]
class TryAgain(Exception):
"""Raised when a non-blocking operation would block."""
pass
[docs]
class SocketCommunicator(Communicator):
"""Cicada communicator that uses Python's builtin :mod:`socket` module as the transport layer.
Note
----
Creating a communicator is a collective operation that must be called by
all players that will be members.
Parameters
----------
sockets: :class:`dict` of :class:`~cicada.communicator.socket.connect.NetstringSocket`, required
Dictionary containing sockets that are connected to the other players
and ready to use. The dictionary keys must be the ranks of the other
players, and there must be one socket in the dictionary for every
player except the caller (since players don't need a socket to
communicate with themselves). Note that the communicator world size is
inferred from the size of the dictionary, and the communicator rank
from whichever key doesn't appear in the dictionary.
name: :class:`str`, optional
Human-readable name for this communicator, used for logging and
debugging. Defaults to "world"
timeout: :class:`numbers.Number`
Maximum time to wait for communication to complete, in seconds.
"""
def __init__(self, *, sockets, name="world", timeout=5):
if not isinstance(sockets, dict):
raise ValueError("sockets must be a dict, got {sockets} instead.") # pragma: no cover
for key, socket in sockets.items():
if not isinstance(key, int):
raise ValueError("sockets keys must be ints, got {sockets} instead.") # pragma: no cover
if not isinstance(socket, NetstringSocket):
raise ValueError("sockets values must be NetstringSocket, got {sockets} instead.") # pragma: no cover
world_size = len(sockets) + 1
for index in range(world_size):
if index not in sockets:
rank = index
if not isinstance(name, str):
raise ValueError("name must be a string, got {name} instead.") # pragma: no cover
if not isinstance(timeout, numbers.Number):
raise ValueError(f"timeout must be a number, got {timeout} instead.") # pragma: no cover
# Setup internal state.
self._name = name
self._world_size = world_size
self._rank = rank
self._timeout = timeout
self._revoked = False
self._log = getLogger(__name__, name, rank)
self._players = sockets
self._sent = {}
self._received = {}
# Begin normal operation.
self._running = True
# Setup queues for incoming messages.
self._incoming_queue = queue.SimpleQueue()
self._message_queues = [[] for rank in range(self.world_size)]
self._message_queue_lock = threading.Lock()
# Start queueing incoming messages.
self._queueing_thread = threading.Thread(name=f"Comm {name} player {rank} queueing thread", target=self._queue_messages, daemon=True)
self._queueing_thread.start()
# Start receiving incoming messages.
self._incoming_thread = threading.Thread(name=f"Comm {name} player {rank} incoming thread", target=self._receive_messages, daemon=True)
self._incoming_thread.start()
self._log.info(f"communicator ready.")
def _queue_messages(self):
# Place incoming messages in the correct queue.
while self._running:
# Wait for the next incoming message.
try:
raw_message = self._incoming_queue.get(block=True, timeout=0.1)
except queue.Empty:
continue
# Drop messages with missing attributes or unexpected values.
try:
src, tag, payload = raw_message
except Exception as e: # pragma: no cover
self._log.warning(f"ignoring message: {raw_message} exception: {e}")
return
self._queue_message(src, tag, payload, raw_message)
self._log.debug(f"queueing thread closed.")
def _queue_message(self, src, tag, payload, raw_message):
try:
tag = Tag(tag)
except:
pass
if tag not in self._received:
self._received[tag] = {"messages": 0}
self._received[tag]["messages"] += 1
# Revoke messages don't get queued because they receive special handling.
if tag == Tag.REVOKE:
if not self._revoked:
self._revoked = True
self._log.warning(f"revoked by player {src}")
return
# Insert the message into the correct queue.
with self._message_queue_lock:
self._message_queues[src].append(raw_message)
def _receive_messages(self):
# Parse and queue incoming messages as they arrive.
while self._running:
# Wait for data to arrive from the other players.
ready, _, _ = select.select(self._players.values(), [], [], 0.01)
for src, player in self._players.items():
if player in ready:
try:
player.feed()
except ConnectionResetError as e: # pragma: no cover
# These are pretty common, log them at a lower priority to streamline outputs.
self._log.info(f"exception reading from player {src} socket: {e}")
except Exception as e: # pragma: no cover
self._log.warning(f"exception reading from player {src} socket: {e}")
# Process any messages that were received. Note that
# we iterate over every player, not just the ones that
# were selected above, because there might be a few
# messages left from the startup process.
for src, player in self._players.items():
for raw_message in player.messages():
# Ignore unparsable messages.
try:
tag, payload = pickle.loads(raw_message)
except Exception as e: # pragma: no cover
self._log.warning(f"ignoring unparsable message: {e}")
continue
#self._log.debug(f"received {tagname(tag), payload}")
# Insert the message into the incoming queue.
self._incoming_queue.put((src, tag, payload), block=True, timeout=None)
# The communicator has been freed, so exit the thread.
self._log.debug(f"receive thread closed.")
def _messages(self, *, src=None, tag=None):
"""Return every message that matches the given src and tag.
Parameters
----------
src: :class:`int` or sequence of :class:`int`, optional
Return messages from the given player(s). If :any:`None` (the
default), return matching messages from every player.
tag: :class:`int`, optional
Return messages that match the given tag. If :any:`None` (the
default), return messages with any tag.
Returns
-------
messages: :class:`list` of (src, tag, payload) tuples.
"""
if src is None:
src = self.ranks
if not isinstance(src, list):
src = [src] # pragma: no cover
matches = []
with self._message_queue_lock:
for rank in src:
misses = []
for message in self._message_queues[rank]:
if tag is not None and tag != message[1]:
misses.append(message)
else:
matches.append(message)
self._message_queues[rank] = misses
return matches
def _next_message(self, *, src, tag):
"""Return the next message (if any) matching the given src and tag.
Parameters
----------
src: :class:`int`, required
Return the next matching message from the given player.
tag: :class:`int`, required
Return the next matching message with the given tag.
"""
with self._message_queue_lock:
for index, message in enumerate(self._message_queues[src]):
if tag == message[1]: # tag
del self._message_queues[src][index]
return message
return None
def _wait_next_payload(self, *, src, tag):
"""Return the next payload matching the given src and tag, blocking if necessary.
Parameters
----------
src: :class:`int`, required
Return the next matching message from the given player.
tag: :class:`int`, required
Return the next matching message with the given tag.
"""
timer = Timer(threshold=self._timeout)
while not timer.expired:
message = self._next_message(src=src, tag=tag)
if message is not None:
return message[2] # payload
time.sleep(0.0001)
raise Timeout(f"Tag {tagname(tag)} from player {src} timed-out after {self._timeout}s")
def _require_rank(self, rank):
if not isinstance(rank, numbers.Integral):
raise ValueError("Rank must be an integer.") # pragma: no cover
if rank < 0 or rank >= self._world_size:
raise ValueError(f"Rank must be in the range [0, {self._world_size}).") # pragma: no cover
return int(rank)
def _require_rank_list(self, ranks):
ranks = [self._require_rank(rank) for rank in ranks]
if len(ranks) != len(set(ranks)):
raise ValueError("Duplicate ranks are not allowed.") # pragma: no cover
return ranks
def _require_running(self):
if not self._running:
raise NotRunning(f"Comm {self.name} player {self.rank} is not running.")
def _require_unrevoked(self):
if self._revoked:
raise Revoked(f"Comm {self.name} player {self.rank} has been revoked.")
def _send(self, *, tag, payload, dst):
if dst not in self.ranks:
raise ValueError(f"Unknown destination: {dst}") # pragma: no cover
if tag not in self._sent:
self._sent[tag] = {"messages": 0}
self._sent[tag]["messages"] += 1
# As a special-case, route messages sent to ourself directly to the incoming queue.
if dst == self.rank:
self._incoming_queue.put((self.rank, tag, payload), block=True, timeout=None)
# Otherwise, send the message to the appropriate socket.
else:
try:
raw_message = pickle.dumps((int(tag), payload))
player = self._players[dst]
player.send(raw_message)
except BlockingIOError as e: # pragma: no cover
raise TryAgain(message(self.name, self.rank, f"operation would block sending to player {dst}."))
except BrokenPipeError as e: # pragma: no cover
raise BrokenPipe(message(self.name, self.rank, f"broken pipe sending to player {dst}."))
[docs]
def allgather(self, value):
self._log.debug(f"allgather()")
self._require_unrevoked()
self._require_running()
# Send messages.
for rank in self.ranks:
self._send(tag=Tag.ALLGATHER, payload=value, dst=rank)
# Receive messages.
values = [self._wait_next_payload(src=rank, tag=Tag.ALLGATHER) for rank in self.ranks]
return values
[docs]
def barrier(self):
"""If the implementation returns without raising an exception, then
every player entered the barrier. If an exception is raised then there
are no guarantees about whether every player entered.
"""
self._log.debug(f"barrier()")
self._require_unrevoked()
self._require_running()
# Notify rank 0 that we've entered the barrier.
self._send(tag=Tag.BARRIER, payload=None, dst=0)
if self.rank == 0:
# Wait until every player enters the barrier.
for rank in self.ranks:
self._wait_next_payload(src=rank, tag=Tag.BARRIER)
# Notify every player that it's time to exit the barrier.
for rank in self.ranks:
self._send(tag=Tag.BARRIER, payload=None, dst=rank)
# Wait until we're told to exit.
self._wait_next_payload(src=0, tag=Tag.BARRIER)
[docs]
def broadcast(self, *, src, value):
self._log.debug(f"broadcast(src={src})")
self._require_unrevoked()
self._require_running()
src = self._require_rank(src)
# Broadcast the value to all players.
if self.rank == src:
for rank in self.ranks:
self._send(tag=Tag.BROADCAST, payload=value, dst=rank)
# Receive the broadcast value.
return self._wait_next_payload(src=src, tag=Tag.BROADCAST)
[docs]
@staticmethod
def connect(*, world_size=None, rank=None, address=None, root_address=None, identity=None, trusted=None, name="world", timeout=5, startup_timeout=5):
"""High level function to create a SocketCommunicator.
This is a high level convenience function that can be used to create a
communicator, given just the calling player's address and the address
of the root player. By default, the parameters will be read from
environment variables that can be set permanently by the user, or
temporarily using the :ref:`cicada` command.
Parameters
----------
world_size: :class:`int`, optional
Number of players. Defaults to the value of the CICADA_WORLD_SIZE
environment variable, which is automatically set by the :ref:`cicada` command.
rank: :class:`int`, optional
Rank of the caller. Defaults to the value of the CICADA_RANK
environment variable, which is automatically set by the :ref:`cicada` command.
address: :class:`str`, optional
Listening address of the caller. This must be a URL of the form
`"tcp://{host}:{port}"` for TCP sockets, or `"file:///path/to/foo"`
for Unix domain sockets. Defaults to the value of the
CICADA_ADDRESS environment variable, which is automatically set
by the :ref:`cicada` command.
root_address: :class:`str`, optional
Listening address of the root (rank 0) player. This must be a URL
of the form `"tcp://{host}:{port}"` for TCP sockets, or
`"file:///path/to/foo"` for Unix domain sockets. Defaults to the
value of the CICADA_ROOT_ADDRESS environment variable, which is
automatically set by the :ref:`cicada` command.
identity: :class:`str`, optional
Path to a private key and certificate in PEM format. Defaults to
the value of the CICADA_IDENTITY environment variable, which is
automatically set by the :ref:`cicada` command.
trusted: sequence of :class:`str`, optional
Path to certificates in PEM format. Defaults to
the value of the CICADA_TRUSTED environment variable, which is
automatically set by the :ref:`cicada` command.
name: :class:`str`, optional
Human-readable name for the new communicator. Defaults to "world".
timeout: :class:`numbers.Number`
Communication timeout for the new communicator, in seconds. Defaults to five.
startup_timeout: :class:`numbers.Number`
Maximum time to wait for communicator setup, in seconds. Defaults to five.
Raises
------
:class:`ValueError`
If there are problems with input parameters.
:class:`Timeout`
If `timeout` seconds elapses before all connections are established.
:class:`TokenMismatch`
If every player doesn't provide the same token during startup.
Returns
-------
communicator: :class:`SocketCommunicator`
A fully-initialized communicator, ready for use.
"""
if world_size is None:
world_size = int(os.environ.get("CICADA_WORLD_SIZE"))
if rank is None:
rank = int(os.environ.get("CICADA_RANK"))
if address is None:
address = os.environ.get("CICADA_ADDRESS")
if root_address is None:
root_address = os.environ.get("CICADA_ROOT_ADDRESS")
if identity is None:
identity = os.environ.get("CICADA_IDENTITY", "")
if trusted is None:
trusted = [trust for trust in os.environ.get("CICADA_TRUSTED", "").split(",") if trust]
tls = gettls(identity=identity, trusted=trusted)
timer = Timer(threshold=startup_timeout)
listen_socket = listen(address=address, rank=rank, name=name, timer=timer)
sockets = rendezvous(listen_socket=listen_socket, root_address=root_address, world_size=world_size, rank=rank, timer=timer, tls=tls)
return SocketCommunicator(sockets=sockets, timeout=timeout)
[docs]
def free(self):
# Calling free() multiple times is a no-op.
if not self._running:
return
self._running = False
# Stop receiving incoming messages.
self._incoming_thread.join()
# Stop queueing incoming messages.
self._queueing_thread.join()
# Close connections to the other players.
for player in self._players.values():
player.close()
self._log.info(f"communicator freed.")
[docs]
def gather(self, *, value, dst):
self._log.debug(f"gather(dst={dst})")
self._require_unrevoked()
self._require_running()
dst = self._require_rank(dst)
# Send local data to the destination.
self._send(tag=Tag.GATHER, payload=value, dst=dst)
# Receive data from all ranks.
if self.rank == dst:
# Messages could arrive out of order.
values = [self._wait_next_payload(src=rank, tag=Tag.GATHER) for rank in self.ranks]
return values
return None
[docs]
def gatherv(self, *, src, value, dst):
self._log.debug(f"gatherv(src={src}, dst={dst})")
self._require_unrevoked()
self._require_running()
src = self._require_rank_list(src)
dst = self._require_rank(dst)
# Send local data to the destination.
if self.rank in src:
self._send(tag=Tag.GATHERV, payload=value, dst=dst)
# Receive data from the other players.
if self.rank == dst:
# Messages could arrive out of order.
values = [self._wait_next_payload(src=rank, tag=Tag.GATHERV) for rank in src]
return values
return None
[docs]
def irecv(self, *, src, tag):
self._log.debug(f"irecv(src={src}, tag={tagname(tag)})")
self._require_unrevoked()
self._require_running()
src = self._require_rank(src)
class Result:
def __init__(self, *, communicator, src, tag):
self._communicator = communicator
self._src = src
self._tag = tag
self._payload = None
@property
def is_completed(self):
if self._payload is None:
message = self._communicator._next_message(src=self._src, tag=self._tag)
if message is not None:
src, tag, payload = message
self._payload = payload
return self._payload is not None
@property
def value(self):
if self._payload is None:
raise RuntimeError("Call not completed.") # pragma: no cover
return self._payload
def wait(self):
if self._payload is None:
self._payload = self._communicator._wait_next_payload(src=self._src, tag=self._tag)
return Result(communicator=self, src=src, tag=tag)
[docs]
def isend(self, *, value, dst, tag):
self._log.debug(f"isend(dst={dst}, tag={tagname(tag)})")
self._require_unrevoked()
self._require_running()
dst = self._require_rank(dst)
self._send(tag=tag, payload=value, dst=dst)
# This is safe because we pickle the value before returning; thus,
# nothing the caller can do to the value will have unexpected
# side-effects.
class Result:
@property
def is_completed(self):
return True
def wait(self):
pass
return Result()
@property
def name(self):
"""The name of this communicator, which can be used for logging / debugging.
Returns
-------
name: :class:`str`
"""
return self._name
[docs]
@contextlib.contextmanager
def override(self, *, timeout=None):
"""Temporarily change communicator properties.
Use :meth:`override` to temporarily modify communicator behavior in a with statement::
with communicator.override(timeout=10):
# Do stuff with the new timeout here.
# The timeout will return to its previous value here.
Parameters
----------
timeout: :class:`numbers.Number`, optional
If specified, override the maximum time for communications to complete, in seconds.
Returns
-------
context: :class:`object`
A context manager object that will restore the communicator state when exited.
"""
original_context = {
"timeout": self._timeout,
}
try:
if timeout is not None:
self._timeout = timeout
yield original_context
finally:
if timeout is not None:
self._timeout = original_context["timeout"]
@property
def rank(self):
return self._rank
[docs]
def recv(self, *, src, tag):
self._log.debug(f"recv(src={src}, tag={tagname(tag)})")
self._require_unrevoked()
self._require_running()
src = self._require_rank(src)
return self._wait_next_payload(src=src, tag=tag)
[docs]
def revoke(self):
"""Revoke the current communicator.
Revokes the communicator for this player, and any players able to
receive messages. A revoked communicator cannot be used to perform any
operation other than :meth:`shrink`. Typically, revoke should be
called by any player that detects a communication failure, to initiate
a recovery phase.
"""
self._log.debug(f"revoke()")
self._require_running()
# Notify all players that the communicator is revoked.
for rank in self.ranks:
try:
self._send(tag=Tag.REVOKE, payload=None, dst=rank)
# Ignore BrokenPipe errors, they're to be expected under the circumstances.
except BrokenPipe as e: # pragma: no cover
pass
[docs]
@staticmethod
def run(*, world_size, fn, identities=None, trusted=None, args=(), kwargs={}, family="tcp", name="world", timeout=5, startup_timeout=5, show_traceback=False):
"""Run a function in parallel using sub-processes on the local host.
This method returns when the callback functions finish, returning a
:class:`list` of results from each, in rank order. Special sentinel
classes are used to indicate whether a process raised an exception or
terminated unexpectedly. This is extremely useful for running examples
and regression tests on a single machine.
The given function `fn` *must* accept a communicator as its first argument.
Additional caller-provided positional and keyword arguments are passed
to the function following the communicator.
To perform computation using multiple hosts, you should use
:meth:`~cicada.communicator.socket.SocketCommunicator.connect` and the
:ref:`cicada` command line executable instead.
Parameters
----------
world_size: :class:`int`, required
The number of players that will run the function.
fn: :func:`callable`, required
The function to execute in parallel.
identities: sequence of :class:`str`, optional
Path to files in PEM format each containing a private key and a certificate, one per player.
trusted: sequence of :class:`str`, optional
Path to files in PEM format containing certificates.
args: :class:`tuple`, optional
Positional arguments to pass to `fn` when it is executed.
kwargs: :class:`dict`, optional
Keyword arguments to pass to `fn` when it is executed.
family: :class:`str`, optional
Address family that matches the scheme used in address URLs
elsewhere in the API. Allowed values are "tcp" and "file".
name: :class:`str`, optional
Human-readable name for the communicator created by this function.
Defaults to "world".
timeout: :class:`numbers.Number`, optional
Maximum time to wait for normal communication to complete in
seconds. Defaults to five seconds.
startup_timeout: :class:`numbers.Number`, optional
Maximum time allowed to setup the communicator in seconds.
Defaults to five seconds.
show_traceback: :class:`bool`, optional
If :any:`True`, a traceback will be printed for every player that fails.
Returns
-------
results: :class:`list`
A value returned from `fn` for each player, in rank order. If a
player process terminates unexpectedly, its value will be an
instance of :class:`Terminated`, which can be used to access the
process exit code. If the player process raises a Python
exception, its value will be an instance of :class:`Failed`, which
can be used to access the Python exception and a traceback for the
failing code.
"""
def launch(*, parent_queue, child_queue, rank, fn, identity, trusted, args, kwargs, family, name, timeout, startup_timeout):
# Run the work function.
try:
# Create a socket with a randomly-assigned address.
if family == "file":
fd, path = tempfile.mkstemp()
os.close(fd)
address = f"file://{path}"
elif family == "tcp":
address = "tcp://127.0.0.1"
timer = Timer(threshold=startup_timeout)
listen_socket = listen(name=name, rank=rank, address=address, timer=timer)
address = geturl(listen_socket)
# Send our address to the parent process.
parent_queue.put((rank, address))
# Get all addresses from the parent process.
addresses = child_queue.get()
# Optionally setup TLS.
tls = gettls(identity=identity, trusted=trusted)
sockets=direct(listen_socket=listen_socket, addresses=addresses, rank=rank, name=name, timer=timer, tls=tls)
communicator = SocketCommunicator(sockets=sockets, name=name, timeout=timeout)
result = fn(communicator, *args, **kwargs)
communicator.free()
except Exception as e: # pragma: no cover
result = Failed(e, traceback.format_exc())
# Return results to the parent process.
parent_queue.put((rank, result))
# Setup the multiprocessing context.
context = multiprocessing.get_context(method="fork") # I don't remember why we prefer fork().
# Create queues for IPC.
parent_queue = context.Queue()
child_queue = context.Queue()
# Create per-player processes.
processes = []
for rank in range(world_size):
identity = None if identities is None else identities[rank]
processes.append(context.Process(
name=f"Player {rank}",
target=launch,
kwargs=dict(parent_queue=parent_queue, child_queue=child_queue, rank=rank, fn=fn, identity=identity, trusted=trusted, args=args, family=family, name=name, kwargs=kwargs, timeout=timeout, startup_timeout=startup_timeout),
))
# Start per-player processes.
for process in processes:
process.daemon = True
process.start()
# Collect addresses from every process.
addresses = [None] * world_size
for process in processes:
rank, address = parent_queue.get(block=True)
addresses[rank] = address
# Send addresses to every process.
for process in processes:
child_queue.put(addresses)
# Collect results from processes until every process has completed.
results = []
while any([process.is_alive() for process in processes]):
while True:
try:
rank, result = parent_queue.get(block=False)
results.append((rank, result))
except:
break
time.sleep(0.01)
# Join all processes, just to be safe.
for process in processes:
process.join()
# Now that every process has exited, collect any remaining results.
while True:
try:
rank, result = parent_queue.get(block=False)
results.append((rank, result))
except:
break
# Return the output of each rank, in rank order, with a sentinel object for missing outputs.
output = [Terminated(process.exitcode) for process in processes]
for rank, result in results:
output[rank] = result
# Log the results for each player.
log = logging.getLogger(__name__)
for rank, result in enumerate(output):
if isinstance(result, Failed):
log.error(f"Comm {name} player {rank} failed: {result.exception!r}")
elif isinstance(result, Exception):
log.error(f"Comm {name} player {rank} failed: {result!r}")
else:
log.info(f"Comm {name} player {rank} result: {result}")
# Print a traceback for players that failed.
if show_traceback: # pragma: no cover
for rank, result in enumerate(output):
if isinstance(result, Failed):
log.error("*" * 80)
log.error(f"Comm {name} player {rank} traceback:")
log.error(result.traceback)
return output
[docs]
@staticmethod
def run_forever(*, world_size, fn, identities=None, trusted=None, args=(), kwargs={}, family="tcp", name="world", timeout=5, startup_timeout=5):
"""Execute a long-running function in parallel using sub-processes on the local host.
This method returns immediately after networking has been setup and the
callback function begins executing, returning a :class:`list` of
network addresses and a :class:`list` of corresponding processes. This
is particularly useful for running "MPC-as-a-service" applications on
the local machine - the caller can use the addresses to communicate with
the individual processes.
The given function *must* accept a listening socket and a communicator
as its first two arguments. Additional caller-provided positional and
keyword arguments are passed to the function after the socket and communicator.
To run a service with multiple hosts, you should use
:meth:`~cicada.communicator.socket.SocketCommunicator.connect` and the
:ref:`cicada` command line executable instead.
Parameters
----------
world_size: :class:`int`, required
The number of players that will run the function.
fn: :func:`callable`, required
The function to execute in parallel.
identities: sequence of :class:`str`, optional
Path to files in PEM format each containing a private key and a certificate, one per player.
trusted: sequence of :class:`str`, optional
Path to files in PEM format containing certificates.
args: :class:`tuple`, optional
Positional arguments to pass to `fn` when it is executed.
kwargs: :class:`dict`, optional
Keyword arguments to pass to `fn` when it is executed.
family: :class:`str`, optional
Address family that matches the scheme used in address URLs
elsewhere in the API. Allowed values are "tcp" and "file".
name: :class:`str`, optional
Human-readable name for the communicator created by this function.
Defaults to "world".
timeout: :class:`numbers.Number`, optional
Maximum time to wait for normal communication to complete in
seconds. Defaults to five seconds.
startup_timeout: :class:`numbers.Number`, optional
Maximum time allowed to setup the communicator in seconds.
Defaults to five seconds.
Returns
-------
addresses: :class:`list` of :class:`str`
A listening address for each player, in rank order.
processes: :class:`list` of :class:`multiprocessing.Process`
An instance of :class:`multiprocessing.Process` for each player, in rank order.
"""
def launch(*, parent_queue, child_queue, rank, fn, identity, trusted, args, kwargs, family, name, timeout, startup_timeout):
# Run the work function.
try:
# Create a socket with a randomly-assigned address.
if family == "file":
fd, path = tempfile.mkstemp()
os.close(fd)
address = f"file://{path}"
elif family == "tcp":
address = "tcp://127.0.0.1"
timer = Timer(threshold=startup_timeout)
listen_socket = listen(name=name, rank=rank, address=address, timer=timer)
address = geturl(listen_socket)
# Send our address to the parent process.
parent_queue.put((rank, address))
# Get all addresses from the parent process.
addresses = child_queue.get()
# Optionally setup TLS.
tls = gettls(identity=identity, trusted=trusted)
sockets=direct(listen_socket=listen_socket, addresses=addresses, rank=rank, name=name, timer=timer, tls=tls)
communicator = SocketCommunicator(sockets=sockets, name=name, timeout=timeout)
parent_queue.put("ready")
result = fn(listen_socket, communicator, *args, **kwargs)
communicator.free()
except Exception as e: # pragma: no cover
result = Failed(e, traceback.format_exc())
# Setup the multiprocessing context.
context = multiprocessing.get_context(method="fork") # I don't remember why we prefer fork().
# Create queues for IPC.
parent_queue = context.Queue()
child_queue = context.Queue()
# Create per-player processes.
processes = []
for rank in range(world_size):
identity = None if identities is None else identities[rank]
processes.append(context.Process(
name=f"Player {rank}",
target=launch,
kwargs=dict(parent_queue=parent_queue, child_queue=child_queue, rank=rank, fn=fn, identity=identity, trusted=trusted, args=args, family=family, name=name, kwargs=kwargs, timeout=timeout, startup_timeout=startup_timeout),
))
# Start per-player processes.
for process in processes:
process.daemon = True
process.start()
# Collect addresses from every process.
addresses = [None] * world_size
for process in processes:
rank, address = parent_queue.get(block=True)
addresses[rank] = address
# Send addresses to every process.
for process in processes:
child_queue.put(addresses)
# Wait until networking has been established.
for process in processes:
parent_queue.get(block=True)
return addresses, processes
[docs]
def scatter(self, *, src, values):
self._log.debug(f"scatter(src={src})")
self._require_unrevoked()
self._require_running()
src = self._require_rank(src)
if self.rank == src:
values = [value for value in values]
if len(values) != self._world_size:
raise ValueError(f"Expected {self._world_size} values, received {len(values)}.") # pragma: no cover
# Send data to every player.
if self.rank == src:
for value, rank in zip(values, self.ranks):
self._send(tag=Tag.SCATTER, payload=value, dst=rank)
# Receive data from the sender.
return self._wait_next_payload(src=src, tag=Tag.SCATTER)
[docs]
def scatterv(self, *, src, values, dst):
self._log.debug(f"scatterv(src={src}, dst={dst})")
self._require_unrevoked()
self._require_running()
src = self._require_rank(src)
dst = self._require_rank_list(dst)
if self.rank == src:
values = [value for value in values]
if len(values) != len(dst):
raise ValueError("values must contain one value instance for every destination player.") # pragma: no cover
# Send data to every destination.
if self.rank == src:
for value, rank in zip(values, dst):
self._send(tag=Tag.SCATTERV, payload=value, dst=rank)
# Receive data from the sender.
if self.rank in dst:
return self._wait_next_payload(src=src, tag=Tag.SCATTERV)
return None
[docs]
def send(self, *, value, dst, tag):
self._log.debug(f"send(dst={dst}, tag={tagname(tag)})")
self._require_unrevoked()
self._require_running()
dst = self._require_rank(dst)
self._send(tag=tag, payload=value, dst=dst)
[docs]
def shrink(self, *, name, identity=None, trusted=None, shrink_timeout=5, startup_timeout=5, timeout=5):
"""Create a new communicator containing surviving players.
This method should be called as part of a failure-recovery phase by as
many players as possible (ideally, every player still running). It
will attempt to rendezvous with the other players and return a new
communicator, but the process could fail and raise an exception
instead. In that case it is up to the application to decide how to
proceed.
Parameters
----------
name: :class:`str`, required
New communicator name.
identity: :class:`str`, optional
Path to a private key and certificate in PEM format that will identify the current player.
trusted: sequence of :class:`str`, optional
Path to certificates in PEM format that will identify the other players in the new communicator.
shrink_timeout: :class:`numbers.Number`, optional
Maximum amount of time to spend identifying remaining members.
startup_timeout: :class:`numbers.Number`, optional
Maximum time to wait for communicator_setup, in seconds.
timeout: :class:`numbers.Number`, optional
Maximum time to wait for communication, in seconds.
Returns
-------
communicator: :class:`SocketCommunicator`
New communicator containing the remaining players.
oldranks: sequence of :class:`int`
Previous ranks of the remaining players, in rank order.
Use this if you need to know the original rank of any member of
`communicator`.
"""
self._log.debug(f"shrink()")
self._require_running()
##################################################################################
# Phase 1: Determine which players are still alive.
# By default, we assume that no-one is alive.
remaining_ranks = set()
tls = gettls(identity=identity, trusted=trusted)
timer = Timer(threshold=shrink_timeout)
while not timer.expired:
# Send beacons to the other players (including ourself).
for rank in self.ranks:
try:
self._send(tag=Tag.BEACON, payload=None, dst=rank)
# Ignore broken pipe errors, they're to be expected under the circumstances.
except BrokenPipe: # pragma: no cover
pass
# Get any received beacons (including our own).
beacons = self._messages(src=None, tag=Tag.BEACON)
# If we received a beacon, the player is alive.
for src, tag, payload in beacons:
remaining_ranks.add(src)
# If every player is accounted for, we can terminate early.
if remaining_ranks == set(self.ranks):
break
time.sleep(0.5)
##################################################################################
# Phase 2: Use the list of remaining players to generate new network parameters.
# Sort the remaining ranks; the lowest rank will become rank 0 in the new communicator.
remaining_ranks = sorted(list(remaining_ranks))
# Generate a token based on a hash of the remaining ranks that we can
# use to ensure that every player is in agreement on who's remaining.
token = hashlib.sha3_256()
for rank in remaining_ranks:
token.update(f"rank-{rank}".encode("utf8"))
token = token.hexdigest()
# Generate new connection information.
world_size=len(remaining_ranks)
rank = remaining_ranks.index(self.rank)
# Generate a random new listening address.
for sock in self._players.values():
address = urllib.parse.urlparse(geturl(sock))
break
if address.scheme == "file":
fd, path = tempfile.mkstemp()
os.close(fd)
address = f"file://{path}"
elif address.scheme == "tcp":
address = f"tcp://{address.hostname}"
else:
raise ValueError(f"Comm {self.name} player {self.rank} cannot parse unknown address scheme: {address.scheme}") # pragma: no cover
# Create a new listening socket and update the address to match
timer = Timer(threshold=startup_timeout)
listen_socket = listen(address=address, rank=rank, name=name, timer=timer)
address = geturl(listen_socket)
##################################################################################
# Phase 3: Send new network parameters to the remaining players.
# Send new connection info to the remaining players.
if self.rank == remaining_ranks[0]:
for remaining_rank in remaining_ranks:
self._send(tag=Tag.SHRINK, payload=address, dst=remaining_rank)
root_address = self._wait_next_payload(src=remaining_ranks[0], tag=Tag.SHRINK)
# Return a new communicator.
sockets=rendezvous(listen_socket=listen_socket, root_address=root_address, world_size=world_size, rank=rank, name=name, token=token, timer=timer, tls=tls)
return SocketCommunicator(sockets=sockets, name=name, timeout=timeout), remaining_ranks
[docs]
def split(self, *, name, identity=None, trusted=None, timeout=5, startup_timeout=5):
"""Return a new communicator with the given name.
If players specify different names - which can be any :class:`str`
- then a new communicator will be created for each unique name, with
those players as members. If a player supplies a name of `None`,
they will not be a part of any communicator, and this method will
return `None`.
.. note::
This is a collective operation that *must* be called by every member
of the communicator, even if they aren't going to be a member of any
of the resulting groups!
Parameters
----------
name: :class:`str` or :any:`None`, required
Communicator name, or `None`.
identity: :class:`str`, optional
Path to a private key and certificate in PEM format that will identify the current player.
trusted: sequence of :class:`str`, optional
Path to certificates in PEM format that will identify the other players in the new communicator.
timeout: :class:`numbers.Number`, optional
Maximum time to wait for communication, in seconds.
startup_timeout: :class:`numbers.Number`, optional
Maximum time to wait for communicator setup, in seconds.
Returns
-------
communicator: a new :class:`SocketCommunicator` instance, or `None`
"""
self._log.debug(f"split(name={name})")
self._require_unrevoked()
self._require_running()
if not isinstance(name, (str, type(None))):
raise ValueError(f"Comm {self.name} player {self.rank} name must be a string or None, got {name} instead.") # pragma: no cover
tls = gettls(identity=identity, trusted=trusted)
timer = Timer(threshold=startup_timeout)
# Generate a random new listening address.
address = None
if name is not None:
for sock in self._players.values():
address = urllib.parse.urlparse(geturl(sock))
break
if address.scheme == "file":
fd, path = tempfile.mkstemp()
os.close(fd)
address = f"file://{path}"
elif address.scheme == "tcp":
address = f"tcp://{address.hostname}"
else:
raise ValueError(f"Comm {self.name} player {self.rank} cannot split unknown address scheme: {address.scheme}") # pragma: no cover
# Create a new listening socket and update the address to match.
if address is not None:
listen_socket = listen(address=address, rank=self.rank, name=self.name, timer=timer)
address = geturl(listen_socket)
# Send names and addresses to rank 0.
addresses = self.gather(value=(name, address), dst=0)
# Compute new communicator parameters.
players = None
if self.rank == 0:
groups = collections.defaultdict(list)
for rank, (name, address) in enumerate(addresses):
groups[name].append((rank, address))
players = []
for rank, (name, address) in enumerate(addresses):
group = sorted(groups[name])
ranks, addresses = zip(*group)
players.append((name, len(group), ranks.index(rank), addresses))
# Send new connection info to all players.
group_name, group_world_size, group_rank, group_addresses = self.scatter(src=0, values=players)
# Return a new communicator.
if group_name is not None:
sockets = direct(listen_socket=listen_socket, addresses=group_addresses, rank=group_rank, name=group_name, timer=timer, tls=tls)
return SocketCommunicator(sockets=sockets, name=group_name, timeout=timeout)
return None
@property
def stats(self):
"""Nested dict containing communication statistics for logging / debugging."""
results = {
"player": {},
"tag": {},
"total": { "sent": {"bytes": 0, "messages": 0}, "received": {"bytes": 0, "messages": 0}},
}
for tag, sent in self._sent.items():
if tag not in results["tag"]:
results["tag"][tag] = {}
results["tag"][tag]["sent"] = sent
for tag, received in self._received.items():
if tag not in results["tag"]:
results["tag"][tag] = {}
results["tag"][tag]["received"] = received
for rank, player in self._players.items():
stats = player.stats
results["player"][rank] = stats
results["total"]["sent"]["bytes"] += stats["sent"]["bytes"]
results["total"]["sent"]["messages"] += stats["sent"]["messages"]
results["total"]["received"]["bytes"] += stats["received"]["bytes"]
results["total"]["received"]["messages"] += stats["received"]["messages"]
return results
@property
def timeout(self):
"""Amount of time allowed for communications to complete, in seconds.
Returns
-------
timeout: :class:`numbers.Number`.
The timeout in seconds.
"""
return self._timeout
@timeout.setter
def timeout(self, timeout):
self._timeout = timeout
@property
def world_size(self):
return self._world_size