diff options
Diffstat (limited to 'src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java')
-rw-r--r-- | src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java | 283 |
1 files changed, 283 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java b/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java new file mode 100644 index 0000000..991a9f8 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.distribution; + +import org.apache.commons.math3.exception.MathArithmeticException; +import org.apache.commons.math3.exception.NotANumberException; +import org.apache.commons.math3.exception.NotFiniteNumberException; +import org.apache.commons.math3.exception.NotPositiveException; +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.random.Well19937c; +import org.apache.commons.math3.util.MathArrays; +import org.apache.commons.math3.util.Pair; + +import java.io.Serializable; +import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * A generic implementation of a <a + * href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution"> + * discrete probability distribution (Wikipedia)</a> over a finite sample space, based on an + * enumerated list of <value, probability> pairs. Input probabilities must all be + * non-negative, but zero values are allowed and their sum does not have to equal one. Constructors + * will normalize input probabilities to make them sum to one. + * + * <p>The list of <value, probability> pairs does not, strictly speaking, have to be a function and + * it can contain null values. The pmf created by the constructor will combine probabilities of + * equal values and will treat null values as equal. For example, if the list of pairs <"dog", + * 0.2>, <null, 0.1>, <"pig", 0.2>, <"dog", 0.1>, <null, 0.4> is provided + * to the constructor, the resulting pmf will assign mass of 0.5 to null, 0.3 to "dog" and 0.2 to + * null. + * + * @param <T> type of the elements in the sample space. + * @since 3.2 + */ +public class EnumeratedDistribution<T> implements Serializable { + + /** Serializable UID. */ + private static final long serialVersionUID = 20123308L; + + /** RNG instance used to generate samples from the distribution. */ + protected final RandomGenerator random; + + /** List of random variable values. */ + private final List<T> singletons; + + /** + * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1, + * probability[i] is the probability that a random variable following this distribution takes + * the value singletons[i]. + */ + private final double[] probabilities; + + /** Cumulative probabilities, cached to speed up sampling. */ + private final double[] cumulativeProbabilities; + + /** + * Create an enumerated distribution using the given probability mass function enumeration. + * + * <p><b>Note:</b> this constructor will implicitly create an instance of {@link Well19937c} as + * random generator to be used for sampling only (see {@link #sample()} and {@link + * #sample(int)}). In case no sampling is needed for the created distribution, it is advised to + * pass {@code null} as random generator via the appropriate constructors to avoid the + * additional initialisation overhead. + * + * @param pmf probability mass function enumerated as a list of <T, probability> pairs. + * @throws NotPositiveException if any of the probabilities are negative. + * @throws NotFiniteNumberException if any of the probabilities are infinite. + * @throws NotANumberException if any of the probabilities are NaN. + * @throws MathArithmeticException all of the probabilities are 0. + */ + public EnumeratedDistribution(final List<Pair<T, Double>> pmf) + throws NotPositiveException, + MathArithmeticException, + NotFiniteNumberException, + NotANumberException { + this(new Well19937c(), pmf); + } + + /** + * Create an enumerated distribution using the given random number generator and probability + * mass function enumeration. + * + * @param rng random number generator. + * @param pmf probability mass function enumerated as a list of <T, probability> pairs. + * @throws NotPositiveException if any of the probabilities are negative. + * @throws NotFiniteNumberException if any of the probabilities are infinite. + * @throws NotANumberException if any of the probabilities are NaN. + * @throws MathArithmeticException all of the probabilities are 0. + */ + public EnumeratedDistribution(final RandomGenerator rng, final List<Pair<T, Double>> pmf) + throws NotPositiveException, + MathArithmeticException, + NotFiniteNumberException, + NotANumberException { + random = rng; + + singletons = new ArrayList<T>(pmf.size()); + final double[] probs = new double[pmf.size()]; + + for (int i = 0; i < pmf.size(); i++) { + final Pair<T, Double> sample = pmf.get(i); + singletons.add(sample.getKey()); + final double p = sample.getValue(); + if (p < 0) { + throw new NotPositiveException(sample.getValue()); + } + if (Double.isInfinite(p)) { + throw new NotFiniteNumberException(p); + } + if (Double.isNaN(p)) { + throw new NotANumberException(); + } + probs[i] = p; + } + + probabilities = MathArrays.normalizeArray(probs, 1.0); + + cumulativeProbabilities = new double[probabilities.length]; + double sum = 0; + for (int i = 0; i < probabilities.length; i++) { + sum += probabilities[i]; + cumulativeProbabilities[i] = sum; + } + } + + /** + * Reseed the random generator used to generate samples. + * + * @param seed the new seed + */ + public void reseedRandomGenerator(long seed) { + random.setSeed(seed); + } + + /** + * For a random variable {@code X} whose values are distributed according to this distribution, + * this method returns {@code P(X = x)}. In other words, this method represents the probability + * mass function (PMF) for the distribution. + * + * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)}, or both are null, + * then {@code probability(x1) = probability(x2)}. + * + * @param x the point at which the PMF is evaluated + * @return the value of the probability mass function at {@code x} + */ + double probability(final T x) { + double probability = 0; + + for (int i = 0; i < probabilities.length; i++) { + if ((x == null && singletons.get(i) == null) + || (x != null && x.equals(singletons.get(i)))) { + probability += probabilities[i]; + } + } + + return probability; + } + + /** + * Return the probability mass function as a list of <value, probability> pairs. + * + * <p>Note that if duplicate and / or null values were provided to the constructor when creating + * this EnumeratedDistribution, the returned list will contain these values. If duplicates + * values exist, what is returned will not represent a pmf (i.e., it is up to the caller to + * consolidate duplicate mass points). + * + * @return the probability mass function. + */ + public List<Pair<T, Double>> getPmf() { + final List<Pair<T, Double>> samples = new ArrayList<Pair<T, Double>>(probabilities.length); + + for (int i = 0; i < probabilities.length; i++) { + samples.add(new Pair<T, Double>(singletons.get(i), probabilities[i])); + } + + return samples; + } + + /** + * Generate a random value sampled from this distribution. + * + * @return a random value. + */ + public T sample() { + final double randomValue = random.nextDouble(); + + int index = Arrays.binarySearch(cumulativeProbabilities, randomValue); + if (index < 0) { + index = -index - 1; + } + + if (index >= 0 + && index < probabilities.length + && randomValue < cumulativeProbabilities[index]) { + return singletons.get(index); + } + + /* This should never happen, but it ensures we will return a correct + * object in case there is some floating point inequality problem + * wrt the cumulative probabilities. */ + return singletons.get(singletons.size() - 1); + } + + /** + * Generate a random sample from the distribution. + * + * @param sampleSize the number of random values to generate. + * @return an array representing the random sample. + * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive. + */ + public Object[] sample(int sampleSize) throws NotStrictlyPositiveException { + if (sampleSize <= 0) { + throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize); + } + + final Object[] out = new Object[sampleSize]; + + for (int i = 0; i < sampleSize; i++) { + out[i] = sample(); + } + + return out; + } + + /** + * Generate a random sample from the distribution. + * + * <p>If the requested samples fit in the specified array, it is returned therein. Otherwise, a + * new array is allocated with the runtime type of the specified array and the size of this + * collection. + * + * @param sampleSize the number of random values to generate. + * @param array the array to populate. + * @return an array representing the random sample. + * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive. + * @throws NullArgumentException if {@code array} is null + */ + public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException { + if (sampleSize <= 0) { + throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize); + } + + if (array == null) { + throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); + } + + T[] out; + if (array.length < sampleSize) { + @SuppressWarnings("unchecked") // safe as both are of type T + final T[] unchecked = + (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize); + out = unchecked; + } else { + out = array; + } + + for (int i = 0; i < sampleSize; i++) { + out[i] = sample(); + } + + return out; + } +} |