aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTimo Ewalds <timo@ewalds.ca>2017-10-05 19:30:45 +0100
committerGregory P. Smith <greg@krypto.org>2017-10-05 11:30:45 -0700
commit6e8e445c9d1f704c945deb0e5484c2f7f8329532 (patch)
treeae7f38324a5043dc64beb3805e0b5359f3135dc3
parent914f47ef62853a2a72d554c23d8f8d940980d7d9 (diff)
downloadportpicker-6e8e445c9d1f704c945deb0e5484c2f7f8329532.tar.gz
Add a way to reserve ports and return the ports so they get reused. (#7)
This way portpicker can also be used to manage a set of ports you were assigned from somewhere other than a portserver.
-rw-r--r--src/portpicker.py53
-rw-r--r--src/tests/portpicker_test.py47
2 files changed, 93 insertions, 7 deletions
diff --git a/src/portpicker.py b/src/portpicker.py
index 7e194dd..15b6d59 100644
--- a/src/portpicker.py
+++ b/src/portpicker.py
@@ -36,19 +36,50 @@ Typical usage:
"""
from __future__ import print_function
+
+import logging
import os
import random
import socket
import sys
# The legacy Bind, IsPortFree, etc. names are not exported.
-__all__ = ('bind', 'is_port_free', 'pick_unused_port',
- 'get_port_from_port_server')
+__all__ = ('bind', 'is_port_free', 'pick_unused_port', 'return_port',
+ 'add_reserved_port', 'get_port_from_port_server')
_PROTOS = [(socket.SOCK_STREAM, socket.IPPROTO_TCP),
(socket.SOCK_DGRAM, socket.IPPROTO_UDP)]
+# Ports that are currently available to be given out.
+_free_ports = set()
+
+# Ports that are reserved or from the portserver that may be returned.
+_owned_ports = set()
+
+# Ports that we chose randomly that may be returned.
+_random_ports = set()
+
+
+def add_reserved_port(port):
+ """Add a port that was acquired by means other than the port server."""
+ _free_ports.add(port)
+
+
+def return_port(port):
+ """Return a port that is no longer being used so it can be reused."""
+ if port in _random_ports:
+ _random_ports.remove(port)
+ elif port in _owned_ports:
+ _owned_ports.remove(port)
+ _free_ports.add(port)
+ elif port in _free_ports:
+ logging.info("Returning a port that was already returned: %s", port)
+ else:
+ logging.info("Returning a port that wasn't given by portpicker: %s",
+ port)
+
+
def bind(port, socket_type, socket_proto):
"""Try to bind to a socket of the specified type, protocol, and port.
@@ -113,14 +144,17 @@ def pick_unused_port(pid=None):
Returns:
A port number that is unused on both TCP and UDP.
"""
- port = None
+ if _free_ports:
+ port = _free_ports.pop()
+ _owned_ports.add(port)
+ return port
# Provide access to the portserver on an opt-in basis.
if 'PORTSERVER_ADDRESS' in os.environ:
port = get_port_from_port_server(os.environ['PORTSERVER_ADDRESS'],
pid=pid)
- if not port:
- return _pick_unused_port_without_server()
- return port
+ if port:
+ return port
+ return _pick_unused_port_without_server()
PickUnusedPort = pick_unused_port # legacy API. pylint: disable=invalid-name
@@ -141,6 +175,7 @@ def _pick_unused_port_without_server(): # Protected. pylint: disable=invalid-na
for _ in range(10):
port = int(rng.randrange(15000, 25000))
if is_port_free(port):
+ _random_ports.add(port)
return port
# Try OS-assigned ports next.
@@ -151,6 +186,7 @@ def _pick_unused_port_without_server(): # Protected. pylint: disable=invalid-na
port = bind(0, _PROTOS[0][0], _PROTOS[0][1])
# Check if this port is unused on the other protocol.
if port and bind(port, _PROTOS[1][0], _PROTOS[1][1]):
+ _random_ports.add(port)
return port
@@ -207,10 +243,13 @@ def get_port_from_port_server(portserver_address, pid=None):
return None
try:
- return int(buf.split(b'\n')[0])
+ port = int(buf.split(b'\n')[0])
except ValueError:
print('Portserver failed to find a port.', file=sys.stderr)
return None
+ _owned_ports.add(port)
+ return port
+
GetPortFromPortServer = get_port_from_port_server # legacy API. pylint: disable=invalid-name
diff --git a/src/tests/portpicker_test.py b/src/tests/portpicker_test.py
index 9e826a6..b3924cd 100644
--- a/src/tests/portpicker_test.py
+++ b/src/tests/portpicker_test.py
@@ -42,6 +42,9 @@ class PickUnusedPortTest(unittest.TestCase):
def setUp(self):
# So we can Bind even if portpicker.bind is stubbed out.
self._bind = portpicker.bind
+ portpicker._owned_ports.clear()
+ portpicker._free_ports.clear()
+ portpicker._random_ports.clear()
def testPickUnusedPortActuallyWorks(self):
"""This test can be flaky."""
@@ -92,6 +95,50 @@ class PickUnusedPortTest(unittest.TestCase):
server.sendall.assert_called_once_with(b'9876\n')
self.assertEqual(port, 52768)
+ @mock.patch.dict(os.environ,{'PORTSERVER_ADDRESS': 'portserver'})
+ def testReusesPortServerPorts(self):
+ server = mock.Mock()
+ server.recv.side_effect = [b'12345\n', b'23456\n', b'34567\n']
+ with mock.patch.object(socket, 'socket', return_value=server):
+ self.assertEqual(portpicker.pick_unused_port(), 12345)
+ self.assertEqual(portpicker.pick_unused_port(), 23456)
+ portpicker.return_port(12345)
+ self.assertEqual(portpicker.pick_unused_port(), 12345)
+
+ @mock.patch.dict(os.environ,{'PORTSERVER_ADDRESS': ''})
+ def testDoesntReuseRandomPorts(self):
+ ports = set()
+ for _ in range(10):
+ port = portpicker.pick_unused_port()
+ ports.add(port)
+ portpicker.return_port(port)
+ self.assertGreater(len(ports), 5) # Allow some random reuse.
+
+ def testReturnsReservedPorts(self):
+ with mock.patch.object(portpicker, '_pick_unused_port_without_server'):
+ portpicker._pick_unused_port_without_server.side_effect = (
+ Exception('eek!'))
+ # Arbitrary port. In practice you should get this from somewhere
+ # that assigns ports.
+ reserved_port = 28465
+ portpicker.add_reserved_port(reserved_port)
+ ports = set()
+ for _ in range(10):
+ port = portpicker.pick_unused_port()
+ ports.add(port)
+ portpicker.return_port(port)
+ self.assertEqual(len(ports), 1)
+ self.assertEqual(ports.pop(), reserved_port)
+
+ @mock.patch.dict(os.environ,{'PORTSERVER_ADDRESS': ''})
+ def testFallsBackToRandomAfterRunningOutOfReservedPorts(self):
+ # Arbitrary port. In practice you should get this from somewhere
+ # that assigns ports.
+ reserved_port = 23456
+ portpicker.add_reserved_port(reserved_port)
+ self.assertEqual(portpicker.pick_unused_port(), reserved_port)
+ self.assertNotEqual(portpicker.pick_unused_port(), reserved_port)
+
def testRandomlyChosenPorts(self):
# Unless this box is under an overwhelming socket load, this test
# will heavily exercise the "pick a port randomly" part of the