diff options
4 files changed, 154 insertions, 112 deletions
diff --git a/src/main/java/com/google/mockwebserver/Dispatcher.java b/src/main/java/com/google/mockwebserver/Dispatcher.java index 0456025..48541a4 100644 --- a/src/main/java/com/google/mockwebserver/Dispatcher.java +++ b/src/main/java/com/google/mockwebserver/Dispatcher.java @@ -26,11 +26,13 @@ public abstract class Dispatcher { public abstract MockResponse dispatch(RecordedRequest request) throws InterruptedException; /** - * Returns the socket policy of the next request. Default implementation - * returns {@link SocketPolicy#KEEP_OPEN}. Mischievous implementations can - * return other values to test HTTP edge cases. + * Returns an early guess of the next response, used for policy on how an + * incoming request should be received. The default implementation returns an + * empty response. Mischievous implementations can return other values to test + * HTTP edge cases, such as unhappy socket policies or throttled request + * bodies. */ - public SocketPolicy peekSocketPolicy() { - return SocketPolicy.KEEP_OPEN; + public MockResponse peek() { + return new MockResponse().setSocketPolicy(SocketPolicy.KEEP_OPEN); } } diff --git a/src/main/java/com/google/mockwebserver/MockResponse.java b/src/main/java/com/google/mockwebserver/MockResponse.java index 7bca741..665d85a 100644 --- a/src/main/java/com/google/mockwebserver/MockResponse.java +++ b/src/main/java/com/google/mockwebserver/MockResponse.java @@ -16,7 +16,6 @@ package com.google.mockwebserver; -import static com.google.mockwebserver.MockWebServer.ASCII; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -25,6 +24,9 @@ import java.io.UnsupportedEncodingException; import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.concurrent.TimeUnit; + +import static java.nio.charset.StandardCharsets.US_ASCII; /** * A scripted response to be replayed by the mock web server. @@ -40,9 +42,14 @@ public final class MockResponse implements Cloneable { /** The response body content, or null if {@code body} is set. */ private InputStream bodyStream; - private int bytesPerSecond = Integer.MAX_VALUE; + private int throttleBytesPerPeriod = Integer.MAX_VALUE; + private long throttlePeriod = 1; + private TimeUnit throttleUnit = TimeUnit.SECONDS; + private SocketPolicy socketPolicy = SocketPolicy.KEEP_OPEN; + private int bodyDelayTimeMs = 0; + /** * Creates a new mock response with an empty body. */ @@ -185,13 +192,13 @@ public final class MockResponse implements Cloneable { int pos = 0; while (pos < body.length) { int chunkSize = Math.min(body.length - pos, maxChunkSize); - bytesOut.write(Integer.toHexString(chunkSize).getBytes(ASCII)); - bytesOut.write("\r\n".getBytes(ASCII)); + bytesOut.write(Integer.toHexString(chunkSize).getBytes(US_ASCII)); + bytesOut.write("\r\n".getBytes(US_ASCII)); bytesOut.write(body, pos, chunkSize); - bytesOut.write("\r\n".getBytes(ASCII)); + bytesOut.write("\r\n".getBytes(US_ASCII)); pos += chunkSize; } - bytesOut.write("0\r\n\r\n".getBytes(ASCII)); // last chunk + empty trailer + crlf + bytesOut.write("0\r\n\r\n".getBytes(US_ASCII)); // last chunk + empty trailer + crlf this.body = bytesOut.toByteArray(); return this; @@ -221,19 +228,43 @@ public final class MockResponse implements Cloneable { return this; } - public int getBytesPerSecond() { - return bytesPerSecond; + /** + * Throttles the response body writer to sleep for the given period after each + * series of {@code bytesPerPeriod} bytes are written. Use this to simulate + * network behavior. + */ + public MockResponse throttleBody(int bytesPerPeriod, long period, TimeUnit unit) { + this.throttleBytesPerPeriod = bytesPerPeriod; + this.throttlePeriod = period; + this.throttleUnit = unit; + return this; + } + + public int getThrottleBytesPerPeriod() { + return throttleBytesPerPeriod; + } + + public long getThrottlePeriod() { + return throttlePeriod; + } + + public TimeUnit getThrottleUnit() { + return throttleUnit; } /** - * Set simulated network speed, in bytes per second. This applies to the - * response body only; response headers are not throttled. + * Set the delayed time of the response body to {@code delay}. This applies to the + * response body only; response headers are not affected. */ - public MockResponse setBytesPerSecond(int bytesPerSecond) { - this.bytesPerSecond = bytesPerSecond; + public MockResponse setBodyDelayTimeMs(int delay) { + bodyDelayTimeMs = delay; return this; } + public int getBodyDelayTimeMs() { + return bodyDelayTimeMs; + } + @Override public String toString() { return "MockResponse{" + status + "}"; } diff --git a/src/main/java/com/google/mockwebserver/MockWebServer.java b/src/main/java/com/google/mockwebserver/MockWebServer.java index afcacc5..13a4597 100644 --- a/src/main/java/com/google/mockwebserver/MockWebServer.java +++ b/src/main/java/com/google/mockwebserver/MockWebServer.java @@ -33,11 +33,14 @@ import java.net.Socket; import java.net.SocketException; import java.net.URL; import java.net.UnknownHostException; +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentHashMap; @@ -58,12 +61,26 @@ import javax.net.ssl.X509TrustManager; * replays them upon request in sequence. */ public final class MockWebServer { + private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() { + @Override public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + throw new CertificateException(); + } + + @Override public void checkServerTrusted(X509Certificate[] chain, String authType) { + throw new AssertionError(); + } - static final String ASCII = "US-ASCII"; + @Override public X509Certificate[] getAcceptedIssuers() { + throw new AssertionError(); + } + }; private static final Logger logger = Logger.getLogger(MockWebServer.class.getName()); + private final BlockingQueue<RecordedRequest> requestQueue = new LinkedBlockingQueue<RecordedRequest>(); + /** All map values are Boolean.TRUE. (Collections.newSetFromMap isn't available in Froyo) */ private final Map<Socket, Boolean> openClientSockets = new ConcurrentHashMap<Socket, Boolean>(); private final AtomicInteger requestCount = new AtomicInteger(); @@ -78,7 +95,6 @@ public final class MockWebServer { private int port = -1; private int workerThreads = Integer.MAX_VALUE; - public int getPort() { if (port == -1) { throw new IllegalStateException("Cannot retrieve port before calling play()"); @@ -90,7 +106,7 @@ public final class MockWebServer { try { return InetAddress.getLocalHost().getHostName(); } catch (UnknownHostException e) { - throw new AssertionError(); + throw new AssertionError(e); } } @@ -250,7 +266,7 @@ public final class MockWebServer { } catch (SocketException e) { return; } - final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy(); + SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); if (socketPolicy == DISCONNECT_AT_START) { dispatchBookkeepingRequest(0, socket); socket.close(); @@ -288,16 +304,20 @@ public final class MockWebServer { if (tunnelProxy) { createTunnel(); } - final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy(); + SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); if (socketPolicy == FAIL_HANDSHAKE) { dispatchBookkeepingRequest(sequenceNumber, raw); - processHandshakeFailure(raw, sequenceNumber++); + processHandshakeFailure(raw); return; } socket = sslSocketFactory.createSocket( raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); - ((SSLSocket) socket).setUseClientMode(false); + SSLSocket sslSocket = (SSLSocket) socket; + sslSocket.setUseClientMode(false); openClientSockets.put(socket, true); + + sslSocket.startHandshake(); + openClientSockets.remove(raw); } else { socket = raw; @@ -325,13 +345,11 @@ public final class MockWebServer { */ private void createTunnel() throws IOException, InterruptedException { while (true) { - final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy(); + SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) { throw new IllegalStateException("Tunnel without any CONNECT!"); } - if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) { - return; - } + if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return; } } @@ -341,7 +359,7 @@ public final class MockWebServer { */ private boolean processOneRequest(Socket socket, InputStream in, OutputStream out) throws IOException, InterruptedException { - RecordedRequest request = readRequest(socket, in, sequenceNumber); + RecordedRequest request = readRequest(socket, in, out, sequenceNumber); if (request == null) { return false; } @@ -385,21 +403,9 @@ public final class MockWebServer { })); } - private void processHandshakeFailure(Socket raw, int sequenceNumber) throws Exception { - X509TrustManager untrusted = new X509TrustManager() { - @Override public void checkClientTrusted(X509Certificate[] chain, String authType) - throws CertificateException { - throw new CertificateException(); - } - @Override public void checkServerTrusted(X509Certificate[] chain, String authType) { - throw new AssertionError(); - } - @Override public X509Certificate[] getAcceptedIssuers() { - throw new AssertionError(); - } - }; + private void processHandshakeFailure(Socket raw) throws Exception { SSLContext context = SSLContext.getInstance("TLS"); - context.init(null, new TrustManager[] { untrusted }, new java.security.SecureRandom()); + context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom()); SSLSocketFactory sslSocketFactory = context.getSocketFactory(); SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket( raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); @@ -416,14 +422,11 @@ public final class MockWebServer { RecordedRequest request = new RecordedRequest(null, null, null, -1, null, sequenceNumber, socket); dispatcher.dispatch(request); - requestQueue.add(request); } - /** - * @param sequenceNumber the index of this request on this connection. - */ - private RecordedRequest readRequest(Socket socket, InputStream in, int sequenceNumber) - throws IOException { + /** @param sequenceNumber the index of this request on this connection. */ + private RecordedRequest readRequest(Socket socket, InputStream in, OutputStream out, + int sequenceNumber) throws IOException { String request; try { request = readAsciiUntilCrlf(in); @@ -435,27 +438,40 @@ public final class MockWebServer { } List<String> headers = new ArrayList<String>(); - int contentLength = -1; + long contentLength = -1; boolean chunked = false; + boolean expectContinue = false; String header; while ((header = readAsciiUntilCrlf(in)).length() != 0) { headers.add(header); - String lowercaseHeader = header.toLowerCase(); + String lowercaseHeader = header.toLowerCase(Locale.US); if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) { - contentLength = Integer.parseInt(header.substring(15).trim()); + contentLength = Long.parseLong(header.substring(15).trim()); } - if (lowercaseHeader.startsWith("transfer-encoding:") && - lowercaseHeader.substring(18).trim().equals("chunked")) { + if (lowercaseHeader.startsWith("transfer-encoding:") + && lowercaseHeader.substring(18).trim().equals("chunked")) { chunked = true; } + if (lowercaseHeader.startsWith("expect:") + && lowercaseHeader.substring(7).trim().equals("100-continue")) { + expectContinue = true; + } + } + + if (expectContinue) { + out.write(("HTTP/1.1 100 Continue\r\n").getBytes(StandardCharsets.US_ASCII)); + out.write(("Content-Length: 0\r\n").getBytes(StandardCharsets.US_ASCII)); + out.write(("\r\n").getBytes(StandardCharsets.US_ASCII)); + out.flush(); } boolean hasBody = false; TruncatingOutputStream requestBody = new TruncatingOutputStream(); List<Integer> chunkSizes = new ArrayList<Integer>(); + MockResponse throttlePolicy = dispatcher.peek(); if (contentLength != -1) { hasBody = true; - transfer(contentLength, in, requestBody); + throttledTransfer(throttlePolicy, in, requestBody, contentLength); } else if (chunked) { hasBody = true; while (true) { @@ -465,79 +481,75 @@ public final class MockWebServer { break; } chunkSizes.add(chunkSize); - transfer(chunkSize, in, requestBody); + throttledTransfer(throttlePolicy, in, requestBody, chunkSize); readEmptyLine(in); } } - if (request.startsWith("OPTIONS ") || request.startsWith("GET ") - || request.startsWith("HEAD ") || request.startsWith("DELETE ") - || request.startsWith("TRACE ") || request.startsWith("CONNECT ")) { + if (request.startsWith("OPTIONS ") + || request.startsWith("GET ") + || request.startsWith("HEAD ") + || request.startsWith("TRACE ") + || request.startsWith("CONNECT ")) { if (hasBody) { throw new IllegalArgumentException("Request must not have a body: " + request); } - } else if (!request.startsWith("POST ") && !request.startsWith("PUT ")) { + } else if (!request.startsWith("POST ") + && !request.startsWith("PUT ") + && !request.startsWith("PATCH ") + && !request.startsWith("DELETE ")) { // Permitted as spec is ambiguous. throw new UnsupportedOperationException("Unexpected method: " + request); } - return new RecordedRequest(request, headers, chunkSizes, - requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber, socket); + return new RecordedRequest(request, headers, chunkSizes, requestBody.numBytesReceived, + requestBody.toByteArray(), sequenceNumber, socket); } private void writeResponse(OutputStream out, MockResponse response) throws IOException { - out.write((response.getStatus() + "\r\n").getBytes(ASCII)); - for (String header : response.getHeaders()) { - out.write((header + "\r\n").getBytes(ASCII)); + out.write((response.getStatus() + "\r\n").getBytes(StandardCharsets.US_ASCII)); + List<String> headers = response.getHeaders(); + for (int i = 0, size = headers.size(); i < size; i++) { + String header = headers.get(i); + out.write((header + "\r\n").getBytes(StandardCharsets.US_ASCII)); } - out.write(("\r\n").getBytes(ASCII)); + out.write(("\r\n").getBytes(StandardCharsets.US_ASCII)); out.flush(); - final InputStream in = response.getBodyStream(); - if (in == null) { - return; - } - final int bytesPerSecond = response.getBytesPerSecond(); - - // Stream data in MTU-sized increments - final byte[] buffer = new byte[1452]; - final long delayMs; - if (bytesPerSecond == Integer.MAX_VALUE) { - delayMs = 0; - } else { - delayMs = (1000 * buffer.length) / bytesPerSecond; - } - - int read; - long sinceDelay = 0; - while ((read = in.read(buffer)) != -1) { - out.write(buffer, 0, read); - out.flush(); - - sinceDelay += read; - if (sinceDelay >= buffer.length && delayMs > 0) { - sinceDelay %= buffer.length; - try { - Thread.sleep(delayMs); - } catch (InterruptedException e) { - throw new AssertionError(); - } - } - } + InputStream in = response.getBodyStream(); + if (in == null) return; + throttledTransfer(response, in, out, Long.MAX_VALUE); } /** * Transfer bytes from {@code in} to {@code out} until either {@code length} - * bytes have been transferred or {@code in} is exhausted. + * bytes have been transferred or {@code in} is exhausted. The transfer is + * throttled according to {@code throttlePolicy}. */ - private void transfer(int length, InputStream in, OutputStream out) throws IOException { + private void throttledTransfer(MockResponse throttlePolicy, InputStream in, OutputStream out, + long limit) throws IOException { byte[] buffer = new byte[1024]; - while (length > 0) { - int count = in.read(buffer, 0, Math.min(buffer.length, length)); - if (count == -1) { - return; + int bytesPerPeriod = throttlePolicy.getThrottleBytesPerPeriod(); + long delayMs = throttlePolicy.getThrottleUnit().toMillis(throttlePolicy.getThrottlePeriod()); + + while (true) { + for (int b = 0; b < bytesPerPeriod; ) { + int toRead = (int) Math.min(Math.min(buffer.length, limit), bytesPerPeriod - b); + int read = in.read(buffer, 0, toRead); + if (read == -1) return; + + out.write(buffer, 0, read); + out.flush(); + b += read; + limit -= read; + + if (limit == 0) return; + } + + try { + if (delayMs != 0) Thread.sleep(delayMs); + } catch (InterruptedException e) { + throw new AssertionError(); } - out.write(buffer, 0, count); - length -= count; } } diff --git a/src/main/java/com/google/mockwebserver/QueueDispatcher.java b/src/main/java/com/google/mockwebserver/QueueDispatcher.java index bc26694..a95089b 100644 --- a/src/main/java/com/google/mockwebserver/QueueDispatcher.java +++ b/src/main/java/com/google/mockwebserver/QueueDispatcher.java @@ -45,14 +45,11 @@ public class QueueDispatcher extends Dispatcher { return responseQueue.take(); } - @Override public SocketPolicy peekSocketPolicy() { + @Override public MockResponse peek() { MockResponse peek = responseQueue.peek(); - if (peek == null) { - return failFastResponse != null - ? failFastResponse.getSocketPolicy() - : SocketPolicy.KEEP_OPEN; - } - return peek.getSocketPolicy(); + if (peek != null) return peek; + if (failFastResponse != null) return failFastResponse; + return super.peek(); } public void enqueueResponse(MockResponse response) { |