Crypto disclaimer! I am NOT a crypto expert. Don’t take the information here as 100% correct; you should verify it yourself. You are dangerously bad at crypto.


The problem

Man-in-the-middle (MITM) attacks are a serious problem when designing any cryptographic protocol. Without a PKI, a common solution is to provide users with the fingerprint of the exchanged public keys, which they should then verify with the other party over another secure channel to ensure there is no MITM. In practice, this is a very poor solution: most users will not check fingerprints, and even if they do, they may compare only the first and last few digits. An attacker then only needs to create a public key whose fingerprint has the same first and last few digits as the fingerprint of the key they are trying to impersonate.
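To get a feel for how little a partial comparison buys you, here is a minimal sketch written for Python 3. Everything in it is an assumption made for illustration: the "keys" are just stand-in byte strings rather than real key pairs, the fingerprint is taken to be a hex SHA-256 digest, and the names fingerprint, lazy_match and forge_key are hypothetical.

```python
import hashlib

def fingerprint(public_key: bytes) -> str:
    # Hypothetical fingerprint: hex SHA-256 of the raw key bytes
    return hashlib.sha256(public_key).hexdigest()

def lazy_match(fp_a: str, fp_b: str, n: int = 2) -> bool:
    # What a hurried user does: compare only the first and last n hex digits
    return fp_a[:n] == fp_b[:n] and fp_a[-n:] == fp_b[-n:]

def forge_key(target_fp: str, n: int = 2) -> bytes:
    # Try stand-in "keys" until one passes the lazy comparison
    counter = 0
    while True:
        candidate = b"attacker-key-%d" % counter
        if lazy_match(fingerprint(candidate), target_fp, n):
            return candidate
        counter += 1

victim_fp = fingerprint(b"victim public key bytes")
fake_fp = fingerprint(forge_key(victim_fp))

print(victim_fp)
print(fake_fp)
print("lazy check passes:", lazy_match(victim_fp, fake_fp))
print("fingerprints equal:", victim_fp == fake_fp)
```

With two hex digits checked at each end, only 16 bits are actually compared, so the search needs about 65,000 attempts on average. Every extra hex digit the user checks multiplies the attacker's work by 16, which is why comparing the whole fingerprint matters.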

The solution

There’s no perfect protection from MITM attacks, but there is a way to rule one out without using a PKI and without checking fingerprints, provided both parties already share a secret. OTR (Off-the-Record) messaging uses the Socialist Millionaire Protocol (SMP). In a (very small) nutshell, SMP allows two parties to check whether the secrets they each hold are equal without revealing the actual secrets to one another (or to anyone else). If the secrets are not equal, nothing is revealed beyond the fact that they differ. Because the secret value is never sent by either party, a would-be MITM attacker cannot interfere with SMP except to make it fail.

How does it work?

As usual, the Wikipedia article on SMP drowns the reader in difficult-to-read math and does a poor job of explaining the basic principle behind SMP. Luckily, there are much better explanations out there. The actual implementation is, unfortunately, just as convoluted as the math. The full implementation details can be found in the OTR protocol version 3 spec under the SMP section. Below is the basic implementation of the protocol as defined in OTR version 3:
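Stripped of the proofs, the core exchange can be sketched roughly like this. It is a toy illustration written for Python 3, not the OTR implementation: the tiny safe prime P = 2879, the generator G = 4 and the function name toy_smp are all assumptions picked to keep the numbers small, and both parties run inside one function with no network in between. The variable names (a2, a3, b2, b3, r, s) follow the OTR spec.

```python
import secrets

P = 2879   # toy safe prime, P = 2*Q + 1 (assumption, far too small for real use)
Q = 1439   # prime order of the subgroup generated by G
G = 4      # a square mod P, so it generates the order-Q subgroup

def rand_exp():
    # Random exponent in the range [1, Q-1]
    return secrets.randbelow(Q - 1) + 1

def inv(value):
    # Modular inverse via Fermat's little theorem
    return pow(value, P - 2, P)

def toy_smp(x, y):
    """Run the exchange for Alice's secret x and Bob's secret y.

    Returns (alice_sees_match, bob_sees_match).
    """
    # Step 1: Alice picks two exponents and sends g2a, g3a
    a2, a3 = rand_exp(), rand_exp()
    g2a, g3a = pow(G, a2, P), pow(G, a3, P)

    # Step 2: Bob does the same, derives the shared g2 and g3,
    # and sends g2b, g3b, pb, qb
    b2, b3, r = rand_exp(), rand_exp(), rand_exp()
    g2b, g3b = pow(G, b2, P), pow(G, b3, P)
    g2, g3 = pow(g2a, b2, P), pow(g3a, b3, P)
    pb = pow(g3, r, P)
    qb = pow(G, r, P) * pow(g2, y, P) % P

    # Step 3: Alice derives the same g2 and g3 and sends pa, qa, ra
    g2_alice, g3_alice = pow(g2b, a2, P), pow(g3b, a3, P)
    s = rand_exp()
    pa = pow(g3_alice, s, P)
    qa = pow(G, s, P) * pow(g2_alice, x, P) % P
    ra = pow(qa * inv(qb) % P, a3, P)

    # Step 4: Bob sends rb and checks whether Ra^b3 == Pa / Pb
    rb = pow(qa * inv(qb) % P, b3, P)
    bob_sees_match = pow(ra, b3, P) == pa * inv(pb) % P

    # Step 5: Alice checks whether Rb^a3 == Pa / Pb
    alice_sees_match = pow(rb, a3, P) == pa * inv(pb) % P

    return alice_sees_match, bob_sees_match

print(toy_smp(42, 42))
print(toy_smp(42, 43))
```

When the secrets are equal, the secret-dependent factor cancels out of Qa/Qb and both sides arrive at the same value. When they differ, a leftover power of g2 remains and the comparison fails, but neither side learns what the other's secret was.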

But that’s not all! There are also data integrity checks at each step of the way, but I will defer to the OTR spec, as those are just as complicated as the basic protocol outlined above, if not more so.
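To give a taste of those checks, the simplest one proves that the sender really knows the exponent behind a value it sends (this is what createLogProof and checkLogProof do in smp.py). The following is a minimal sketch written for Python 3, assuming the same kind of toy group as a stand-in for the real 1536-bit one. The names challenge, prove_log and verify_log are hypothetical.

```python
import hashlib
import secrets

P, Q, G = 2879, 1439, 4   # toy group (assumption): G has prime order Q mod P

def challenge(tag, value):
    # Hash a per-proof tag together with a group element into an integer
    digest = hashlib.sha256((tag + str(value)).encode()).digest()
    return int.from_bytes(digest, "big")

def prove_log(tag, x):
    # Prove knowledge of x, where the public value is G^x mod P
    r = secrets.randbelow(Q - 1) + 1
    c = challenge(tag, pow(G, r, P))
    d = (r - x * c) % Q
    return c, d

def verify_log(tag, gx, c, d):
    # G^d * (G^x)^c equals G^r only if d was built from the real x
    return challenge(tag, pow(G, d, P) * pow(gx, c, P) % P) == c

x = 123                       # the prover's private exponent
gx = pow(G, x, P)             # the value that actually gets sent
c, d = prove_log("1", x)
print(verify_log("1", gx, c, d))
print(verify_log("1", gx, c, (d + 1) % Q))
```

The verifier never sees x or r, only c and d, yet it can recompute G^r and confirm that the hash matches. The other proofs in the spec follow the same pattern with more values fed into the hash.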

Show me the code already!

Hold your horses. First, my implementation is not meant to adhere to the OTR protocol 100%. It’s pretty close, but I intend for this to be merged into my own project and used as an example here. If you want an OTR implementation, look elsewhere. If you’re looking for a simple example of exchanging a secret over a network and then checking if that secret was transmitted securely, read on.

Below is a Python implementation of SMP as defined by the OTR spec (somewhat; it’s not exactly OTR). The test program asks for a shared secret and then uses SMP to check that both secrets are the same. One final disclaimer: the test program is intended for demonstration purposes only. Never do socket programming like this in a production environment!

smp.py:

import hashlib
import os
import random
import struct

class SMP(object):
    def __init__(self, secret=None):
        self.mod = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919
        self.modOrder = (self.mod-1) // 2
        self.gen = 2
        self.match = False

        if type(secret) is str:
            # Encode the string as a hex value
            self.secret = int(secret.encode('hex'), 16)
        elif type(secret) is int or type(secret) is long:
            self.secret = secret
        else:
            raise TypeError("Secret must be an int or a string. Got type: " + str(type(secret)))

    def step1(self):
        self.x2 = createRandomExponent()
        self.x3 = createRandomExponent()

        self.g2 = pow(self.gen, self.x2, self.mod)
        self.g3 = pow(self.gen, self.x3, self.mod)

        (c1, d1) = self.createLogProof('1', self.x2)
        (c2, d2) = self.createLogProof('2', self.x3)

        # Send g2a, g3a, c1, d1, c2, d2
        return packList(self.g2, self.g3, c1, d1, c2, d2)

    def step2(self, buffer):
        (g2a, g3a, c1, d1, c2, d2) = unpackList(buffer)

        if not self.isValidArgument(g2a) or not self.isValidArgument(g3a):
            raise ValueError("Invalid g2a/g3a values")

        if not self.checkLogProof('1', g2a, c1, d1):
            raise ValueError("Proof 1 check failed")

        if not self.checkLogProof('2', g3a, c2, d2):
            raise ValueError("Proof 2 check failed")

        self.g2a = g2a
        self.g3a = g3a

        self.x2 = createRandomExponent()
        self.x3 = createRandomExponent()

        r = createRandomExponent()

        self.g2 = pow(self.gen, self.x2, self.mod)
        self.g3 = pow(self.gen, self.x3, self.mod)

        (c3, d3) = self.createLogProof('3', self.x2)
        (c4, d4) = self.createLogProof('4', self.x3)

        self.gb2 = pow(self.g2a, self.x2, self.mod)
        self.gb3 = pow(self.g3a, self.x3, self.mod)

        self.pb = pow(self.gb3, r, self.mod)
        self.qb = mulm(pow(self.gen, r, self.mod), pow(self.gb2, self.secret, self.mod), self.mod)

        (c5, d5, d6) = self.createCoordsProof('5', self.gb2, self.gb3, r)

        # Sends g2b, g3b, pb, qb, all the c's and d's
        return packList(self.g2, self.g3, self.pb, self.qb, c3, d3, c4, d4, c5, d5, d6)

    def step3(self, buffer):
        (g2b, g3b, pb, qb, c3, d3, c4, d4, c5, d5, d6) = unpackList(buffer)

        if not self.isValidArgument(g2b) or not self.isValidArgument(g3b) or \
           not self.isValidArgument(pb) or not self.isValidArgument(qb):
            raise ValueError("Invalid g2b/g3b/pb/qb values")

        if not self.checkLogProof('3', g2b, c3, d3):
            raise ValueError("Proof 3 check failed")

        if not self.checkLogProof('4', g3b, c4, d4):
            raise ValueError("Proof 4 check failed")

        self.g2b = g2b
        self.g3b = g3b

        self.ga2 = pow(self.g2b, self.x2, self.mod)
        self.ga3 = pow(self.g3b, self.x3, self.mod)

        if not self.checkCoordsProof('5', c5, d5, d6, self.ga2, self.ga3, pb, qb):
            raise ValueError("Proof 5 check failed")

        s = createRandomExponent()

        self.qb = qb
        self.pb = pb
        self.pa = pow(self.ga3, s, self.mod)
        self.qa = mulm(pow(self.gen, s, self.mod), pow(self.ga2, self.secret, self.mod), self.mod)

        (c6, d7, d8) = self.createCoordsProof('6', self.ga2, self.ga3, s)

        inv = self.invm(qb)
        self.ra = pow(mulm(self.qa, inv, self.mod), self.x3, self.mod)

        (c7, d9) = self.createEqualLogsProof('7', self.qa, inv, self.x3)

        # Sends pa, qa, ra, c6, d7, d8, c7, d9
        return packList(self.pa, self.qa, self.ra, c6, d7, d8, c7, d9)

    def step4(self, buffer):
        (pa, qa, ra, c6, d7, d8, c7, d9) = unpackList(buffer)

        if not self.isValidArgument(pa) or not self.isValidArgument(qa) or not self.isValidArgument(ra):
            raise ValueError("Invalid pa/qa/ra values")

        if not self.checkCoordsProof('6', c6, d7, d8, self.gb2, self.gb3, pa, qa):
            raise ValueError("Proof 6 check failed")

        if not self.checkEqualLogs('7', c7, d9, self.g3a, mulm(qa, self.invm(self.qb), self.mod), ra):
            raise ValueError("Proof 7 check failed")

        inv = self.invm(self.qb)
        rb = pow(mulm(qa, inv, self.mod), self.x3, self.mod)

        (c8, d10) = self.createEqualLogsProof('8', qa, inv, self.x3)

        rab = pow(ra, self.x3, self.mod)

        inv = self.invm(self.pb)
        if rab == mulm(pa, inv, self.mod):
            self.match = True

        # Send rb, c8, d10
        return packList(rb, c8, d10)

    def step5(self, buffer):
        (rb, c8, d10) = unpackList(buffer)

        if not self.isValidArgument(rb):
            raise ValueError("Invalid rb values")

        if not self.checkEqualLogs('8', c8, d10, self.g3b, mulm(self.qa, self.invm(self.qb), self.mod), rb):
            raise ValueError("Proof 8 check failed")

        rab = pow(rb, self.x3, self.mod)

        inv = self.invm(self.pb)
        if rab == mulm(self.pa, inv, self.mod):
            self.match = True

    def createLogProof(self, version, x):
        randExponent = createRandomExponent()
        c = sha256(version + str(pow(self.gen, randExponent, self.mod)))
        d = (randExponent - mulm(x, c, self.modOrder)) % self.modOrder
        return (c, d)

    def checkLogProof(self, version, g, c, d):
        gd = pow(self.gen, d, self.mod)
        gc = pow(g, c, self.mod)
        gdgc = gd * gc % self.mod
        return (sha256(version + str(gdgc)) == c)

    def createCoordsProof(self, version, g2, g3, r):
        r1 = createRandomExponent()
        r2 = createRandomExponent()

        tmp1 = pow(g3, r1, self.mod)
        tmp2 = mulm(pow(self.gen, r1, self.mod), pow(g2, r2, self.mod), self.mod)

        c = sha256(version + str(tmp1) + str(tmp2))

        # TODO: make a subm function
        d1 = (r1 - mulm(r, c, self.modOrder)) % self.modOrder
        d2 = (r2 - mulm(self.secret, c, self.modOrder)) % self.modOrder

        return (c, d1, d2)

    def checkCoordsProof(self, version, c, d1, d2, g2, g3, p, q):
        tmp1 = mulm(pow(g3, d1, self.mod), pow(p, c, self.mod), self.mod)

        tmp2 = mulm(mulm(pow(self.gen, d1, self.mod), pow(g2, d2, self.mod), self.mod), pow(q, c, self.mod), self.mod)

        cprime = sha256(version + str(tmp1) + str(tmp2))

        return (c == cprime)

    def createEqualLogsProof(self, version, qa, qb, x):
        r = createRandomExponent()
        tmp1 = pow(self.gen, r, self.mod)
        qab = mulm(qa, qb, self.mod)
        tmp2 = pow(qab, r, self.mod)

        c = sha256(version + str(tmp1) + str(tmp2))
        tmp1 = mulm(x, c, self.modOrder)
        d = (r - tmp1) % self.modOrder

        return (c, d)

    def checkEqualLogs(self, version, c, d, g3, qab, r):
        tmp1 = mulm(pow(self.gen, d, self.mod), pow(g3, c, self.mod), self.mod)

        tmp2 = mulm(pow(qab, d, self.mod), pow(r, c, self.mod), self.mod)

        cprime = sha256(version + str(tmp1) + str(tmp2))
        return (c == cprime)

    def invm(self, x):
        return pow(x, self.mod-2, self.mod)

    def isValidArgument(self, val):
        return (val >= 2 and val <= self.mod-2)

def packList(*items):
    buffer = ''

    # For each item in the list, convert it to a byte string and add its length as a prefix
    for item in items:
        bytes = longToBytes(item)
        buffer += struct.pack('!I', len(bytes)) + bytes

    return buffer

def unpackList(buffer):
    items = []

    index = 0
    while index < len(buffer):
        # Get the length of the long (4 byte int before the actual long)
        length = struct.unpack('!I', buffer[index:index+4])[0]
        index += 4

        # Convert the data back to a long and add it to the list
        item = bytesToLong(buffer[index:index+length])
        items.append(item)
        index += length

    return items

def bytesToLong(bytes):
    length = len(bytes)
    string = 0
    for i in range(length):
        string += byteToLong(bytes[i:i+1]) << 8*(length-i-1)
    return string

def longToBytes(long):
    bytes = ''
    while long != 0:
        bytes = longToByte(long & 0xff) + bytes
        long >>= 8
    return bytes

def byteToLong(byte):
    return struct.unpack('B', byte)[0]

def longToByte(long):
    return struct.pack('B', long)

def mulm(x, y, mod):
    return x * y % mod

def createRandomExponent():
    # Use the OS CSPRNG; the default Mersenne Twister generator is predictable
    return random.SystemRandom().getrandbits(192*8)

def sha256(message):
    return long(hashlib.sha256(str(message)).hexdigest(), 16)

smpTest.py:

import smp
import socket
import sys

# Check command line args
if len(sys.argv) != 2:
    print "Usage: %s [IP/listen]" % sys.argv[0]
    sys.exit(1)

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

if sys.argv[1] == 'listen':
    # Listen for incoming connections
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(('0.0.0.0', 5000))
    sock.listen(1)
    print "Listening for client"
    client = sock.accept()[0]

    # Prompt the user for a shared secret to use in SMP
    secret = raw_input("Enter shared secret: ")

    # Create an SMP object with the entered secret
    smp = smp.SMP(secret)

    # Do the SMP protocol
    buffer = client.recv(4096)
    buffer = smp.step2(buffer)
    client.send(buffer)

    buffer = client.recv(4096)
    buffer = smp.step4(buffer)
    client.send(buffer)
else:
    # Connect to the server
    sock.connect((sys.argv[1], 5000))

    # Prompt the user for a shared secret to use in SMP
    secret = raw_input("Enter shared secret: ")

    # Create an SMP object with the entered secret
    smp = smp.SMP(secret)

    # Do the SMP protocol
    buffer = smp.step1()
    sock.send(buffer)

    buffer = sock.recv(4096)
    buffer = smp.step3(buffer)
    sock.send(buffer)

    buffer = sock.recv(4096)
    smp.step5(buffer)

# Check if the secrets match
if smp.match:
    print "Secrets match"
else:
    print "Secrets do not match"

To use it, run one instance as a server and another as a client, for example:

$ python smpTest.py listen
$ python smpTest.py localhost # In another terminal

If everything went well, you should see output similar to:

$ python smpTest.py listen
Listening for client
Enter shared secret: biscuits
Secrets match

$ python smpTest.py localhost
Enter shared secret: biscuits
Secrets match
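As for why the socket code in the test program isn't production-worthy: one reason (among others) is that TCP is a byte stream, so a single recv(4096) call is not guaranteed to return a whole SMP message. A common fix is to prefix each message with its length and keep reading until everything has arrived. Here is a minimal sketch written for Python 3; the helper names send_msg, recv_exact and recv_msg are hypothetical and not part of smpTest.py.

```python
import socket
import struct

def send_msg(sock, payload):
    # Prefix the payload with its length as a 4-byte big-endian integer
    sock.sendall(struct.pack("!I", len(payload)) + payload)

def recv_exact(sock, count):
    # Keep calling recv() until exactly `count` bytes have arrived
    chunks = []
    while count > 0:
        chunk = sock.recv(count)
        if not chunk:
            raise ConnectionError("socket closed before the full message arrived")
        chunks.append(chunk)
        count -= len(chunk)
    return b"".join(chunks)

def recv_msg(sock):
    # Read the 4-byte length prefix, then read that many payload bytes
    (length,) = struct.unpack("!I", recv_exact(sock, 4))
    return recv_exact(sock, length)

if __name__ == "__main__":
    # Quick local check using a connected pair of sockets
    left, right = socket.socketpair()
    send_msg(left, b"x" * 5000)
    print(len(recv_msg(right)))
    left.close()
    right.close()
```

In the test program, each sock.send(buffer) / sock.recv(4096) pair would be swapped for send_msg and recv_msg. Error handling and timeouts are still left out here.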