Normand Briere
2018-05-22 42107f9a01652cb2f47228d20c1148a2a22f6a63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
/*
 * This software is a cooperative product of The MathWorks and the National
 * Institute of Standards and Technology (NIST) which has been released to the
 * public domain. Neither The MathWorks nor NIST assumes any responsibility
 * whatsoever for its use by other parties, and makes no guarantees, expressed
 * or implied, about its quality, reliability, or any other characteristic.
 */
 
/*
 * LinearRegression.java
 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
 *
 */
 
package //weka.core.
        matrix;
 
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
 
/**
 * Class for performing (ridged) linear regression.
 *
 * @author Fracpete (fracpete at waikato dot ac dot nz)
 * @version $Revision: 1.4 $
 */
 
public class LinearRegression
  implements RevisionHandler {
 
  /** the coefficients */
  protected double[] m_Coefficients = null;
 
  /**
   * Performs a (ridged) linear regression.
   *
   * @param a the matrix to perform the regression on
   * @param y the dependent variable vector
   * @param ridge the ridge parameter
   * @throws IllegalArgumentException if not successful
   */
  public LinearRegression(Matrix a, Matrix y, double ridge) {
    calculate(a, y, ridge);
  }
 
  /**
   * Performs a weighted (ridged) linear regression. 
   *
   * @param a the matrix to perform the regression on
   * @param y the dependent variable vector
   * @param w the array of data point weights
   * @param ridge the ridge parameter
   * @throws IllegalArgumentException if the wrong number of weights were
   * provided.
   */
  public LinearRegression(Matrix a, Matrix y, double[] w, double ridge) {
 
    if (w.length != a.getRowDimension())
      throw new IllegalArgumentException("Incorrect number of weights provided");
    Matrix weightedThis = new Matrix(
                              a.getRowDimension(), a.getColumnDimension());
    Matrix weightedDep = new Matrix(a.getRowDimension(), 1);
    for (int i = 0; i < w.length; i++) {
      double sqrt_weight = Math.sqrt(w[i]);
      for (int j = 0; j < a.getColumnDimension(); j++)
        weightedThis.set(i, j, a.get(i, j) * sqrt_weight);
      weightedDep.set(i, 0, y.get(i, 0) * sqrt_weight);
    }
 
    calculate(weightedThis, weightedDep, ridge);
  }
 
  /**
   * performs the actual regression.
   *
   * @param a the matrix to perform the regression on
   * @param y the dependent variable vector
   * @param ridge the ridge parameter
   * @throws IllegalArgumentException if not successful
   */
  protected void calculate(Matrix a, Matrix y, double ridge) {
 
    if (y.getColumnDimension() > 1)
      throw new IllegalArgumentException("Only one dependent variable allowed");
 
    int nc = a.getColumnDimension();
    m_Coefficients = new double[nc];
    Matrix xt = a.transpose();
    Matrix solution;
 
    boolean success = true;
 
    do {
      Matrix ss = xt.times(a);
 
      // Set ridge regression adjustment
      for (int i = 0; i < nc; i++)
        ss.set(i, i, ss.get(i, i) + ridge);
 
      // Carry out the regression
      Matrix bb = xt.times(y);
      for(int i = 0; i < nc; i++)
        m_Coefficients[i] = bb.get(i, 0);
 
      try {
        solution = ss.solve(new Matrix(m_Coefficients, m_Coefficients.length));
        for (int i = 0; i < nc; i++)
          m_Coefficients[i] = solution.get(i, 0);
        success = true;
      } 
      catch (Exception ex) {
        ridge *= 10;
        success = false;
      }
    } while (!success);
  }
 
  /**
   * returns the calculated coefficients
   *
   * @return the coefficients
   */
  public final double[] getCoefficients() {
    return m_Coefficients;
  }
 
  /**
   * returns the coefficients in a string representation
   */
  public String toString() {
    return Utils.arrayToString(getCoefficients());
  }
  
  /**
   * Returns the revision string.
   * 
   * @return        the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.4 $");
  }
}