diff options
Diffstat (limited to 'src/main/java/org/apache/commons/math3/stat/correlation/KendallsCorrelation.java')
-rw-r--r-- | src/main/java/org/apache/commons/math3/stat/correlation/KendallsCorrelation.java | 272 |
1 files changed, 272 insertions, 0 deletions
diff --git a/src/main/java/org/apache/commons/math3/stat/correlation/KendallsCorrelation.java b/src/main/java/org/apache/commons/math3/stat/correlation/KendallsCorrelation.java new file mode 100644 index 0000000..d38cf71 --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/correlation/KendallsCorrelation.java @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.correlation; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.linear.BlockRealMatrix; +import org.apache.commons.math3.linear.MatrixUtils; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.Pair; + +import java.util.Arrays; +import java.util.Comparator; + +/** + * Implementation of Kendall's Tau-b rank correlation</a>. + * <p> + * A pair of observations (x<sub>1</sub>, y<sub>1</sub>) and + * (x<sub>2</sub>, y<sub>2</sub>) are considered <i>concordant</i> if + * x<sub>1</sub> < x<sub>2</sub> and y<sub>1</sub> < y<sub>2</sub> + * or x<sub>2</sub> < x<sub>1</sub> and y<sub>2</sub> < y<sub>1</sub>. + * The pair is <i>discordant</i> if x<sub>1</sub> < x<sub>2</sub> and + * y<sub>2</sub> < y<sub>1</sub> or x<sub>2</sub> < x<sub>1</sub> and + * y<sub>1</sub> < y<sub>2</sub>. If either x<sub>1</sub> = x<sub>2</sub> + * or y<sub>1</sub> = y<sub>2</sub>, the pair is neither concordant nor + * discordant. + * <p> + * Kendall's Tau-b is defined as: + * <pre> + * tau<sub>b</sub> = (n<sub>c</sub> - n<sub>d</sub>) / sqrt((n<sub>0</sub> - n<sub>1</sub>) * (n<sub>0</sub> - n<sub>2</sub>)) + * </pre> + * <p> + * where: + * <ul> + * <li>n<sub>0</sub> = n * (n - 1) / 2</li> + * <li>n<sub>c</sub> = Number of concordant pairs</li> + * <li>n<sub>d</sub> = Number of discordant pairs</li> + * <li>n<sub>1</sub> = sum of t<sub>i</sub> * (t<sub>i</sub> - 1) / 2 for all i</li> + * <li>n<sub>2</sub> = sum of u<sub>j</sub> * (u<sub>j</sub> - 1) / 2 for all j</li> + * <li>t<sub>i</sub> = Number of tied values in the i<sup>th</sup> group of ties in x</li> + * <li>u<sub>j</sub> = Number of tied values in the j<sup>th</sup> group of ties in y</li> + * </ul> + * <p> + * This implementation uses the O(n log n) algorithm described in + * William R. Knight's 1966 paper "A Computer Method for Calculating + * Kendall's Tau with Ungrouped Data" in the Journal of the American + * Statistical Association. + * + * @see <a href="http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient"> + * Kendall tau rank correlation coefficient (Wikipedia)</a> + * @see <a href="http://www.jstor.org/stable/2282833">A Computer + * Method for Calculating Kendall's Tau with Ungrouped Data</a> + * + * @since 3.3 + */ +public class KendallsCorrelation { + + /** correlation matrix */ + private final RealMatrix correlationMatrix; + + /** + * Create a KendallsCorrelation instance without data. + */ + public KendallsCorrelation() { + correlationMatrix = null; + } + + /** + * Create a KendallsCorrelation from a rectangular array + * whose columns represent values of variables to be correlated. + * + * @param data rectangular array with columns representing variables + * @throws IllegalArgumentException if the input data array is not + * rectangular with at least two rows and two columns. + */ + public KendallsCorrelation(double[][] data) { + this(MatrixUtils.createRealMatrix(data)); + } + + /** + * Create a KendallsCorrelation from a RealMatrix whose columns + * represent variables to be correlated. + * + * @param matrix matrix with columns representing variables to correlate + */ + public KendallsCorrelation(RealMatrix matrix) { + correlationMatrix = computeCorrelationMatrix(matrix); + } + + /** + * Returns the correlation matrix. + * + * @return correlation matrix + */ + public RealMatrix getCorrelationMatrix() { + return correlationMatrix; + } + + /** + * Computes the Kendall's Tau rank correlation matrix for the columns of + * the input matrix. + * + * @param matrix matrix with columns representing variables to correlate + * @return correlation matrix + */ + public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) { + int nVars = matrix.getColumnDimension(); + RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars); + for (int i = 0; i < nVars; i++) { + for (int j = 0; j < i; j++) { + double corr = correlation(matrix.getColumn(i), matrix.getColumn(j)); + outMatrix.setEntry(i, j, corr); + outMatrix.setEntry(j, i, corr); + } + outMatrix.setEntry(i, i, 1d); + } + return outMatrix; + } + + /** + * Computes the Kendall's Tau rank correlation matrix for the columns of + * the input rectangular array. The columns of the array represent values + * of variables to be correlated. + * + * @param matrix matrix with columns representing variables to correlate + * @return correlation matrix + */ + public RealMatrix computeCorrelationMatrix(final double[][] matrix) { + return computeCorrelationMatrix(new BlockRealMatrix(matrix)); + } + + /** + * Computes the Kendall's Tau rank correlation coefficient between the two arrays. + * + * @param xArray first data array + * @param yArray second data array + * @return Returns Kendall's Tau rank correlation coefficient for the two arrays + * @throws DimensionMismatchException if the arrays lengths do not match + */ + public double correlation(final double[] xArray, final double[] yArray) + throws DimensionMismatchException { + + if (xArray.length != yArray.length) { + throw new DimensionMismatchException(xArray.length, yArray.length); + } + + final int n = xArray.length; + final long numPairs = sum(n - 1); + + @SuppressWarnings("unchecked") + Pair<Double, Double>[] pairs = new Pair[n]; + for (int i = 0; i < n; i++) { + pairs[i] = new Pair<Double, Double>(xArray[i], yArray[i]); + } + + Arrays.sort(pairs, new Comparator<Pair<Double, Double>>() { + /** {@inheritDoc} */ + public int compare(Pair<Double, Double> pair1, Pair<Double, Double> pair2) { + int compareFirst = pair1.getFirst().compareTo(pair2.getFirst()); + return compareFirst != 0 ? compareFirst : pair1.getSecond().compareTo(pair2.getSecond()); + } + }); + + long tiedXPairs = 0; + long tiedXYPairs = 0; + long consecutiveXTies = 1; + long consecutiveXYTies = 1; + Pair<Double, Double> prev = pairs[0]; + for (int i = 1; i < n; i++) { + final Pair<Double, Double> curr = pairs[i]; + if (curr.getFirst().equals(prev.getFirst())) { + consecutiveXTies++; + if (curr.getSecond().equals(prev.getSecond())) { + consecutiveXYTies++; + } else { + tiedXYPairs += sum(consecutiveXYTies - 1); + consecutiveXYTies = 1; + } + } else { + tiedXPairs += sum(consecutiveXTies - 1); + consecutiveXTies = 1; + tiedXYPairs += sum(consecutiveXYTies - 1); + consecutiveXYTies = 1; + } + prev = curr; + } + tiedXPairs += sum(consecutiveXTies - 1); + tiedXYPairs += sum(consecutiveXYTies - 1); + + long swaps = 0; + @SuppressWarnings("unchecked") + Pair<Double, Double>[] pairsDestination = new Pair[n]; + for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) { + for (int offset = 0; offset < n; offset += 2 * segmentSize) { + int i = offset; + final int iEnd = FastMath.min(i + segmentSize, n); + int j = iEnd; + final int jEnd = FastMath.min(j + segmentSize, n); + + int copyLocation = offset; + while (i < iEnd || j < jEnd) { + if (i < iEnd) { + if (j < jEnd) { + if (pairs[i].getSecond().compareTo(pairs[j].getSecond()) <= 0) { + pairsDestination[copyLocation] = pairs[i]; + i++; + } else { + pairsDestination[copyLocation] = pairs[j]; + j++; + swaps += iEnd - i; + } + } else { + pairsDestination[copyLocation] = pairs[i]; + i++; + } + } else { + pairsDestination[copyLocation] = pairs[j]; + j++; + } + copyLocation++; + } + } + final Pair<Double, Double>[] pairsTemp = pairs; + pairs = pairsDestination; + pairsDestination = pairsTemp; + } + + long tiedYPairs = 0; + long consecutiveYTies = 1; + prev = pairs[0]; + for (int i = 1; i < n; i++) { + final Pair<Double, Double> curr = pairs[i]; + if (curr.getSecond().equals(prev.getSecond())) { + consecutiveYTies++; + } else { + tiedYPairs += sum(consecutiveYTies - 1); + consecutiveYTies = 1; + } + prev = curr; + } + tiedYPairs += sum(consecutiveYTies - 1); + + final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps; + final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs); + return concordantMinusDiscordant / FastMath.sqrt(nonTiedPairsMultiplied); + } + + /** + * Returns the sum of the number from 1 .. n according to Gauss' summation formula: + * \[ \sum\limits_{k=1}^n k = \frac{n(n + 1)}{2} \] + * + * @param n the summation end + * @return the sum of the number from 1 to n + */ + private static long sum(long n) { + return n * (n + 1) / 2l; + } +} |