Skip to content

Commit af46629

Browse files
author
clowwindy
committed
fix salsa20
1 parent 9a18997 commit af46629

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

shadowsocks/encrypt_salsa20.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import struct
55
import logging
66
import sys
7+
import encrypt
78

89
slow_xor = False
910
imported = False
@@ -72,37 +73,55 @@ def update(self, data):
7273
cur_data = data[:remain]
7374
cur_data_len = len(cur_data)
7475
cur_stream = self._stream[self._pos:self._pos + cur_data_len]
75-
self._pos = (self._pos + cur_data_len) % BLOCK_SIZE
76+
self._pos = self._pos + cur_data_len
7677
data = data[remain:]
7778

7879
results.append(numpy_xor(cur_data, cur_stream))
7980

81+
if self._pos >= BLOCK_SIZE:
82+
self._next_stream()
83+
self._pos -= BLOCK_SIZE
84+
assert self._pos == 0
8085
if not data:
8186
break
82-
self._next_stream()
8387
return ''.join(results)
8488

8589

8690
def test():
8791
from os import urandom
8892
import random
8993

90-
rounds = 1 * 10
94+
rounds = 1 * 1024
9195
plain = urandom(BLOCK_SIZE * rounds)
96+
import M2Crypto.EVP
97+
cipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 1,
98+
key_as_bytes=0, d='md5', salt=None, i=1,
99+
padding=1)
100+
decipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 0,
101+
key_as_bytes=0, d='md5', salt=None, i=1,
102+
padding=1)
103+
92104
cipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1)
93105
decipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1)
94106
results = []
95107
pos = 0
96108
print 'start'
97109
start = time.time()
98110
while pos < len(plain):
99-
l = random.randint(10000, 32768)
111+
l = random.randint(100, 16384)
100112
c = cipher.update(plain[pos:pos + l])
101-
results.append(decipher.update(c))
113+
results.append(c)
114+
pos += l
115+
pos = 0
116+
c = ''.join(results)
117+
results = []
118+
while pos < len(plain):
119+
l = random.randint(100, 16384)
120+
results.append(decipher.update(c[pos:pos + l]))
102121
pos += l
103-
assert ''.join(results) == plain
104122
end = time.time()
105123
print BLOCK_SIZE * rounds / (end - start)
124+
assert ''.join(results) == plain
106125

107126

108127
if __name__ == '__main__':

0 commit comments

Comments
 (0)