Multiplication and Truncation
As you might imagine, multiplication using AdditiveProtocolSuite is a straightforward operation:
[1]:
import logging
import numpy
from cicada.additive import AdditiveProtocolSuite
from cicada.communicator import SocketCommunicator
from cicada.logging import Logger
logging.basicConfig(level=logging.INFO)
def main(communicator):
    log = Logger(logging.getLogger(), communicator)
    protocol = AdditiveProtocolSuite(communicator)
    # Each player will contribute one operand.
    a = numpy.array(2) if communicator.rank == 0 else None
    b = numpy.array(3) if communicator.rank == 1 else None
    log.info(f"Operand a: {a}", src=0)
    log.info(f"Operand b: {b}", src=1)
    # Secret share our operands.
    a_share = protocol.share(src=0, secret=a, shape=())
    b_share = protocol.share(src=1, secret=b, shape=())
    # Multiply the shared operands; the result is still secret shared.
    c_share = protocol.multiply(a_share, b_share)
    c = protocol.reveal(c_share)
    log.info(f"Player {communicator.rank} result c: {c}")
SocketCommunicator.run(world_size=2, fn=main);
INFO:root:Operand a: 2
INFO:root:Operand b: 3
INFO:root:Player 0 result c: 6.0
INFO:root:Player 1 result c: 6.0
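For intuition about what share() and reveal() do with these operands, here is a plaintext sketch of additive secret sharing, the idea the suite's name refers to. The `share` and `reveal` functions below are hypothetical stand-ins, not the AdditiveProtocolSuite API, and the field order `P` is an assumption chosen for illustration.

```python
import secrets

# Assumed prime field order, for illustration only; this is not
# necessarily the modulus AdditiveProtocolSuite uses internally.
P = 2**64 - 59

def share(secret, players):
    """Split an integer secret into `players` random additive shares mod P."""
    shares = [secrets.randbelow(P) for _ in range(players - 1)]
    # The last share is chosen so that all shares sum to the secret mod P.
    shares.append((secret - sum(shares)) % P)
    return shares

def reveal(shares):
    """Recombine additive shares by summing them mod P."""
    return sum(shares) % P

a_shares = share(2, players=2)  # player 0's operand, split two ways
b_shares = share(3, players=2)  # player 1's operand, split two ways
print(reveal(a_shares), reveal(b_shares))  # 2 3

# Addition is local: each player adds the two shares it holds.
sum_shares = [(x + y) % P for x, y in zip(a_shares, b_shares)]
print(reveal(sum_shares))  # 5
```

Multiplying shares locally does not work the same way, since \((a_0 + a_1)(b_0 + b_1)\) contains cross terms that no single player can compute, which is why multiplying shared values requires a protocol rather than purely local arithmetic.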
However, there are nuances to multiplication that you should be aware of, due to a concept we call truncation. To illustrate truncation in action, let’s see what happens when we ignore it:
[2]:
def main(communicator):
    log = Logger(logging.getLogger(), communicator)
    protocol = AdditiveProtocolSuite(communicator)
    # Each player will contribute one operand.
    a = numpy.array(2) if communicator.rank == 0 else None
    b = numpy.array(3) if communicator.rank == 1 else None
    log.info(f"Operand a: {a}", src=0)
    log.info(f"Operand b: {b}", src=1)
    # Secret share our operands.
    a_share = protocol.share(src=0, secret=a, shape=())
    b_share = protocol.share(src=1, secret=b, shape=())
    # Multiply the raw field values, ignoring the fixed-point encoding.
    c_share = protocol.field_multiply(a_share, b_share)
    c = protocol.reveal(c_share)
    log.info(f"Player {communicator.rank} result c: {c}")
SocketCommunicator.run(world_size=2, fn=main);
INFO:root:Operand a: 2
INFO:root:Operand b: 3
INFO:root:Player 0 result c: 393216.0
INFO:root:Player 1 result c: 393216.0
What the heck!? \(2 \times 3 = 393216\) is clearly wrong. What happened?
Actually, this answer is correct, from a certain point of view. Note the call to field_multiply() instead of multiply(), and remember that real values must be encoded as integers before they can be secret shared for computation; for example, AdditiveProtocolSuite uses a private instance of FixedPoint encoding to encode and decode real values. FixedPoint encodes a real value as an integer with a configurable number of bits set aside for the fractional part. Addition works directly on this representation, but the product of two encoded values carries twice as many fractional bits, so after decoding it is still shifted left by that number of bits. field_multiply() ignores the encoding and simply multiplies the field values together.
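To see where the strange result comes from without any secret sharing, here is a plaintext sketch of a fixed-point encoding. The `encode` and `decode` helpers are hypothetical stand-ins, not the FixedPoint API; the 16 fractional bits match the default described in the text, while the field order and the convention that the upper half of the field represents negative values are assumptions.

```python
# Hypothetical plaintext fixed-point encoding; not the FixedPoint API.
P = 2**64 - 59   # assumed prime field order
PRECISION = 16   # fractional bits, matching the default described in the text

def encode(x):
    """Encode a real value as a field element with PRECISION fractional bits."""
    return round(x * 2**PRECISION) % P

def decode(v):
    """Decode a field element to a real value (upper half treated as negative)."""
    if v > P // 2:
        v -= P
    return v / 2**PRECISION

a, b = encode(2), encode(3)
print(a, b)                 # 131072 196608
print(decode((a + b) % P))  # 5.0: addition works directly on encodings
print(decode((a * b) % P))  # 393216.0: the raw product has 2**16 too many
```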
By default the encoder uses 16 bits to store fractions, so in the example above the result returned by AdditiveProtocolSuite.field_multiply is shifted left by 16 bits, that is, multiplied by \(2^{16}\). If we shift right by the same number of bits, we get the expected answer:
\(393216 \div 2^{16} = 6\)
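In general, writing \(f\) for the number of fractional bits and \(\tilde{x} = x \cdot 2^{f}\) for the encoding of \(x\), the extra factor follows directly (ignoring rounding for values that are not exactly representable):

\[ \tilde{a}\,\tilde{b} = (a \cdot 2^{f})(b \cdot 2^{f}) = ab \cdot 2^{2f} \]

\[ \operatorname{decode}(\tilde{a}\,\tilde{b}) = \frac{ab \cdot 2^{2f}}{2^{f}} = ab \cdot 2^{f} \]

\[ \tilde{a}\,\tilde{b} \gg f = ab \cdot 2^{f} = \widetilde{ab} \]

With \(a = 2\), \(b = 3\), and \(f = 16\), the second line gives \(6 \cdot 2^{16} = 393216\), and the third shows that shifting the raw product right by \(f\) bits yields the correct encoding of \(6\).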
We use the term truncation for this process of shifting right to eliminate the extra bits, and we can perform it explicitly ourselves:
[3]:
def main(communicator):
    log = Logger(logging.getLogger(), communicator)
    protocol = AdditiveProtocolSuite(communicator)
    # Each player will contribute one operand.
    a = numpy.array(2) if communicator.rank == 0 else None
    b = numpy.array(3) if communicator.rank == 1 else None
    log.info(f"Operand a: {a}", src=0)
    log.info(f"Operand b: {b}", src=1)
    # Secret share our operands.
    a_share = protocol.share(src=0, secret=a, shape=())
    b_share = protocol.share(src=1, secret=b, shape=())
    c_share = protocol.field_multiply(a_share, b_share)
    # Truncate: shift right by the number of fractional bits in the encoding.
    c_share = protocol.right_shift(c_share, bits=protocol.encoding.precision)
    c = protocol.reveal(c_share)
    log.info(f"Player {communicator.rank} result c: {c}")
SocketCommunicator.run(world_size=2, fn=main);
INFO:root:Operand a: 2
INFO:root:Operand b: 3
INFO:root:Player 0 result c: 6.0
INFO:root:Player 1 result c: 6.0
Now we get the expected answer. As you might have guessed, multiply() is implemented as a call to field_multiply() followed by a call to right_shift(). However, you may be wondering why AdditiveProtocolSuite provides separate AdditiveProtocolSuite.field_multiply and AdditiveProtocolSuite.multiply methods to begin with.
We provide both because the protocol to right-shift secret shared values is extremely expensive. Keeping it separate allows you to defer truncation until absolutely necessary. For example, if you're computing the dot product of two large vectors, the distributive property allows you to compute all of the element-wise products using AdditiveProtocolSuite.field_multiply, sum the results to a single scalar, and then call AdditiveProtocolSuite.right_shift once instead of once per element, for a potentially huge savings in execution time. This is how the AdditiveProtocolSuite.dot method works, and you can apply the same technique in your own programs where appropriate.
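The savings come from the number of truncations, not from the arithmetic itself. The following plaintext sketch (no secret sharing involved) compares truncating every product with summing the raw field products and truncating once. The `field_multiply`, `multiply`, and `truncate` functions here are hypothetical plaintext helpers named after the methods discussed above, not the AdditiveProtocolSuite implementations; the field order and signed-value convention are assumptions.

```python
# Plaintext sketch of deferred truncation; all helpers are hypothetical.
P = 2**64 - 59   # assumed prime field order
PRECISION = 16   # fractional bits

def encode(x):
    return round(x * 2**PRECISION) % P

def decode(v):
    if v > P // 2:  # upper half of the field represents negative values
        v -= P
    return v / 2**PRECISION

def truncate(v, bits=PRECISION):
    """Arithmetic right shift of a signed field element by `bits`."""
    signed = v - P if v > P // 2 else v
    return (signed >> bits) % P

def field_multiply(x, y):
    """Multiply raw field values, ignoring the fixed-point encoding."""
    return (x * y) % P

def multiply(x, y):
    """A field multiply followed by a right shift, as described in the text."""
    return truncate(field_multiply(x, y))

xs = [encode(v) for v in (0.5, -1.25, 2.0, 3.75)]
ys = [encode(v) for v in (4.0, 0.5, -1.5, 2.0)]

# Eager: one truncation per element-wise product.
eager = sum(multiply(x, y) for x, y in zip(xs, ys)) % P
# Deferred: sum the raw field products, then truncate once.
deferred = truncate(sum(field_multiply(x, y) for x, y in zip(xs, ys)) % P)

print(decode(eager), decode(deferred))  # 5.875 5.875
```

Both paths agree here because every product is exactly representable; in general the eager version rounds each product down while the deferred version rounds only once, so the two can differ in the lowest bits. Deferring also assumes the untruncated sum, which carries twice the fractional bits, still fits within the field's signed range.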
See also