diff options
author | Gregory P. Smith <greg@krypto.org> | 2015-12-02 19:13:23 -0800 |
---|---|---|
committer | Gregory P. Smith <greg@krypto.org> | 2015-12-02 19:13:23 -0800 |
commit | 90ab29b4a4f835fd0395228a050a4aa4e02f9cdd (patch) | |
tree | 2d87f8bdd5a338ce9cff6d8b522d93ee62e57250 | |
parent | ff089ad94c37bd865769a93cb86889d8d85cda10 (diff) | |
parent | 48e564ea0b5ef7aeb52a27fe7a3de83b3d72a5a9 (diff) | |
download | portpicker-90ab29b4a4f835fd0395228a050a4aa4e02f9cdd.tar.gz |
Merge pull request #1 from pmarks-net/master
Use both IPv4+IPv6 sockets to check whether a port is free.
-rw-r--r-- | src/portpicker.py | 34 | ||||
-rw-r--r-- | src/portserver.py | 53 | ||||
-rw-r--r-- | src/tests/portpicker_test.py | 48 | ||||
-rw-r--r-- | src/tests/portserver_test.py | 59 |
4 files changed, 160 insertions, 34 deletions
diff --git a/src/portpicker.py b/src/portpicker.py index a77869e..08dde0a 100644 --- a/src/portpicker.py +++ b/src/portpicker.py @@ -55,6 +55,10 @@ def bind(port, socket_type, socket_proto): This is primarily a helper function for PickUnusedPort, used to see if a particular port number is available. + For the port to be considered available, the kernel must support at least + one of (IPv6, IPv4), and the port must be available on each supported + family. + Args: port: The port number to bind to, or 0 to have the OS pick a free port. socket_type: The type of the socket (ex: socket.SOCK_STREAM). @@ -63,15 +67,24 @@ def bind(port, socket_type, socket_proto): Returns: The port number on success or None on failure. """ - sock = socket.socket(socket.AF_INET, socket_type, socket_proto) - try: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(('', port)) - return sock.getsockname()[1] - except socket.error: - return None - finally: - sock.close() + got_socket = False + for family in (socket.AF_INET6, socket.AF_INET): + try: + sock = socket.socket(family, socket_type, socket_proto) + got_socket = True + except socket.error: + continue + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(('', port)) + if socket_type == socket.SOCK_STREAM: + sock.listen(1) + port = sock.getsockname()[1] + except socket.error: + return None + finally: + sock.close() + return port if got_socket else None Bind = bind # legacy API. pylint: disable=invalid-name @@ -84,8 +97,7 @@ def is_port_free(port): Returns: boolean, whether it is free to use for both TCP and UDP """ - return (bind(port, _PROTOS[0][0], _PROTOS[0][1]) and - bind(port, _PROTOS[1][0], _PROTOS[1][1])) + return bind(port, *_PROTOS[0]) and bind(port, *_PROTOS[1]) IsPortFree = is_port_free # legacy API. pylint: disable=invalid-name diff --git a/src/portserver.py b/src/portserver.py index 54a480f..fcade6c 100644 --- a/src/portserver.py +++ b/src/portserver.py @@ -38,6 +38,9 @@ import sys log = None # Initialized to a logging.Logger by _configure_logging(). +_PROTOS = [(socket.SOCK_STREAM, socket.IPPROTO_TCP), + (socket.SOCK_DGRAM, socket.IPPROTO_UDP)] + def _get_process_command_line(pid): try: @@ -55,23 +58,51 @@ def _get_process_start_time(pid): return 0 -def _port_is_available(port): - """Return False if the given network port is currently in use.""" - for socket_type, proto in ((socket.SOCK_STREAM, socket.IPPROTO_TCP), - (socket.SOCK_DGRAM, 0)): - sock = None +# TODO: Consider importing portpicker.bind() instead of duplicating the code. +def _bind(port, socket_type, socket_proto): + """Try to bind to a socket of the specified type, protocol, and port. + + For the port to be considered available, the kernel must support at least + one of (IPv6, IPv4), and the port must be available on each supported + family. + + Args: + port: The port number to bind to, or 0 to have the OS pick a free port. + socket_type: The type of the socket (ex: socket.SOCK_STREAM). + socket_proto: The protocol of the socket (ex: socket.IPPROTO_TCP). + + Returns: + The port number on success or None on failure. + """ + got_socket = False + for family in (socket.AF_INET6, socket.AF_INET): + try: + sock = socket.socket(family, socket_type, socket_proto) + got_socket = True + except socket.error: + continue try: - sock = socket.socket(socket.AF_INET, socket_type, proto) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind(('', port)) if socket_type == socket.SOCK_STREAM: sock.listen(1) + port = sock.getsockname()[1] except socket.error: - return False + return None finally: - if sock: - sock.close() - return True + sock.close() + return port if got_socket else None + + +def _is_port_free(port): + """Check if specified port is free. + + Args: + port: integer, port to check + Returns: + boolean, whether it is free to use for both TCP and UDP + """ + return _bind(port, *_PROTOS[0]) and _bind(port, *_PROTOS[1]) def _should_allocate_port(pid): @@ -149,7 +180,7 @@ class _PortPool(object): check_count += 1 if (candidate.start_time == 0 or candidate.start_time != _get_process_start_time(candidate.pid)): - if _port_is_available(candidate.pid): + if _is_port_free(candidate.pid): candidate.pid = pid candidate.start_time = _get_process_start_time(pid) if not candidate.start_time: diff --git a/src/tests/portpicker_test.py b/src/tests/portpicker_test.py index daabb41..9e826a6 100644 --- a/src/tests/portpicker_test.py +++ b/src/tests/portpicker_test.py @@ -16,9 +16,11 @@ # """Unittests for the portpicker module.""" +from __future__ import print_function import os import random import socket +import sys import unittest try: @@ -137,6 +139,52 @@ class PickUnusedPortTest(unittest.TestCase): self.assertTrue(self.IsUnusedTCPPort(port)) self.assertTrue(self.IsUnusedUDPPort(port)) + def testIsPortFree(self): + """This might be flaky unless this test is run with a portserver.""" + # The port should be free initially. + port = portpicker.pick_unused_port() + self.assertTrue(portpicker.is_port_free(port)) + + cases = [ + (socket.AF_INET, socket.SOCK_STREAM, None), + (socket.AF_INET6, socket.SOCK_STREAM, 0), + (socket.AF_INET6, socket.SOCK_STREAM, 1), + (socket.AF_INET, socket.SOCK_DGRAM, None), + (socket.AF_INET6, socket.SOCK_DGRAM, 0), + (socket.AF_INET6, socket.SOCK_DGRAM, 1), + ] + for (sock_family, sock_type, v6only) in cases: + # Occupy the port on a subset of possible protocols. + try: + sock = socket.socket(sock_family, sock_type, 0) + except socket.error: + print('Kernel does not support sock_family=%d' % sock_family, + file=sys.stderr) + # Skip this case, since we cannot occupy a port. + continue + if v6only is not None: + try: + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, + v6only) + except socket.error: + print('Kernel does not support IPV6_V6ONLY=%d' % v6only, + file=sys.stderr) + # Don't care; just proceed with the default. + sock.bind(('', port)) + + # The port should be busy. + self.assertFalse(portpicker.is_port_free(port)) + sock.close() + + # Now it's free again. + self.assertTrue(portpicker.is_port_free(port)) + + def testIsPortFreeException(self): + port = portpicker.pick_unused_port() + with mock.patch.object(socket, 'socket') as mock_sock: + mock_sock.side_effect = socket.error('fake socket error', 0) + self.assertFalse(portpicker.is_port_free(port)) + def testThatLegacyCapWordsAPIsExist(self): """The original APIs were CapWords style, 1.1 added PEP8 names.""" self.assertEqual(portpicker.bind, portpicker.Bind) diff --git a/src/tests/portserver_test.py b/src/tests/portserver_test.py index 2e49595..f0475c3 100644 --- a/src/tests/portserver_test.py +++ b/src/tests/portserver_test.py @@ -16,6 +16,7 @@ # """Tests for the example portserver.""" +from __future__ import print_function import asyncio import os import socket @@ -43,15 +44,49 @@ class PortserverFunctionsTest(unittest.TestCase): def test_get_process_start_time(self): self.assertGreater(portserver._get_process_start_time(os.getpid()), 0) - def test_port_is_available_true(self): + def test_is_port_free(self): """This might be flaky unless this test is run with a portserver.""" - # Insert Inception "we must go deeper" meme here. - self.assertTrue(portserver._port_is_available(self.port)) - - def test_port_is_available_false(self): + # The port should be free initially. + self.assertTrue(portserver._is_port_free(self.port)) + + cases = [ + (socket.AF_INET, socket.SOCK_STREAM, None), + (socket.AF_INET6, socket.SOCK_STREAM, 0), + (socket.AF_INET6, socket.SOCK_STREAM, 1), + (socket.AF_INET, socket.SOCK_DGRAM, None), + (socket.AF_INET6, socket.SOCK_DGRAM, 0), + (socket.AF_INET6, socket.SOCK_DGRAM, 1), + ] + for (sock_family, sock_type, v6only) in cases: + # Occupy the port on a subset of possible protocols. + try: + sock = socket.socket(sock_family, sock_type, 0) + except socket.error: + print('Kernel does not support sock_family=%d' % sock_family, + file=sys.stderr) + # Skip this case, since we cannot occupy a port. + continue + if v6only is not None: + try: + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, + v6only) + except socket.error: + print('Kernel does not support IPV6_V6ONLY=%d' % v6only, + file=sys.stderr) + # Don't care; just proceed with the default. + sock.bind(('', self.port)) + + # The port should be busy. + self.assertFalse(portserver._is_port_free(self.port)) + sock.close() + + # Now it's free again. + self.assertTrue(portserver._is_port_free(self.port)) + + def test_is_port_free_exception(self): with mock.patch.object(socket, 'socket') as mock_sock: mock_sock.side_effect = socket.error('fake socket error', 0) - self.assertFalse(portserver._port_is_available(self.port)) + self.assertFalse(portserver._is_port_free(self.port)) def test_should_allocate_port(self): self.assertFalse(portserver._should_allocate_port(0)) @@ -140,18 +175,18 @@ class PortPoolTest(unittest.TestCase): self.assertRaises(ValueError, self.pool.add_port_to_free_pool, 0) self.assertRaises(ValueError, self.pool.add_port_to_free_pool, 65536) - @mock.patch.object(portserver, '_port_is_available') - def test_get_port_for_process_ok(self, mock_port_is_available): + @mock.patch.object(portserver, '_is_port_free') + def test_get_port_for_process_ok(self, mock_is_port_free): self.pool.add_port_to_free_pool(self.port) - mock_port_is_available.return_value = True + mock_is_port_free.return_value = True self.assertEqual(self.port, self.pool.get_port_for_process(os.getpid())) self.assertEqual(1, self.pool.ports_checked_for_last_request) - @mock.patch.object(portserver, '_port_is_available') - def test_get_port_for_process_none_left(self, mock_port_is_available): + @mock.patch.object(portserver, '_is_port_free') + def test_get_port_for_process_none_left(self, mock_is_port_free): self.pool.add_port_to_free_pool(self.port) self.pool.add_port_to_free_pool(22) - mock_port_is_available.return_value = False + mock_is_port_free.return_value = False self.assertEqual(2, self.pool.num_ports()) self.assertEqual(0, self.pool.get_port_for_process(os.getpid())) self.assertEqual(2, self.pool.num_ports()) |