aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTony An <40644135+tonyjongyoonan@users.noreply.github.com>2023-07-06 10:03:08 -0700
committerGitHub <noreply@github.com>2023-07-06 10:03:08 -0700
commit0b53dd7304365e20e4523ab0cc3cd113cfae8133 (patch)
tree40324f49219bf9c1f6287708e431acbaf34ae705
parent361616ae7ce081fd22a2f2e4d4a2ab34cc4e954a (diff)
downloadgrpc-grpc-java-0b53dd7304365e20e4523ab0cc3cd113cfae8133.tar.gz
implemented and tested static stride scheduler for weighted round robin load balancing policy (#10272)
-rw-r--r--xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java216
-rw-r--r--xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java301
2 files changed, 378 insertions, 139 deletions
diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
index 48442a84b..d5d8c4d9e 100644
--- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
+++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
@@ -44,10 +44,10 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
-import java.util.PriorityQueue;
import java.util.Random;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -120,7 +120,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
@Override
public void run() {
if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
- ((WeightedRoundRobinPicker)currentPicker).updateWeight();
+ ((WeightedRoundRobinPicker) currentPicker).updateWeight();
}
weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
TimeUnit.NANOSECONDS, timeService);
@@ -258,7 +258,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
new HashMap<>();
private final boolean enableOobLoadReport;
private final float errorUtilizationPenalty;
- private volatile EdfScheduler scheduler;
+ private volatile StaticStrideScheduler scheduler;
WeightedRoundRobinPicker(List<Subchannel> list, boolean enableOobLoadReport,
float errorUtilizationPenalty) {
@@ -279,7 +279,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
Subchannel subchannel = list.get(scheduler.pick());
if (!enableOobLoadReport) {
return PickResult.withSubchannel(subchannel,
- OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
+ OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
subchannelToReportListenerMap.getOrDefault(subchannel,
((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty))));
} else {
@@ -288,26 +288,14 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
private void updateWeight() {
- int weightedChannelCount = 0;
- double avgWeight = 0;
- for (Subchannel value : list) {
- double newWeight = ((WrrSubchannel) value).getWeight();
- if (newWeight > 0) {
- avgWeight += newWeight;
- weightedChannelCount++;
- }
- }
- EdfScheduler scheduler = new EdfScheduler(list.size(), random);
- if (weightedChannelCount >= 1) {
- avgWeight /= 1.0 * weightedChannelCount;
- } else {
- avgWeight = 1;
- }
+ float[] newWeights = new float[list.size()];
for (int i = 0; i < list.size(); i++) {
WrrSubchannel subchannel = (WrrSubchannel) list.get(i);
double newWeight = subchannel.getWeight();
- scheduler.add(i, newWeight > 0 ? newWeight : avgWeight);
+ newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
}
+
+ StaticStrideScheduler scheduler = new StaticStrideScheduler(newWeights, random);
this.scheduler = scheduler;
}
@@ -340,111 +328,125 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
}
- /**
- * The earliest deadline first implementation in which each object is
- * chosen deterministically and periodically with frequency proportional to its weight.
- *
- * <p>Specifically, each object added to chooser is given a deadline equal to the multiplicative
- * inverse of its weight. The place of each object in its deadline is tracked, and each call to
- * choose returns the object with the least remaining time in its deadline.
- * (Ties are broken by the order in which the children were added to the chooser.) The deadline
- * advances by the multiplicative inverse of the object's weight.
- * For example, if items A and B are added with weights 0.5 and 0.2, successive chooses return:
- *
- * <ul>
- * <li>In the first call, the deadlines are A=2 (1/0.5) and B=5 (1/0.2), so A is returned.
- * The deadline of A is updated to 4.
- * <li>Next, the remaining deadlines are A=4 and B=5, so A is returned. The deadline of A (2) is
- * updated to A=6.
- * <li>Remaining deadlines are A=6 and B=5, so B is returned. The deadline of B is updated with
- * with B=10.
- * <li>Remaining deadlines are A=6 and B=10, so A is returned. The deadline of A is updated with
- * A=8.
- * <li>Remaining deadlines are A=8 and B=10, so A is returned. The deadline of A is updated with
- * A=10.
- * <li>Remaining deadlines are A=10 and B=10, so A is returned. The deadline of A is updated
- * with A=12.
- * <li>Remaining deadlines are A=12 and B=10, so B is returned. The deadline of B is updated
- * with B=15.
- * <li>etc.
- * </ul>
- *
- * <p>In short: the entry with the highest weight is preferred.
+ /*
+ * The Static Stride Scheduler is an implementation of an earliest deadline first (EDF) scheduler
+ * in which each object's deadline is the multiplicative inverse of the object's weight.
+ * <p>
+ * The way in which this is implemented is through a static stride scheduler.
+ * The Static Stride Scheduler works by iterating through the list of subchannel weights
+ * and using modular arithmetic to proportionally distribute picks, favoring entries
+ * with higher weights. It is based on the observation that the intended sequence generated
+ * from an EDF scheduler is a periodic one that can be achieved through modular arithmetic.
+ * The Static Stride Scheduler is more performant than other implementations of the EDF
+ * Scheduler, as it removes the need for a priority queue (and thus mutex locks).
+ * <p>
+ * go/static-stride-scheduler
+ * <p>
*
* <ul>
- * <li>add() - O(lg n)
- * <li>pick() - O(lg n)
- * </ul>
- *
+ * <li>nextSequence() - O(1)
+ * <li>pick() - O(n)
*/
@VisibleForTesting
- static final class EdfScheduler {
- private final PriorityQueue<ObjectState> prioQueue;
-
- /**
- * Weights below this value will be upped to this minimum weight.
- */
- private static final double MINIMUM_WEIGHT = 0.0001;
-
- private final Object lock = new Object();
+ static final class StaticStrideScheduler {
+ private final short[] scaledWeights;
+ private final int sizeDivisor;
+ private final AtomicInteger sequence;
+ private static final int K_MAX_WEIGHT = 0xFFFF;
+
+ StaticStrideScheduler(float[] weights, Random random) {
+ checkArgument(weights.length >= 1, "Couldn't build scheduler: requires at least one weight");
+ int numChannels = weights.length;
+ int numWeightedChannels = 0;
+ double sumWeight = 0;
+ float maxWeight = 0;
+ short meanWeight = 0;
+ for (float weight : weights) {
+ if (weight > 0) {
+ sumWeight += weight;
+ maxWeight = Math.max(weight, maxWeight);
+ numWeightedChannels++;
+ }
+ }
- private final Random random;
+ double scalingFactor = K_MAX_WEIGHT / maxWeight;
+ if (numWeightedChannels > 0) {
+ meanWeight = (short) Math.round(scalingFactor * sumWeight / numWeightedChannels);
+ } else {
+ meanWeight = 1;
+ }
- /**
- * Use the item's deadline as the order in the priority queue. If the deadlines are the same,
- * use the index. Index should be unique.
- */
- EdfScheduler(int initialCapacity, Random random) {
- this.prioQueue = new PriorityQueue<ObjectState>(initialCapacity, (o1, o2) -> {
- if (o1.deadline == o2.deadline) {
- return Integer.compare(o1.index, o2.index);
+ // scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly
+ short[] scaledWeights = new short[numChannels];
+ for (int i = 0; i < numChannels; i++) {
+ if (weights[i] <= 0) {
+ scaledWeights[i] = meanWeight;
} else {
- return Double.compare(o1.deadline, o2.deadline);
+ scaledWeights[i] = (short) Math.round(weights[i] * scalingFactor);
}
- });
- this.random = random;
+ }
+
+ this.scaledWeights = scaledWeights;
+ this.sizeDivisor = numChannels;
+ this.sequence = new AtomicInteger(random.nextInt());
+
}
- /**
- * Adds the item in the scheduler. This is not thread safe.
- *
- * @param index The field {@link ObjectState#index} to be added
- * @param weight positive weight for the added object
- */
- void add(int index, double weight) {
- checkArgument(weight > 0.0, "Weights need to be positive.");
- ObjectState state = new ObjectState(Math.max(weight, MINIMUM_WEIGHT), index);
- // Randomize the initial deadline.
- state.deadline = random.nextDouble() * (1 / state.weight);
- prioQueue.add(state);
+ /** Returns the next sequence number and atomically increases sequence with wraparound. */
+ private long nextSequence() {
+ return Integer.toUnsignedLong(sequence.getAndIncrement());
}
- /**
- * Picks the next WRR object.
+ @VisibleForTesting
+ long getSequence() {
+ return Integer.toUnsignedLong(sequence.get());
+ }
+
+ /*
+ * Selects index of next backend server.
+ * <p>
+ * A 2D array is compactly represented as a function of W(backend), where the row
+ * represents the generation and the column represents the backend index:
+ * X(backend,generation) | generation ∈ [0,kMaxWeight).
+ * Each element in the conceptual array is a boolean indicating whether the backend at
+ * this index should be picked now. If false, the counter is incremented again,
+ * and the new element is checked. An atomically incremented counter keeps track of our
+ * backend and generation through modular arithmetic within the pick() method.
+ * <p>
+ * Modular arithmetic allows us to evenly distribute picks and skips between
+ * generations based on W(backend).
+ * X(backend,generation) = (W(backend) * generation) % kMaxWeight >= kMaxWeight - W(backend)
+ * If we have the same three backends with weights:
+ * W(backend) = {2,3,6} scaled to max(W(backend)) = 6, then X(backend,generation) is:
+ * <p>
+ * B0 B1 B2
+ * T T T
+ * F F T
+ * F T T
+ * T F T
+ * F T T
+ * F F T
+ * The sequence of picked backend indices is given by
+ * walking across and down: {0,1,2,2,1,2,0,2,1,2,2}.
+ * <p>
+ * To reduce the variance and spread the wasted work among different picks,
+ * an offset that varies per backend index is also included to the calculation.
*/
int pick() {
- synchronized (lock) {
- ObjectState minObject = prioQueue.remove();
- minObject.deadline += 1.0 / minObject.weight;
- prioQueue.add(minObject);
- return minObject.index;
+ while (true) {
+ long sequence = this.nextSequence();
+ int backendIndex = (int) (sequence % this.sizeDivisor);
+ long generation = sequence / this.sizeDivisor;
+ int weight = Short.toUnsignedInt(this.scaledWeights[backendIndex]);
+ long offset = (long) K_MAX_WEIGHT / 2 * backendIndex;
+ if ((weight * generation + offset) % K_MAX_WEIGHT < K_MAX_WEIGHT - weight) {
+ continue;
+ }
+ return backendIndex;
}
}
}
- /** Holds the state of the object. */
- @VisibleForTesting
- static class ObjectState {
- private final double weight;
- private final int index;
- private volatile double deadline;
-
- ObjectState(double weight, int index) {
- this.weight = weight;
- this.index = index;
- }
- }
-
static final class WeightedRoundRobinLoadBalancerConfig {
final long blackoutPeriodNanos;
final long weightExpirationPeriodNanos;
diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java
index daf58a174..58a19af96 100644
--- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java
+++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java
@@ -52,7 +52,7 @@ import io.grpc.SynchronizationContext;
import io.grpc.internal.FakeClock;
import io.grpc.services.InternalCallMetricRecorder;
import io.grpc.services.MetricReport;
-import io.grpc.xds.WeightedRoundRobinLoadBalancer.EdfScheduler;
+import io.grpc.xds.WeightedRoundRobinLoadBalancer.StaticStrideScheduler;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WrrSubchannel;
@@ -175,7 +175,7 @@ public class WeightedRoundRobinLoadBalancerTest {
}
});
wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker(),
- new FakeRandom());
+ new FakeRandom(0));
}
@Test
@@ -220,7 +220,7 @@ public class WeightedRoundRobinLoadBalancerTest {
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
assertThat(weightedPicker.pickSubchannel(mockArgs)
- .getSubchannel()).isEqualTo(weightedSubchannel1);
+ .getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder()
.setWeightUpdatePeriodNanos(500_000_000L) //.5s
@@ -338,7 +338,7 @@ public class WeightedRoundRobinLoadBalancerTest {
}
@Test
- public void pickByWeight_LargeWeight() {
+ public void pickByWeight_largeWeight() {
MetricReport report1 = InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 999, 0, new HashMap<>(), new HashMap<>());
MetricReport report2 = InternalCallMetricRecorder.createMetricReport(
@@ -593,6 +593,7 @@ public class WeightedRoundRobinLoadBalancerTest {
assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(1);
assertThat(weightedPicker.pickSubchannel(mockArgs)
.getSubchannel()).isEqualTo(weightedSubchannel2);
+
}
@Test
@@ -750,12 +751,12 @@ public class WeightedRoundRobinLoadBalancerTest {
}
assertThat(pickCount.size()).isEqualTo(3);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9))
- .isAtMost(0.001);
+ .isAtMost(0.002);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9))
- .isAtMost(0.001);
+ .isAtMost(0.002);
// subchannel3's weight is average of subchannel1 and subchannel2
assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9))
- .isAtMost(0.001);
+ .isAtMost(0.002);
}
@Test
@@ -821,45 +822,275 @@ public class WeightedRoundRobinLoadBalancerTest {
.isAtMost(0.001);
}
+ @Test(expected = NullPointerException.class)
+ public void wrrConfig_TimeValueNonNull() {
+ WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos((Long) null);
+ }
+
+ @Test(expected = NullPointerException.class)
+ public void wrrConfig_BooleanValueNonNull() {
+ WeightedRoundRobinLoadBalancerConfig.newBuilder().setEnableOobLoadReport((Boolean) null);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void emptyWeights() {
+ float[] weights = {};
+ Random random = new Random();
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ sss.pick();
+ }
+
+ @Test
+ public void testPicksEqualsWeights() {
+ float[] weights = {1.0f, 2.0f, 3.0f};
+ Random random = new Random();
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ int[] expectedPicks = new int[] {1, 2, 3};
+ int[] picks = new int[3];
+ for (int i = 0; i < 6; i++) {
+ picks[sss.pick()] += 1;
+ }
+ assertThat(picks).isEqualTo(expectedPicks);
+ }
+
@Test
- public void edfScheduler() {
+ public void testContainsZeroWeightUseMean() {
+ float[] weights = {3.0f, 0.0f, 1.0f};
Random random = new Random();
- double totalWeight = 0;
- int capacity = random.nextInt(10) + 1;
- double[] weights = new double[capacity];
- EdfScheduler scheduler = new EdfScheduler(capacity, random);
- for (int i = 0; i < capacity; i++) {
- weights[i] = random.nextDouble();
- scheduler.add(i, weights[i]);
- totalWeight += weights[i];
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ int[] expectedPicks = new int[] {3, 2, 1};
+ int[] picks = new int[3];
+ for (int i = 0; i < 6; i++) {
+ picks[sss.pick()] += 1;
}
+ assertThat(picks).isEqualTo(expectedPicks);
+ }
+
+ @Test
+ public void testContainsNegativeWeightUseMean() {
+ float[] weights = {3.0f, -1.0f, 1.0f};
+ Random random = new Random();
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ int[] expectedPicks = new int[] {3, 2, 1};
+ int[] picks = new int[3];
+ for (int i = 0; i < 6; i++) {
+ picks[sss.pick()] += 1;
+ }
+ assertThat(picks).isEqualTo(expectedPicks);
+ }
+
+ @Test
+ public void testAllSameWeights() {
+ float[] weights = {1.0f, 1.0f, 1.0f};
+ Random random = new Random();
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ int[] expectedPicks = new int[] {2, 2, 2};
+ int[] picks = new int[3];
+ for (int i = 0; i < 6; i++) {
+ picks[sss.pick()] += 1;
+ }
+ assertThat(picks).isEqualTo(expectedPicks);
+ }
+
+ @Test
+ public void testAllZeroWeightsUseOne() {
+ float[] weights = {0.0f, 0.0f, 0.0f};
+ Random random = new Random();
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ int[] expectedPicks = new int[] {2, 2, 2};
+ int[] picks = new int[3];
+ for (int i = 0; i < 6; i++) {
+ picks[sss.pick()] += 1;
+ }
+ assertThat(picks).isEqualTo(expectedPicks);
+ }
+
+ @Test
+ public void testAllInvalidWeightsUseOne() {
+ float[] weights = {-3.1f, -0.0f, 0.0f};
+ Random random = new Random();
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ int[] expectedPicks = new int[] {2, 2, 2};
+ int[] picks = new int[3];
+ for (int i = 0; i < 6; i++) {
+ picks[sss.pick()] += 1;
+ }
+ assertThat(picks).isEqualTo(expectedPicks);
+ }
+
+ @Test
+ public void testLargestWeightIndexPickedEveryGeneration() {
+ float[] weights = {1.0f, 2.0f, 3.0f};
+ int largestWeightIndex = 2;
+ Random random = new Random();
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ int largestWeightPickCount = 0;
+ int kMaxWeight = 65535;
+ for (int i = 0; i < largestWeightIndex * kMaxWeight; i++) {
+ if (sss.pick() == largestWeightIndex) {
+ largestWeightPickCount += 1;
+ }
+ }
+ assertThat(largestWeightPickCount).isEqualTo(kMaxWeight);
+ }
+
+ @Test
+ public void testStaticStrideSchedulerNonIntegers1() {
+ float[] weights = {2.0f, (float) (10.0 / 3.0), 1.0f};
+ Random random = new Random();
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ double totalWeight = 2 + 10.0 / 3.0 + 1.0;
Map<Integer, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
- int result = scheduler.pick();
+ int result = sss.pick();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
- for (int i = 0; i < capacity; i++) {
- assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight) )
+ for (int i = 0; i < 3; i++) {
+ assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isAtMost(0.01);
}
}
@Test
- public void edsScheduler_sameWeight() {
- EdfScheduler scheduler = new EdfScheduler(2, new FakeRandom());
- scheduler.add(0, 0.5);
- scheduler.add(1, 0.5);
- assertThat(scheduler.pick()).isEqualTo(0);
+ public void testStaticStrideSchedulerNonIntegers2() {
+ float[] weights = {0.5f, 0.3f, 1.0f};
+ Random random = new Random();
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ double totalWeight = 1.8;
+ Map<Integer, Integer> pickCount = new HashMap<>();
+ for (int i = 0; i < 1000; i++) {
+ int result = sss.pick();
+ pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
+ }
+ for (int i = 0; i < 3; i++) {
+ assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
+ .isAtMost(0.01);
+ }
}
- @Test(expected = NullPointerException.class)
- public void wrrConfig_TimeValueNonNull() {
- WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos((Long) null);
+ @Test
+ public void testTwoWeights() {
+ float[] weights = {1.0f, 2.0f};
+ Random random = new Random();
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ double totalWeight = 3;
+ Map<Integer, Integer> pickCount = new HashMap<>();
+ for (int i = 0; i < 1000; i++) {
+ int result = sss.pick();
+ pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
+ }
+ for (int i = 0; i < 2; i++) {
+ assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
+ .isAtMost(0.01);
+ }
}
- @Test(expected = NullPointerException.class)
- public void wrrConfig_BooleanValueNonNull() {
- WeightedRoundRobinLoadBalancerConfig.newBuilder().setEnableOobLoadReport((Boolean) null);
+ @Test
+ public void testManyWeights() {
+ float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+ Random random = new Random();
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ double totalWeight = 15;
+ Map<Integer, Integer> pickCount = new HashMap<>();
+ for (int i = 0; i < 1000; i++) {
+ int result = sss.pick();
+ pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
+ }
+ for (int i = 0; i < 5; i++) {
+ assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
+ .isAtMost(0.0011);
+ }
+ }
+
+ @Test
+ public void testManyComplexWeights() {
+ float[] weights = {1.2f, 2.4f, 222.56f, 1.1f, 15.0f, 226342.0f, 5123.0f, 532.2f};
+ Random random = new Random();
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ double totalWeight = 1.2 + 2.4 + 222.56 + 15.0 + 226342.0 + 5123.0 + 0.0001;
+ Map<Integer, Integer> pickCount = new HashMap<>();
+ for (int i = 0; i < 1000; i++) {
+ int result = sss.pick();
+ pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
+ }
+ for (int i = 0; i < 8; i++) {
+ assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
+ .isAtMost(0.01);
+ }
+ }
+
+ @Test
+ public void testDeterministicPicks() {
+ float[] weights = {2.0f, 3.0f, 6.0f};
+ Random random = new FakeRandom(0);
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ assertThat(sss.getSequence()).isEqualTo(0);
+ assertThat(sss.pick()).isEqualTo(1);
+ assertThat(sss.getSequence()).isEqualTo(2);
+ assertThat(sss.pick()).isEqualTo(2);
+ assertThat(sss.getSequence()).isEqualTo(3);
+ assertThat(sss.pick()).isEqualTo(2);
+ assertThat(sss.getSequence()).isEqualTo(6);
+ assertThat(sss.pick()).isEqualTo(0);
+ assertThat(sss.getSequence()).isEqualTo(7);
+ assertThat(sss.pick()).isEqualTo(1);
+ assertThat(sss.getSequence()).isEqualTo(8);
+ assertThat(sss.pick()).isEqualTo(2);
+ assertThat(sss.getSequence()).isEqualTo(9);
+ }
+
+ @Test
+ public void testImmediateWraparound() {
+ float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+ Random random = new FakeRandom(-1);
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ double totalWeight = 15;
+ Map<Integer, Integer> pickCount = new HashMap<>();
+ for (int i = 0; i < 1000; i++) {
+ int result = sss.pick();
+ pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
+ }
+ for (int i = 0; i < 5; i++) {
+ assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
+ .isAtMost(0.001);
+ }
+ }
+
+ @Test
+ public void testWraparound() {
+ float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+ Random random = new FakeRandom(-500);
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ double totalWeight = 15;
+ Map<Integer, Integer> pickCount = new HashMap<>();
+ for (int i = 0; i < 1000; i++) {
+ int result = sss.pick();
+ pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
+ }
+ for (int i = 0; i < 5; i++) {
+ assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
+ .isAtMost(0.0011);
+ }
+ }
+
+ @Test
+ public void testDeterministicWraparound() {
+ float[] weights = {2.0f, 3.0f, 6.0f};
+ Random random = new FakeRandom(-1);
+ StaticStrideScheduler sss = new StaticStrideScheduler(weights, random);
+ assertThat(sss.getSequence()).isEqualTo(0xFFFF_FFFFL);
+ assertThat(sss.pick()).isEqualTo(1);
+ assertThat(sss.getSequence()).isEqualTo(2);
+ assertThat(sss.pick()).isEqualTo(2);
+ assertThat(sss.getSequence()).isEqualTo(3);
+ assertThat(sss.pick()).isEqualTo(2);
+ assertThat(sss.getSequence()).isEqualTo(6);
+ assertThat(sss.pick()).isEqualTo(0);
+ assertThat(sss.getSequence()).isEqualTo(7);
+ assertThat(sss.pick()).isEqualTo(1);
+ assertThat(sss.getSequence()).isEqualTo(8);
+ assertThat(sss.pick()).isEqualTo(2);
+ assertThat(sss.getSequence()).isEqualTo(9);
}
private static class FakeSocketAddress extends SocketAddress {
@@ -875,10 +1106,16 @@ public class WeightedRoundRobinLoadBalancerTest {
}
private static class FakeRandom extends Random {
+ private int nextInt;
+
+ public FakeRandom(int nextInt) {
+ this.nextInt = nextInt;
+ }
+
@Override
- public double nextDouble() {
+ public int nextInt() {
// return constant value to disable init deadline randomization in the scheduler
- return 0.322023;
+ return nextInt;
}
}
}