[viff-devel] [PATCH 08 of 12] Implementation of the basic multiplication command
Janus Dam Nielsen
janus.nielsen at alexandra.dk
Fri Jun 19 02:32:20 PDT 2009
# HG changeset patch
# User Janus Dam Nielsen <janus.nielsen at alexandra.dk>
# Date 1245395036 -7200
# Node ID a07740da4582869d11ead0f56ae055965aa2b4b0
# Parent 07a8329e75322d482dae15186422dd75e9ddb653
Implementation of the basic multiplication command.
diff --git a/viff/orlandi.py b/viff/orlandi.py
--- a/viff/orlandi.py
+++ b/viff/orlandi.py
@@ -15,6 +15,8 @@
# You should have received a copy of the GNU Lesser General Public
# License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
+from twisted.internet.defer import DeferredList, gatherResults
+
from viff.runtime import Runtime, increment_pc, Share, ShareList, gather_shares
from viff.util import rand, dprint
@@ -442,6 +444,21 @@
return results[0]
return results
@increment_pc
def mul(self, share_x, share_y):
    """Multiplication of shares.

    At least one of share_x and share_y must be a Share. If the other
    argument is a public constant, the product is computed locally by
    _cmul and no multiplication triple is drawn. Otherwise a triple is
    obtained from _get_triple and the product is computed by
    _basic_multiplication.

    Communication cost: none in the constant case; otherwise that of
    _basic_multiplication plus whatever _get_triple costs.
    """
    assert isinstance(share_x, Share) or isinstance(share_y, Share), \
        "At least one of share_x and share_y must be a Share."

    field = getattr(share_x, "field", getattr(share_y, "field", None))

    # Multiplying by a public constant is purely local, so do not spend
    # a triple on it (with real preprocessing, triples are expensive).
    cmul_result = self._cmul(share_x, share_y, field)
    if cmul_result is not None:
        return cmul_result

    a, b, c = self._get_triple(field)
    return self._basic_multiplication(share_x, share_y, a, b, c)
