From db66ed2d529ec1a5ad08ca36d33db0e47d94ffb5 Mon Sep 17 00:00:00 2001 From: yumoqing Date: Mon, 28 Feb 2022 14:01:18 +0800 Subject: [PATCH] bugfix --- appPublic/across_nat.py | 23 ++++++++++++---- appPublic/rc4.py | 61 +++++++++++++++++++++++++++++------------ test/test_aioupnp.py | 19 ++++++++----- 3 files changed, 73 insertions(+), 30 deletions(-) diff --git a/appPublic/across_nat.py b/appPublic/across_nat.py index 434bac3..2871f73 100644 --- a/appPublic/across_nat.py +++ b/appPublic/across_nat.py @@ -38,14 +38,25 @@ class AcrossNat(object): return get('https://ipapi.co/ip/').text async def upnp_map_port(self, inner_port, - protocol='TCP', from_port=40003): + protocol='TCP', from_port=40003, ip=None, desc=None): + if self.upnp is None: await self.init_upnp() protocol = protocol.upper() + if ip is None: + ip = self.upnp.lan_address + + all_mappings = [i for i in await self.upnp.get_redirects()] + x = [ i for i in all_mappings if i.internal_port == inner_port \ + and i.lan_address == ip \ + and i.protocol == protocol ] + if len(x) > 0: + return x[0].external_port + + occupied_ports = [ i.external_port for i in all_mappings if i.protocol == protocol ] external_port = from_port while external_port < 52333: - x = await self.upnp.get_specific_port_mapping(external_port, protocol) - if len(x) == 0: + if external_port not in occupied_ports: break external_port += 1 @@ -53,7 +64,7 @@ class AcrossNat(object): await self.upnp.add_port_mapping(external_port, protocol, inner_port, - lan_address, + ip, desc or 'user added') return external_port return None @@ -87,9 +98,9 @@ class AcrossNat(object): lifetime=999999999) return x.public_port - async def map_port(self, inner_port, protocol='tcp', from_port=40003): + async def map_port(self, inner_port, protocol='tcp', from_port=40003, lan_ip=None, desc=None): if self.pmp_supported: return self.pmp_map_port(inner_port, protocol=protocol) - return await self.upnp_map_port( inner_port, protocol=protocol) + return await self.upnp_map_port( inner_port, protocol=protocol, ip=lan_ip, desc=desc) diff --git a/appPublic/rc4.py b/appPublic/rc4.py index e2ef31b..41c1df7 100644 --- a/appPublic/rc4.py +++ b/appPublic/rc4.py @@ -62,17 +62,44 @@ class RC4: return r.decode(self.dcoding) class KeyChain(object): - def __init__(self, seed_str, crypter, keylen=23): + def __init__(self, seed_str, crypter=None, keylen=23, period=600, threshold=60): self.seed_str = seed_str + self.period = int(period) + self.threshold = int(threshold) self.crypter = crypter + if crypter is None: + self.crypter = RC4() self.keylen = keylen self.keypool = { } delta = datetime.timedelta(0) self.timezone = datetime.timezone(delta, name='gmt') - def genKey(self, y, m, d): - vv = y * 1000 + m * 100 + d + def is_near_bottom(self, indicator=None): + ts = time.time() + i = indicator + if i is None: + i = self.get_indicator(ts) + if i + self.threshold > ts: + return True + return FalseTrue + + def is_near_top(self, indicator=None): + ts = time.time() + i = indicator + if i is None: + i = self.get_indicator(ts) + if i + self.period - self.threshold < ts: + return True + return False + + def get_indicator(self, ts=None): + if ts is None: + ts = time.time() + return int(ts / self.period) * self.period + + def genKey(self, indicator): + vv = indicator if self.keypool.get(vv): return self.keypool[vv] v = vv @@ -83,13 +110,13 @@ class KeyChain(object): j = v % self.keylen v = v - (j + k1) * m + self.keylen k = k + self.seed_str[j] - k1 += 1 + k1 += self.threshold / 2 key = k.encode('utf-8') self.keypool[vv] = key dates = [ d for d in self.keypool.keys() ] - if len(dates) > 6: - d = min(dates) - del self.keypool[d] + for d in dates: + if d < indicator - self.period: + del self.keypool[d] return key def encode(self, text): @@ -97,8 +124,8 @@ class KeyChain(object): return self.encode_bytes(bdata) def encode_bytes(self, bdata): - dt = datetime.datetime.now(self.timezone) - key = self.genKey(dt.year, dt.month, dt.day) + indicator = self.get_indicator() + key = self.genKey(indicator) data = key + bdata return self.crypter.encode_bytes(data, key) @@ -109,20 +136,20 @@ class KeyChain(object): return None def decode_bytes(self, data): - dt = datetime.datetime.now(self.timezone) - key = self.genKey(dt.year, dt.month, dt.day) + indicator = self.get_indicator() + key = self.genKey(indicator) d = self._decode(data, key) if d is not None: return d - if dt.hour == 0 and dt.minute < 1: - ndt = dt + datetime.timedelta(-1) - key = self.genKey(ndt.year, ndt.month, ndt.day) + if self.is_near_bottom(indicator): + indicator -= self.period + key = self.genKey(indicator) return self._decode(data, key) - if dt.hour ==23 and dt.minute == 59: - ndt = dt + datetime.timedelta(1) - key = self.genKey(ndt.year, ndt.month, ndt.day) + if self.is_near_top(indicator): + indicator += self.period + key = self.genKey(indicator) return self._decode(data, key) return None diff --git a/test/test_aioupnp.py b/test/test_aioupnp.py index a826084..99ba7cb 100644 --- a/test/test_aioupnp.py +++ b/test/test_aioupnp.py @@ -7,20 +7,25 @@ async def main(): print(dir(upnp)) print('gateway=', upnp.gateway, upnp.gateway_address, upnp.lan_address) print(await upnp.get_external_ip()) - print(await upnp.get_redirects()) - - x = await upnp.get_specific_port_mapping(40009, 'TCP') - if len(x) == 0: - print('port available') + port = 40009 + while port < 41000: + x = await upnp.get_specific_port_mapping(40009, 'TCP') + if len(x) == 0: + print(port, 'port available') + break + else: + print(port, 'port occupied') + port += 1 print("adding a port mapping") - x = await upnp.add_port_mapping(40009, 'TCP', 8999, '192.168.1.8', 'test mapping') + x = await upnp.add_port_mapping(port, 'TCP', 8999, '192.168.1.8', 'test mapping') + print(8999, '-->', port) print('x=', x, await upnp.get_redirects()) # print("deleting the port mapping") # await upnp.delete_port_mapping(51234, 'TCP') - print(await upnp.get_redirects()) + # print(await upnp.get_redirects()) asyncio.run(main())