Fault Detection and Recovery

One of Cicada’s most distinctive features is that it supports fault tolerance at the application level: put simply, Cicada raises exceptions when failures occur, whereas most MPC tools just stop executing. This allows Cicada MPC programs to keep running, react to failures, and continue computation. Of course, an application can react to a failure in many ways, and developing a fault tolerance strategy involves many engineering tradeoffs, which is why Cicada leaves it to the application developer to decide which approach is appropriate. We don’t want to suggest that designing for fault recovery is easy, but Cicada provides a solid set of tools for implementing a fault recovery policy.

To begin, let’s set up a naïve calculation without fault tolerance. In the following example, four players take turns using Shamir Secret Sharing to increment a running total:

[1]:
import logging

import numpy

from cicada.communicator import SocketCommunicator
from cicada.logging import Logger
from cicada.shamir import ShamirBasicProtocolSuite

logging.basicConfig(level=logging.INFO)

def main(communicator):
    # One-time initialization.
    log = Logger(logging.getLogger(), communicator=communicator)
    protocol = ShamirBasicProtocolSuite(communicator=communicator, threshold=2)

    total_share = protocol.share(src=0, secret=numpy.array(0), shape=())

    # Main iteration loop.
    for iteration in range(0, 8):
        # Increment the total.
        contributor = iteration % communicator.world_size
        increment = numpy.array(1) if communicator.rank == contributor else None
        increment_share = protocol.share(src=contributor, secret=increment, shape=())
        total_share = protocol.add(total_share, increment_share)

        # Print the current total.
        total = protocol.reveal(total_share)
        log.info(f"Iteration {iteration} comm {communicator.name} total: {total}", src=0)

SocketCommunicator.run(world_size=4, fn=main);
INFO:root:Iteration 0 comm world total: 1.0
INFO:root:Iteration 1 comm world total: 2.0
INFO:root:Iteration 2 comm world total: 3.0
INFO:root:Iteration 3 comm world total: 4.0
INFO:root:Iteration 4 comm world total: 5.0
INFO:root:Iteration 5 comm world total: 6.0
INFO:root:Iteration 6 comm world total: 7.0
INFO:root:Iteration 7 comm world total: 8.0

For simplicity, every player in this example increments the total by the same hard-coded value (one), so that after eight iterations the total is eight. Clearly, this is neither particularly useful nor privacy-preserving; for a real problem, we assume that each player would increment the total with some meaningful private value, such as a count of events detected since the previous iteration.

Warning

Because the total is revealed at the end of each iteration, a malicious player could easily keep track of the running total and recover the other players’ secret increments by subtraction. As always, we use logging strictly for pedagogical purposes; be sure you aren’t leaking secrets when you deploy your own code!
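To make the warning concrete, here is a minimal plain-Python sketch of the attack; it is not part of Cicada. It assumes the observer sees every revealed total and knows the public contributor schedule (`iteration % world_size`) used in the example; the private increments and the helper `recover_increments` are made up for illustration.

```python
import itertools

def recover_increments(revealed_totals, initial_total=0):
    """Recover each iteration's increment from the sequence of revealed running totals."""
    increments = []
    previous = initial_total
    for total in revealed_totals:
        increments.append(total - previous)  # consecutive differences expose each increment
        previous = total
    return increments

world_size = 4
private_increments = [3, 1, 4, 1, 5, 9, 2, 6]  # hypothetical secret values, one per iteration
revealed_totals = list(itertools.accumulate(private_increments))  # what the log reveals

recovered = recover_increments(revealed_totals)
# The contributor schedule is public, so each recovered value can be attributed to a player.
for iteration, increment in enumerate(recovered):
    print(f"iteration {iteration}: player {iteration % world_size} contributed {increment}")
```

Since every secret is recovered exactly, revealing intermediate totals is only acceptable when the increments themselves are not sensitive.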

Now that we have a working example, let’s simulate a fault by killing player 3 partway through the program:

[2]:
import os
import signal

def main(communicator):
    # One-time initialization.
    log = Logger(logging.getLogger(), communicator=communicator)
    protocol = ShamirBasicProtocolSuite(communicator=communicator, threshold=2)

    total_share = protocol.share(src=0, secret=numpy.array(0), shape=())

    # Main iteration loop.
    for iteration in range(0, 8):
        # Simulate the unexpected death of player 3.
        if iteration == 4 and communicator.rank == 3:
            os.kill(os.getpid(), signal.SIGKILL)

        # Increment the total.
        contributor = iteration % communicator.world_size
        increment = numpy.array(1) if communicator.rank == contributor else None
        increment_share = protocol.share(src=contributor, secret=increment, shape=())
        total_share = protocol.add(total_share, increment_share)

        # Print the current total.
        total = protocol.reveal(total_share)
        log.info(f"Iteration {iteration} comm {communicator.name} total: {total}", src=0)

SocketCommunicator.run(world_size=4, fn=main);
INFO:root:Iteration 0 comm world total: 1.0
INFO:root:Iteration 1 comm world total: 2.0
INFO:root:Iteration 2 comm world total: 3.0
INFO:root:Iteration 3 comm world total: 4.0
ERROR:cicada.communicator.socket:Comm world player 0 failed: Timeout('Tag GATHERV from player 3 timed-out after 5s')
ERROR:cicada.communicator.socket:Comm world player 1 failed: Timeout('Tag GATHERV from player 0 timed-out after 5s')
ERROR:cicada.communicator.socket:Comm world player 2 failed: Timeout('Tag GATHERV from player 0 timed-out after 5s')
ERROR:cicada.communicator.socket:Comm world player 3 failed: Terminated(exitcode=-9)

Take the time to examine the log output closely: you’ll see that player 0 raises an exception after waiting five seconds without hearing from player 3 (as a side effect, players 1 and 2 also time out waiting for player 0, which is itself waiting for player 3; secondary timeouts like these are very common). Communicator methods that raise exceptions are the primary mechanism Cicada uses to notify the application that a fault has occurred. To create a fault-tolerant application, you must be prepared to handle exceptions any time you implicitly or explicitly use a communicator.

Important

You should understand that without communication, there is no fault detection. In general, Cicada player processes are meant to execute on separate hosts. Imagine that you are a player process: how do you know that a player process on another host has failed? The information that a failure has occurred must reach you somehow: either explicitly, because some other process informed you that the player failed, or implicitly, because you didn’t hear from the other player within a set period of time. Either approach implies communication, which is why fault detection is inextricably tied to communicators.
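The implicit approach can be sketched with nothing but the standard library. This is a hypothetical model, not Cicada’s implementation: the `Timeout` class and `receive()` helper are our own names, loosely modeled on the log messages above, and an in-process `queue.Queue` stands in for a network connection.

```python
import queue
import threading

class Timeout(Exception):
    """Hypothetical stand-in for a communicator timeout; not Cicada's exception class."""

def receive(inbox, src, timeout):
    """Wait up to `timeout` seconds for a message from player `src`.

    Silence is the only evidence of a dead peer, so an unbounded wait could hang
    forever; instead, an overdue message is reported as an exception.
    """
    try:
        return inbox.get(timeout=timeout)
    except queue.Empty:
        raise Timeout(f"message from player {src} timed-out after {timeout}s") from None

# A live peer: another thread delivers a message, so receive() returns it.
live_inbox = queue.Queue()
threading.Thread(target=live_inbox.put, args=("hello",)).start()
print(receive(live_inbox, src=1, timeout=1.0))  # hello

# A dead peer: nothing ever arrives, so the receiver detects the fault by timing out.
dead_inbox = queue.Queue()
try:
    receive(dead_inbox, src=3, timeout=0.1)
except Timeout as e:
    print(f"fault detected: {e}")
```

The timeout is a tradeoff: shorter values detect failures sooner but risk mistaking a slow peer for a dead one.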

So, the first step in adding fault tolerance to our program is to add exception handling:

[3]:
def main(communicator):
    # One-time initialization.
    log = Logger(logging.getLogger(), communicator=communicator)
    protocol = ShamirBasicProtocolSuite(communicator=communicator, threshold=2)

    total_share = protocol.share(src=0, secret=numpy.array(0), shape=())

    # Main iteration loop.
    for iteration in range(0, 8):
        # Simulate the unexpected death of player 3.
        if iteration == 4 and communicator.rank == 3:
            os.kill(os.getpid(), signal.SIGKILL)

        # Do computation in this block.
        try:
            # Increment the total.
            contributor = iteration % communicator.world_size
            increment = numpy.array(1) if communicator.rank == contributor else None
            increment_share = protocol.share(src=contributor, secret=increment, shape=())
            total_share = protocol.add(total_share, increment_share)

            # Print the current total.
            total = protocol.reveal(total_share)
            log.info(f"Iteration {iteration} comm {communicator.name} total: {total}", src=0)

        # Implement failure recovery in this block.  Be careful here! Many
        # operations can't be used when there are unresponsive players.
        except Exception as e:
            log.sync = False
            log.error(f"Iteration {iteration} comm {communicator.name} player {communicator.rank} exception: {e}")
            break

SocketCommunicator.run(world_size=4, fn=main);
INFO:root:Iteration 0 comm world total: 1.0
INFO:root:Iteration 1 comm world total: 2.0
INFO:root:Iteration 2 comm world total: 3.0
INFO:root:Iteration 3 comm world total: 4.0
ERROR:root:Iteration 4 comm world player 0 exception: Tag GATHERV from player 3 timed-out after 5s
ERROR:root:Iteration 4 comm world player 2 exception: Tag GATHERV from player 0 timed-out after 5s
ERROR:root:Iteration 4 comm world player 1 exception: Tag GATHERV from player 0 timed-out after 5s
ERROR:cicada.communicator.socket:Comm world player 3 failed: Terminated(exitcode=-9)

Note that we’ve put all of our iterative computation code in the try block, because any communication of any kind could raise an exception, and we have to assume that any method of Logger or ShamirBasicProtocolSuite could use the communicator at any time.

Important

We turn off coordinated logging before using the Logger in our except block:

log.sync = False

since coordinated logging requires communication with every player and would fail otherwise; see Logging for details.

Now that we’re reacting to the exception, what should we do? We could simply report the error and shut down gracefully as in the example above, which is already an improvement over other MPC tools. But suppose we don’t want to lose our secret total, even when a player dies. Is it gone forever? Fortunately, the answer is “no”: because we used Shamir Secret Sharing with four players and a threshold of two (so that any two shares suffice to reveal a secret), we should be able to continue iterating with the remaining three players. But how can we call methods like reveal() using just a subset of the original four players? Fundamentally, our original communicator was created with four players, and the number of players in a communicator is never allowed to change (this is by design, because it provides important benefits that we won’t go into here). When we call reveal(), the protocol object expects to communicate with all four players, but can’t, because one of them is no longer responding.
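To see why the secret survives a lost player, here is a minimal, self-contained Shamir sketch in plain Python. It is not Cicada’s implementation: the prime modulus, the helpers `share()` and `reconstruct()`, and the reading of `threshold=2` as “a degree-one polynomial, so any two shares reconstruct” are our assumptions, although that reading matches the behavior described above. The modular inverse `pow(d, -1, p)` requires Python 3.8 or later.

```python
import random

PRIME = 2**31 - 1  # hypothetical field modulus for this sketch (a Mersenne prime)

def share(secret, threshold, indices, rng):
    """Evaluate a random polynomial of degree threshold - 1 with constant term `secret`."""
    coefficients = [secret] + [rng.randrange(PRIME) for _ in range(threshold - 1)]
    return [sum(c * pow(x, k, PRIME) for k, c in enumerate(coefficients)) % PRIME for x in indices]

def reconstruct(shares, indices):
    """Lagrange-interpolate the sharing polynomial at x = 0."""
    secret = 0
    for j, (xj, yj) in enumerate(zip(indices, shares)):
        numerator, denominator = 1, 1
        for m, xm in enumerate(indices):
            if m != j:
                numerator = numerator * (-xm) % PRIME
                denominator = denominator * (xj - xm) % PRIME
        secret = (secret + yj * numerator * pow(denominator, -1, PRIME)) % PRIME
    return secret

rng = random.Random(0)
indices = [1, 2, 3, 4]  # one x-coordinate per player, ranks 0-3
total = share(4, threshold=2, indices=indices, rng=rng)
increment = share(1, threshold=2, indices=indices, rng=rng)
# Adding shares element-wise adds the secrets, with no communication needed.
total = [(a + b) % PRIME for a, b in zip(total, increment)]

# Player 3 dies; any two of the three survivors can still reconstruct the total.
for pair in [(0, 1), (0, 2), (1, 2)]:
    print(pair, reconstruct([total[r] for r in pair], [indices[r] for r in pair]))  # 5
```

A single share, by contrast, is consistent with every possible secret, which is why the program can only continue while at least `threshold` players remain.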

This may seem like an impasse, but the solution is quite elegant: we simply stop using the old communicator and replace it with a new one that represents the remaining players. Happily, SocketCommunicator provides a pair of methods to make this easy. Let’s put them to use for our failure recovery:

[4]:
def main(communicator):
    # One-time initialization.
    log = Logger(logging.getLogger(), communicator=communicator)
    protocol = ShamirBasicProtocolSuite(communicator=communicator, threshold=2)

    total_share = protocol.share(src=0, secret=numpy.array(0), shape=())

    # Main iteration loop.
    for iteration in range(0, 8):
        # Simulate the unexpected death of player 3.
        if iteration == 4 and communicator.rank == 3:
            os.kill(os.getpid(), signal.SIGKILL)

        # Do computation in this block.
        try:
            # Increment the total.
            contributor = iteration % communicator.world_size
            increment = numpy.array(1) if communicator.rank == contributor else None
            increment_share = protocol.share(src=contributor, secret=increment, shape=())
            total_share = protocol.add(total_share, increment_share)

            # Print the current total.
            total = protocol.reveal(total_share)
            log.info(f"Iteration {iteration} comm {communicator.name} total: {total}", src=0)

        # Implement failure recovery in this block.  Be careful here! Many
        # operations can't be used when there are unresponsive players.
        except Exception as e:
            log.sync = False
            log.error(f"Iteration {iteration} comm {communicator.name} player {communicator.rank} exception: {e}")

            # Something went wrong.  Revoke the current communicator to
            # ensure that all players are aware of it.
            communicator.revoke()

            # Obtain a new communicator that contains the remaining players.
            newcommunicator, oldranks = communicator.shrink(name="fallback")

            # Recreate the logger since objects that depend on the old,
            # revoked communicator must be rebuilt from scratch using the
            # new communicator.
            log = Logger(logging.getLogger(), newcommunicator)
            log.info(f"Iteration {iteration} shrank comm {communicator.name} with {communicator.world_size} players to comm {newcommunicator.name} with {newcommunicator.world_size} players.", src=0)

            # Recreate the protocol since objects that depend on the old,
            # revoked communicator must be rebuilt from scratch using the
            # new communicator.
            protocol = ShamirBasicProtocolSuite(newcommunicator, threshold=2, indices=protocol.indices[oldranks])

            # Cleanup the old communicator.
            communicator.free()
            communicator = newcommunicator

SocketCommunicator.run(world_size=4, fn=main);
INFO:root:Iteration 0 comm world total: 1.0
INFO:root:Iteration 1 comm world total: 2.0
INFO:root:Iteration 2 comm world total: 3.0
INFO:root:Iteration 3 comm world total: 4.0
ERROR:root:Iteration 4 comm world player 0 exception: Tag GATHERV from player 3 timed-out after 5s
ERROR:root:Iteration 4 comm world player 1 exception: Tag GATHERV from player 0 timed-out after 5s
ERROR:root:Iteration 4 comm world player 2 exception: Tag GATHERV from player 0 timed-out after 5s
WARNING:cicada.communicator.socket:Comm world player 0 revoked by player 0
WARNING:cicada.communicator.socket:Comm world player 2 revoked by player 0
WARNING:cicada.communicator.socket:Comm world player 1 revoked by player 0
INFO:root:Iteration 4 shrank comm world with 4 players to comm fallback with 3 players.
INFO:root:Iteration 5 comm fallback total: 6.0
INFO:root:Iteration 6 comm fallback total: 7.0
INFO:root:Iteration 7 comm fallback total: 8.0
ERROR:cicada.communicator.socket:Comm world player 3 failed: Terminated(exitcode=-9)

There’s a lot happening here, so we’ll go over it line by line. First, we call revoke() immediately after detecting an error:

communicator.revoke()

Revoking a communicator lets any player who detects a problem get the immediate attention of the remaining players: once a communicator is revoked, any attempt to use it raises a Revoked exception, so it can’t be used for any other purpose. In our case, this ensures that every player will eventually execute the failure recovery block. Notice in the program output above that every surviving player reports the revocation. Because each player also calls revoke() as it enters its own failure recovery block, a communicator may be revoked more than once, by different players.
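The semantics described here can be modeled in a few lines. The sketch below is a hypothetical single-process toy, not Cicada’s SocketCommunicator: `ToyCommunicator`, its `send()` method, and the `revokers` list are our inventions, and only the idea that every operation on a revoked communicator raises Revoked comes from the text.

```python
class Revoked(Exception):
    """Raised when a revoked communicator is used (name taken from the text above)."""

class ToyCommunicator:
    """Hypothetical in-process model of revocation; NOT Cicada's SocketCommunicator.

    A single instance stands in for the communicator as seen by every player,
    modeling the fact that a revocation reaches all of them.
    """

    def __init__(self, world_size):
        self.world_size = world_size
        self.revokers = []  # ranks of every player that called revoke(), in order

    def revoke(self, rank):
        # Revocation is permanent, and more than one player may revoke.
        self.revokers.append(rank)

    def send(self, value, dst):
        # Every operation checks for revocation before doing any work.
        if self.revokers:
            raise Revoked(f"communicator revoked by player {self.revokers[0]}")
        return value

comm = ToyCommunicator(world_size=4)
comm.send("ok", dst=1)  # works while the communicator is healthy
comm.revoke(rank=0)     # player 0 detects a fault and revokes
comm.revoke(rank=2)     # player 2 enters its own recovery block and revokes too

for rank in (1, 2):
    try:
        comm.send("more work", dst=0)
    except Revoked as e:
        print(f"player {rank}: {e}")
```

The key design point is that revocation is sticky: no later operation can “succeed by accident” on a communicator that some player has already declared broken.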

Note

You may wonder why revocation is necessary for the current program, since, as we pointed out earlier, players 1 and 2 experience secondary timeouts waiting for player 0, who times out waiting for player 3. The answer is that the secondary timeouts are incidental and, depending on random quirks of scheduling and the type of communication the program is using at the time, aren’t guaranteed to happen in all cases. Thus, you should always use revoke() to interrupt an MPC program and put all of the players into failure recovery mode.

Next, we call shrink(), which returns a new communicator with just the remaining players:

newcommunicator, oldranks = communicator.shrink(name="fallback")

Note that shrink() returns a new instance of SocketCommunicator and a list of ranks from the old communicator in new-rank order. Recall that player ranks are always numbered contiguously from zero (see Multiple Communicators for more detail), which means that some players’ ranks in the new communicator will differ from their original ranks; oldranks is provided so that applications can map between old and new ranks (and perform any related bookkeeping).
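A minimal sketch of that bookkeeping in plain Python follows. The helper `rank_maps()` and the per-player data are hypothetical, and the survivors’ old ranks `[0, 2, 3]` are an assumed example (player 1 of four has died); the only property relied on is the one stated above, that `oldranks[new_rank]` is a survivor’s rank in the old communicator.

```python
def rank_maps(oldranks):
    """Build new->old and old->new rank lookups from shrink()'s `oldranks`.

    oldranks[new_rank] is the survivor's rank in the old communicator.
    """
    new_to_old = {new: int(old) for new, old in enumerate(oldranks)}  # int() also accepts NumPy integers
    old_to_new = {old: new for new, old in new_to_old.items()}
    return new_to_old, old_to_new

# Assumed example: player 1 of four died, leaving old ranks [0, 2, 3] in new-rank order.
new_to_old, old_to_new = rank_maps([0, 2, 3])
print(new_to_old)  # {0: 0, 1: 2, 2: 3}
print(old_to_new)  # {0: 0, 2: 1, 3: 2}

# Bookkeeping example: per-player state keyed by old rank, re-keyed by new rank.
events_by_old_rank = {0: 10, 2: 7, 3: 4}  # hypothetical per-player data
events_by_new_rank = {old_to_new[old]: n for old, n in events_by_old_rank.items()}
print(events_by_new_rank)  # {0: 10, 1: 7, 2: 4}
```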

Next, we recreate our Logger object from scratch, because it still references the original communicator, which is revoked and can no longer be used:

log = Logger(logging.getLogger(), newcommunicator)

Similarly, we recreate the ShamirBasicProtocolSuite object for the same reason. Note that we also use the oldranks mapping returned by shrink() to copy the surviving players’ share indices from the original protocol object; this is necessary when working with the Shamir Secret Sharing protocols, so that the new protocol object can work with secret shares generated by the original:

protocol = ShamirBasicProtocolSuite(newcommunicator, threshold=2, indices=protocol.indices[oldranks])
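Why must the indices be carried over? A hedged illustration in plain Python: we assume, based on the role the `indices` argument plays here, that each Cicada player’s share is a polynomial evaluated at that player’s index (its x-coordinate). The modulus and the `reconstruct()` helper are ours, not Cicada’s, and the polynomial uses a fixed slope instead of a random one so the output is deterministic.

```python
PRIME = 2**31 - 1  # hypothetical field modulus for this sketch

def reconstruct(shares, indices):
    """Lagrange-interpolate the sharing polynomial at x = 0."""
    secret = 0
    for j, (xj, yj) in enumerate(zip(indices, shares)):
        numerator, denominator = 1, 1
        for m, xm in enumerate(indices):
            if m != j:
                numerator = numerator * (-xm) % PRIME
                denominator = denominator * (xj - xm) % PRIME
        secret = (secret + yj * numerator * pow(denominator, -1, PRIME)) % PRIME
    return secret

# Shares of the secret 7 on the line f(x) = 7 + 5x, held by ranks 0-3 at x-coordinates 1-4.
indices = [1, 2, 3, 4]
shares = [(7 + 5 * x) % PRIME for x in indices]

# Player 1 dies; shrink() reports the survivors' old ranks in new-rank order.
oldranks = [0, 2, 3]
surviving_shares = [shares[r] for r in oldranks]

# Correct: keep each survivor's original x-coordinate (the role of indices[oldranks]).
print(reconstruct(surviving_shares, [indices[r] for r in oldranks]))  # 7

# Wrong: hand out fresh x-coordinates 1-3 based on the new ranks.
print(reconstruct(surviving_shares, [1, 2, 3]))  # not 7
```

A share is only meaningful together with the x-coordinate it was evaluated at, so renumbering players must not renumber their indices.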

Finally, we free the old communicator and replace it with the new, completing the recovery:

communicator.free()
communicator = newcommunicator

When we run the program now, we still get the failure, but the three remaining players are able to continue working without data loss. Keep in mind that, due to random quirks of timing and scheduling, the final total when the fault-tolerant program ends may or may not be eight, depending on whether the failure manifests before or after the total is incremented in iteration four. You should convince yourself that this indeterminacy is not a problem: just as in real life, players can fail at any time, and whether a given player’s contribution to the total is “lost” is strictly a matter of random chance.

Let’s continue to refine this example. Hard-coding a fault for a specific player at a specific time isn’t very true to life, so we’ll introduce a more realistic failure model. We will assume that, for each player, there is a fixed probability that a failure will occur during each iteration of the program, which we can capture with a generator function:

[5]:
import os
import signal

import numpy

def random_failure(pfail, seed):
    generator = numpy.random.default_rng(seed=seed)
    while True:
        if generator.uniform() <= pfail:
            os.kill(os.getpid(), signal.SIGKILL)
        yield

Tip

Python generators are functions that behave like iterators. See Generators for details.
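Because random_failure kills the process, it is awkward to experiment with directly. Here is a minimal testable variant, offered as a sketch: `random_failure_flags` is our own name, it yields True where the original would call os.kill, and it uses the standard library’s `random.Random` instead of NumPy’s `default_rng` (with `<` instead of `<=`), so it will not reproduce the exact failure iterations seen in the logs.

```python
import random

def random_failure_flags(pfail, seed):
    """Testable variant of random_failure: yields True on iterations where the
    original generator would kill the process, and False otherwise."""
    generator = random.Random(seed)
    while True:
        # random() is uniform on [0, 1), so this is True with probability pfail.
        yield generator.random() < pfail

# One independent stream per player, seeded by rank as in the examples that follow.
streams = {rank: random_failure_flags(pfail=0.05, seed=42 + rank) for rank in range(4)}
first_failure = {}
for iteration in range(100):
    for rank, stream in streams.items():
        if next(stream) and rank not in first_failure:
            first_failure[rank] = iteration
print(first_failure)  # iteration at which each player would first fail (if within 100)

# Chance that a single player survives 8 iterations at a 5% per-iteration failure rate.
print(round((1 - 0.05) ** 8, 3))  # 0.663
```

Each call to next() advances the generator by exactly one iteration, which is why the main loop can call it once per pass without any extra bookkeeping.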

Now, we can replace our hard-coded failure logic. We’ll assign every player a 5% chance of failure per iteration using a random seed based on their rank, so that different players don’t fail in unison, except by chance:

[6]:
def main(communicator):
    # One-time initialization.
    failure = random_failure(pfail=0.05, seed=42 + communicator.rank)
    log = Logger(logging.getLogger(), communicator=communicator)
    protocol = ShamirBasicProtocolSuite(communicator=communicator, threshold=2)

    total_share = protocol.share(src=0, secret=numpy.array(0), shape=())

    # Main iteration loop.
    for iteration in range(0, 8):
        # Allow failures to occur at random.
        next(failure)

        # Do computation in this block.
        try:
            # Increment the total.
            contributor = iteration % communicator.world_size
            increment = numpy.array(1) if communicator.rank == contributor else None
            increment_share = protocol.share(src=contributor, secret=increment, shape=())
            total_share = protocol.add(total_share, increment_share)

            # Print the current total.
            total = protocol.reveal(total_share)
            log.info(f"Iteration {iteration} comm {communicator.name} total: {total}", src=0)

        # Implement failure recovery in this block.  Be careful here! Many
        # operations can't be used when there are unresponsive players.
        except Exception as e:
            log.sync = False
            log.error(f"Iteration {iteration} comm {communicator.name} player {communicator.rank} exception: {e}")

            # Something went wrong.  Revoke the current communicator to
            # ensure that all players are aware of it.
            communicator.revoke()

            # Obtain a new communicator that contains the remaining players.
            newcommunicator, oldranks = communicator.shrink(name="fallback")

            # Recreate the logger since objects that depend on the old,
            # revoked communicator must be rebuilt from scratch using the
            # new communicator.
            log = Logger(logging.getLogger(), newcommunicator)
            log.info(f"Iteration {iteration} shrank comm {communicator.name} with {communicator.world_size} players to comm {newcommunicator.name} with {newcommunicator.world_size} players.", src=0)

            # Recreate the protocol since objects that depend on the old,
            # revoked communicator must be rebuilt from scratch using the
            # new communicator.
            protocol = ShamirBasicProtocolSuite(newcommunicator, threshold=2, indices=protocol.indices[oldranks])

            # Cleanup the old communicator.
            communicator.free()
            communicator = newcommunicator

SocketCommunicator.run(world_size=4, fn=main);
INFO:root:Iteration 0 comm world total: 1.0
ERROR:root:Iteration 1 comm world player 0 exception: Tag SCATTER from player 1 timed-out after 5s
ERROR:root:Iteration 1 comm world player 3 exception: Tag SCATTER from player 1 timed-out after 5s
ERROR:root:Iteration 1 comm world player 2 exception: Tag SCATTER from player 1 timed-out after 5s
WARNING:cicada.communicator.socket:Comm world player 0 revoked by player 0
WARNING:cicada.communicator.socket:Comm world player 3 revoked by player 0
WARNING:cicada.communicator.socket:Comm world player 2 revoked by player 0
INFO:root:Iteration 1 shrank comm world with 4 players to comm fallback with 3 players.
INFO:root:Iteration 2 comm fallback total: 2.0
INFO:root:Iteration 3 comm fallback total: 3.0
INFO:root:Iteration 4 comm fallback total: 4.0
INFO:root:Iteration 5 comm fallback total: 5.0
INFO:root:Iteration 6 comm fallback total: 6.0
INFO:root:Iteration 7 comm fallback total: 7.0
ERROR:cicada.communicator.socket:Comm world player 1 failed: Terminated(exitcode=-9)

Note that:

next(failure)

takes the place of our hard-coded logic. Each time it’s called, there’s a 5% chance that it will kill the player. As we can see, the program still works, and the new failure logic kills a different player on a different iteration, but there’s still only a single failure. Let’s allow the program to run for more than eight iterations and see how it handles multiple failures:

[7]:
import itertools

def main(communicator):
    # One-time initialization.
    failure = random_failure(pfail=0.05, seed=42 + communicator.rank)
    log = Logger(logging.getLogger(), communicator=communicator)
    protocol = ShamirBasicProtocolSuite(communicator=communicator, threshold=2)

    total_share = protocol.share(src=0, secret=numpy.array(0), shape=())

    # Main iteration loop.
    for iteration in itertools.count():
        # Allow failures to occur at random.
        next(failure)

        # Do computation in this block.
        try:
            # Increment the total.
            contributor = iteration % communicator.world_size
            increment = numpy.array(1) if communicator.rank == contributor else None
            increment_share = protocol.share(src=contributor, secret=increment, shape=())
            total_share = protocol.add(total_share, increment_share)

            # Print the current total.
            total = protocol.reveal(total_share)
            log.info(f"Iteration {iteration} comm {communicator.name} total: {total}", src=0)

        # Implement failure recovery in this block.  Be careful here! Many
        # operations can't be used when there are unresponsive players.
        except Exception as e:
            log.sync = False
            log.error(f"Iteration {iteration} comm {communicator.name} player {communicator.rank} exception: {e}")

            # Something went wrong.  Revoke the current communicator to
            # ensure that all players are aware of it.
            communicator.revoke()

            # Obtain a new communicator that contains the remaining players.
            newcommunicator, oldranks = communicator.shrink(name="fallback")

            # Recreate the logger since objects that depend on the old,
            # revoked communicator must be rebuilt from scratch using the
            # new communicator.
            log = Logger(logging.getLogger(), newcommunicator)
            log.info(f"Iteration {iteration} shrank comm {communicator.name} with {communicator.world_size} players to comm {newcommunicator.name} with {newcommunicator.world_size} players.", src=0)

            # Recreate the protocol since objects that depend on the old,
            # revoked communicator must be rebuilt from scratch using the
            # new communicator.
            protocol = ShamirBasicProtocolSuite(newcommunicator, threshold=2, indices=protocol.indices[oldranks])

            # Cleanup the old communicator.
            communicator.free()
            communicator = newcommunicator

SocketCommunicator.run(world_size=4, fn=main);
INFO:root:Iteration 0 comm world total: 1.0
ERROR:root:Iteration 1 comm world player 0 exception: Tag SCATTER from player 1 timed-out after 5s
WARNING:cicada.communicator.socket:Comm world player 0 revoked by player 0
ERROR:root:Iteration 1 comm world player 2 exception: Tag SCATTER from player 1 timed-out after 5s
ERROR:root:Iteration 1 comm world player 3 exception: Tag SCATTER from player 1 timed-out after 5s
WARNING:cicada.communicator.socket:Comm world player 3 revoked by player 0
WARNING:cicada.communicator.socket:Comm world player 2 revoked by player 0
INFO:root:Iteration 1 shrank comm world with 4 players to comm fallback with 3 players.
INFO:root:Iteration 2 comm fallback total: 2.0
INFO:root:Iteration 3 comm fallback total: 3.0
INFO:root:Iteration 4 comm fallback total: 4.0
INFO:root:Iteration 5 comm fallback total: 5.0
INFO:root:Iteration 6 comm fallback total: 6.0
INFO:root:Iteration 7 comm fallback total: 7.0
INFO:root:Iteration 8 comm fallback total: 8.0
INFO:root:Iteration 9 comm fallback total: 9.0
INFO:root:Iteration 10 comm fallback total: 10.0
INFO:root:Iteration 11 comm fallback total: 11.0
INFO:root:Iteration 12 comm fallback total: 12.0
INFO:root:Iteration 13 comm fallback total: 13.0
INFO:root:Iteration 14 comm fallback total: 14.0
ERROR:root:Iteration 15 comm fallback player 0 exception: Tag GATHERV from player 1 timed-out after 5s
ERROR:root:Iteration 15 comm fallback player 2 exception: Tag GATHERV from player 0 timed-out after 5s
WARNING:cicada.communicator.socket:Comm fallback player 0 revoked by player 0
WARNING:cicada.communicator.socket:Comm fallback player 2 revoked by player 2
INFO:root:Iteration 15 shrank comm fallback with 3 players to comm fallback with 2 players.
INFO:root:Iteration 16 comm fallback total: 16.0
INFO:root:Iteration 17 comm fallback total: 17.0
INFO:root:Iteration 18 comm fallback total: 18.0
ERROR:root:Iteration 19 comm fallback player 0 exception: Tag SCATTER from player 1 timed-out after 5s
WARNING:cicada.communicator.socket:Comm fallback player 0 revoked by player 0
INFO:root:Iteration 19 shrank comm fallback with 2 players to comm fallback with 1 players.
ERROR:cicada.communicator.socket:Comm world player 0 failed: ValueError('threshold must be <= world_size')
ERROR:cicada.communicator.socket:Comm world player 1 failed: Terminated(exitcode=-9)
ERROR:cicada.communicator.socket:Comm world player 2 failed: Terminated(exitcode=-9)
ERROR:cicada.communicator.socket:Comm world player 3 failed: Terminated(exitcode=-9)

This is great! We lost one player during iteration 1, another during iteration 15, and a third during iteration 19, yet each time the communicator was revoked and shrunk. In fact, our code is written to run forever:

for iteration in itertools.count():

… but it didn’t, because we eventually ran out of players! The program finally exits because we can’t create an instance of ShamirBasicProtocolSuite with one player and a threshold of two. The resulting exception is raised inside our failure recovery block, where nothing catches it, so the program exits.

Since this is a foreseeable situation, we should detect it and shut down the program cleanly:

[8]:
def main(communicator):
    # One-time initialization.
    failure = random_failure(pfail=0.05, seed=42 + communicator.rank)
    log = Logger(logging.getLogger(), communicator=communicator)
    protocol = ShamirBasicProtocolSuite(communicator=communicator, threshold=2)

    total_share = protocol.share(src=0, secret=numpy.array(0), shape=())

    # Main iteration loop.
    for iteration in itertools.count():
        # Allow failures to occur at random.
        next(failure)

        # Do computation in this block.
        try:
            # Increment the total.
            contributor = iteration % communicator.world_size
            increment = numpy.array(1) if communicator.rank == contributor else None
            increment_share = protocol.share(src=contributor, secret=increment, shape=())
            total_share = protocol.add(total_share, increment_share)

            # Print the current total.
            total = protocol.reveal(total_share)
            log.info(f"Iteration {iteration} comm {communicator.name} total: {total}", src=0)

        # Implement failure recovery in this block.  Be careful here! Many
        # operations can't be used when there are unresponsive players.
        except Exception as e:
            log.sync = False
            log.error(f"Iteration {iteration} comm {communicator.name} player {communicator.rank} exception: {e}")

            # Something went wrong.  Revoke the current communicator to
            # ensure that all players are aware of it.
            communicator.revoke()

            # If we don't have enough players to continue, it's time to shut down cleanly.
            if communicator.world_size == protocol.threshold:
                log.info(f"Iteration {iteration} not enough players to continue.", src=0)
                break

            # Obtain a new communicator that contains the remaining players.
            newcommunicator, oldranks = communicator.shrink(name="fallback")

            # Recreate the logger since objects that depend on the old,
            # revoked communicator must be rebuilt from scratch using the
            # new communicator.
            log = Logger(logging.getLogger(), newcommunicator)
            log.info(f"Iteration {iteration} shrank comm {communicator.name} with {communicator.world_size} players to comm {newcommunicator.name} with {newcommunicator.world_size} players.", src=0)

            # Recreate the protocol since objects that depend on the old,
            # revoked communicator must be rebuilt from scratch using the
            # new communicator.
            protocol = ShamirBasicProtocolSuite(newcommunicator, threshold=2, indices=protocol.indices[oldranks])

            # Cleanup the old communicator.
            communicator.free()
            communicator = newcommunicator

SocketCommunicator.run(world_size=4, fn=main);
INFO:root:Iteration 0 comm world total: 1.0
ERROR:root:Iteration 1 comm world player 0 exception: Tag SCATTER from player 1 timed-out after 5s
ERROR:root:Iteration 1 comm world player 3 exception: Tag SCATTER from player 1 timed-out after 5s
ERROR:root:Iteration 1 comm world player 2 exception: Tag SCATTER from player 1 timed-out after 5s
WARNING:cicada.communicator.socket:Comm world player 0 revoked by player 0
WARNING:cicada.communicator.socket:Comm world player 2 revoked by player 0
WARNING:cicada.communicator.socket:Comm world player 3 revoked by player 0
INFO:root:Iteration 1 shrank comm world with 4 players to comm fallback with 3 players.
INFO:root:Iteration 2 comm fallback total: 2.0
INFO:root:Iteration 3 comm fallback total: 3.0
INFO:root:Iteration 4 comm fallback total: 4.0
INFO:root:Iteration 5 comm fallback total: 5.0
INFO:root:Iteration 6 comm fallback total: 6.0
INFO:root:Iteration 7 comm fallback total: 7.0
INFO:root:Iteration 8 comm fallback total: 8.0
INFO:root:Iteration 9 comm fallback total: 9.0
INFO:root:Iteration 10 comm fallback total: 10.0
INFO:root:Iteration 11 comm fallback total: 11.0
INFO:root:Iteration 12 comm fallback total: 12.0
INFO:root:Iteration 13 comm fallback total: 13.0
INFO:root:Iteration 14 comm fallback total: 14.0
ERROR:root:Iteration 15 comm fallback player 0 exception: Tag GATHERV from player 1 timed-out after 5s
ERROR:root:Iteration 15 comm fallback player 2 exception: Tag GATHERV from player 0 timed-out after 5s
WARNING:cicada.communicator.socket:Comm fallback player 0 revoked by player 0
WARNING:cicada.communicator.socket:Comm fallback player 2 revoked by player 2
INFO:root:Iteration 15 shrank comm fallback with 3 players to comm fallback with 2 players.
INFO:root:Iteration 16 comm fallback total: 16.0
INFO:root:Iteration 17 comm fallback total: 17.0
INFO:root:Iteration 18 comm fallback total: 18.0
ERROR:root:Iteration 19 comm fallback player 0 exception: Tag SCATTER from player 1 timed-out after 5s
WARNING:cicada.communicator.socket:Comm fallback player 0 revoked by player 0
INFO:root:Iteration 19 not enough players to continue.
ERROR:cicada.communicator.socket:Comm world player 1 failed: Terminated(exitcode=-9)
ERROR:cicada.communicator.socket:Comm world player 2 failed: Terminated(exitcode=-9)
ERROR:cicada.communicator.socket:Comm world player 3 failed: Terminated(exitcode=-9)

If you look closely at the logs, you’ll notice that the original communicator is named “world”, but after the first failure it’s named “fallback”. After the second failure it’s still named “fallback”, even though it’s a different communicator with one fewer player. That’s because “fallback” is the name we supply every time the communicator is shrunk:

newcommunicator, oldranks = communicator.shrink(name="fallback")
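Besides the new communicator, `shrink` also returns `oldranks`, which the recovery code uses to rebuild the protocol suite with `protocol.indices[oldranks]`. Here is a minimal sketch of that selection in plain NumPy. It assumes, as that expression suggests, that `oldranks` lists each survivor’s rank in the old communicator in order of its new rank, and that `indices` holds each player’s Shamir x-coordinate. The index values 1 through 4 are hypothetical, and the failure of player 1 mirrors the first failure in the run above.

```python
import numpy

# Hypothetical Shamir x-coordinates for a four-player communicator.
indices = numpy.array([1, 2, 3, 4])

# Assumed form of shrink()'s second return value after old rank 1 fails:
# the old rank of each survivor, ordered by its rank in the new communicator.
oldranks = [0, 2, 3]

# Integer-array ("fancy") indexing keeps each survivor's original x-coordinate.
new_indices = indices[oldranks]
print(new_indices)  # [1 3 4]

for newrank, oldrank in enumerate(oldranks):
    print(f"new rank {newrank}: old rank {oldrank}, index {new_indices[newrank]}")
```

Keeping the original indices matters because every surviving share is still a point on the same secret-sharing polynomial. Renumbering the survivors 1, 2, 3 would make their existing shares inconsistent with their new x-coordinates.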

Since communicator names are arbitrary, let’s specify something more useful, like a name that tells us how many “generations” of communicators we’ve created:

[9]:
def main(communicator):
    # One-time initialization.
    communicator_index = itertools.count(1)
    failure = random_failure(pfail=0.05, seed=42 + communicator.rank)
    log = Logger(logging.getLogger(), communicator=communicator)
    protocol = ShamirBasicProtocolSuite(communicator=communicator, threshold=2)

    total_share = protocol.share(src=0, secret=numpy.array(0), shape=())

    # Main iteration loop.
    for iteration in itertools.count():
        # Allow failures to occur at random.
        next(failure)

        # Do computation in this block.
        try:
            # Increment the total.
            contributor = iteration % communicator.world_size
            increment = numpy.array(1) if communicator.rank == contributor else None
            increment_share = protocol.share(src=contributor, secret=increment, shape=())
            total_share = protocol.add(total_share, increment_share)

            # Print the current total.
            total = protocol.reveal(total_share)
            log.info(f"Iteration {iteration} comm {communicator.name} total: {total}", src=0)

        # Implement failure recovery in this block.  Be careful here! Many
        # operations can't be used when there are unresponsive players.
        except Exception as e:
            log.sync = False
            log.error(f"Iteration {iteration} comm {communicator.name} player {communicator.rank} exception: {e}")

            # Something went wrong.  Revoke the current communicator to
            # ensure that all players are aware of it.
            communicator.revoke()

            # If we don't have enough players to continue, it's time to shut down cleanly.
            if communicator.world_size == protocol.threshold:
                log.info(f"Iteration {iteration} not enough players to continue.", src=0)
                break

            # Obtain a new communicator that contains the remaining players.
            newcommunicator, oldranks = communicator.shrink(name=f"world-{next(communicator_index)}")

            # Recreate the logger since objects that depend on the old,
            # revoked communicator must be rebuilt from scratch using the
            # new communicator.
            log = Logger(logging.getLogger(), newcommunicator)
            log.info(f"Iteration {iteration} shrank comm {communicator.name} with {communicator.world_size} players to comm {newcommunicator.name} with {newcommunicator.world_size} players.", src=0)

            # Recreate the protocol since objects that depend on the old,
            # revoked communicator must be rebuilt from scratch using the
            # new communicator.
            protocol = ShamirBasicProtocolSuite(newcommunicator, threshold=2, indices=protocol.indices[oldranks])

            # Cleanup the old communicator.
            communicator.free()
            communicator = newcommunicator

SocketCommunicator.run(world_size=4, fn=main, name="world-0");
INFO:root:Iteration 0 comm world-0 total: 1.0
ERROR:root:Iteration 1 comm world-0 player 0 exception: Tag SCATTER from player 1 timed-out after 5s
ERROR:root:Iteration 1 comm world-0 player 3 exception: Tag SCATTER from player 1 timed-out after 5s
WARNING:cicada.communicator.socket:Comm world-0 player 0 revoked by player 0
ERROR:root:Iteration 1 comm world-0 player 2 exception: Tag SCATTER from player 1 timed-out after 5s
WARNING:cicada.communicator.socket:Comm world-0 player 3 revoked by player 0
WARNING:cicada.communicator.socket:Comm world-0 player 2 revoked by player 0
INFO:root:Iteration 1 shrank comm world-0 with 4 players to comm world-1 with 3 players.
INFO:root:Iteration 2 comm world-1 total: 2.0
INFO:root:Iteration 3 comm world-1 total: 3.0
INFO:root:Iteration 4 comm world-1 total: 4.0
INFO:root:Iteration 5 comm world-1 total: 5.0
INFO:root:Iteration 6 comm world-1 total: 6.0
INFO:root:Iteration 7 comm world-1 total: 7.0
INFO:root:Iteration 8 comm world-1 total: 8.0
INFO:root:Iteration 9 comm world-1 total: 9.0
INFO:root:Iteration 10 comm world-1 total: 10.0
INFO:root:Iteration 11 comm world-1 total: 11.0
INFO:root:Iteration 12 comm world-1 total: 12.0
INFO:root:Iteration 13 comm world-1 total: 13.0
INFO:root:Iteration 14 comm world-1 total: 14.0
ERROR:root:Iteration 15 comm world-1 player 0 exception: Tag GATHERV from player 1 timed-out after 5s
ERROR:root:Iteration 15 comm world-1 player 2 exception: Tag GATHERV from player 0 timed-out after 5s
WARNING:cicada.communicator.socket:Comm world-1 player 0 revoked by player 0
WARNING:cicada.communicator.socket:Comm world-1 player 2 revoked by player 2
INFO:root:Iteration 15 shrank comm world-1 with 3 players to comm world-2 with 2 players.
INFO:root:Iteration 16 comm world-2 total: 16.0
INFO:root:Iteration 17 comm world-2 total: 17.0
INFO:root:Iteration 18 comm world-2 total: 18.0
ERROR:root:Iteration 19 comm world-2 player 0 exception: Tag SCATTER from player 1 timed-out after 5s
WARNING:cicada.communicator.socket:Comm world-2 player 0 revoked by player 0
INFO:root:Iteration 19 not enough players to continue.
ERROR:cicada.communicator.socket:Comm world-0 player 1 failed: Terminated(exitcode=-9)
ERROR:cicada.communicator.socket:Comm world-0 player 2 failed: Terminated(exitcode=-9)
ERROR:cicada.communicator.socket:Comm world-0 player 3 failed: Terminated(exitcode=-9)

Now each communicator has a distinct name, so the logs show exactly when each new generation of communicator is created and which one every message came from.
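The naming scheme relies on each player keeping its own `itertools.count(1)` counter. Because every survivor runs the same recovery code and calls `next()` exactly once per shrink, the counters should stay in lockstep, so all survivors pass the same name to `shrink`. A standalone sketch of the names this produces, assuming three hypothetical failures:

```python
import itertools

# Starts at 1 because the example passes name="world-0" to
# SocketCommunicator.run() for the original communicator.
communicator_index = itertools.count(1)

names = ["world-0"]
for _ in range(3):  # three hypothetical failures, each followed by a shrink
    names.append(f"world-{next(communicator_index)}")

print(names)  # ['world-0', 'world-1', 'world-2', 'world-3']
```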

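With things working reliably, let’s cut down on logging so we can focus on the totals and the recovery steps. We’ll raise the level of the “cicada.communicator” logger so its warnings and errors are suppressed, and remove the per-player error message from the recovery block. We’ll also increase the number of players to five: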

[10]:
logging.getLogger("cicada.communicator").setLevel(logging.CRITICAL)

def main(communicator):
    # One-time initialization.
    communicator_index = itertools.count(1)
    failure = random_failure(pfail=0.05, seed=42 + communicator.rank)
    log = Logger(logging.getLogger(), communicator=communicator)
    protocol = ShamirBasicProtocolSuite(communicator=communicator, threshold=2)

    total_share = protocol.share(src=0, secret=numpy.array(0), shape=())

    # Main iteration loop.
    for iteration in itertools.count():
        # Allow failures to occur at random.
        next(failure)

        # Do computation in this block.
        try:
            # Increment the total.
            contributor = iteration % communicator.world_size
            increment = numpy.array(1) if communicator.rank == contributor else None
            increment_share = protocol.share(src=contributor, secret=increment, shape=())
            total_share = protocol.add(total_share, increment_share)

            # Print the current total.
            total = protocol.reveal(total_share)
            log.info(f"Iteration {iteration} comm {communicator.name} total: {total}", src=0)

        # Implement failure recovery in this block.  Be careful here! Many
        # operations can't be used when there are unresponsive players.
        except Exception:
            # Something went wrong.  Revoke the current communicator to
            # ensure that all players are aware of it.
            communicator.revoke()

            # If we don't have enough players to continue, it's time to shut down cleanly.
            if communicator.world_size == protocol.threshold:
                log.info(f"Iteration {iteration} not enough players to continue.", src=0)
                break

            # Obtain a new communicator that contains the remaining players.
            newcommunicator, oldranks = communicator.shrink(name=f"world-{next(communicator_index)}")

            # Recreate the logger since objects that depend on the old,
            # revoked communicator must be rebuilt from scratch using the
            # new communicator.
            log = Logger(logging.getLogger(), newcommunicator)
            log.info(f"Iteration {iteration} shrank comm {communicator.name} with {communicator.world_size} players to comm {newcommunicator.name} with {newcommunicator.world_size} players.", src=0)

            # Recreate the protocol since objects that depend on the old,
            # revoked communicator must be rebuilt from scratch using the
            # new communicator.
            protocol = ShamirBasicProtocolSuite(newcommunicator, threshold=2, indices=protocol.indices[oldranks])

            # Cleanup the old communicator.
            communicator.free()
            communicator = newcommunicator

SocketCommunicator.run(world_size=5, fn=main, name="world-0");
INFO:root:Iteration 0 comm world-0 total: 1.0
INFO:root:Iteration 1 shrank comm world-0 with 5 players to comm world-1 with 4 players.
INFO:root:Iteration 2 comm world-1 total: 2.0
INFO:root:Iteration 3 comm world-1 total: 3.0
INFO:root:Iteration 4 comm world-1 total: 4.0
INFO:root:Iteration 5 comm world-1 total: 5.0
INFO:root:Iteration 6 comm world-1 total: 6.0
INFO:root:Iteration 7 comm world-1 total: 7.0
INFO:root:Iteration 8 comm world-1 total: 8.0
INFO:root:Iteration 9 comm world-1 total: 9.0
INFO:root:Iteration 10 comm world-1 total: 10.0
INFO:root:Iteration 11 comm world-1 total: 11.0
INFO:root:Iteration 12 comm world-1 total: 12.0
INFO:root:Iteration 13 comm world-1 total: 13.0
INFO:root:Iteration 14 comm world-1 total: 14.0
INFO:root:Iteration 15 shrank comm world-1 with 4 players to comm world-2 with 3 players.
INFO:root:Iteration 16 comm world-2 total: 16.0
INFO:root:Iteration 17 comm world-2 total: 17.0
INFO:root:Iteration 18 comm world-2 total: 18.0
INFO:root:Iteration 19 shrank comm world-2 with 3 players to comm world-3 with 2 players.
INFO:root:Iteration 20 comm world-3 total: 19.0
INFO:root:Iteration 21 comm world-3 total: 20.0
INFO:root:Iteration 22 not enough players to continue.
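The run ends once the communicator is down to two players: with `threshold=2`, one more failure would leave a single share. The following standalone sketch shows why, using textbook Shamir secret sharing over a prime field with Lagrange interpolation. It is not Cicada’s implementation. The field size and the `share` and `reveal` helpers are hypothetical choices for this illustration. It also assumes that Cicada’s `threshold` is the number of shares needed to reveal a secret, which is consistent with the logs, where two players can still reveal the total. The sketch also shows why the recovery code passes the survivors’ original `indices` to the new protocol suite.

```python
import random

# An arbitrary prime (2**31 - 1) defining this sketch's finite field.
PRIME = 2**31 - 1


def share(secret, threshold, indices, rng):
    """Evaluate a random degree-(threshold - 1) polynomial with constant
    term `secret` at each x-coordinate in `indices`."""
    coefficients = [secret] + [rng.randrange(PRIME) for _ in range(threshold - 1)]
    shares = {}
    for x in indices:
        y = 0
        for c in reversed(coefficients):  # Horner's rule
            y = (y * x + c) % PRIME
        shares[x] = y
    return shares


def reveal(shares, threshold):
    """Recover the constant term by Lagrange interpolation at x = 0,
    using the first `threshold` (x, y) pairs in `shares`."""
    if len(shares) < threshold:
        raise ValueError(f"need {threshold} shares, have {len(shares)}")
    points = list(shares.items())[:threshold]
    secret = 0
    for xi, yi in points:
        numerator, denominator = 1, 1
        for xj, _ in points:
            if xj != xi:
                numerator = numerator * (-xj) % PRIME
                denominator = denominator * (xi - xj) % PRIME
        secret = (secret + yi * numerator * pow(denominator, -1, PRIME)) % PRIME
    return secret


rng = random.Random(0)
shares = share(21, threshold=2, indices=[1, 2, 3, 4, 5], rng=rng)

# Any two survivors can reveal the secret, provided they keep their
# original x-coordinates.
print(reveal({3: shares[3], 5: shares[5]}, threshold=2))  # 21

# Relabeling the survivors as 1 and 2 interpolates a different line,
# so this almost certainly does not print 21.
print(reveal({1: shares[3], 2: shares[5]}, threshold=2))

# A single survivor cannot reveal a threshold-2 secret.
try:
    reveal({3: shares[3]}, threshold=2)
except ValueError as e:
    print(e)  # need 2 shares, have 1
```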