1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27 package org.ximtec.igesture.algorithm.rubinebd;
28
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.apache.commons.math.linear.BigMatrix;
import org.apache.commons.math.linear.BigMatrixImpl;
import org.apache.commons.math.linear.InvalidMatrixException;
import org.sigtec.ink.Note;
import org.sigtec.util.Constant;
import org.ximtec.igesture.algorithm.AlgorithmException;
import org.ximtec.igesture.algorithm.SampleBasedAlgorithm;
import org.ximtec.igesture.algorithm.AlgorithmException.ExceptionType;
import org.ximtec.igesture.algorithm.feature.FeatureException;
import org.ximtec.igesture.configuration.Configuration;
import org.ximtec.igesture.core.Gesture;
import org.ximtec.igesture.core.GestureClass;
import org.ximtec.igesture.core.GestureSample;
import org.ximtec.igesture.core.GestureSet;
import org.ximtec.igesture.core.Result;
import org.ximtec.igesture.core.ResultSet;
import org.ximtec.igesture.util.BigDecimalVector;
import org.ximtec.igesture.util.DoubleVector;
import org.ximtec.igesture.util.GestureTool;
57
58
59
60
61
62
63
64
65
66 public class RubineAlgorithmBigDecimal extends SampleBasedAlgorithm {
67
68 private static final Logger LOGGER = Logger.getLogger(RubineAlgorithmBigDecimal.class
69 .getName());
70
71 private static final String RESULT = "Result: ";
72
73 private static final String NO_RESULT = "No result";
74
75 private static final String DISTANCE = "Distance: ";
76
77 private static final String PROBABILITY = "Probability: ";
78
79 private static final String COMPUTATION_FAILED = "Computation failed because of NaN fields in the matrix";
80
81
82
83
84 private BigMatrix matrix;
85
86
87
88
89 private BigMatrix inverse;
90
91 private RubineConfiguration rubineConfig;
92
93 private Map<GestureClass, GestureClassHelper> helpers;
94
95
96
97
98 GestureSet gestureSet;
99
100 private Executor threadPool;
101
102
103
104
105
106 public RubineAlgorithmBigDecimal() {
107 super();
108 LOGGER.setLevel(Level.SEVERE);
109 helpers = new HashMap<GestureClass, GestureClassHelper>();
110 threadPool = Executors.newFixedThreadPool(3);
111 }
112
113
114 public ResultSet recognise(Gesture< ? > gesture) throws AlgorithmException {
115 ResultSet resultSet = new ResultSet();
116 Note note = null;
117
118 if (gesture instanceof GestureSample) {
119 note = ((GestureSample)gesture).getGesture();
120 }
121
122 if (isApplicable(note)) {
123
124 try {
125 GestureSampleHelper helper = new GestureSampleHelper(note,
126 rubineConfig);
127
128 resultSet = classify(helper.getFeatureVector());
129 }
130 catch (FeatureException exception) {
131 throw new AlgorithmException(ExceptionType.Recognition);
132 }
133
134 resultSet.setGesture(gesture);
135
136 if (resultSet.getResult() != null) {
137 LOGGER.info(RESULT
138 + resultSet.getResult().getGestureClass().getName());
139 }
140 else {
141 LOGGER.info(NO_RESULT);
142 }
143 }
144
145 return resultSet;
146 }
147
148
149 public void init(Configuration config) throws AlgorithmException {
150 this.rubineConfig = new RubineConfiguration(config);
151 preprocess(GestureTool.combine(config.getGestureSets()));
152 }
153
154
155 private void preprocess(GestureSet gestureSet) throws AlgorithmException {
156 this.gestureSet = gestureSet;
157 if(gestureSet.size() == 0){
158 return;
159 }
160 CountDownLatch latch = new CountDownLatch(gestureSet.size());
161
162 for (GestureClass gestureClass : gestureSet.getGestureClasses()) {
163 GestureClassHelper helper = new GestureClassHelper(gestureClass,
164 rubineConfig, latch);
165 helpers.put(gestureClass, helper);
166 threadPool.execute(helper);
167 }
168
169 try {
170 latch.await();
171
172 this.matrix = getCovarianceMatrix();
173 inverse = matrix.inverse();
174 }
175 catch (InterruptedException e) {
176 throw new AlgorithmException(
177 AlgorithmException.ExceptionType.Initialisation, e);
178 }
179 catch (InvalidMatrixException e) {
180 e.printStackTrace();
181 throw new AlgorithmException(
182 AlgorithmException.ExceptionType.Initialisation, e);
183 }
184
185 computeWeights();
186 }
187
188
189
190
191
192
193
194 private BigMatrix getCovarianceMatrix() {
195 int dim = rubineConfig.getNumberOfFeatures();
196 BigDecimal[][] commonCovMatrix = new BigDecimal[dim][dim];
197
198 for (int i = 0; i < dim; i++) {
199
200 for (int j = 0; j < dim; j++) {
201 BigDecimal dividend = new BigDecimal(0);
202
203 for (GestureClass gestureClass : gestureSet.getGestureClasses()) {
204 BigMatrix covarianceMatrix = helpers.get(gestureClass)
205 .getCovarianceMatrix();
206
207 dividend = dividend.add(covarianceMatrix.getEntry(i, j).divide(
208 (new BigDecimal(helpers.get(gestureClass).getNumberOfSamples() - 1)), BigDecimal.ROUND_HALF_DOWN));
209 }
210
211 int divisor = -gestureSet.size();
212
213 for (GestureClass gestureClass : gestureSet.getGestureClasses()) {
214 divisor += helpers.get(gestureClass).getNumberOfSamples();
215 }
216
217
218 if(divisor != 0){
219 commonCovMatrix[i][j] = dividend.divide(new BigDecimal(divisor));
220 }else{
221 commonCovMatrix[i][j] = new BigDecimal(0);
222 }
223 }
224
225 }
226
227 return new BigMatrixImpl(commonCovMatrix);
228 }
229
230
231
232
233
234
235 private void computeWeights() throws AlgorithmException {
236 for (GestureClassHelper helper : helpers.values()) {
237 helper.computeWeights(inverse);
238 }
239
240 }
241
242
243
244
245
246
247 private ResultSet classify(BigDecimalVector inputFeatureVector) {
248 double max = -Double.MAX_VALUE;
249 GestureClass classifiedGesture = null;
250 HashMap<GestureClass, BigDecimal> classifiers = new HashMap<GestureClass, BigDecimal>();
251
252 for (GestureClassHelper helper : helpers.values()) {
253
254 BigDecimalVector weightVector = helper.getWeights();
255 BigDecimal v = helper.getInitialWeight();
256
257 for (int i = 0; i < inputFeatureVector.size(); i++) {
258 v = v.add(weightVector.get(i).multiply(inputFeatureVector.get(i)));
259 }
260
261 LOGGER.info(helper.getGestureClass().getName() + Constant.COLON_BLANK + v);
262 classifiers.put(helper.getGestureClass(), v);
263
264 if (v.doubleValue() > max) {
265 max = v.doubleValue();
266 classifiedGesture = helper.getGestureClass();
267 }
268
269
270
271
272
273
274
275 }
276
277
278 double divisor = 0;
279
280 for (GestureClass gestureClass : classifiers.keySet()) {
281
282
283 divisor += Math.exp(classifiers.get(gestureClass).doubleValue()
284 - classifiers.get(classifiedGesture).doubleValue());
285 }
286
287 double probability = 1.0 / divisor;
288
289
290 double distance = getMahalanobisDistance(classifiedGesture, inputFeatureVector);
291
292 ResultSet resultSet = new ResultSet();
293
294 if (probability >= rubineConfig.getProbability()
295 && distance <= rubineConfig.getMahalanobisDistance()) {
296 resultSet.addResult(new Result(classifiedGesture, 1));
297 }
298
299 LOGGER.info(DISTANCE + distance);
300 LOGGER.info(PROBABILITY + probability);
301 return resultSet;
302 }
303
304
305
306
307
308
309
310
311
312
313 private double getMahalanobisDistance(GestureClass gestureClass,
314 BigDecimalVector inputVector) {
315 BigDecimal result = new BigDecimal(0);
316 BigDecimalVector meanVector = helpers.get(gestureClass).getMeanFeatureVector();
317
318 for (int j = 0; j < rubineConfig.getNumberOfFeatures(); j++) {
319
320 for (int k = 0; k < rubineConfig.getNumberOfFeatures(); k++) {
321 result = result.add(
322 inverse.getEntry(j, k).multiply(
323 (inputVector.get(j).subtract(meanVector.get(j)))).multiply(
324 (inputVector.get(k).subtract(meanVector.get(k)))));
325 }
326
327 }
328
329 return result.doubleValue();
330 }
331
332
333 public RubineConfiguration.Config[] getConfigParameters() {
334 return RubineConfiguration.Config.values();
335 }
336
337
338
339
340
341
342
343 private boolean isApplicable(Note gesture) {
344 return gesture != null
345 && gesture.getPoints().size() >= rubineConfig
346 .getMinimalNumberOfPoints();
347 }
348
349
350 @Override
351 public String getDefaultParameterValue(String parameterName) {
352 return RubineConfiguration.getDefaultConfiguration().get(parameterName);
353 }
354
355
356 @Override
357 public int getType() {
358 return org.ximtec.igesture.util.Constant.TYPE_2D;
359 }
360
361 }