diff options
author | Timo Ewalds <timo@ewalds.ca> | 2017-10-05 19:30:45 +0100 |
---|---|---|
committer | Gregory P. Smith <greg@krypto.org> | 2017-10-05 11:30:45 -0700 |
commit | 6e8e445c9d1f704c945deb0e5484c2f7f8329532 (patch) | |
tree | ae7f38324a5043dc64beb3805e0b5359f3135dc3 | |
parent | 914f47ef62853a2a72d554c23d8f8d940980d7d9 (diff) | |
download | portpicker-6e8e445c9d1f704c945deb0e5484c2f7f8329532.tar.gz |
Add a way to reserve ports and return the ports so they get reused. (#7)
This way portpicker can also be used to manage a set of ports you were assigned from somewhere other than a portserver.
-rw-r--r-- | src/portpicker.py | 53 | ||||
-rw-r--r-- | src/tests/portpicker_test.py | 47 |
2 files changed, 93 insertions, 7 deletions
diff --git a/src/portpicker.py b/src/portpicker.py index 7e194dd..15b6d59 100644 --- a/src/portpicker.py +++ b/src/portpicker.py @@ -36,19 +36,50 @@ Typical usage: """ from __future__ import print_function + +import logging import os import random import socket import sys # The legacy Bind, IsPortFree, etc. names are not exported. -__all__ = ('bind', 'is_port_free', 'pick_unused_port', - 'get_port_from_port_server') +__all__ = ('bind', 'is_port_free', 'pick_unused_port', 'return_port', + 'add_reserved_port', 'get_port_from_port_server') _PROTOS = [(socket.SOCK_STREAM, socket.IPPROTO_TCP), (socket.SOCK_DGRAM, socket.IPPROTO_UDP)] +# Ports that are currently available to be given out. +_free_ports = set() + +# Ports that are reserved or from the portserver that may be returned. +_owned_ports = set() + +# Ports that we chose randomly that may be returned. +_random_ports = set() + + +def add_reserved_port(port): + """Add a port that was acquired by means other than the port server.""" + _free_ports.add(port) + + +def return_port(port): + """Return a port that is no longer being used so it can be reused.""" + if port in _random_ports: + _random_ports.remove(port) + elif port in _owned_ports: + _owned_ports.remove(port) + _free_ports.add(port) + elif port in _free_ports: + logging.info("Returning a port that was already returned: %s", port) + else: + logging.info("Returning a port that wasn't given by portpicker: %s", + port) + + def bind(port, socket_type, socket_proto): """Try to bind to a socket of the specified type, protocol, and port. @@ -113,14 +144,17 @@ def pick_unused_port(pid=None): Returns: A port number that is unused on both TCP and UDP. """ - port = None + if _free_ports: + port = _free_ports.pop() + _owned_ports.add(port) + return port # Provide access to the portserver on an opt-in basis. if 'PORTSERVER_ADDRESS' in os.environ: port = get_port_from_port_server(os.environ['PORTSERVER_ADDRESS'], pid=pid) - if not port: - return _pick_unused_port_without_server() - return port + if port: + return port + return _pick_unused_port_without_server() PickUnusedPort = pick_unused_port # legacy API. pylint: disable=invalid-name @@ -141,6 +175,7 @@ def _pick_unused_port_without_server(): # Protected. pylint: disable=invalid-na for _ in range(10): port = int(rng.randrange(15000, 25000)) if is_port_free(port): + _random_ports.add(port) return port # Try OS-assigned ports next. @@ -151,6 +186,7 @@ def _pick_unused_port_without_server(): # Protected. pylint: disable=invalid-na port = bind(0, _PROTOS[0][0], _PROTOS[0][1]) # Check if this port is unused on the other protocol. if port and bind(port, _PROTOS[1][0], _PROTOS[1][1]): + _random_ports.add(port) return port @@ -207,10 +243,13 @@ def get_port_from_port_server(portserver_address, pid=None): return None try: - return int(buf.split(b'\n')[0]) + port = int(buf.split(b'\n')[0]) except ValueError: print('Portserver failed to find a port.', file=sys.stderr) return None + _owned_ports.add(port) + return port + GetPortFromPortServer = get_port_from_port_server # legacy API. pylint: disable=invalid-name diff --git a/src/tests/portpicker_test.py b/src/tests/portpicker_test.py index 9e826a6..b3924cd 100644 --- a/src/tests/portpicker_test.py +++ b/src/tests/portpicker_test.py @@ -42,6 +42,9 @@ class PickUnusedPortTest(unittest.TestCase): def setUp(self): # So we can Bind even if portpicker.bind is stubbed out. self._bind = portpicker.bind + portpicker._owned_ports.clear() + portpicker._free_ports.clear() + portpicker._random_ports.clear() def testPickUnusedPortActuallyWorks(self): """This test can be flaky.""" @@ -92,6 +95,50 @@ class PickUnusedPortTest(unittest.TestCase): server.sendall.assert_called_once_with(b'9876\n') self.assertEqual(port, 52768) + @mock.patch.dict(os.environ,{'PORTSERVER_ADDRESS': 'portserver'}) + def testReusesPortServerPorts(self): + server = mock.Mock() + server.recv.side_effect = [b'12345\n', b'23456\n', b'34567\n'] + with mock.patch.object(socket, 'socket', return_value=server): + self.assertEqual(portpicker.pick_unused_port(), 12345) + self.assertEqual(portpicker.pick_unused_port(), 23456) + portpicker.return_port(12345) + self.assertEqual(portpicker.pick_unused_port(), 12345) + + @mock.patch.dict(os.environ,{'PORTSERVER_ADDRESS': ''}) + def testDoesntReuseRandomPorts(self): + ports = set() + for _ in range(10): + port = portpicker.pick_unused_port() + ports.add(port) + portpicker.return_port(port) + self.assertGreater(len(ports), 5) # Allow some random reuse. + + def testReturnsReservedPorts(self): + with mock.patch.object(portpicker, '_pick_unused_port_without_server'): + portpicker._pick_unused_port_without_server.side_effect = ( + Exception('eek!')) + # Arbitrary port. In practice you should get this from somewhere + # that assigns ports. + reserved_port = 28465 + portpicker.add_reserved_port(reserved_port) + ports = set() + for _ in range(10): + port = portpicker.pick_unused_port() + ports.add(port) + portpicker.return_port(port) + self.assertEqual(len(ports), 1) + self.assertEqual(ports.pop(), reserved_port) + + @mock.patch.dict(os.environ,{'PORTSERVER_ADDRESS': ''}) + def testFallsBackToRandomAfterRunningOutOfReservedPorts(self): + # Arbitrary port. In practice you should get this from somewhere + # that assigns ports. + reserved_port = 23456 + portpicker.add_reserved_port(reserved_port) + self.assertEqual(portpicker.pick_unused_port(), reserved_port) + self.assertNotEqual(portpicker.pick_unused_port(), reserved_port) + def testRandomlyChosenPorts(self): # Unless this box is under an overwhelming socket load, this test # will heavily exercise the "pick a port randomly" part of the |