aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory P. Smith <greg@krypto.org>2015-12-02 19:13:23 -0800
committerGregory P. Smith <greg@krypto.org>2015-12-02 19:13:23 -0800
commit90ab29b4a4f835fd0395228a050a4aa4e02f9cdd (patch)
tree2d87f8bdd5a338ce9cff6d8b522d93ee62e57250
parentff089ad94c37bd865769a93cb86889d8d85cda10 (diff)
parent48e564ea0b5ef7aeb52a27fe7a3de83b3d72a5a9 (diff)
downloadportpicker-90ab29b4a4f835fd0395228a050a4aa4e02f9cdd.tar.gz
Merge pull request #1 from pmarks-net/master
Use both IPv4+IPv6 sockets to check whether a port is free.
-rw-r--r--src/portpicker.py34
-rw-r--r--src/portserver.py53
-rw-r--r--src/tests/portpicker_test.py48
-rw-r--r--src/tests/portserver_test.py59
4 files changed, 160 insertions, 34 deletions
diff --git a/src/portpicker.py b/src/portpicker.py
index a77869e..08dde0a 100644
--- a/src/portpicker.py
+++ b/src/portpicker.py
@@ -55,6 +55,10 @@ def bind(port, socket_type, socket_proto):
This is primarily a helper function for PickUnusedPort, used to see
if a particular port number is available.
+ For the port to be considered available, the kernel must support at least
+ one of (IPv6, IPv4), and the port must be available on each supported
+ family.
+
Args:
port: The port number to bind to, or 0 to have the OS pick a free port.
socket_type: The type of the socket (ex: socket.SOCK_STREAM).
@@ -63,15 +67,24 @@ def bind(port, socket_type, socket_proto):
Returns:
The port number on success or None on failure.
"""
- sock = socket.socket(socket.AF_INET, socket_type, socket_proto)
- try:
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- sock.bind(('', port))
- return sock.getsockname()[1]
- except socket.error:
- return None
- finally:
- sock.close()
+ got_socket = False
+ for family in (socket.AF_INET6, socket.AF_INET):
+ try:
+ sock = socket.socket(family, socket_type, socket_proto)
+ got_socket = True
+ except socket.error:
+ continue
+ try:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.bind(('', port))
+ if socket_type == socket.SOCK_STREAM:
+ sock.listen(1)
+ port = sock.getsockname()[1]
+ except socket.error:
+ return None
+ finally:
+ sock.close()
+ return port if got_socket else None
Bind = bind # legacy API. pylint: disable=invalid-name
@@ -84,8 +97,7 @@ def is_port_free(port):
Returns:
boolean, whether it is free to use for both TCP and UDP
"""
- return (bind(port, _PROTOS[0][0], _PROTOS[0][1]) and
- bind(port, _PROTOS[1][0], _PROTOS[1][1]))
+ return bind(port, *_PROTOS[0]) and bind(port, *_PROTOS[1])
IsPortFree = is_port_free # legacy API. pylint: disable=invalid-name
diff --git a/src/portserver.py b/src/portserver.py
index 54a480f..fcade6c 100644
--- a/src/portserver.py
+++ b/src/portserver.py
@@ -38,6 +38,9 @@ import sys
log = None # Initialized to a logging.Logger by _configure_logging().
+_PROTOS = [(socket.SOCK_STREAM, socket.IPPROTO_TCP),
+ (socket.SOCK_DGRAM, socket.IPPROTO_UDP)]
+
def _get_process_command_line(pid):
try:
@@ -55,23 +58,51 @@ def _get_process_start_time(pid):
return 0
-def _port_is_available(port):
- """Return False if the given network port is currently in use."""
- for socket_type, proto in ((socket.SOCK_STREAM, socket.IPPROTO_TCP),
- (socket.SOCK_DGRAM, 0)):
- sock = None
+# TODO: Consider importing portpicker.bind() instead of duplicating the code.
+def _bind(port, socket_type, socket_proto):
+ """Try to bind to a socket of the specified type, protocol, and port.
+
+ For the port to be considered available, the kernel must support at least
+ one of (IPv6, IPv4), and the port must be available on each supported
+ family.
+
+ Args:
+ port: The port number to bind to, or 0 to have the OS pick a free port.
+ socket_type: The type of the socket (ex: socket.SOCK_STREAM).
+ socket_proto: The protocol of the socket (ex: socket.IPPROTO_TCP).
+
+ Returns:
+ The port number on success or None on failure.
+ """
+ got_socket = False
+ for family in (socket.AF_INET6, socket.AF_INET):
+ try:
+ sock = socket.socket(family, socket_type, socket_proto)
+ got_socket = True
+ except socket.error:
+ continue
try:
- sock = socket.socket(socket.AF_INET, socket_type, proto)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(('', port))
if socket_type == socket.SOCK_STREAM:
sock.listen(1)
+ port = sock.getsockname()[1]
except socket.error:
- return False
+ return None
finally:
- if sock:
- sock.close()
- return True
+ sock.close()
+ return port if got_socket else None
+
+
+def _is_port_free(port):
+ """Check if specified port is free.
+
+ Args:
+ port: integer, port to check
+ Returns:
+ boolean, whether it is free to use for both TCP and UDP
+ """
+ return _bind(port, *_PROTOS[0]) and _bind(port, *_PROTOS[1])
def _should_allocate_port(pid):
@@ -149,7 +180,7 @@ class _PortPool(object):
check_count += 1
if (candidate.start_time == 0 or
candidate.start_time != _get_process_start_time(candidate.pid)):
- if _port_is_available(candidate.pid):
+ if _is_port_free(candidate.pid):
candidate.pid = pid
candidate.start_time = _get_process_start_time(pid)
if not candidate.start_time:
diff --git a/src/tests/portpicker_test.py b/src/tests/portpicker_test.py
index daabb41..9e826a6 100644
--- a/src/tests/portpicker_test.py
+++ b/src/tests/portpicker_test.py
@@ -16,9 +16,11 @@
#
"""Unittests for the portpicker module."""
+from __future__ import print_function
import os
import random
import socket
+import sys
import unittest
try:
@@ -137,6 +139,52 @@ class PickUnusedPortTest(unittest.TestCase):
self.assertTrue(self.IsUnusedTCPPort(port))
self.assertTrue(self.IsUnusedUDPPort(port))
+ def testIsPortFree(self):
+ """This might be flaky unless this test is run with a portserver."""
+ # The port should be free initially.
+ port = portpicker.pick_unused_port()
+ self.assertTrue(portpicker.is_port_free(port))
+
+ cases = [
+ (socket.AF_INET, socket.SOCK_STREAM, None),
+ (socket.AF_INET6, socket.SOCK_STREAM, 0),
+ (socket.AF_INET6, socket.SOCK_STREAM, 1),
+ (socket.AF_INET, socket.SOCK_DGRAM, None),
+ (socket.AF_INET6, socket.SOCK_DGRAM, 0),
+ (socket.AF_INET6, socket.SOCK_DGRAM, 1),
+ ]
+ for (sock_family, sock_type, v6only) in cases:
+ # Occupy the port on a subset of possible protocols.
+ try:
+ sock = socket.socket(sock_family, sock_type, 0)
+ except socket.error:
+ print('Kernel does not support sock_family=%d' % sock_family,
+ file=sys.stderr)
+ # Skip this case, since we cannot occupy a port.
+ continue
+ if v6only is not None:
+ try:
+ sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY,
+ v6only)
+ except socket.error:
+ print('Kernel does not support IPV6_V6ONLY=%d' % v6only,
+ file=sys.stderr)
+ # Don't care; just proceed with the default.
+ sock.bind(('', port))
+
+ # The port should be busy.
+ self.assertFalse(portpicker.is_port_free(port))
+ sock.close()
+
+ # Now it's free again.
+ self.assertTrue(portpicker.is_port_free(port))
+
+ def testIsPortFreeException(self):
+ port = portpicker.pick_unused_port()
+ with mock.patch.object(socket, 'socket') as mock_sock:
+ mock_sock.side_effect = socket.error('fake socket error', 0)
+ self.assertFalse(portpicker.is_port_free(port))
+
def testThatLegacyCapWordsAPIsExist(self):
"""The original APIs were CapWords style, 1.1 added PEP8 names."""
self.assertEqual(portpicker.bind, portpicker.Bind)
diff --git a/src/tests/portserver_test.py b/src/tests/portserver_test.py
index 2e49595..f0475c3 100644
--- a/src/tests/portserver_test.py
+++ b/src/tests/portserver_test.py
@@ -16,6 +16,7 @@
#
"""Tests for the example portserver."""
+from __future__ import print_function
import asyncio
import os
import socket
@@ -43,15 +44,49 @@ class PortserverFunctionsTest(unittest.TestCase):
def test_get_process_start_time(self):
self.assertGreater(portserver._get_process_start_time(os.getpid()), 0)
- def test_port_is_available_true(self):
+ def test_is_port_free(self):
"""This might be flaky unless this test is run with a portserver."""
- # Insert Inception "we must go deeper" meme here.
- self.assertTrue(portserver._port_is_available(self.port))
-
- def test_port_is_available_false(self):
+ # The port should be free initially.
+ self.assertTrue(portserver._is_port_free(self.port))
+
+ cases = [
+ (socket.AF_INET, socket.SOCK_STREAM, None),
+ (socket.AF_INET6, socket.SOCK_STREAM, 0),
+ (socket.AF_INET6, socket.SOCK_STREAM, 1),
+ (socket.AF_INET, socket.SOCK_DGRAM, None),
+ (socket.AF_INET6, socket.SOCK_DGRAM, 0),
+ (socket.AF_INET6, socket.SOCK_DGRAM, 1),
+ ]
+ for (sock_family, sock_type, v6only) in cases:
+ # Occupy the port on a subset of possible protocols.
+ try:
+ sock = socket.socket(sock_family, sock_type, 0)
+ except socket.error:
+ print('Kernel does not support sock_family=%d' % sock_family,
+ file=sys.stderr)
+ # Skip this case, since we cannot occupy a port.
+ continue
+ if v6only is not None:
+ try:
+ sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY,
+ v6only)
+ except socket.error:
+ print('Kernel does not support IPV6_V6ONLY=%d' % v6only,
+ file=sys.stderr)
+ # Don't care; just proceed with the default.
+ sock.bind(('', self.port))
+
+ # The port should be busy.
+ self.assertFalse(portserver._is_port_free(self.port))
+ sock.close()
+
+ # Now it's free again.
+ self.assertTrue(portserver._is_port_free(self.port))
+
+ def test_is_port_free_exception(self):
with mock.patch.object(socket, 'socket') as mock_sock:
mock_sock.side_effect = socket.error('fake socket error', 0)
- self.assertFalse(portserver._port_is_available(self.port))
+ self.assertFalse(portserver._is_port_free(self.port))
def test_should_allocate_port(self):
self.assertFalse(portserver._should_allocate_port(0))
@@ -140,18 +175,18 @@ class PortPoolTest(unittest.TestCase):
self.assertRaises(ValueError, self.pool.add_port_to_free_pool, 0)
self.assertRaises(ValueError, self.pool.add_port_to_free_pool, 65536)
- @mock.patch.object(portserver, '_port_is_available')
- def test_get_port_for_process_ok(self, mock_port_is_available):
+ @mock.patch.object(portserver, '_is_port_free')
+ def test_get_port_for_process_ok(self, mock_is_port_free):
self.pool.add_port_to_free_pool(self.port)
- mock_port_is_available.return_value = True
+ mock_is_port_free.return_value = True
self.assertEqual(self.port, self.pool.get_port_for_process(os.getpid()))
self.assertEqual(1, self.pool.ports_checked_for_last_request)
- @mock.patch.object(portserver, '_port_is_available')
- def test_get_port_for_process_none_left(self, mock_port_is_available):
+ @mock.patch.object(portserver, '_is_port_free')
+ def test_get_port_for_process_none_left(self, mock_is_port_free):
self.pool.add_port_to_free_pool(self.port)
self.pool.add_port_to_free_pool(22)
- mock_port_is_available.return_value = False
+ mock_is_port_free.return_value = False
self.assertEqual(2, self.pool.num_ports())
self.assertEqual(0, self.pool.get_port_for_process(os.getpid()))
self.assertEqual(2, self.pool.num_ports())