Crypto disclaimer! I am NOT a crypto expert. Don’t take the information here as 100% correct; you should verify it yourself. You are dangerously bad at crypto.
The problem
Man-in-the-middle (MITM) attacks are a serious problem when designing any cryptographic protocol. Without a PKI, a common solution is to show users the fingerprints of the exchanged public keys. Users should then verify these with the other party over another secure channel to make sure there is no MITM. In practice, this is a very poor solution. Most users will not check fingerprints, and even those who do may only compare the first and last few digits. An attacker then only needs to create a public key whose fingerprint shares those first and last few digits with the fingerprint of the key they are trying to impersonate.
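To get a feel for how cheap that attack is, here is a small Python 3 sketch. It is not part of the original post. It hashes random bytes with SHA-256 as a stand-in for generating real key pairs, which is an assumption: real key generation is slower per attempt, but the attack scales the same way. Matching `n` hex digits at each end of the fingerprint takes about 16^(2n) attempts on average, so `n = 2` needs roughly 65,536 tries.

```python
import hashlib
import os


def fingerprint(key_bytes):
    """Hex SHA-256 digest, standing in for a real public-key fingerprint."""
    return hashlib.sha256(key_bytes).hexdigest()


def find_lookalike(target_fp, n):
    """Generate random 'keys' until one's fingerprint shares the first and
    last n hex digits with target_fp. Returns (key, fingerprint, attempts)."""
    attempts = 0
    while True:
        attempts += 1
        candidate = os.urandom(32)  # random bytes stand in for a freshly generated key
        fp = fingerprint(candidate)
        if fp[:n] == target_fp[:n] and fp[-n:] == target_fp[-n:]:
            return candidate, fp, attempts


victim_fp = fingerprint(os.urandom(32))
key, fake_fp, tries = find_lookalike(victim_fp, 2)
print(victim_fp)
print(fake_fp)
print("found after", tries, "attempts")
```

Each extra digit checked at each end multiplies the work by 256. This is why comparing only the ends of a fingerprint gives so little protection.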
The solution
There’s no perfect protection against MITM attacks. However, there is a way to confirm that two parties share a secret without worrying about a MITM, without a PKI, and without checking fingerprints. OTR (Off-the-Record) messaging uses the Socialist Millionaire Protocol (SMP) for this. In a (very small) nutshell, SMP lets two parties check whether the secrets they each hold are equal without revealing those secrets to each other (or to anyone else). If the secrets are not equal, nothing is revealed except that they are not equal. Because the secret value itself is never exchanged, a would-be MITM attacker cannot interfere with SMP except to make it fail.
How does it work?
As usual, the Wikipedia article on SMP drowns the reader in difficult-to-read math and does a poor job of explaining the basic principle behind SMP. Luckily, there are much better explanations out there. The actual implementation is, unfortunately, just as convoluted as the math. The full details are in the SMP section of the OTR version 3 protocol spec. Below is the basic protocol as defined in OTR version 3. Here $x$ is Alice’s secret, $y$ is Bob’s secret, $g_1$ is the group generator, and all arithmetic is done modulo a large prime:
- Alice:
  - Picks random exponents $a_2$ and $a_3$
  - Sends Bob $g_{2a} = g_1^{a_2}$ and $g_{3a} = g_1^{a_3}$
- Bob:
  - Picks random exponents $b_2$ and $b_3$
  - Computes $g_{2b} = g_1^{b_2}$ and $g_{3b} = g_1^{b_3}$
  - Computes $g_2 = g_{2a}^{b_2}$ and $g_3 = g_{3a}^{b_3}$
  - Picks random exponent $r$
  - Computes $P_b = g_3^{r}$ and $Q_b = g_1^{r} g_2^{y}$
  - Sends Alice $g_{2b}$, $g_{3b}$, $P_b$ and $Q_b$
- Alice:
  - Computes $g_2 = g_{2b}^{a_2}$ and $g_3 = g_{3b}^{a_3}$
  - Picks random exponent $s$
  - Computes $P_a = g_3^{s}$ and $Q_a = g_1^{s} g_2^{x}$
  - Computes $R_a = (Q_a / Q_b)^{a_3}$
  - Sends Bob $P_a$, $Q_a$ and $R_a$
- Bob:
  - Computes $R_b = (Q_a / Q_b)^{b_3}$
  - Computes $R_{ab} = R_a^{b_3}$
  - Checks whether $R_{ab} = P_a / P_b$
  - Sends Alice $R_b$
- Alice:
  - Computes $R_{ab} = R_b^{a_3}$
  - Checks whether $R_{ab} = P_a / P_b$
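Ignoring the integrity checks for a moment, the exchange above can be simulated in a single process. This is a Python 3 sketch, not the author’s code. It reuses the modulus and generator from smp.py below and draws exponents from `secrets`. It also hashes the secret string down to an exponent with SHA-256, which is an assumption of this sketch and not what smp.py does. The final comparison works for the following reason. Because $g_3 = g_1^{a_3 b_3}$, we have $R_{ab} = (Q_a/Q_b)^{a_3 b_3} = g_3^{s-r} g_2^{(x-y) a_3 b_3}$. We also have $P_a/P_b = g_3^{s-r}$. So the two sides agree exactly when $x = y$.

```python
import hashlib
import secrets

# Same modulus and generator as smp.py below.
P = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919
Q = (P - 1) // 2  # order of the subgroup generated by G1
G1 = 2


def rand_exp():
    # Random exponent in [1, Q-1] from the OS's CSPRNG
    return secrets.randbelow(Q - 1) + 1


def to_exponent(secret):
    # Assumption of this sketch: hash the secret string to an integer exponent
    return int.from_bytes(hashlib.sha256(secret.encode()).digest(), "big")


def inv(v):
    # Modular inverse (pow with exponent -1 needs Python 3.8+)
    return pow(v, -1, P)


def smp_core(alice_secret, bob_secret):
    """Both sides of the basic exchange in one process, with no proofs.
    Returns (Alice's verdict, Bob's verdict)."""
    x, y = to_exponent(alice_secret), to_exponent(bob_secret)

    # Alice: a2, a3 -> g2a, g3a
    a2, a3 = rand_exp(), rand_exp()
    g2a, g3a = pow(G1, a2, P), pow(G1, a3, P)

    # Bob: b2, b3 -> g2b, g3b; shared g2, g3; then Pb, Qb
    b2, b3, r = rand_exp(), rand_exp(), rand_exp()
    g2b, g3b = pow(G1, b2, P), pow(G1, b3, P)
    bob_g2, bob_g3 = pow(g2a, b2, P), pow(g3a, b3, P)
    Pb = pow(bob_g3, r, P)
    Qb = pow(G1, r, P) * pow(bob_g2, y, P) % P

    # Alice: derives the same g2, g3 from Bob's values, then Pa, Qa, Ra
    alice_g2, alice_g3 = pow(g2b, a2, P), pow(g3b, a3, P)
    s = rand_exp()
    Pa = pow(alice_g3, s, P)
    Qa = pow(G1, s, P) * pow(alice_g2, x, P) % P
    Ra = pow(Qa * inv(Qb) % P, a3, P)

    # Bob: computes Rb, then checks Rab = Ra^b3 against Pa / Pb
    Rb = pow(Qa * inv(Qb) % P, b3, P)
    bob_match = pow(Ra, b3, P) == Pa * inv(Pb) % P

    # Alice: checks Rab = Rb^a3 against Pa / Pb
    alice_match = pow(Rb, a3, P) == Pa * inv(Pb) % P
    return alice_match, bob_match


print(smp_core("biscuits", "biscuits"))  # (True, True)
print(smp_core("biscuits", "cookies"))   # (False, False)
```

Note that neither side ever sees $x$ or $y$ from the other. Only blinded group elements cross the "wire".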
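But that’s not all! There are also integrity checks at each step of the way. These are zero-knowledge proofs that each value was formed correctly from exponents the sender actually knows. I will defer to the OTR spec for those, as they are at least as complicated as the basic protocol outlined above.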
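The simplest of these checks is `createLogProof`/`checkLogProof` in smp.py below. It is a Schnorr-style proof that the sender knows the exponent behind a value such as $g_{2a}$. Here is a Python 3 sketch of the same idea, reusing smp.py’s modulus and generator. The snake_case function names are hypothetical and not taken from smp.py or the OTR spec.

```python
import hashlib
import secrets

# Same modulus and generator as smp.py below.
P = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919
Q = (P - 1) // 2  # order of G; exponents are reduced mod Q
G = 2


def hash_values(version, *values):
    # Mirrors smp.py's sha256(version + str(...)): the version tag is hashed in
    data = version + "".join(str(v) for v in values)
    return int(hashlib.sha256(data.encode()).hexdigest(), 16)


def create_log_proof(version, x):
    """Prove knowledge of x where y = G^x mod P, without revealing x."""
    r = secrets.randbelow(Q - 1) + 1  # one-time random exponent
    c = hash_values(version, pow(G, r, P))
    d = (r - x * c) % Q
    return c, d


def check_log_proof(version, y, c, d):
    # G^d * y^c = G^(r - x*c) * G^(x*c) = G^r, so an honest proof re-hashes to c
    return hash_values(version, pow(G, d, P) * pow(y, c, P) % P) == c


x = secrets.randbelow(Q - 1) + 1  # e.g. Alice's a2
y = pow(G, x, P)                  # e.g. g2a, which she sends to Bob
c, d = create_log_proof("1", x)
print(check_log_proof("1", y, c, d))          # True
print(check_log_proof("2", y, c, d))          # False: wrong version tag
print(check_log_proof("1", y * G % P, c, d))  # False: tampered value
```

The version tag is part of the hash, so a proof made for one step cannot be replayed as the proof for another.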
Show me the code already!
Hold your horses. First, my implementation is not meant to adhere 100% to the OTR protocol. It’s pretty close, but I intend to merge it into my own project and use it as an example here. If you want an OTR implementation, look elsewhere. If you’re looking for a simple example of two parties using SMP over a network to check that they hold the same secret, read on.
Below is a Python 2 implementation of SMP that follows the OTR spec closely, but not exactly. The test program asks for a shared secret and then uses SMP to check that both sides entered the same one. One final disclaimer: the test program is intended for demonstration purposes only. Never do socket programming like this in a production environment!
smp.py:
import hashlib
import random
import struct


class SMP(object):
    def __init__(self, secret=None):
        # 1536-bit group prime (the group OTR uses) and its generator
        self.mod = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919
        self.modOrder = (self.mod-1) // 2
        self.gen = 2
        self.match = False

        if type(secret) is str:
            # Encode the string as a hex value
            self.secret = int(secret.encode('hex'), 16)
        elif type(secret) is int or type(secret) is long:
            self.secret = secret
        else:
            raise TypeError("Secret must be an int or a string. Got type: " + str(type(secret)))

    def step1(self):
        # Alice: pick a2, a3 and prove knowledge of them
        self.x2 = createRandomExponent()
        self.x3 = createRandomExponent()

        self.g2 = pow(self.gen, self.x2, self.mod)
        self.g3 = pow(self.gen, self.x3, self.mod)

        (c1, d1) = self.createLogProof('1', self.x2)
        (c2, d2) = self.createLogProof('2', self.x3)

        # Send g2a, g3a, c1, d1, c2, d2
        return packList(self.g2, self.g3, c1, d1, c2, d2)

    def step2(self, buffer):
        # Bob: verify Alice's values, then compute g2, g3, Pb and Qb
        (g2a, g3a, c1, d1, c2, d2) = unpackList(buffer)

        if not self.isValidArgument(g2a) or not self.isValidArgument(g3a):
            raise ValueError("Invalid g2a/g3a values")

        if not self.checkLogProof('1', g2a, c1, d1):
            raise ValueError("Proof 1 check failed")

        if not self.checkLogProof('2', g3a, c2, d2):
            raise ValueError("Proof 2 check failed")

        self.g2a = g2a
        self.g3a = g3a

        self.x2 = createRandomExponent()
        self.x3 = createRandomExponent()

        r = createRandomExponent()

        self.g2 = pow(self.gen, self.x2, self.mod)
        self.g3 = pow(self.gen, self.x3, self.mod)

        (c3, d3) = self.createLogProof('3', self.x2)
        (c4, d4) = self.createLogProof('4', self.x3)

        self.gb2 = pow(self.g2a, self.x2, self.mod)
        self.gb3 = pow(self.g3a, self.x3, self.mod)

        self.pb = pow(self.gb3, r, self.mod)
        self.qb = mulm(pow(self.gen, r, self.mod), pow(self.gb2, self.secret, self.mod), self.mod)

        (c5, d5, d6) = self.createCoordsProof('5', self.gb2, self.gb3, r)

        # Sends g2b, g3b, pb, qb, all the c's and d's
        return packList(self.g2, self.g3, self.pb, self.qb, c3, d3, c4, d4, c5, d5, d6)

    def step3(self, buffer):
        # Alice: verify Bob's values, then compute Pa, Qa and Ra
        (g2b, g3b, pb, qb, c3, d3, c4, d4, c5, d5, d6) = unpackList(buffer)

        if not self.isValidArgument(g2b) or not self.isValidArgument(g3b) or \
           not self.isValidArgument(pb) or not self.isValidArgument(qb):
            raise ValueError("Invalid g2b/g3b/pb/qb values")

        if not self.checkLogProof('3', g2b, c3, d3):
            raise ValueError("Proof 3 check failed")

        if not self.checkLogProof('4', g3b, c4, d4):
            raise ValueError("Proof 4 check failed")

        self.g2b = g2b
        self.g3b = g3b

        self.ga2 = pow(self.g2b, self.x2, self.mod)
        self.ga3 = pow(self.g3b, self.x3, self.mod)

        if not self.checkCoordsProof('5', c5, d5, d6, self.ga2, self.ga3, pb, qb):
            raise ValueError("Proof 5 check failed")

        s = createRandomExponent()

        self.qb = qb
        self.pb = pb
        self.pa = pow(self.ga3, s, self.mod)
        self.qa = mulm(pow(self.gen, s, self.mod), pow(self.ga2, self.secret, self.mod), self.mod)

        (c6, d7, d8) = self.createCoordsProof('6', self.ga2, self.ga3, s)

        inv = self.invm(qb)
        self.ra = pow(mulm(self.qa, inv, self.mod), self.x3, self.mod)

        (c7, d9) = self.createEqualLogsProof('7', self.qa, inv, self.x3)

        # Sends pa, qa, ra, c6, d7, d8, c7, d9
        return packList(self.pa, self.qa, self.ra, c6, d7, d8, c7, d9)

    def step4(self, buffer):
        # Bob: verify Alice's values, compute Rb and Rab, and compare
        (pa, qa, ra, c6, d7, d8, c7, d9) = unpackList(buffer)

        if not self.isValidArgument(pa) or not self.isValidArgument(qa) or not self.isValidArgument(ra):
            raise ValueError("Invalid pa/qa/ra values")

        if not self.checkCoordsProof('6', c6, d7, d8, self.gb2, self.gb3, pa, qa):
            raise ValueError("Proof 6 check failed")

        if not self.checkEqualLogs('7', c7, d9, self.g3a, mulm(qa, self.invm(self.qb), self.mod), ra):
            raise ValueError("Proof 7 check failed")

        inv = self.invm(self.qb)
        rb = pow(mulm(qa, inv, self.mod), self.x3, self.mod)

        (c8, d10) = self.createEqualLogsProof('8', qa, inv, self.x3)

        rab = pow(ra, self.x3, self.mod)

        inv = self.invm(self.pb)
        if rab == mulm(pa, inv, self.mod):
            self.match = True

        # Send rb, c8, d10
        return packList(rb, c8, d10)

    def step5(self, buffer):
        # Alice: verify Rb, compute Rab, and compare
        (rb, c8, d10) = unpackList(buffer)

        if not self.isValidArgument(rb):
            raise ValueError("Invalid rb values")

        if not self.checkEqualLogs('8', c8, d10, self.g3b, mulm(self.qa, self.invm(self.qb), self.mod), rb):
            raise ValueError("Proof 8 check failed")

        rab = pow(rb, self.x3, self.mod)

        inv = self.invm(self.pb)
        if rab == mulm(self.pa, inv, self.mod):
            self.match = True

    def createLogProof(self, version, x):
        randExponent = createRandomExponent()
        c = sha256(version + str(pow(self.gen, randExponent, self.mod)))
        d = (randExponent - mulm(x, c, self.modOrder)) % self.modOrder
        return (c, d)

    def checkLogProof(self, version, g, c, d):
        gd = pow(self.gen, d, self.mod)
        gc = pow(g, c, self.mod)
        gdgc = gd * gc % self.mod
        return (sha256(version + str(gdgc)) == c)

    def createCoordsProof(self, version, g2, g3, r):
        r1 = createRandomExponent()
        r2 = createRandomExponent()

        tmp1 = pow(g3, r1, self.mod)
        tmp2 = mulm(pow(self.gen, r1, self.mod), pow(g2, r2, self.mod), self.mod)

        c = sha256(version + str(tmp1) + str(tmp2))

        # TODO: make a subm function
        d1 = (r1 - mulm(r, c, self.modOrder)) % self.modOrder
        d2 = (r2 - mulm(self.secret, c, self.modOrder)) % self.modOrder

        return (c, d1, d2)

    def checkCoordsProof(self, version, c, d1, d2, g2, g3, p, q):
        tmp1 = mulm(pow(g3, d1, self.mod), pow(p, c, self.mod), self.mod)

        tmp2 = mulm(mulm(pow(self.gen, d1, self.mod), pow(g2, d2, self.mod), self.mod), pow(q, c, self.mod), self.mod)

        cprime = sha256(version + str(tmp1) + str(tmp2))

        return (c == cprime)

    def createEqualLogsProof(self, version, qa, qb, x):
        r = createRandomExponent()
        tmp1 = pow(self.gen, r, self.mod)
        qab = mulm(qa, qb, self.mod)
        tmp2 = pow(qab, r, self.mod)

        c = sha256(version + str(tmp1) + str(tmp2))
        tmp1 = mulm(x, c, self.modOrder)
        d = (r - tmp1) % self.modOrder

        return (c, d)

    def checkEqualLogs(self, version, c, d, g3, qab, r):
        tmp1 = mulm(pow(self.gen, d, self.mod), pow(g3, c, self.mod), self.mod)

        tmp2 = mulm(pow(qab, d, self.mod), pow(r, c, self.mod), self.mod)

        cprime = sha256(version + str(tmp1) + str(tmp2))
        return (c == cprime)

    def invm(self, x):
        # Modular inverse via Fermat's little theorem (mod is prime)
        return pow(x, self.mod-2, self.mod)

    def isValidArgument(self, val):
        return (val >= 2 and val <= self.mod-2)


def packList(*items):
    buffer = ''

    # For each item in the list, convert it to a byte string and add its length as a prefix
    for item in items:
        bytes = longToBytes(item)
        buffer += struct.pack('!I', len(bytes)) + bytes

    return buffer


def unpackList(buffer):
    items = []

    index = 0
    while index < len(buffer):
        # Get the length of the long (4 byte int before the actual long)
        length = struct.unpack('!I', buffer[index:index+4])[0]
        index += 4

        # Convert the data back to a long and add it to the list
        item = bytesToLong(buffer[index:index+length])
        items.append(item)
        index += length

    return items


def bytesToLong(bytes):
    length = len(bytes)
    string = 0
    for i in range(length):
        string += byteToLong(bytes[i:i+1]) << 8*(length-i-1)
    return string


def longToBytes(long):
    bytes = ''
    while long != 0:
        bytes = longToByte(long & 0xff) + bytes
        long >>= 8
    return bytes


def byteToLong(byte):
    return struct.unpack('B', byte)[0]


def longToByte(long):
    return struct.pack('B', long)


def mulm(x, y, mod):
    return x * y % mod


def createRandomExponent():
    # Use the OS's CSPRNG; the random module's default generator is not suitable for crypto
    return random.SystemRandom().getrandbits(192*8)


def sha256(message):
    return long(hashlib.sha256(str(message)).hexdigest(), 16)
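smp.py targets Python 2. It relies on `long` and `str.encode('hex')`, and it treats `str` as raw bytes. If you port it to Python 3, the wire format produced by `packList`/`unpackList` (a 4-byte big-endian length followed by the big-endian bytes of each number) could be reproduced with `int.to_bytes` and `int.from_bytes`. The snake_case names in this sketch are hypothetical.

```python
import struct


def pack_list(*items):
    """Serialize non-negative ints as a 4-byte big-endian length + big-endian bytes."""
    out = bytearray()
    for item in items:
        data = item.to_bytes((item.bit_length() + 7) // 8, "big")  # 0 becomes b""
        out += struct.pack("!I", len(data)) + data
    return bytes(out)


def unpack_list(buffer):
    """Inverse of pack_list: read length-prefixed big-endian ints until the end."""
    items = []
    index = 0
    while index < len(buffer):
        (length,) = struct.unpack("!I", buffer[index:index + 4])
        index += 4
        items.append(int.from_bytes(buffer[index:index + length], "big"))
        index += length
    return items


packed = pack_list(0, 255, 2**1536 - 1)
print(len(packed))                                   # 205
print(unpack_list(packed) == [0, 255, 2**1536 - 1])  # True
```

The 205 bytes break down as 4 + 0 bytes for zero, 4 + 1 for 255, and 4 + 192 for a full 1536-bit value. This matches what the Python 2 helpers emit, because `longToBytes(0)` also produces an empty string.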
smpTest.py:
import smp
import socket
import sys

# Check command line args
if len(sys.argv) != 2:
    print "Usage: %s [IP/listen]" % sys.argv[0]
    sys.exit(1)

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

if sys.argv[1] == 'listen':
    # Listen for incoming connections
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(('0.0.0.0', 5000))
    sock.listen(1)
    print "Listening for client"
    client = sock.accept()[0]

    # Prompt the user for a shared secret to use in SMP
    secret = raw_input("Enter shared secret: ")

    # Create an SMP object with the entered secret
    smp = smp.SMP(secret)

    # Do the SMP protocol (Bob's side: steps 2 and 4)
    buffer = client.recv(4096)
    buffer = smp.step2(buffer)
    client.sendall(buffer)

    buffer = client.recv(4096)
    buffer = smp.step4(buffer)
    client.sendall(buffer)
else:
    # Connect to the server
    sock.connect((sys.argv[1], 5000))

    # Prompt the user for a shared secret to use in SMP
    secret = raw_input("Enter shared secret: ")

    # Create an SMP object with the entered secret
    smp = smp.SMP(secret)

    # Do the SMP protocol (Alice's side: steps 1, 3 and 5)
    buffer = smp.step1()
    sock.sendall(buffer)

    buffer = sock.recv(4096)
    buffer = smp.step3(buffer)
    sock.sendall(buffer)

    buffer = sock.recv(4096)
    smp.step5(buffer)

# Check if the secrets match
if smp.match:
    print "Secrets match"
else:
    print "Secrets do not match"
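The test program calls `recv(4096)` once per message. That is one reason not to use it in production. TCP is a byte stream, so a single `recv()` can return only part of a message, or pieces of two. A common fix is to prefix each message with its length and loop until the whole message has arrived. This is a Python 3 sketch of that fix; the helper names are hypothetical.

```python
import socket
import struct


def send_msg(sock, payload):
    """Send one message, prefixed with its length as a 4-byte big-endian int."""
    sock.sendall(struct.pack("!I", len(payload)) + payload)


def recv_exact(sock, n):
    """Read exactly n bytes; a single recv() may return fewer."""
    chunks = []
    while n > 0:
        chunk = sock.recv(n)
        if not chunk:  # peer closed the connection
            raise ConnectionError("connection closed mid-message")
        chunks.append(chunk)
        n -= len(chunk)
    return b"".join(chunks)


def recv_msg(sock):
    """Read one length-prefixed message written by send_msg."""
    (length,) = struct.unpack("!I", recv_exact(sock, 4))
    return recv_exact(sock, length)


# Demo over a connected socket pair (no network needed)
left, right = socket.socketpair()
send_msg(left, b"\x01" * 3000)
print(len(recv_msg(right)))  # 3000
left.close()
right.close()
```

In a Python 3 port of smpTest.py, each `client.send(buffer)`/`client.recv(4096)` pair would become `send_msg(client, buffer)`/`recv_msg(client)` (and likewise for `sock`). Both sides must use the framing for it to work.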
To use it, run one instance as a server and another as a client, for example:
$ python smpTest.py listen
$ python smpTest.py localhost # In another terminal
If everything went well, you should see output similar to:
$ python smpTest.py listen
Listening for client
Enter shared secret: biscuits
Secrets match
$ python smpTest.py localhost
Enter shared secret: biscuits
Secrets match