1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27 package org.ximtec.igesture.algorithm.rubine;
28
29 import java.util.HashMap;
30 import java.util.Map;
31 import java.util.concurrent.CountDownLatch;
32 import java.util.concurrent.Executor;
33 import java.util.concurrent.Executors;
34 import java.util.logging.Level;
35 import java.util.logging.Logger;
36
37 import org.apache.commons.math.linear.InvalidMatrixException;
38 import org.apache.commons.math.linear.RealMatrix;
39 import org.apache.commons.math.linear.RealMatrixImpl;
40 import org.sigtec.ink.Note;
41 import org.sigtec.util.Constant;
42 import org.ximtec.igesture.algorithm.AlgorithmException;
43 import org.ximtec.igesture.algorithm.SampleBasedAlgorithm;
44 import org.ximtec.igesture.algorithm.AlgorithmException.ExceptionType;
45 import org.ximtec.igesture.algorithm.feature.FeatureException;
46 import org.ximtec.igesture.configuration.Configuration;
47 import org.ximtec.igesture.core.Gesture;
48 import org.ximtec.igesture.core.GestureClass;
49 import org.ximtec.igesture.core.GestureSample;
50 import org.ximtec.igesture.core.GestureSet;
51 import org.ximtec.igesture.core.Result;
52 import org.ximtec.igesture.core.ResultSet;
53 import org.ximtec.igesture.util.DoubleVector;
54 import org.ximtec.igesture.util.GestureTool;
55
56
57
58
59
60
61
62
63
64 public class RubineAlgorithm extends SampleBasedAlgorithm {
65
66 private static final Logger LOGGER = Logger.getLogger(RubineAlgorithm.class
67 .getName());
68
69 private static final String RESULT = "Result: ";
70
71 private static final String NO_RESULT = "No result";
72
73 private static final String DISTANCE = "Distance: ";
74
75 private static final String PROBABILITY = "Probability: ";
76
77 private static final String COMPUTATION_FAILED = "Computation failed because of NaN fields in the matrix";
78
79
80
81
82 private RealMatrix matrix;
83
84
85
86
87 private RealMatrix inverse;
88
89 private RubineConfiguration rubineConfig;
90
91 private Map<GestureClass, GestureClassHelper> helpers;
92
93
94
95
96 GestureSet gestureSet;
97
98 private Executor threadPool;
99
100
101
102
103
104 public RubineAlgorithm() {
105 super();
106 LOGGER.setLevel(Level.SEVERE);
107 helpers = new HashMap<GestureClass, GestureClassHelper>();
108 threadPool = Executors.newFixedThreadPool(3);
109 }
110
111
112 public ResultSet recognise(Gesture< ? > gesture) throws AlgorithmException {
113 ResultSet resultSet = new ResultSet();
114 Note note = null;
115
116 if (gesture instanceof GestureSample) {
117 note = ((GestureSample)gesture).getGesture();
118 }
119
120 if (isApplicable(note)) {
121
122 try {
123 GestureSampleHelper helper = new GestureSampleHelper(note,
124 rubineConfig);
125
126 resultSet = classify(helper.getFeatureVector());
127 }
128 catch (FeatureException exception) {
129 throw new AlgorithmException(ExceptionType.Recognition);
130 }
131
132 resultSet.setGesture(gesture);
133
134 if (resultSet.getResult() != null) {
135 LOGGER.info(RESULT
136 + resultSet.getResult().getGestureClass().getName());
137 }
138 else {
139 LOGGER.info(NO_RESULT);
140 }
141 }
142
143 return resultSet;
144 }
145
146
147 public void init(Configuration config) throws AlgorithmException {
148 this.rubineConfig = new RubineConfiguration(config);
149 preprocess(GestureTool.combine(config.getGestureSets()));
150 }
151
152
153 private void preprocess(GestureSet gestureSet) throws AlgorithmException {
154 this.gestureSet = gestureSet;
155 CountDownLatch latch = new CountDownLatch(gestureSet.size());
156
157 for (GestureClass gestureClass : gestureSet.getGestureClasses()) {
158 GestureClassHelper helper = new GestureClassHelper(gestureClass,
159 rubineConfig, latch);
160 helpers.put(gestureClass, helper);
161 threadPool.execute(helper);
162 }
163
164 try {
165 latch.await();
166
167 this.matrix = getCovarianceMatrix();
168 inverse = matrix.inverse();
169 }
170 catch (InterruptedException e) {
171 throw new AlgorithmException(
172 AlgorithmException.ExceptionType.Initialisation, e);
173 }
174 catch (InvalidMatrixException e) {
175 e.printStackTrace();
176 throw new AlgorithmException(
177 AlgorithmException.ExceptionType.Initialisation, e);
178 }
179
180 computeWeights();
181 }
182
183
184
185
186
187
188
189 private RealMatrix getCovarianceMatrix() {
190 int dim = rubineConfig.getNumberOfFeatures();
191 double[][] commonCovMatrix = new double[dim][dim];
192
193 for (int i = 0; i < dim; i++) {
194
195 for (int j = 0; j < dim; j++) {
196 double dividend = 0;
197
198 for (GestureClass gestureClass : gestureSet.getGestureClasses()) {
199 RealMatrix covarianceMatrix = helpers.get(gestureClass)
200 .getCovarianceMatrix();
201
202 dividend += covarianceMatrix.getEntry(i, j)
203 / (helpers.get(gestureClass).getNumberOfSamples() - 1);
204 }
205
206 double divisor = -gestureSet.size();
207
208 for (GestureClass gestureClass : gestureSet.getGestureClasses()) {
209 divisor += helpers.get(gestureClass).getNumberOfSamples();
210 }
211
212 commonCovMatrix[i][j] = dividend / divisor;
213 }
214
215 }
216
217 return new RealMatrixImpl(commonCovMatrix);
218 }
219
220
221
222
223
224
225 private void computeWeights() throws AlgorithmException {
226 for (GestureClassHelper helper : helpers.values()) {
227 helper.computeWeights(inverse);
228 }
229
230 }
231
232
233
234
235
236
237 private ResultSet classify(DoubleVector inputFeatureVector) {
238 double max = -Double.MAX_VALUE;
239 GestureClass classifiedGesture = null;
240 HashMap<GestureClass, Double> classifiers = new HashMap<GestureClass, Double>();
241
242 for (GestureClassHelper helper : helpers.values()) {
243
244 DoubleVector weightVector = helper.getWeights();
245 double v = helper.getInitialWeight();
246
247 for (int i = 0; i < inputFeatureVector.size(); i++) {
248 v += weightVector.get(i) * inputFeatureVector.get(i);
249 }
250
251 LOGGER.info(helper.getGestureClass().getName() + Constant.COLON_BLANK
252 + v);
253 classifiers.put(helper.getGestureClass(), v);
254
255 if (v > max) {
256 max = v;
257 classifiedGesture = helper.getGestureClass();
258 }
259
260 if (Double.isNaN(v)) {
261 LOGGER.warning(COMPUTATION_FAILED);
262 return new ResultSet();
263 }
264
265 }
266
267
268 double divisor = 0;
269
270 for (GestureClass gestureClass : classifiers.keySet()) {
271 divisor += Math.exp(classifiers.get(gestureClass)
272 - classifiers.get(classifiedGesture));
273 }
274
275 double probability = 1.0 / divisor;
276
277
278 double distance = getMahalanobisDistance(classifiedGesture,
279 inputFeatureVector);
280
281 ResultSet resultSet = new ResultSet();
282
283 if (probability >= rubineConfig.getProbability()
284 && distance <= rubineConfig.getMahalanobisDistance()) {
285 resultSet.addResult(new Result(classifiedGesture, 1));
286 }
287
288 LOGGER.info(DISTANCE + distance);
289 LOGGER.info(PROBABILITY + probability);
290 return resultSet;
291 }
292
293
294
295
296
297
298
299
300
301
302 private double getMahalanobisDistance(GestureClass gestureClass,
303 DoubleVector inputVector) {
304 double result = 0;
305 DoubleVector meanVector = helpers.get(gestureClass).getMeanFeatureVector();
306
307 for (int j = 0; j < rubineConfig.getNumberOfFeatures(); j++) {
308
309 for (int k = 0; k < rubineConfig.getNumberOfFeatures(); k++) {
310 result += inverse.getEntry(j, k)
311 * (inputVector.get(j) - meanVector.get(j))
312 * (inputVector.get(k) - meanVector.get(k));
313 }
314
315 }
316
317 return result;
318 }
319
320
321 public RubineConfiguration.Config[] getConfigParameters() {
322 return RubineConfiguration.Config.values();
323 }
324
325
326
327
328
329
330
331 private boolean isApplicable(Note gesture) {
332 return gesture != null
333 && gesture.getPoints().size() >= rubineConfig
334 .getMinimalNumberOfPoints();
335 }
336
337
338 @Override
339 public String getDefaultParameterValue(String parameterName) {
340 return RubineConfiguration.getDefaultConfiguration().get(parameterName);
341 }
342
343
344 @Override
345 public int getType() {
346 return org.ximtec.igesture.util.Constant.TYPE_2D;
347 }
348
349 }