001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.math.analysis.interpolation;
018
019 import java.io.Serializable;
020 import java.util.Arrays;
021
022 import org.apache.commons.math.MathException;
023 import org.apache.commons.math.analysis.polynomials.PolynomialSplineFunction;
024 import org.apache.commons.math.exception.util.Localizable;
025 import org.apache.commons.math.exception.util.LocalizedFormats;
026 import org.apache.commons.math.util.FastMath;
027
028 /**
029 * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression">
030 * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of
031 * real univariate functions.
032 * <p/>
033 * For reference, see
034 * <a href="http://www.math.tau.ac.il/~yekutiel/MA seminar/Cleveland 1979.pdf">
035 * William S. Cleveland - Robust Locally Weighted Regression and Smoothing
036 * Scatterplots</a>
037 * <p/>
038 * This class implements both the loess method and serves as an interpolation
039 * adapter to it, allowing to build a spline on the obtained loess fit.
040 *
041 * @version $Revision: 990655 $ $Date: 2010-08-29 23:49:40 +0200 (dim. 29 ao??t 2010) $
042 * @since 2.0
043 */
044 public class LoessInterpolator
045 implements UnivariateRealInterpolator, Serializable {
046
047 /** Default value of the bandwidth parameter. */
048 public static final double DEFAULT_BANDWIDTH = 0.3;
049
050 /** Default value of the number of robustness iterations. */
051 public static final int DEFAULT_ROBUSTNESS_ITERS = 2;
052
053 /**
054 * Default value for accuracy.
055 * @since 2.1
056 */
057 public static final double DEFAULT_ACCURACY = 1e-12;
058
059 /** serializable version identifier. */
060 private static final long serialVersionUID = 5204927143605193821L;
061
062 /**
063 * The bandwidth parameter: when computing the loess fit at
064 * a particular point, this fraction of source points closest
065 * to the current point is taken into account for computing
066 * a least-squares regression.
067 * <p/>
068 * A sensible value is usually 0.25 to 0.5.
069 */
070 private final double bandwidth;
071
072 /**
073 * The number of robustness iterations parameter: this many
074 * robustness iterations are done.
075 * <p/>
076 * A sensible value is usually 0 (just the initial fit without any
077 * robustness iterations) to 4.
078 */
079 private final int robustnessIters;
080
081 /**
082 * If the median residual at a certain robustness iteration
083 * is less than this amount, no more iterations are done.
084 */
085 private final double accuracy;
086
087 /**
088 * Constructs a new {@link LoessInterpolator}
089 * with a bandwidth of {@link #DEFAULT_BANDWIDTH},
090 * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations
091 * and an accuracy of {#link #DEFAULT_ACCURACY}.
092 * See {@link #LoessInterpolator(double, int, double)} for an explanation of
093 * the parameters.
094 */
095 public LoessInterpolator() {
096 this.bandwidth = DEFAULT_BANDWIDTH;
097 this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS;
098 this.accuracy = DEFAULT_ACCURACY;
099 }
100
101 /**
102 * Constructs a new {@link LoessInterpolator}
103 * with given bandwidth and number of robustness iterations.
104 * <p>
105 * Calling this constructor is equivalent to calling {link {@link
106 * #LoessInterpolator(double, int, double) LoessInterpolator(bandwidth,
107 * robustnessIters, LoessInterpolator.DEFAULT_ACCURACY)}
108 * </p>
109 *
110 * @param bandwidth when computing the loess fit at
111 * a particular point, this fraction of source points closest
112 * to the current point is taken into account for computing
113 * a least-squares regression.</br>
114 * A sensible value is usually 0.25 to 0.5, the default value is
115 * {@link #DEFAULT_BANDWIDTH}.
116 * @param robustnessIters This many robustness iterations are done.</br>
117 * A sensible value is usually 0 (just the initial fit without any
118 * robustness iterations) to 4, the default value is
119 * {@link #DEFAULT_ROBUSTNESS_ITERS}.
120 * @throws MathException if bandwidth does not lie in the interval [0,1]
121 * or if robustnessIters is negative.
122 * @see #LoessInterpolator(double, int, double)
123 */
124 public LoessInterpolator(double bandwidth, int robustnessIters) throws MathException {
125 this(bandwidth, robustnessIters, DEFAULT_ACCURACY);
126 }
127
128 /**
129 * Constructs a new {@link LoessInterpolator}
130 * with given bandwidth, number of robustness iterations and accuracy.
131 *
132 * @param bandwidth when computing the loess fit at
133 * a particular point, this fraction of source points closest
134 * to the current point is taken into account for computing
135 * a least-squares regression.</br>
136 * A sensible value is usually 0.25 to 0.5, the default value is
137 * {@link #DEFAULT_BANDWIDTH}.
138 * @param robustnessIters This many robustness iterations are done.</br>
139 * A sensible value is usually 0 (just the initial fit without any
140 * robustness iterations) to 4, the default value is
141 * {@link #DEFAULT_ROBUSTNESS_ITERS}.
142 * @param accuracy If the median residual at a certain robustness iteration
143 * is less than this amount, no more iterations are done.
144 * @throws MathException if bandwidth does not lie in the interval [0,1]
145 * or if robustnessIters is negative.
146 * @see #LoessInterpolator(double, int)
147 * @since 2.1
148 */
149 public LoessInterpolator(double bandwidth, int robustnessIters, double accuracy) throws MathException {
150 if (bandwidth < 0 || bandwidth > 1) {
151 throw new MathException(LocalizedFormats.BANDWIDTH_OUT_OF_INTERVAL,
152 bandwidth);
153 }
154 this.bandwidth = bandwidth;
155 if (robustnessIters < 0) {
156 throw new MathException(LocalizedFormats.NEGATIVE_ROBUSTNESS_ITERATIONS, robustnessIters);
157 }
158 this.robustnessIters = robustnessIters;
159 this.accuracy = accuracy;
160 }
161
162 /**
163 * Compute an interpolating function by performing a loess fit
164 * on the data at the original abscissae and then building a cubic spline
165 * with a
166 * {@link org.apache.commons.math.analysis.interpolation.SplineInterpolator}
167 * on the resulting fit.
168 *
169 * @param xval the arguments for the interpolation points
170 * @param yval the values for the interpolation points
171 * @return A cubic spline built upon a loess fit to the data at the original abscissae
172 * @throws MathException if some of the following conditions are false:
173 * <ul>
174 * <li> Arguments and values are of the same size that is greater than zero</li>
175 * <li> The arguments are in a strictly increasing order</li>
176 * <li> All arguments and values are finite real numbers</li>
177 * </ul>
178 */
179 public final PolynomialSplineFunction interpolate(
180 final double[] xval, final double[] yval) throws MathException {
181 return new SplineInterpolator().interpolate(xval, smooth(xval, yval));
182 }
183
184 /**
185 * Compute a weighted loess fit on the data at the original abscissae.
186 *
187 * @param xval the arguments for the interpolation points
188 * @param yval the values for the interpolation points
189 * @param weights point weights: coefficients by which the robustness weight of a point is multiplied
190 * @return values of the loess fit at corresponding original abscissae
191 * @throws MathException if some of the following conditions are false:
192 * <ul>
193 * <li> Arguments and values are of the same size that is greater than zero</li>
194 * <li> The arguments are in a strictly increasing order</li>
195 * <li> All arguments and values are finite real numbers</li>
196 * </ul>
197 * @since 2.1
198 */
199 public final double[] smooth(final double[] xval, final double[] yval, final double[] weights)
200 throws MathException {
201 if (xval.length != yval.length) {
202 throw new MathException(LocalizedFormats.MISMATCHED_LOESS_ABSCISSA_ORDINATE_ARRAYS,
203 xval.length, yval.length);
204 }
205
206 final int n = xval.length;
207
208 if (n == 0) {
209 throw new MathException(LocalizedFormats.LOESS_EXPECTS_AT_LEAST_ONE_POINT);
210 }
211
212 checkAllFiniteReal(xval, LocalizedFormats.NON_REAL_FINITE_ABSCISSA);
213 checkAllFiniteReal(yval, LocalizedFormats.NON_REAL_FINITE_ORDINATE);
214 checkAllFiniteReal(weights, LocalizedFormats.NON_REAL_FINITE_WEIGHT);
215
216 checkStrictlyIncreasing(xval);
217
218 if (n == 1) {
219 return new double[]{yval[0]};
220 }
221
222 if (n == 2) {
223 return new double[]{yval[0], yval[1]};
224 }
225
226 int bandwidthInPoints = (int) (bandwidth * n);
227
228 if (bandwidthInPoints < 2) {
229 throw new MathException(LocalizedFormats.TOO_SMALL_BANDWIDTH,
230 n, 2.0 / n, bandwidth);
231 }
232
233 final double[] res = new double[n];
234
235 final double[] residuals = new double[n];
236 final double[] sortedResiduals = new double[n];
237
238 final double[] robustnessWeights = new double[n];
239
240 // Do an initial fit and 'robustnessIters' robustness iterations.
241 // This is equivalent to doing 'robustnessIters+1' robustness iterations
242 // starting with all robustness weights set to 1.
243 Arrays.fill(robustnessWeights, 1);
244
245 for (int iter = 0; iter <= robustnessIters; ++iter) {
246 final int[] bandwidthInterval = {0, bandwidthInPoints - 1};
247 // At each x, compute a local weighted linear regression
248 for (int i = 0; i < n; ++i) {
249 final double x = xval[i];
250
251 // Find out the interval of source points on which
252 // a regression is to be made.
253 if (i > 0) {
254 updateBandwidthInterval(xval, weights, i, bandwidthInterval);
255 }
256
257 final int ileft = bandwidthInterval[0];
258 final int iright = bandwidthInterval[1];
259
260 // Compute the point of the bandwidth interval that is
261 // farthest from x
262 final int edge;
263 if (xval[i] - xval[ileft] > xval[iright] - xval[i]) {
264 edge = ileft;
265 } else {
266 edge = iright;
267 }
268
269 // Compute a least-squares linear fit weighted by
270 // the product of robustness weights and the tricube
271 // weight function.
272 // See http://en.wikipedia.org/wiki/Linear_regression
273 // (section "Univariate linear case")
274 // and http://en.wikipedia.org/wiki/Weighted_least_squares
275 // (section "Weighted least squares")
276 double sumWeights = 0;
277 double sumX = 0;
278 double sumXSquared = 0;
279 double sumY = 0;
280 double sumXY = 0;
281 double denom = FastMath.abs(1.0 / (xval[edge] - x));
282 for (int k = ileft; k <= iright; ++k) {
283 final double xk = xval[k];
284 final double yk = yval[k];
285 final double dist = (k < i) ? x - xk : xk - x;
286 final double w = tricube(dist * denom) * robustnessWeights[k] * weights[k];
287 final double xkw = xk * w;
288 sumWeights += w;
289 sumX += xkw;
290 sumXSquared += xk * xkw;
291 sumY += yk * w;
292 sumXY += yk * xkw;
293 }
294
295 final double meanX = sumX / sumWeights;
296 final double meanY = sumY / sumWeights;
297 final double meanXY = sumXY / sumWeights;
298 final double meanXSquared = sumXSquared / sumWeights;
299
300 final double beta;
301 if (FastMath.sqrt(FastMath.abs(meanXSquared - meanX * meanX)) < accuracy) {
302 beta = 0;
303 } else {
304 beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX);
305 }
306
307 final double alpha = meanY - beta * meanX;
308
309 res[i] = beta * x + alpha;
310 residuals[i] = FastMath.abs(yval[i] - res[i]);
311 }
312
313 // No need to recompute the robustness weights at the last
314 // iteration, they won't be needed anymore
315 if (iter == robustnessIters) {
316 break;
317 }
318
319 // Recompute the robustness weights.
320
321 // Find the median residual.
322 // An arraycopy and a sort are completely tractable here,
323 // because the preceding loop is a lot more expensive
324 System.arraycopy(residuals, 0, sortedResiduals, 0, n);
325 Arrays.sort(sortedResiduals);
326 final double medianResidual = sortedResiduals[n / 2];
327
328 if (FastMath.abs(medianResidual) < accuracy) {
329 break;
330 }
331
332 for (int i = 0; i < n; ++i) {
333 final double arg = residuals[i] / (6 * medianResidual);
334 if (arg >= 1) {
335 robustnessWeights[i] = 0;
336 } else {
337 final double w = 1 - arg * arg;
338 robustnessWeights[i] = w * w;
339 }
340 }
341 }
342
343 return res;
344 }
345
346 /**
347 * Compute a loess fit on the data at the original abscissae.
348 *
349 * @param xval the arguments for the interpolation points
350 * @param yval the values for the interpolation points
351 * @return values of the loess fit at corresponding original abscissae
352 * @throws MathException if some of the following conditions are false:
353 * <ul>
354 * <li> Arguments and values are of the same size that is greater than zero</li>
355 * <li> The arguments are in a strictly increasing order</li>
356 * <li> All arguments and values are finite real numbers</li>
357 * </ul>
358 */
359 public final double[] smooth(final double[] xval, final double[] yval)
360 throws MathException {
361 if (xval.length != yval.length) {
362 throw new MathException(LocalizedFormats.MISMATCHED_LOESS_ABSCISSA_ORDINATE_ARRAYS,
363 xval.length, yval.length);
364 }
365
366 final double[] unitWeights = new double[xval.length];
367 Arrays.fill(unitWeights, 1.0);
368
369 return smooth(xval, yval, unitWeights);
370 }
371
372 /**
373 * Given an index interval into xval that embraces a certain number of
374 * points closest to xval[i-1], update the interval so that it embraces
375 * the same number of points closest to xval[i], ignoring zero weights.
376 *
377 * @param xval arguments array
378 * @param weights weights array
379 * @param i the index around which the new interval should be computed
380 * @param bandwidthInterval a two-element array {left, right} such that: <p/>
381 * <tt>(left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])</tt>
382 * <p/> and also <p/>
383 * <tt>(right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])</tt>.
384 * The array will be updated.
385 */
386 private static void updateBandwidthInterval(final double[] xval, final double[] weights,
387 final int i,
388 final int[] bandwidthInterval) {
389 final int left = bandwidthInterval[0];
390 final int right = bandwidthInterval[1];
391
392 // The right edge should be adjusted if the next point to the right
393 // is closer to xval[i] than the leftmost point of the current interval
394 int nextRight = nextNonzero(weights, right);
395 if (nextRight < xval.length && xval[nextRight] - xval[i] < xval[i] - xval[left]) {
396 int nextLeft = nextNonzero(weights, bandwidthInterval[0]);
397 bandwidthInterval[0] = nextLeft;
398 bandwidthInterval[1] = nextRight;
399 }
400 }
401
402 /**
403 * Returns the smallest index j such that j > i && (j==weights.length || weights[j] != 0)
404 * @param weights weights array
405 * @param i the index from which to start search; must be < weights.length
406 * @return the smallest index j such that j > i && (j==weights.length || weights[j] != 0)
407 */
408 private static int nextNonzero(final double[] weights, final int i) {
409 int j = i + 1;
410 while(j < weights.length && weights[j] == 0) {
411 j++;
412 }
413 return j;
414 }
415
416 /**
417 * Compute the
418 * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a>
419 * weight function
420 *
421 * @param x the argument
422 * @return (1-|x|^3)^3
423 */
424 private static double tricube(final double x) {
425 final double tmp = 1 - x * x * x;
426 return tmp * tmp * tmp;
427 }
428
429 /**
430 * Check that all elements of an array are finite real numbers.
431 *
432 * @param values the values array
433 * @param pattern pattern of the error message
434 * @throws MathException if one of the values is not a finite real number
435 */
436 private static void checkAllFiniteReal(final double[] values, final Localizable pattern)
437 throws MathException {
438 for (int i = 0; i < values.length; i++) {
439 final double x = values[i];
440 if (Double.isInfinite(x) || Double.isNaN(x)) {
441 throw new MathException(pattern, i, x);
442 }
443 }
444 }
445
446 /**
447 * Check that elements of the abscissae array are in a strictly
448 * increasing order.
449 *
450 * @param xval the abscissae array
451 * @throws MathException if the abscissae array
452 * is not in a strictly increasing order
453 */
454 private static void checkStrictlyIncreasing(final double[] xval)
455 throws MathException {
456 for (int i = 0; i < xval.length; ++i) {
457 if (i >= 1 && xval[i - 1] >= xval[i]) {
458 throw new MathException(LocalizedFormats.OUT_OF_ORDER_ABSCISSA_ARRAY,
459 i - 1, xval[i - 1], i, xval[i]);
460 }
461 }
462 }
463 }