forked from data61/MP-SPDZ
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoprf_higher_order_residuals.mpc
More file actions
116 lines (90 loc) · 3.94 KB
/
oprf_higher_order_residuals.mpc
File metadata and controls
116 lines (90 loc) · 3.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from Compiler import instructions_base
import sys
from sympy.ntheory import primefactors
import random
# Statistical security parameter.
program.set_security(64)

# Positional arguments 2..5 of sys.argv, in order: number of parallel
# evaluations, number of PRF evaluations, residue order, and the generator
# used to compute the higher-order residuals.
nparallel, eval_len, order, generator = (
    int(sys.argv[position]) for position in (2, 3, 4, 5)
)
prime = program.prime

# Compile-time switch: when True, progress messages are emitted.
debug = False

if debug:
    print_ln("Getting indexes")

# One public input is consumed per evaluation; the read order is significant.
indexes = []
for _ in range(eval_len):
    indexes.append(public_input())

if debug:
    print_ln("Got Indexes")
def compute_residual(value, power, generator, prime):
    """Return the list of candidates that identifies the power-residue of ``value``.

    For each ``i`` in ``1..power`` (in that order) the result contains
    ``(value * generator**(-i)) ** ((prime - 1) // power)``, where
    ``generator**(-i)`` is taken modulo ``prime``.

    Rationale: testing whether ``value / g^i = h^power`` for some ``h`` is
    equivalent to testing whether ``(value / g^i) ** ((p - 1) / power) = 1``.
    The residue is therefore read off as the position of the entry that equals
    one (when the arithmetic is carried out in the field of size ``prime``).
    """

    def candidate(shift):
        # Three-argument pow() with a negative exponent yields the modular
        # inverse of generator**shift.
        inverse_power = pow(generator, -shift, prime)
        return (value * inverse_power) ** ((prime - 1) // power)

    return [candidate(shift) for shift in range(1, power + 1)]
# The chosen order is only accepted when prime is congruent to 1 modulo order.
order_is_valid = prime % order == 1
if not order_is_valid:
    raise Exception("Invalid prime for this order!")

# Author's note: the vector size is set *before* reading the key, so this
# reads nparallel keys. That is a limitation of the library; reading first
# and setting nparallel afterwards would cost an extra round.
instructions_base.set_global_vector_size(nparallel)

if debug:
    print_ln("Getting key")
server_key = sint.get_input_from(1)

# NOTE(review): indentation was lost in SOURCE; both messages below are
# assumed to be debug-only, matching the two adjacent guards in the original.
if debug:
    print_ln("Got key")
    print_ln("Getting masks")

# One public input is consumed per evaluation; the read order is significant.
revealed_masks = []
for _ in range(eval_len):
    revealed_masks.append(public_input())

if debug:
    print_ln("Got masks")
class LegHigherOrderOPRF(object):
    """Evaluate the masked higher-order-residue PRF inputs.

    ``bit_len`` is the number of evaluations and ``indexes`` holds one
    additive offset per evaluation.
    """

    def __init__(self, bit_len, indexes):
        self.indexes = indexes
        self.prf_bit_len = bit_len

    def compute_symbols(self, m, user_s, user_mask, server_s, k):
        """Return the revealed values ``s_i**order * (k + m + index_i) + user_mask_i``.

        ``s_i`` is the sum of the user's and the server's share of the
        blinding value. An empty list is returned when either share list does
        not have exactly ``prf_bit_len`` entries.
        """
        expected = self.prf_bit_len
        if len(user_s) != expected or len(server_s) != expected:
            return []

        # Combine both parties' blinding shares position by position.
        combined_s = [mine + theirs for mine, theirs in zip(user_s, server_s)]

        # Blinded PRF input for every position.
        blinded = self.power_prf(m, combined_s, k)

        # Add the user's mask and reveal the masked values.
        return [
            (entry + user_mask[position]).reveal()
            for position, entry in enumerate(blinded)
        ]

    def power_prf(self, msg, s, k):
        """Return ``s[i]**order * (k + msg + indexes[i])`` for every position.

        ``order`` is the module-level residue order. An empty list is returned
        when ``s`` does not have exactly ``prf_bit_len`` entries.
        """
        count = self.prf_bit_len
        if len(s) != count:
            return []
        return [
            s[position] ** order * (k + msg + self.indexes[position])
            for position in range(count)
        ]
# Evaluator over the public indexes, one entry per evaluation.
HigherOrderOPRF = LegHigherOrderOPRF(eval_len, indexes)

# Read the inputs. Party 0 supplies its input, its blinding shares and its
# output masks; party 1 supplies its blinding shares. The read order is
# significant and is kept exactly as before.
user_input = sint.get_input_from(0)
user_s_values = [sint.get_input_from(0) for _ in range(eval_len)]
output_masks = [sint.get_input_from(0) for _ in range(eval_len)]
s_values = [sint.get_input_from(1) for _ in range(eval_len)]

# NOTE(review): indentation was lost in SOURCE; both prints are assumed to be
# debug-only since they reveal secret values.
if debug:
    print_ln('Output masks %s %s', revealed_masks[0], output_masks[0].reveal())
    print_ln('User input and Key %s %s', user_input.reveal(), server_key.reveal())

res = HigherOrderOPRF.compute_symbols(
    user_input, user_s_values, output_masks, s_values, server_key
)

for i in range(eval_len):
    # Remove the (public) mask once and reuse the value for both the debug
    # print and the residual computation.
    unmasked = res[i] - revealed_masks[i]
    if debug:
        print_ln('%s', unmasked)
        print_ln('Next entry:')

    # This yields `order` many candidates; the residual is recovered
    # afterwards by checking which entry is one.
    results = compute_residual(unmasked, order, generator, prime)
    for position, candidate in enumerate(results, start=1):
        if debug:
            print_ln('Result %s %s', position, candidate)
        candidate.binary_output()  # Output to file