# Copyright 2021 National Technology & Engineering Solutions
# of Sandia, LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS,
# the U.S. Government retains certain rights in this software.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functionality for generating transcripts of library activity.
Use-cases include detailed debugging, logging network traffic, function
tracing, and MPC-in-the-head for zero knowledge proofs. The latter case
is described in
Ishai, Yuval, et al. "Zero-knowledge from secure multiparty computation." *Proceedings of the thirty-ninth annual ACM symposium on Theory of computing. 2007.*
Examples of this technique appear in
https://csrc.nist.gov/projects/pqc-dig-sig/round-1-additional-signatures as
proposals for NIST PQC standardization.
"""
import copy
import logging
import types
import hunter
import numpy
from .active import ActiveArrayShare, ActiveProtocolSuite
from .additive import AdditiveArrayShare, AdditiveProtocolSuite
from .communicator.interface import tagname
from .shamir import ShamirArrayShare, ShamirProtocolSuite, ShamirBasicProtocolSuite
class _CallLogger(hunter.actions.Action):
    """Hunter action that converts traced function events into transcript log records.

    An instance is installed as the ``action`` of :func:`hunter.trace` by
    :func:`record`.  For every traced event it may emit, through the
    module-level ``logger``:

    * a network record (``extra={"net": Message}``) when the socket
      communicator sends or queues a message,
    * a comment line describing a call to a function listed in
      ``display_whitelist``,
    * consistency verification statements
      (``cicada.transcript.assert_equal(...)``) when a function listed in
      ``test_whitelist`` returns.
    """

    def __init__(self):
        # Arguments captured at "call" time for whitelisted test functions
        # that have not returned yet; pushed on call, popped on return.
        self.stack = []

        # Fully-qualified names whose calls are shown as a "# obj.name(args)"
        # comment line in the transcript.
        self.display_whitelist = {
            "cicada.active.ActiveArrayShare.__repr__",
            "cicada.active.ActiveProtocolSuite.absolute",
            "cicada.active.ActiveProtocolSuite.add",
            "cicada.active.ActiveProtocolSuite.bit_compose",
            "cicada.active.ActiveProtocolSuite.bit_decompose",
            "cicada.active.ActiveProtocolSuite.divide",
            "cicada.active.ActiveProtocolSuite.dot",
            "cicada.active.ActiveProtocolSuite.equal",
            "cicada.active.ActiveProtocolSuite.field_add",
            "cicada.active.ActiveProtocolSuite.field_dot",
            "cicada.active.ActiveProtocolSuite.field_multiply",
            "cicada.active.ActiveProtocolSuite.field_power",
            "cicada.active.ActiveProtocolSuite.field_subtract",
            "cicada.active.ActiveProtocolSuite.field_uniform",
            "cicada.active.ActiveProtocolSuite.floor",
            "cicada.active.ActiveProtocolSuite.less",
            "cicada.active.ActiveProtocolSuite.less_zero",
            "cicada.active.ActiveProtocolSuite.logical_and",
            "cicada.active.ActiveProtocolSuite.logical_not",
            "cicada.active.ActiveProtocolSuite.logical_or",
            "cicada.active.ActiveProtocolSuite.logical_xor",
            "cicada.active.ActiveProtocolSuite.maximum",
            "cicada.active.ActiveProtocolSuite.minimum",
            "cicada.active.ActiveProtocolSuite.multiplicative_inverse",
            "cicada.active.ActiveProtocolSuite.multiply",
            "cicada.active.ActiveProtocolSuite.negative",
            "cicada.active.ActiveProtocolSuite.power",
            "cicada.active.ActiveProtocolSuite.random_bitwise_secret",
            "cicada.active.ActiveProtocolSuite.relu",
            "cicada.active.ActiveProtocolSuite.reshare",
            "cicada.active.ActiveProtocolSuite.reveal",
            "cicada.active.ActiveProtocolSuite.right_shift",
            "cicada.active.ActiveProtocolSuite.share",
            "cicada.active.ActiveProtocolSuite.subtract",
            "cicada.active.ActiveProtocolSuite.sum",
            "cicada.active.ActiveProtocolSuite.zigmoid",
            "cicada.additive.AdditiveArrayShare.__repr__",
            "cicada.additive.AdditiveProtocolSuite.absolute",
            "cicada.additive.AdditiveProtocolSuite.add",
            "cicada.additive.AdditiveProtocolSuite.bit_compose",
            "cicada.additive.AdditiveProtocolSuite.bit_decompose",
            "cicada.additive.AdditiveProtocolSuite.divide",
            "cicada.additive.AdditiveProtocolSuite.dot",
            "cicada.additive.AdditiveProtocolSuite.equal",
            "cicada.additive.AdditiveProtocolSuite.field_add",
            "cicada.additive.AdditiveProtocolSuite.field_dot",
            "cicada.additive.AdditiveProtocolSuite.field_multiply",
            "cicada.additive.AdditiveProtocolSuite.field_power",
            "cicada.additive.AdditiveProtocolSuite.field_subtract",
            "cicada.additive.AdditiveProtocolSuite.field_uniform",
            "cicada.additive.AdditiveProtocolSuite.floor",
            "cicada.additive.AdditiveProtocolSuite.less",
            "cicada.additive.AdditiveProtocolSuite.less_zero",
            "cicada.additive.AdditiveProtocolSuite.logical_and",
            "cicada.additive.AdditiveProtocolSuite.logical_not",
            "cicada.additive.AdditiveProtocolSuite.logical_or",
            "cicada.additive.AdditiveProtocolSuite.logical_xor",
            "cicada.additive.AdditiveProtocolSuite.maximum",
            "cicada.additive.AdditiveProtocolSuite.minimum",
            "cicada.additive.AdditiveProtocolSuite.multiplicative_inverse",
            "cicada.additive.AdditiveProtocolSuite.multiply",
            "cicada.additive.AdditiveProtocolSuite.negative",
            "cicada.additive.AdditiveProtocolSuite.power",
            "cicada.additive.AdditiveProtocolSuite.random_bitwise_secret",
            "cicada.additive.AdditiveProtocolSuite.relu",
            "cicada.additive.AdditiveProtocolSuite.reshare",
            "cicada.additive.AdditiveProtocolSuite.reveal",
            "cicada.additive.AdditiveProtocolSuite.right_shift",
            "cicada.additive.AdditiveProtocolSuite.share",
            "cicada.additive.AdditiveProtocolSuite.subtract",
            "cicada.additive.AdditiveProtocolSuite.sum",
            "cicada.additive.AdditiveProtocolSuite.zigmoid",
            "cicada.shamir.ShamirArrayShare.__repr__",
            "cicada.shamir.ShamirBasicProtocolSuite.add",
            "cicada.shamir.ShamirBasicProtocolSuite.bit_compose",
            "cicada.shamir.ShamirBasicProtocolSuite.field_add",
            "cicada.shamir.ShamirBasicProtocolSuite.field_subtract",
            "cicada.shamir.ShamirBasicProtocolSuite.field_uniform",
            "cicada.shamir.ShamirBasicProtocolSuite.negative",
            "cicada.shamir.ShamirBasicProtocolSuite.reshare",
            "cicada.shamir.ShamirBasicProtocolSuite.reveal",
            "cicada.shamir.ShamirBasicProtocolSuite.share",
            "cicada.shamir.ShamirBasicProtocolSuite.subtract",
            "cicada.shamir.ShamirBasicProtocolSuite.sum",
            "cicada.shamir.ShamirProtocolSuite.absolute",
            "cicada.shamir.ShamirProtocolSuite.add",
            "cicada.shamir.ShamirProtocolSuite.bit_compose",
            "cicada.shamir.ShamirProtocolSuite.bit_decompose",
            "cicada.shamir.ShamirProtocolSuite.divide",
            "cicada.shamir.ShamirProtocolSuite.dot",
            "cicada.shamir.ShamirProtocolSuite.equal",
            "cicada.shamir.ShamirProtocolSuite.field_add",
            "cicada.shamir.ShamirProtocolSuite.field_dot",
            "cicada.shamir.ShamirProtocolSuite.field_multiply",
            "cicada.shamir.ShamirProtocolSuite.field_power",
            "cicada.shamir.ShamirProtocolSuite.field_subtract",
            "cicada.shamir.ShamirProtocolSuite.field_uniform",
            "cicada.shamir.ShamirProtocolSuite.floor",
            "cicada.shamir.ShamirProtocolSuite.less",
            "cicada.shamir.ShamirProtocolSuite.less_zero",
            "cicada.shamir.ShamirProtocolSuite.logical_and",
            "cicada.shamir.ShamirProtocolSuite.logical_not",
            "cicada.shamir.ShamirProtocolSuite.logical_or",
            "cicada.shamir.ShamirProtocolSuite.logical_xor",
            "cicada.shamir.ShamirProtocolSuite.matvec",
            "cicada.shamir.ShamirProtocolSuite.maximum",
            "cicada.shamir.ShamirProtocolSuite.minimum",
            "cicada.shamir.ShamirProtocolSuite.multiplicative_inverse",
            "cicada.shamir.ShamirProtocolSuite.multiply",
            "cicada.shamir.ShamirProtocolSuite.negative",
            "cicada.shamir.ShamirProtocolSuite.power",
            "cicada.shamir.ShamirProtocolSuite.random_bitwise_secret",
            "cicada.shamir.ShamirProtocolSuite.relu",
            "cicada.shamir.ShamirProtocolSuite.reshare",
            "cicada.shamir.ShamirProtocolSuite.reveal",
            "cicada.shamir.ShamirProtocolSuite.right_shift",
            "cicada.shamir.ShamirProtocolSuite.share",
            "cicada.shamir.ShamirProtocolSuite.subtract",
            "cicada.shamir.ShamirProtocolSuite.sum",
            "cicada.shamir.ShamirProtocolSuite.zigmoid",
        }

        # Fully-qualified names whose call/return pairs are turned into
        # consistency verification statements.
        self.test_whitelist = {
            "cicada.arithmetic.Field.__call__",
            "cicada.arithmetic.Field.add",
            "cicada.arithmetic.Field.full_like",
            "cicada.arithmetic.Field.inplace_add",
            "cicada.arithmetic.Field.inplace_subtract",
            "cicada.arithmetic.Field.multiply",
            "cicada.arithmetic.Field.negative",
            "cicada.arithmetic.Field.ones",
            "cicada.arithmetic.Field.ones_like",
            "cicada.arithmetic.Field.subtract",
            "cicada.arithmetic.Field.sum",
            "cicada.arithmetic.Field.uniform",
            "cicada.arithmetic.Field.zeros",
            "cicada.arithmetic.Field.zeros_like",
            "cicada.encoding.Bits.decode",
            "cicada.encoding.Bits.encode",
            "cicada.encoding.Boolean.decode",
            "cicada.encoding.Boolean.encode",
            "cicada.encoding.FixedPoint.decode",
            "cicada.encoding.FixedPoint.encode",
            "cicada.encoding.Identity.decode",
            "cicada.encoding.Identity.encode",
        }

    def __call__(self, event):
        """Process one traced event.

        Parameters
        ----------
        event:
            The event supplied by hunter.  Only its ``kind``, ``module``,
            ``function_object``, ``locals``, ``arg``, ``filename`` and
            ``lineno`` attributes are read here (the last two via
            :func:`_log_code`).
        """
        function = event.function_object
        if not hasattr(function, "__qualname__"):
            return

        fqname = event.module + "." + function.__qualname__
        name = function.__name__

        #####################################################################################
        # Log sent messages

        if event.kind == "call" and fqname == "cicada.communicator.socket.SocketCommunicator._send":
            communicator = event.locals["self"]
            dst = event.locals["dst"]
            net = Message(
                arrow="-->",
                comm=communicator,
                dir=">",
                dst=dst,
                other=dst,
                payload=event.locals["payload"],
                src=communicator.rank,
                tag=tagname(event.locals["tag"]),
                verb="sent",
            )
            logger.info("Sent message", extra={"net": net})
            # Private function: nothing below applies to it.
            return

        #####################################################################################
        # Log received messages

        if event.kind == "call" and fqname == "cicada.communicator.socket.SocketCommunicator._queue_message":
            communicator = event.locals["self"]
            src = event.locals["src"]
            net = Message(
                arrow="<--",
                comm=communicator,
                dir="<",
                dst=communicator.rank,
                other=src,
                payload=event.locals["payload"],
                src=src,
                tag=tagname(event.locals["tag"]),
                verb="received",
            )
            logger.info("Received message", extra={"net": net})
            # Private function: nothing below applies to it.
            return

        #####################################################################################
        # Log consistency verification code

        # Hide __init__ functions.
        if name == "__init__":
            return

        # Hide private functions (single leading underscore), but keep dunders.
        if name.startswith("_") and not name.startswith("__"):
            return

        # Identify functions that should be displayed in the transcript.
        display = fqname in self.display_whitelist

        # Identify functions that should generate consistency verification code in the transcript.
        test = fqname in self.test_whitelist

        # Generate output comments for some function calls.
        if event.kind == "call" and display:
            args = event.locals
            o = self.repr(args["self"])
            signature = ", ".join(f"{key}={self.repr(value)}" for key, value in args.items() if key != "self")
            _log_code(event, f"# {o}.{name}({signature})", first=True, last=True)
            return

        # Make copies of function arguments, in-case they're modified in-place.
        if event.kind == "call" and test:
            args = {
                key: value if key == "self" else copy.deepcopy(value)
                for key, value in event.locals.items()
            }
            self.stack.append(args)
            return

        # Generate consistency verification statements for some function calls.
        if event.kind == "return" and test:
            args = self.stack.pop()
            frame_locals = event.locals
            result = event.arg

            if fqname in ("cicada.arithmetic.Field.inplace_add", "cicada.arithmetic.Field.inplace_subtract"):
                # In-place operations: replay on the saved pre-call `lhs` and
                # compare against `lhs` as it looks when the function returns.
                o = self.repr(args["self"])
                lhs = self.repr(args["lhs"])
                rhs = self.repr(args["rhs"])
                result = self.repr(frame_locals["lhs"])
                _log_code(event, f"lhs = {lhs}", first=True)
                _log_code(event, f"{o}.{name}(lhs=lhs, rhs={rhs})")
                _log_code(event, f"cicada.transcript.assert_equal(lhs, {result})", last=True)
            elif fqname == "cicada.arithmetic.Field.uniform":
                # Random generation: emit code that restores the bit generator
                # state saved at call time before replaying the call.
                o = self.repr(args["self"])
                size = self.repr(args["size"])
                bg = self.repr(args["generator"].bit_generator)
                state = self.repr(args["generator"].bit_generator.state)
                result = self.repr(result)
                _log_code(event, f"bg = {bg}", first=True)
                _log_code(event, f"bg.state = {state}")
                _log_code(event, f"cicada.transcript.assert_equal({o}.{name}(size={size}, generator=numpy.random.Generator(bg)), {result})", last=True)
            elif "self" in args:
                o = self.repr(args["self"])
                signature = ", ".join(f"{key}={self.repr(value)}" for key, value in args.items() if key != "self")
                result = self.repr(result)
                _log_code(event, f"cicada.transcript.assert_equal({o}.{name}({signature}), {result})", first=True, last=True)
            else:
                signature = ", ".join(f"{key}={self.repr(value)}" for key, value in args.items())
                result = self.repr(result)
                _log_code(event, f"cicada.transcript.assert_equal({name}({signature}), {result})", first=True, last=True)

    def repr(self, o):
        """Return a source-like string representation of `o` for use in the transcript.

        Lists, share objects, protocol suites, numpy arrays, and numpy
        generators get dedicated renderings; anything else falls back to
        the builtin :func:`repr`.  Protocol suites are rendered without
        constructor arguments and ``PCG64`` without its state.
        """
        if isinstance(o, list):
            return f"[{', '.join(self.repr(item) for item in o)}]"
        if isinstance(o, ActiveArrayShare):
            return f"cicada.active.ActiveArrayShare(storage=({self.repr(o.storage[0])}, {self.repr(o.storage[1])}))"
        if isinstance(o, ActiveProtocolSuite):
            return "cicada.active.ActiveProtocolSuite()"
        if isinstance(o, AdditiveArrayShare):
            return f"cicada.additive.AdditiveArrayShare(storage={self.repr(o.storage)})"
        if isinstance(o, AdditiveProtocolSuite):
            return "cicada.additive.AdditiveProtocolSuite()"
        if isinstance(o, ShamirArrayShare):
            return f"cicada.shamir.ShamirArrayShare(storage={self.repr(o.storage)})"
        # Test the full suite before the basic one: if ShamirProtocolSuite
        # derives from ShamirBasicProtocolSuite, the opposite order would
        # label every full suite as a basic suite.
        if isinstance(o, ShamirProtocolSuite):
            return "cicada.shamir.ShamirProtocolSuite()"
        if isinstance(o, ShamirBasicProtocolSuite):
            return "cicada.shamir.ShamirBasicProtocolSuite()"
        if isinstance(o, numpy.ndarray):
            return f"numpy.array({self.repr(o.tolist())}, dtype='{o.dtype}')"
        if isinstance(o, numpy.random.Generator):
            return f"numpy.random.Generator({self.repr(o.bit_generator)})"
        if isinstance(o, numpy.random.PCG64):
            return "numpy.random.PCG64()"
        return repr(o)
[docs]
class Code(types.SimpleNamespace):
    """Namespace holding code-related metadata for use in :class:`Formatter`.

    Instances are attached to log records under the ``code`` key.
    """
[docs]
class HideCode:
    """Log filter that suppresses code records."""

    def filter(self, record):
        """Return :any:`False` for records carrying a ``code`` attribute, :any:`True` otherwise."""
        return not hasattr(record, "code")
[docs]
class HideContextMessages:
    """Log filter that suppresses context message records."""

    def filter(self, record):
        """Keep only records that carry a ``code`` or a ``net`` attribute."""
        return hasattr(record, "code") or hasattr(record, "net")
[docs]
class HideReceivedMessages:
    """Log filter that suppresses network records describing received messages."""

    def filter(self, record):
        """Return :any:`False` when the record has a ``net`` attribute whose verb is ``"received"``."""
        is_received = hasattr(record, "net") and record.net.verb == "received"
        return not is_received
[docs]
class HideSentMessages:
    """Log filter that suppresses network records describing sent messages."""

    def filter(self, record):
        """Return :any:`False` when the record has a ``net`` attribute whose verb is ``"sent"``."""
        is_sent = hasattr(record, "net") and record.net.verb == "sent"
        return not is_sent
[docs]
class Message(types.SimpleNamespace):
    """Namespace holding message-related metadata for use in :class:`Formatter`.

    Instances are attached to log records under the ``net`` key.
    """
def _log_code(event, message, first=False, last=False):
    """Emit `message` as a code record on the module logger.

    The record carries a :class:`Code` instance under the ``code`` key,
    holding the filename and line number of `event` together with the
    `first` / `last` flags supplied by the caller.
    """
    metadata = Code(filename=event.filename, lineno=event.lineno, first=first, last=last)
    logger.info(message, extra={"code": metadata})
[docs]
def assert_equal(lhs, rhs):
    """Raise :class:`AssertionError` unless two objects are equal.

    When either operand is a :class:`numpy.ndarray` the comparison is made
    with :func:`numpy.array_equal`; otherwise the ``==`` operator is used.
    """
    if any(isinstance(operand, numpy.ndarray) for operand in (lhs, rhs)):
        matched = numpy.array_equal(lhs, rhs)
    else:
        matched = lhs == rhs

    if not matched:
        raise AssertionError(f"{lhs} != {rhs}")
[docs]
def code_handler(handler=None, fmt=None, netfmt=None, codefmt=None, codepre=None, codepost=None, sent=False, received=False):
    """Build a :class:`logging.Handler` set up for consistency verification code output.

    Parameters
    ----------
    handler: :class:`logging.Handler`, optional
        Handler to configure.  A new :class:`logging.StreamHandler` is
        created when :any:`None`.
    fmt: :class:`str`, optional
        Format string for context records.  Defaults to ``"# {msg}"``.
    netfmt: :class:`str`, optional
        Format string for sent- and received-message records.
    codefmt: :class:`str`, optional
        Format string for consistency verification records.  Defaults to
        ``"{msg}"``.
    codepre: :class:`str`, optional
        Format string displayed before each group of consistency
        verification code records.
    codepost: :class:`str`, optional
        Format string displayed after each group of consistency
        verification code records.  Defaults to a newline.
    sent: :any:`bool`, optional
        Sent-message records are displayed only if :any:`True`.
    received: :any:`bool`, optional
        Received-message records are displayed only if :any:`True`.

    Returns
    -------
    handler: :class:`logging.Handler`
        The configured handler.
    """
    handler = logging.StreamHandler() if handler is None else handler
    fmt = "# {msg}" if fmt is None else fmt
    netfmt = "# {net.comm.rank} {net.arrow} {net.other} {net.tag} {net.payload}" if netfmt is None else netfmt
    codefmt = "{msg}" if codefmt is None else codefmt
    codepost = "\n" if codepost is None else codepost

    # Network records are opt-in here: hide whichever direction was not requested.
    for enabled, hide in ((sent, HideSentMessages), (received, HideReceivedMessages)):
        if not enabled:
            handler.addFilter(hide())

    formatter = Formatter(fmt=fmt, netfmt=netfmt, codefmt=codefmt, codepre=codepre, codepost=codepost)
    handler.setFormatter(formatter)
    return handler
[docs]
def net_handler(handler=None, fmt=None, netfmt=None, codefmt=None, codepre=None, codepost=None, sent=True, received=True, code=False):
    """Build a :class:`logging.Handler` set up for network message output.

    Parameters
    ----------
    handler: :class:`logging.Handler`, optional
        Handler to configure.  A new :class:`logging.StreamHandler` is
        created when :any:`None`.
    fmt: :class:`str`, optional
        Format string for context records.  Defaults to
        ``"{processName}: {msg}"``.
    netfmt: :class:`str`, optional
        Format string for sent- and received-message records.
    codefmt: :class:`str`, optional
        Format string for consistency verification records.  Defaults to
        ``"{processName}: {msg}"``.
    codepre: :class:`str`, optional
        Format string displayed before each group of consistency
        verification code records.
    codepost: :class:`str`, optional
        Format string displayed after each group of consistency
        verification code records.
    sent: :any:`bool`, optional
        Sent-message records are hidden if :any:`False`.
    received: :any:`bool`, optional
        Received-message records are hidden if :any:`False`.
    code: :any:`bool`, optional
        Consistency verification code records are displayed only if
        :any:`True`.

    Returns
    -------
    handler: :class:`logging.Handler`
        The configured handler.
    """
    handler = logging.StreamHandler() if handler is None else handler
    fmt = "{processName}: {msg}" if fmt is None else fmt
    netfmt = "{processName}: {net.arrow} {net.other} {net.tag} {net.payload}" if netfmt is None else netfmt
    codefmt = "{processName}: {msg}" if codefmt is None else codefmt

    # Install one hiding filter for every record category that was not requested.
    for enabled, hide in ((code, HideCode), (sent, HideSentMessages), (received, HideReceivedMessages)):
        if not enabled:
            handler.addFilter(hide())

    formatter = Formatter(fmt=fmt, netfmt=netfmt, codefmt=codefmt, codepre=codepre, codepost=codepost)
    handler.setFormatter(formatter)
    return handler
[docs]
def set_handler(logger, handler):
    """Set the handler for a logger, removing any other handlers.

    The logger's level is also set to :data:`logging.INFO`.

    Parameters
    ----------
    logger: :class:`logging.Logger`, required
        The logger to be modified.
    handler: :class:`logging.Handler`, required
        The handler to be assigned to `logger`.
    """
    # Use setLevel() rather than assigning `logger.level`: setLevel() also
    # clears the cached isEnabledFor() results, so a logger that previously
    # rejected INFO records starts accepting them immediately.
    logger.setLevel(logging.INFO)

    # Iterate over a copy, since removeHandler() mutates `logger.handlers`.
    for existing in list(logger.handlers):
        logger.removeHandler(existing)

    logger.addHandler(handler)
[docs]
def log(message=None):
    """Add a general-purpose event to the transcript.

    Intended for application code that wants high-level context
    information to appear alongside the lower-level network and
    consistency verification records.

    Parameters
    ----------
    message: optional
        The content to log, passed unchanged to the module logger at
        :data:`logging.INFO` level.
    """
    logger.log(logging.INFO, message)
[docs]
def record():
    """Turn on transcription.

    Every transcription feature relies on tracing function calls, so this
    must be called before any transcript output is produced.

    Returns
    -------
    The object returned by :func:`hunter.trace`, which can be used as a
    context manager in with-statements.
    """
    action = _CallLogger()
    # Only "call" and "return" events from modules whose name starts with "cicada" are traced.
    return hunter.trace(module_startswith="cicada", kind_in=("call", "return"), action=action)
# Module-level logger that receives every transcript record emitted by this
# module (network messages, code records, and application context messages).
logger = logging.getLogger(__name__)
# Every record in this module is emitted with logger.info(), so the logger
# must accept INFO-level records for any transcript output to appear.
logger.setLevel(logging.INFO)