+
def _additive_constant(self, zero, field_element):
"""Greate an additive constant.
@@ -488,6 +505,97 @@
Cz = Cx * Cy
return (zi, (rhozi1, rhozi2), Cz)
def _cmul(self, share_x, share_y, field):
    """Multiply a share by a public constant.

    Exactly one of share_x and share_y is expected to be a Share; the
    other one is used as the constant. The Share is cloned, so the
    caller's share is not modified. A callback that scales the share's
    value tuple by the constant is attached to the clone, and the clone
    is returned.

    Returns None when both arguments are Shares, since that case needs
    an interactive multiplication protocol instead.
    """
    x_is_share = isinstance(share_x, Share)
    y_is_share = isinstance(share_y, Share)
    if x_is_share and y_is_share:
        return None

    if x_is_share:
        share, constant = share_x, share_y
    else:
        # share_x is the constant, so share_y has to be the Share.
        assert y_is_share, \
            "At least one of the arguments must be a share."
        share, constant = share_y, share_x

    def scale(value, factor):
        zi, rhoz, Cz = self._const_mul(factor, value)
        return OrlandiShare(self, field, zi, rhoz, Cz)

    product = share.clone()
    product.addCallback(scale, constant)
    return product
+
+ def _const_mul(self, c, x):
+ """Multiplication of a share-tuple with a constant c."""
+ xi, (rhoi1, rhoi2), Cx = x
+ zi = xi * c
+ rhoz = (rhoi1 * c, rhoi2 * c)
+ Cz = Cx # TODO: exponentiation
+ return (zi, rhoz, Cx)
+
+
def _get_triple(self, field):
    """Return a multiplication triple ([a], [b], [c]) with c = a * b.

    NOTE: this is a fixed, insecure placeholder intended for testing.
    Every player holds the shares a_i = 2 and b_i = 4, and all
    randomness and commitment components are zero.

    Assuming additive sharing (the opened value is the sum of all
    players' shares -- confirm against open()), the opened values for n
    players are a = 2n and b = 4n. Each player's share of c must then be
    8n so that c = 8n**2 = a * b. For n = 3 this gives c_i = 24.
    """
    zero = field(0)
    num_players = self.num_players
    a = OrlandiShare(self, field, field(2), (zero, zero), zero)
    b = OrlandiShare(self, field, field(4), (zero, zero), zero)
    c = OrlandiShare(self, field, field(8 * num_players), (zero, zero), zero)
    return (a, b, c)
+
@increment_pc
def _basic_multiplication(self, share_x, share_y, triple_a, triple_b, triple_c):
    """Multiplication of shares given a triple.

    ([a], [b], [c]) must be a multiplication triple with c = a * b. If
    one of share_x and share_y is a constant, the product is computed
    locally by _cmul and the triple is not used. Otherwise:

    d = Open([x] - [a])
    e = Open([y] - [b])
    [z] = e[x] + d[y] - [de] + [c]

    This is correct because x = a + d and y = b + e give
    xy = ab + ea + db + de = c + ex + dy - de.

    Communication cost: two openings (d and e).
    """
    assert isinstance(share_x, Share) or isinstance(share_y, Share), \
        "At least one of share_x and share_y must be a Share."

    field = getattr(share_x, "field", getattr(share_y, "field", None))

    # Multiplication by a public constant needs no interaction.
    cmul_result = self._cmul(share_x, share_y, field)
    if cmul_result is not None:
        return cmul_result

    def multiply(values):
        # Explicit unpacking instead of Python-2-only tuple parameters.
        x, y, d, e, c = values
        # [de]: the public value d * e turned into a share.
        de = self._additive_constant(field(0), d * e)
        # e[x]
        t1 = self._const_mul(e, x)
        # d[y]
        t2 = self._const_mul(d, y)
        # d[y] - [de]
        t3 = self._minus(t2, de)
        # d[y] - [de] + [c]
        t4 = self._plus(t3, c)
        # [z] = e[x] + d[y] - [de] + [c]
        zi, rhoz, Cz = self._plus(t1, t4)
        return OrlandiShare(self, field, zi, rhoz, Cz)

    # d = Open([x] - [a])
    d = self.open(share_x - triple_a)
    # e = Open([y] - [b])
    e = self.open(share_y - triple_b)
    result = gather_shares([share_x, share_y, d, e, triple_c])
    result.addCallbacks(multiply, self.error_handler)

    # Do the actual communication.
    self.activate_reactor()

    return result
+
def error_handler(self, ex):
    """Errback that prints a failure and passes it on.

    Used as the errback in _basic_multiplication, for example. Returning
    ex unchanged keeps the Deferred chain in its error state, so later
    errbacks still see the failure. (ex is presumably a twisted Failure;
    confirm.)
    """
    print "Error: ", ex
    return ex
+
diff --git a/viff/test/test_orlandi_runtime.py b/viff/test/test_orlandi_runtime.py
--- a/viff/test/test_orlandi_runtime.py
+++ b/viff/test/test_orlandi_runtime.py
@@ -252,3 +252,126 @@
d2.addCallback(check)
return DeferredList([d1, d2])
@protocol
def test_basic_multiply(self, runtime):
    """Test _basic_multiplication of two secret-shared numbers."""
    x_value, y_value = 42, 7

    x_share = runtime.shift([2], self.Zp, x_value)
    y_share = runtime.shift([3], self.Zp, y_value)

    triple = runtime._get_triple(self.Zp)
    product = runtime._basic_multiplication(x_share, y_share, *triple)

    opened = runtime.open(product)
    opened.addCallback(lambda v: self.assertEquals(v, x_value * y_value))
    return opened
+
@protocol
def test_mul_mul(self, runtime):
    """Test multiplication of two shares with the * operator."""
    lhs, rhs = 42, 7
    expected = lhs * rhs

    product = (runtime.shift([2], self.Zp, lhs) *
               runtime.shift([3], self.Zp, rhs))

    opened = runtime.open(product)
    opened.addCallback(lambda v: self.assertEquals(v, expected))
    return opened
+
@protocol
def test_basic_multiply_constant_right(self, runtime):
    """Test _basic_multiplication of a share by a constant on the right."""
    secret, constant = 42, 7

    def verify(opened_value):
        self.assertEquals(opened_value, secret * constant)

    share = runtime.shift([1], self.Zp, secret)
    triple_a, triple_b, triple_c = runtime._get_triple(self.Zp)
    product = runtime._basic_multiplication(share, constant,
                                            triple_a, triple_b, triple_c)
    opened = runtime.open(product)
    opened.addCallback(verify)
    return opened
+
@protocol
def test_basic_multiply_constant_left(self, runtime):
    """Test _basic_multiplication of a constant on the left by a share."""
    secret, constant = 42, 7

    def verify(opened_value):
        self.assertEquals(opened_value, secret * constant)

    share = runtime.shift([1], self.Zp, secret)
    triple_a, triple_b, triple_c = runtime._get_triple(self.Zp)
    product = runtime._basic_multiplication(constant, share,
                                            triple_a, triple_b, triple_c)
    opened = runtime.open(product)
    opened.addCallback(verify)
    return opened
+
@protocol
def test_constant_multiplication_constant_left(self, runtime):
    """Test _cmul with the constant as the left argument."""

    x1 = 42
    y1 = 7

    def check(v):
        self.assertEquals(v, x1 * y1)

    x2 = runtime.shift([1], self.Zp, x1)

    # _cmul is a local operation, so no multiplication triple is needed.
    z2 = runtime._cmul(y1, x2, self.Zp)
    d = runtime.open(z2)
    d.addCallback(check)
    return d
+
@protocol
def test_constant_multiplication_constant_right(self, runtime):
    """Test _cmul with the constant as the right argument."""

    x1 = 42
    y1 = 7

    def check(v):
        self.assertEquals(v, x1 * y1)

    x2 = runtime.shift([1], self.Zp, x1)

    # _cmul is a local operation, so no multiplication triple is needed.
    z2 = runtime._cmul(x2, y1, self.Zp)
    d = runtime.open(z2)
    d.addCallback(check)
    return d
+
@protocol
def test_constant_multiplication_constant_None(self, runtime):
    """Test that _cmul returns None when both arguments are shares."""

    x1 = 42
    y1 = 7

    x2 = runtime.shift([1], self.Zp, x1)
    y2 = runtime.shift([1], self.Zp, y1)

    # Two shares cannot be multiplied locally, so _cmul must decline.
    z2 = runtime._cmul(y2, x2, self.Zp)
    self.assertEquals(z2, None)
    return z2
More information about the viff-devel
mailing list