diff options
Diffstat (limited to 'src/tests/portserver_test.py')
-rw-r--r-- | src/tests/portserver_test.py | 370 |
1 files changed, 370 insertions, 0 deletions
diff --git a/src/tests/portserver_test.py b/src/tests/portserver_test.py new file mode 100644 index 0000000..b7de094 --- /dev/null +++ b/src/tests/portserver_test.py @@ -0,0 +1,370 @@ +#!/usr/bin/python3 +# +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for the example portserver.""" + +import asyncio +import os +import signal +import socket +import subprocess +import sys +import time +import unittest +from unittest import mock +from multiprocessing import Process + +import portpicker + +# On Windows, portserver.py is located in the "Scripts" folder, which isn't +# added to the import path by default +if sys.platform == 'win32': + sys.path.append(os.path.join(os.path.split(sys.executable)[0])) + +import portserver + + +def setUpModule(): + portserver._configure_logging(verbose=True) + +def exit_immediately(): + os._exit(0) + +class PortserverFunctionsTest(unittest.TestCase): + + @classmethod + def setUp(cls): + cls.port = portpicker.PickUnusedPort() + + def test_get_process_command_line(self): + portserver._get_process_command_line(os.getpid()) + + def test_get_process_start_time(self): + self.assertGreater(portserver._get_process_start_time(os.getpid()), 0) + + def test_is_port_free(self): + """This might be flaky unless this test is run with a portserver.""" + # The port should be free initially. + self.assertTrue(portserver._is_port_free(self.port)) + + cases = [ + (socket.AF_INET, socket.SOCK_STREAM, None), + (socket.AF_INET6, socket.SOCK_STREAM, 1), + (socket.AF_INET, socket.SOCK_DGRAM, None), + (socket.AF_INET6, socket.SOCK_DGRAM, 1), + ] + + # Using v6only=0 on Windows doesn't result in collisions + if sys.platform != 'win32': + cases.extend([ + (socket.AF_INET6, socket.SOCK_STREAM, 0), + (socket.AF_INET6, socket.SOCK_DGRAM, 0), + ]) + + for (sock_family, sock_type, v6only) in cases: + # Occupy the port on a subset of possible protocols. + try: + sock = socket.socket(sock_family, sock_type, 0) + except socket.error: + print('Kernel does not support sock_family=%d' % sock_family, + file=sys.stderr) + # Skip this case, since we cannot occupy a port. + continue + + if not hasattr(socket, 'IPPROTO_IPV6'): + v6only = None + + if v6only is not None: + try: + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, + v6only) + except socket.error: + print('Kernel does not support IPV6_V6ONLY=%d' % v6only, + file=sys.stderr) + # Don't care; just proceed with the default. + sock.bind(('', self.port)) + + # The port should be busy. + self.assertFalse(portserver._is_port_free(self.port)) + sock.close() + + # Now it's free again. + self.assertTrue(portserver._is_port_free(self.port)) + + def test_is_port_free_exception(self): + with mock.patch.object(socket, 'socket') as mock_sock: + mock_sock.side_effect = socket.error('fake socket error', 0) + self.assertFalse(portserver._is_port_free(self.port)) + + def test_should_allocate_port(self): + self.assertFalse(portserver._should_allocate_port(0)) + self.assertFalse(portserver._should_allocate_port(1)) + self.assertTrue(portserver._should_allocate_port, os.getpid()) + + p = Process(target=exit_immediately) + p.start() + child_pid = p.pid + p.join() + + # This test assumes that after waitpid returns the kernel has finished + # cleaning the process. We also assume that the kernel will not reuse + # the former child's pid before our next call checks for its existence. + # Likely assumptions, but not guaranteed. + self.assertFalse(portserver._should_allocate_port(child_pid)) + + def test_parse_command_line(self): + with mock.patch.object( + sys, 'argv', ['program_name', '--verbose', + '--portserver_static_pool=1-1,3-8', + '--portserver_unix_socket_address=@hello-test']): + portserver._parse_command_line() + + def test_parse_port_ranges(self): + self.assertFalse(portserver._parse_port_ranges('')) + self.assertCountEqual(portserver._parse_port_ranges('1-1'), {1}) + self.assertCountEqual(portserver._parse_port_ranges('1-1,3-8,375-378'), + {1, 3, 4, 5, 6, 7, 8, 375, 376, 377, 378}) + # Unparsable parts are logged but ignored. + self.assertEqual({1, 2}, + portserver._parse_port_ranges('1-2,not,numbers')) + self.assertEqual(set(), portserver._parse_port_ranges('8080-8081x')) + # Port ranges that go out of bounds are logged but ignored. + self.assertEqual(set(), portserver._parse_port_ranges('0-1138')) + self.assertEqual(set(range(19, 84 + 1)), + portserver._parse_port_ranges('1138-65536,19-84')) + + def test_configure_logging(self): + """Just code coverage really.""" + portserver._configure_logging(False) + portserver._configure_logging(True) + + + _test_socket_addr = f'@TST-{os.getpid()}' + + @mock.patch.object( + sys, 'argv', ['PortserverFunctionsTest.test_main', + f'--portserver_unix_socket_address={_test_socket_addr}'] + ) + @mock.patch.object(portserver, '_parse_port_ranges') + def test_main_no_ports(self, *unused_mocks): + portserver._parse_port_ranges.return_value = set() + with self.assertRaises(SystemExit): + portserver.main() + + @unittest.skipUnless(sys.executable, 'Requires a stand alone interpreter') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'AF_UNIX required') + def test_portserver_binary(self): + """Launch python portserver.py and test it.""" + # Blindly assuming tree layout is src/tests/portserver_test.py + # with src/portserver.py. + portserver_py = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + 'portserver.py') + anon_addr = self._test_socket_addr.replace('@', '\0') + + conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + with self.assertRaises( + ConnectionRefusedError, + msg=f'{self._test_socket_addr} should not listen yet.'): + conn.connect(anon_addr) + conn.close() + + server = subprocess.Popen( + [sys.executable, portserver_py, + f'--portserver_unix_socket_address={self._test_socket_addr}'], + stderr=subprocess.PIPE, + ) + try: + # Wait a few seconds for the server to start listening. + start_time = time.monotonic() + while True: + time.sleep(0.05) + try: + conn.connect(anon_addr) + conn.close() + except ConnectionRefusedError: + delta = time.monotonic() - start_time + if delta < 4: + continue + else: + server.kill() + self.fail('Failed to connect to portserver ' + f'{self._test_socket_addr} within ' + f'{delta} seconds. STDERR:\n' + + server.stderr.read().decode('utf-8')) + else: + break + + ports = set() + port = portpicker.get_port_from_port_server( + portserver_address=self._test_socket_addr) + ports.add(port) + port = portpicker.get_port_from_port_server( + portserver_address=self._test_socket_addr) + ports.add(port) + + with subprocess.Popen('exit 0', shell=True) as quick_process: + quick_process.wait() + # This process doesn't exist so it should be a denied alloc. + # We use the pid from the above quick_process under the assumption + # that most OSes try to avoid rapid pid recycling. + denied_port = portpicker.get_port_from_port_server( + portserver_address=self._test_socket_addr, + pid=quick_process.pid) # A now unused pid. + self.assertIsNone(denied_port) + + self.assertEqual(len(ports), 2, msg=ports) + + # Check statistics from portserver + server.send_signal(signal.SIGUSR1) + # TODO implement an I/O timeout + for line in server.stderr: + if b'denied-allocations ' in line: + denied_allocations = int( + line.split(b'denied-allocations ', 2)[1]) + self.assertEqual(1, denied_allocations, msg=line) + elif b'total-allocations ' in line: + total_allocations = int( + line.split(b'total-allocations ', 2)[1]) + self.assertEqual(2, total_allocations, msg=line) + break + + rejected_port = portpicker.get_port_from_port_server( + portserver_address=self._test_socket_addr, + pid=99999999999999999999999999999999999) # Out of range. + self.assertIsNone(rejected_port) + + # Done. shutdown gracefully. + server.send_signal(signal.SIGINT) + server.communicate(timeout=2) + finally: + server.kill() + server.wait() + + +class PortPoolTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.port = portpicker.PickUnusedPort() + + def setUp(self): + self.pool = portserver._PortPool() + + def test_initialization(self): + self.assertEqual(0, self.pool.num_ports()) + self.pool.add_port_to_free_pool(self.port) + self.assertEqual(1, self.pool.num_ports()) + self.pool.add_port_to_free_pool(1138) + self.assertEqual(2, self.pool.num_ports()) + self.assertRaises(ValueError, self.pool.add_port_to_free_pool, 0) + self.assertRaises(ValueError, self.pool.add_port_to_free_pool, 65536) + + @mock.patch.object(portserver, '_is_port_free') + def test_get_port_for_process_ok(self, mock_is_port_free): + self.pool.add_port_to_free_pool(self.port) + mock_is_port_free.return_value = True + self.assertEqual(self.port, self.pool.get_port_for_process(os.getpid())) + self.assertEqual(1, self.pool.ports_checked_for_last_request) + + @mock.patch.object(portserver, '_is_port_free') + def test_get_port_for_process_none_left(self, mock_is_port_free): + self.pool.add_port_to_free_pool(self.port) + self.pool.add_port_to_free_pool(22) + mock_is_port_free.return_value = False + self.assertEqual(2, self.pool.num_ports()) + self.assertEqual(0, self.pool.get_port_for_process(os.getpid())) + self.assertEqual(2, self.pool.num_ports()) + self.assertEqual(2, self.pool.ports_checked_for_last_request) + + @mock.patch.object(portserver, '_is_port_free') + @mock.patch.object(os, 'getpid') + def test_get_port_for_process_pid_eq_port(self, mock_getpid, mock_is_port_free): + self.pool.add_port_to_free_pool(12345) + self.pool.add_port_to_free_pool(12344) + mock_is_port_free.side_effect = lambda port: port == os.getpid() + mock_getpid.return_value = 12345 + self.assertEqual(2, self.pool.num_ports()) + self.assertEqual(12345, self.pool.get_port_for_process(os.getpid())) + self.assertEqual(2, self.pool.ports_checked_for_last_request) + + @mock.patch.object(portserver, '_is_port_free') + @mock.patch.object(os, 'getpid') + def test_get_port_for_process_pid_ne_port(self, mock_getpid, mock_is_port_free): + self.pool.add_port_to_free_pool(12344) + self.pool.add_port_to_free_pool(12345) + mock_is_port_free.side_effect = lambda port: port != os.getpid() + mock_getpid.return_value = 12345 + self.assertEqual(2, self.pool.num_ports()) + self.assertEqual(12344, self.pool.get_port_for_process(os.getpid())) + self.assertEqual(2, self.pool.ports_checked_for_last_request) + + +@mock.patch.object(portserver, '_get_process_command_line') +@mock.patch.object(portserver, '_should_allocate_port') +@mock.patch.object(portserver._PortPool, 'get_port_for_process') +class PortServerRequestHandlerTest(unittest.TestCase): + def setUp(self): + portserver._configure_logging(verbose=True) + self.rh = portserver._PortServerRequestHandler([23, 42, 54]) + + def test_stats_reporting(self, *unused_mocks): + with mock.patch.object(portserver, 'log') as mock_logger: + self.rh.dump_stats() + mock_logger.info.assert_called_with('total-allocations 0') + + def test_handle_port_request_bad_data(self, *unused_mocks): + self._test_bad_data_from_client(b'') + self._test_bad_data_from_client(b'\n') + self._test_bad_data_from_client(b'99Z\n') + self._test_bad_data_from_client(b'99 8\n') + self.assertEqual([], portserver._get_process_command_line.mock_calls) + + def _test_bad_data_from_client(self, data): + mock_writer = mock.Mock(asyncio.StreamWriter) + self.rh._handle_port_request(data, mock_writer) + self.assertFalse(portserver._should_allocate_port.mock_calls) + + def test_handle_port_request_denied_allocation(self, *unused_mocks): + portserver._should_allocate_port.return_value = False + self.assertEqual(0, self.rh._denied_allocations) + mock_writer = mock.Mock(asyncio.StreamWriter) + self.rh._handle_port_request(b'5\n', mock_writer) + self.assertEqual(1, self.rh._denied_allocations) + + def test_handle_port_request_bad_port_returned(self, *unused_mocks): + portserver._should_allocate_port.return_value = True + self.rh._port_pool.get_port_for_process.return_value = 0 + mock_writer = mock.Mock(asyncio.StreamWriter) + self.rh._handle_port_request(b'6\n', mock_writer) + self.rh._port_pool.get_port_for_process.assert_called_once_with(6) + self.assertEqual(1, self.rh._denied_allocations) + + def test_handle_port_request_success(self, *unused_mocks): + portserver._should_allocate_port.return_value = True + self.rh._port_pool.get_port_for_process.return_value = 999 + mock_writer = mock.Mock(asyncio.StreamWriter) + self.assertEqual(0, self.rh._total_allocations) + self.rh._handle_port_request(b'8', mock_writer) + portserver._should_allocate_port.assert_called_once_with(8) + self.rh._port_pool.get_port_for_process.assert_called_once_with(8) + self.assertEqual(1, self.rh._total_allocations) + self.assertEqual(0, self.rh._denied_allocations) + mock_writer.write.assert_called_once_with(b'999\n') + + +if __name__ == '__main__': + unittest.main() |