Changeset View
Changeset View
Standalone View
Standalone View
src/backend/generalTest/CorrelationCoefficient.cpp
- This file was added.
1 | /*************************************************************************** | ||||
---|---|---|---|---|---|
2 | File : CorrelationCoefficient.cpp | ||||
3 | Project : LabPlot | ||||
4 | Description : Finding Correlation Coefficient on data provided | ||||
5 | -------------------------------------------------------------------- | ||||
6 | Copyright : (C) 2019 Devanshu Agarwal(agarwaldevanshu8@gmail.com) | ||||
7 | | ||||
8 | ***************************************************************************/ | ||||
9 | | ||||
10 | /*************************************************************************** | ||||
11 | * * | ||||
12 | * This program is free software; you can redistribute it and/or modify * | ||||
13 | * it under the terms of the GNU General Public License as published by * | ||||
14 | * the Free Software Foundation; either version 2 of the License, or * | ||||
15 | * (at your option) any later version. * | ||||
16 | * * | ||||
17 | * This program is distributed in the hope that it will be useful, * | ||||
18 | * but WITHOUT ANY WARRANTY; without even the implied warranty of * | ||||
19 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * | ||||
20 | * GNU General Public License for more details. * | ||||
21 | * * | ||||
22 | * You should have received a copy of the GNU General Public License * | ||||
23 | * along with this program; if not, write to the Free Software * | ||||
24 | * Foundation, Inc., 51 Franklin Street, Fifth Floor, * | ||||
25 | * Boston, MA 02110-1301 USA * | ||||
26 | * * | ||||
27 | ***************************************************************************/ | ||||
28 | | ||||
29 | #include "CorrelationCoefficient.h" | ||||
30 | #include "GeneralTest.h" | ||||
31 | #include "kdefrontend/generalTest/CorrelationCoefficientView.h" | ||||
32 | #include "backend/spreadsheet/Spreadsheet.h" | ||||
33 | #include "backend/core/column/Column.h" | ||||
34 | #include "backend/lib/macros.h" | ||||
35 | | ||||
36 | #include <QVector> | ||||
37 | #include <QStandardItemModel> | ||||
38 | #include <QLocale> | ||||
39 | #include <QLabel> | ||||
40 | #include <QVBoxLayout> | ||||
41 | #include <QWidget> | ||||
42 | #include <QtMath> | ||||
43 | #include <QQueue> | ||||
44 | | ||||
45 | #include <KLocalizedString> | ||||
46 | | ||||
47 | #include <gsl/gsl_math.h> | ||||
48 | #include <gsl/gsl_statistics.h> | ||||
49 | #include <algorithm> | ||||
50 | | ||||
51 | extern "C" { | ||||
52 | #include "backend/nsl/nsl_stats.h" | ||||
53 | } | ||||
54 | | ||||
55 | CorrelationCoefficient::CorrelationCoefficient(const QString &name) : GeneralTest (name) { | ||||
56 | } | ||||
57 | | ||||
58 | CorrelationCoefficient::~CorrelationCoefficient() { | ||||
59 | } | ||||
60 | | ||||
61 | void CorrelationCoefficient::performTest(Test test, bool categoricalVariable) { | ||||
62 | m_statsTable = ""; | ||||
63 | m_tooltips.clear(); | ||||
64 | for (int i = 0; i < RESULTLINESCOUNT; i++) | ||||
65 | m_resultLine[i]->clear(); | ||||
66 | | ||||
67 | switch (test) { | ||||
68 | case CorrelationCoefficient::Test::Pearson: { | ||||
69 | m_currTestName = "<h2>" + i18n("Pearson's r Correlation Test") + "</h2>"; | ||||
70 | performPearson(categoricalVariable); | ||||
71 | break; | ||||
72 | } | ||||
73 | case CorrelationCoefficient::Test::Kendall: | ||||
74 | m_currTestName = "<h2>" + i18n("Kendall's Rank Correlation Test") + "</h2>"; | ||||
75 | performKendall(); | ||||
76 | break; | ||||
77 | case CorrelationCoefficient::Test::Spearman: { | ||||
78 | m_currTestName = "<h2>" + i18n("Spearman Correlation Coefficient Test") + "</h2>"; | ||||
79 | performSpearman(); | ||||
80 | break; | ||||
81 | } | ||||
82 | } | ||||
83 | | ||||
84 | emit changed(); | ||||
85 | } | ||||
86 | | ||||
87 | | ||||
88 | double CorrelationCoefficient::correlationValue() { | ||||
89 | return m_correlationValue; | ||||
90 | } | ||||
91 | | ||||
92 | | ||||
93 | /*************************************************************************************************************************** | ||||
94 | * Private Implementations | ||||
95 | * ************************************************************************************************************************/ | ||||
96 | | ||||
97 | /*********************************************Pearson r ******************************************************************/ | ||||
98 | //Formulas are taken from https://www.statisticssolutions.com/correlation-pearson-kendall-spearman/ | ||||
99 | | ||||
100 | // variables: | ||||
101 | // N = total number of observations | ||||
102 | // sumColx = sum of values in colx | ||||
103 | // sumSqColx = sum of square of values in colx | ||||
104 | // sumColxColy = sum of product of values in colx and coly | ||||
105 | | ||||
106 | //TODO: support for col1 is categorical. | ||||
107 | //TODO: add symbols in stats table header. | ||||
108 | //TODO: add automatic test | ||||
109 | //TODO: add tooltip for correlation value result | ||||
110 | //TODO: find p value | ||||
111 | void CorrelationCoefficient::performPearson(bool categoricalVariable) { | ||||
112 | if (m_columns.count() != 2) { | ||||
113 | printError("Select only 2 columns "); | ||||
114 | return; | ||||
115 | } | ||||
116 | | ||||
117 | if (categoricalVariable) { | ||||
118 | printLine(1, "currently categorical variable not supported", "blue"); | ||||
119 | return; | ||||
120 | } | ||||
121 | | ||||
122 | QString col1Name = m_columns[0]->name(); | ||||
123 | QString col2Name = m_columns[1]->name(); | ||||
124 | | ||||
125 | | ||||
126 | if (!isNumericOrInteger(m_columns[1])) { | ||||
127 | printError("Column " + col2Name + " should contain only numeric or interger values"); | ||||
128 | } | ||||
129 | | ||||
130 | | ||||
131 | int N = findCount(m_columns[0]); | ||||
132 | if (N != findCount(m_columns[1])) { | ||||
133 | printError("Number of data values in Column: " + col1Name + "and Column: " + col2Name + "are not equal"); | ||||
134 | return; | ||||
135 | } | ||||
136 | | ||||
137 | double sumCol1 = findSum(m_columns[0], N); | ||||
138 | double sumCol2 = findSum(m_columns[1], N); | ||||
139 | double sumSqCol1 = findSumSq(m_columns[0], N); | ||||
140 | double sumSqCol2 = findSumSq(m_columns[1], N); | ||||
141 | | ||||
142 | double sumCol12 = 0; | ||||
143 | | ||||
144 | for (int i = 0; i < N; i++) | ||||
145 | sumCol12 += m_columns[0]->valueAt(i) * | ||||
146 | m_columns[1]->valueAt(i); | ||||
147 | | ||||
148 | // printing table; | ||||
149 | // cell constructor structure; data, level, rowSpanCount, m_columnspanCount, isHeader; | ||||
150 | QList<Cell*> rowMajor; | ||||
151 | int level = 0; | ||||
152 | | ||||
153 | // horizontal header | ||||
154 | rowMajor.append(new Cell("", level, true)); | ||||
155 | rowMajor.append(new Cell("N", level, true, "Total Number of Observations")); | ||||
156 | rowMajor.append(new Cell("Sigma", level, true, "Sum of Scores in each column")); | ||||
157 | rowMajor.append(new Cell("Sigma x2", level, true, "Sum of Squares of scores in each column")); | ||||
158 | rowMajor.append(new Cell("Sigma xy", level, true, "Sum of Squares of scores in each column")); | ||||
159 | | ||||
160 | //data with vertical header. | ||||
161 | level++; | ||||
162 | rowMajor.append(new Cell(col1Name, level, true)); | ||||
163 | rowMajor.append(new Cell(N, level)); | ||||
164 | rowMajor.append(new Cell(sumCol1, level)); | ||||
165 | rowMajor.append(new Cell(sumSqCol1, level)); | ||||
166 | | ||||
167 | rowMajor.append(new Cell(sumCol12, level, false, "", 2, 1)); | ||||
168 | | ||||
169 | level++; | ||||
170 | rowMajor.append(new Cell(col2Name, level, true)); | ||||
171 | rowMajor.append(new Cell(N, level)); | ||||
172 | rowMajor.append(new Cell(sumCol2, level)); | ||||
173 | rowMajor.append(new Cell(sumSqCol2, level)); | ||||
174 | | ||||
175 | m_statsTable += getHtmlTable3(rowMajor); | ||||
176 | | ||||
177 | | ||||
178 | m_correlationValue = (N * sumCol12 - sumCol1*sumCol2) / | ||||
179 | sqrt((N * sumSqCol1 - gsl_pow_2(sumCol1)) * | ||||
180 | (N * sumSqCol2 - gsl_pow_2(sumCol2))); | ||||
181 | | ||||
182 | printLine(0, QString("Correlation Value is %1").arg(round(m_correlationValue)), "green"); | ||||
183 | | ||||
184 | } | ||||
185 | | ||||
186 | /***********************************************Kendall ******************************************************************/ | ||||
187 | // uses Knight's algorithm, which runs in O(n log n) rather than O(n^2) | ||||
188 | // http://adereth.github.io/blog/2013/10/30/efficiently-computing-kendalls-tau/ | ||||
189 | | ||||
190 | // TODO: Change date format type to original for numeric type; | ||||
191 | // TODO: add tooltips. | ||||
192 | // TODO: Compute tauB for ties. | ||||
193 | // TODO: find P Value from Z Value | ||||
194 | void CorrelationCoefficient::performKendall() { | ||||
195 | if (m_columns.count() != 2) { | ||||
196 | printError("Select only 2 columns "); | ||||
197 | return; | ||||
198 | } | ||||
199 | | ||||
200 | QString col1Name = m_columns[0]->name(); | ||||
201 | QString col2Name = m_columns[1]->name(); | ||||
202 | | ||||
203 | int N = findCount(m_columns[0]); | ||||
204 | if (N != findCount(m_columns[1])) { | ||||
205 | printError("Number of data values in Column: " + col1Name + "and Column: " + col2Name + "are not equal"); | ||||
206 | return; | ||||
207 | } | ||||
208 | | ||||
209 | int col2Ranks[N]; | ||||
210 | if (isNumericOrInteger(m_columns[0]) || isNumericOrInteger(m_columns[1])) { | ||||
211 | if (isNumericOrInteger(m_columns[0]) && isNumericOrInteger(m_columns[1])) { | ||||
212 | for (int i = 0; i < N; i++) | ||||
213 | col2Ranks[int(m_columns[0]->valueAt(i)) - 1] = int(m_columns[1]->valueAt(i)); | ||||
214 | } else { | ||||
215 | printError(QString("Ranking System should be same for both Column: %1 and Column: %2 <br/>" | ||||
216 | "Hint: Check for data types of columns").arg(col1Name).arg(col2Name)); | ||||
217 | return; | ||||
218 | } | ||||
219 | } else { | ||||
220 | AbstractColumn::ColumnMode origCol1Mode = m_columns[0]->columnMode(); | ||||
221 | AbstractColumn::ColumnMode origCol2Mode = m_columns[1]->columnMode(); | ||||
222 | | ||||
223 | m_columns[0]->setColumnMode(AbstractColumn::Text); | ||||
224 | m_columns[1]->setColumnMode(AbstractColumn::Text); | ||||
225 | | ||||
226 | QMap<QString, int> ValueToRank; | ||||
227 | | ||||
228 | for (int i = 0; i < N; i++) { | ||||
229 | if (ValueToRank[m_columns[0]->textAt(i)] != 0) { | ||||
230 | printError("Currently ties are not supported"); | ||||
231 | m_columns[0]->setColumnMode(origCol1Mode); | ||||
232 | m_columns[1]->setColumnMode(origCol2Mode); | ||||
233 | return; | ||||
234 | } | ||||
235 | ValueToRank[m_columns[0]->textAt(i)] = i + 1; | ||||
236 | } | ||||
237 | | ||||
238 | for (int i = 0; i < N; i++) | ||||
239 | col2Ranks[i] = ValueToRank[m_columns[1]->textAt(i)]; | ||||
240 | | ||||
241 | m_columns[0]->setColumnMode(origCol1Mode); | ||||
242 | m_columns[1]->setColumnMode(origCol2Mode); | ||||
243 | } | ||||
244 | | ||||
245 | int nPossiblePairs = (N * (N - 1)) / 2; | ||||
246 | | ||||
247 | int nDiscordant = findDiscordants(col2Ranks, 0, N - 1); | ||||
248 | int nCorcordant = nPossiblePairs - nDiscordant; | ||||
249 | | ||||
250 | double tauA = double(nCorcordant - nDiscordant) / nPossiblePairs; | ||||
251 | | ||||
252 | double zA = (3 * (nCorcordant - nDiscordant)) / | ||||
253 | sqrt(N * (N- 1) * (2 * N + 5) / 2); | ||||
254 | | ||||
255 | printLine(0 , QString("Number of Discordants are %1").arg(nDiscordant), "green"); | ||||
256 | printLine(1 , QString("Number of Concordant are %1").arg(nCorcordant), "green"); | ||||
257 | | ||||
258 | printLine(2 , QString("Tau a is %1").arg(round(tauA)), "green"); | ||||
259 | printLine(3 , QString("Z Value is %1").arg(round(zA)), "green"); | ||||
260 | | ||||
261 | m_correlationValue = tauA; | ||||
262 | return; | ||||
263 | | ||||
264 | } | ||||
265 | | ||||
266 | /***********************************************Spearman ******************************************************************/ | ||||
267 | // All formulas and symbols are taken from: https://www.statisticshowto.datasciencecentral.com/spearman-rank-correlation-definition-calculate/ | ||||
268 | | ||||
269 | void CorrelationCoefficient::performSpearman() { | ||||
270 | if (m_columns.count() != 2) { | ||||
271 | printError("Select only 2 columns "); | ||||
272 | return; | ||||
273 | } | ||||
274 | | ||||
275 | QString col1Name = m_columns[0]->name(); | ||||
276 | QString col2Name = m_columns[1]->name(); | ||||
277 | | ||||
278 | int N = findCount(m_columns[0]); | ||||
279 | if (N != findCount(m_columns[1])) { | ||||
280 | printError("Number of data values in Column: " + col1Name + "and Column: " + col2Name + "are not equal"); | ||||
281 | return; | ||||
282 | } | ||||
283 | | ||||
284 | QMap<double, int> col1Ranks; | ||||
285 | convertToRanks(m_columns[0], N, col1Ranks); | ||||
286 | | ||||
287 | QMap<double, int> col2Ranks; | ||||
288 | convertToRanks(m_columns[1], N, col2Ranks); | ||||
289 | | ||||
290 | double ranksCol1Mean = 0; | ||||
291 | double ranksCol2Mean = 0; | ||||
292 | | ||||
293 | // QString ranks1 = ""; | ||||
294 | // QString ranks2 = ""; | ||||
295 | for (int i = 0; i < N; i++) { | ||||
296 | ranksCol1Mean += col1Ranks[int(m_columns[0]->valueAt(i))]; | ||||
297 | ranksCol2Mean += col2Ranks[int(m_columns[1]->valueAt(i))]; | ||||
298 | | ||||
299 | // ranks1 += ", " + QString::number(col1Ranks[m_columns[0]->valueAt(i)]); | ||||
300 | // ranks2 += ", " + QString::number(col2Ranks[m_columns[1]->valueAt(i)]); | ||||
301 | } | ||||
302 | | ||||
303 | ranksCol1Mean = ranksCol1Mean / N; | ||||
304 | ranksCol2Mean = ranksCol2Mean / N; | ||||
305 | | ||||
306 | //QDEBUG("ranks 1 and ranks2 are " ); | ||||
307 | //QDEBUG(ranks1); | ||||
308 | //QDEBUG(ranks2); | ||||
309 | | ||||
310 | //QDEBUG("Mean ranks are " << ranksCol1Mean << ranksCol2Mean); | ||||
311 | | ||||
312 | double s12 = 0; | ||||
313 | double s1 = 0; | ||||
314 | double s2 = 0; | ||||
315 | | ||||
316 | for (int i = 0; i < N; i++) { | ||||
317 | double centeredRank_1 = col1Ranks[int(m_columns[0]->valueAt(i))] - ranksCol1Mean; | ||||
318 | double centeredRank_2 = col2Ranks[int(m_columns[1]->valueAt(i))] - ranksCol2Mean; | ||||
319 | | ||||
320 | s12 += centeredRank_1 * centeredRank_2; | ||||
321 | | ||||
322 | s1 += gsl_pow_2(centeredRank_1); | ||||
323 | s2 += gsl_pow_2(centeredRank_2); | ||||
324 | } | ||||
325 | | ||||
326 | s12 = s12 / N; | ||||
327 | s1 = s1 / N; | ||||
328 | s2 = s2 / N; | ||||
329 | | ||||
330 | //QDEBUG("s12, s1, s2 are " << s12 << " " << s1 << " " << s2); | ||||
331 | | ||||
332 | m_correlationValue = s12 / std::sqrt(s1 * s2); | ||||
333 | | ||||
334 | printLine(0, QString("Spearman Rank Correlation value is %1").arg(m_correlationValue), "green"); | ||||
335 | } | ||||
336 | | ||||
337 | /***********************************************Helper Functions******************************************************************/ | ||||
338 | | ||||
339 | int CorrelationCoefficient::findDiscordants(int *ranks, int start, int end) { | ||||
340 | if (start >= end) | ||||
341 | return 0; | ||||
342 | | ||||
343 | int mid = (start + end) / 2; | ||||
344 | | ||||
345 | int leftDiscordants = findDiscordants(ranks, start, mid); | ||||
346 | int rightDiscordants = findDiscordants(ranks, mid + 1, end); | ||||
347 | | ||||
348 | int len = end - start + 1; | ||||
349 | int leftLen = mid - start + 1; | ||||
350 | int rightLen = end - mid; | ||||
351 | int leftLenRemain = leftLen; | ||||
352 | | ||||
353 | int leftRanks[leftLen]; | ||||
354 | int rightRanks[rightLen]; | ||||
355 | | ||||
356 | for (int i = 0; i < leftLen; i++) | ||||
357 | leftRanks[i] = ranks[start + i]; | ||||
358 | | ||||
359 | for (int i = leftLen; i < leftLen + rightLen; i++) | ||||
360 | rightRanks[i - leftLen] = ranks[start + i]; | ||||
361 | | ||||
362 | int mergeDiscordants = 0; | ||||
363 | int i = 0, j = 0, k =0; | ||||
364 | while (i < len) { | ||||
365 | if (j >= leftLen) { | ||||
366 | ranks[start + i] = rightRanks[k]; | ||||
367 | k++; | ||||
368 | } else if (k >= rightLen) { | ||||
369 | ranks[start + i] = leftRanks[j]; | ||||
370 | j++; | ||||
371 | } else if (leftRanks[j] < rightRanks[k]) { | ||||
372 | ranks[start + i] = leftRanks[j]; | ||||
373 | j++; | ||||
374 | leftLenRemain--; | ||||
375 | } else if (leftRanks[j] > rightRanks[k]) { | ||||
376 | ranks[start + i] = rightRanks[k]; | ||||
377 | mergeDiscordants += leftLenRemain; | ||||
378 | k++; | ||||
379 | } | ||||
380 | i++; | ||||
381 | } | ||||
382 | return leftDiscordants + rightDiscordants + mergeDiscordants; | ||||
383 | } | ||||
384 | | ||||
385 | void CorrelationCoefficient::convertToRanks(const Column* col, int N, QMap<double, int> &ranks) { | ||||
386 | if (!isNumericOrInteger(col)) | ||||
387 | return; | ||||
388 | | ||||
389 | //QDEBUG("in convert to ranks"); | ||||
390 | double* sortedList = new double[N]; | ||||
391 | for (int i = 0; i < N; i++) | ||||
392 | sortedList[i] = col->valueAt(i); | ||||
393 | | ||||
394 | std::sort(sortedList, sortedList + N, std::greater<double>()); | ||||
395 | | ||||
396 | // QString debug_sortedList = ""; | ||||
397 | ranks.clear(); | ||||
398 | for (int i = 0; i < N; i++) { | ||||
399 | ranks[sortedList[i]] = i + 1; | ||||
400 | // debug_sortedList += ", " + QString::number(sortedList[i]); | ||||
401 | } | ||||
402 | | ||||
403 | //QDEBUG("sorted list is " << debug_sortedList); | ||||
404 | delete[] sortedList; | ||||
405 | } | ||||
406 | | ||||
407 | void CorrelationCoefficient::convertToRanks(const Column* col, QMap<double, int> &ranks) { | ||||
408 | convertToRanks(col, findCount(col), ranks); | ||||
409 | } | ||||
410 | | ||||
411 | /***********************************************Virtual Functions******************************************************************/ | ||||
412 | | ||||
413 | QWidget* CorrelationCoefficient::view() const { | ||||
414 | if (!m_partView) { | ||||
415 | m_view = new CorrelationCoefficientView(const_cast<CorrelationCoefficient*>(this)); | ||||
416 | m_partView = m_view; | ||||
417 | } | ||||
418 | return m_partView; | ||||
419 | } |