GestureDatabase.java

00001 package edu.stanford.hci.r3.pen.gesture;
00002 
00003 import java.awt.BorderLayout;
00004 import java.awt.Component;
00005 import java.awt.Container;
00006 import java.awt.event.ActionEvent;
00007 import java.awt.event.ActionListener;
00008 import java.awt.event.FocusEvent;
00009 import java.awt.event.FocusListener;
00010 import java.io.BufferedReader;
00011 import java.io.File;
00012 import java.io.FileWriter;
00013 import java.io.IOException;
00014 import java.io.InputStreamReader;
00015 import java.io.Writer;
00016 import java.util.ArrayList;
00017 import java.util.Date;
00018 import java.util.HashMap;
00019 import java.util.Random;
00020 
00021 import javax.swing.JComponent;
00022 import javax.swing.JFrame;
00023 import javax.swing.JLabel;
00024 import javax.swing.JPanel;
00025 import javax.swing.JTextField;
00026 import javax.swing.Spring;
00027 import javax.swing.SpringLayout;
00028 import javax.swing.WindowConstants;
00029 
00030 import com.thoughtworks.xstream.XStream;
00031 
00032 import edu.stanford.hci.r3.components.InkPanel;
00033 import edu.stanford.hci.r3.pen.PenSample;
00034 import edu.stanford.hci.r3.pen.ink.Ink;
00035 import edu.stanford.hci.r3.pen.ink.InkStroke;
00036 import edu.stanford.hci.r3.render.ink.InkRenderer;
00037 import edu.stanford.hci.r3.units.Pixels;
00038 import edu.stanford.hci.r3.units.Points;
00039 import edu.stanford.hci.r3.util.WindowUtils;
00040 
00041 public class GestureDatabase implements ActionListener, FocusListener {
00042         private transient static JTextField entryField;
00043 
00044         private transient static JFrame inkDisplay;
00045 
00046         private transient static InkPanel inkPanel;
00047 
00048         transient static GestureDatabase instance;
00049 
00050         private transient static JLabel labelField;
00051 
00052         private transient static JPanel mainPanel;
00053 
00054         private transient static JPanel statusPanel;
00055 
00056         private transient static BufferedReader stdin = new BufferedReader(new InputStreamReader(System.in));
00057 
00058         /* Used by makeCompactGrid. */
00059         private static SpringLayout.Constraints getConstraintsForCell(int row, int col, Container parent, int cols) {
00060                 SpringLayout layout = (SpringLayout) parent.getLayout();
00061                 Component c = parent.getComponent(row * cols + col);
00062                 return layout.getConstraints(c);
00063         }
00064 
00065         static JFrame getInkDisplay() {
00066                 if (inkDisplay == null) {
00067                         JFrame.setDefaultLookAndFeelDecorated(true);
00068                         inkDisplay = new JFrame("Sketch! Display");
00069                         inkDisplay.setContentPane(getMainPanel());
00070                         inkDisplay.setSize(690, 740);
00071                         inkDisplay.setLocation(WindowUtils.getWindowOrigin(inkDisplay, WindowUtils.DESKTOP_CENTER));
00072                         inkDisplay.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
00073                         // inkDisplay.pack();
00074                         inkDisplay.setVisible(true);
00075                 }
00076                 return inkDisplay;
00077         }
00078 
00082         private static Container getInkPanel() {
00083                 if (inkPanel == null) {
00084                         inkPanel = new InkPanel();
00085                 }
00086                 return inkPanel;
00087         }
00088 
00092         private static Container getMainPanel() {
00093                 if (mainPanel == null) {
00094                         mainPanel = new JPanel();
00095                         mainPanel.setLayout(new BorderLayout());
00096                         mainPanel.add(getInkPanel(), BorderLayout.CENTER);
00097                         mainPanel.add(getStatusPanel(), BorderLayout.SOUTH);
00098                 }
00099                 return mainPanel;
00100         }
00101 
00105         private static Component getStatusPanel() {
00106                 if (statusPanel == null) {
00107                         statusPanel = new JPanel(new SpringLayout());
00108                         String[] labelStrings = { "0: Add a gesture", "1: Add more of an existing gesture", "2: Save",
00109                                         "3: List", "4: Test (interactive)", "5: Add a test gesture (for existing)",
00110                                         "6: Run autotest", "7: Determine class parameters",
00111                                         "8: Assign class labels to unlabeled gestures", "9: Add unlabeled gestures",
00112                                         "10: Add unlabeled test gestures", "11: Optimize cost weighting",
00113                                         "12: Run leave-10%-out test", "13: Move gestures out of test",
00114                                         "14: Run leave-one-user-out test", "15: Compute best examples per class",
00115                                         "16: Run leave-10%-out test using only best examples", "17: Create best image",
00116                                         "18: Generate gesture", "-1: Exit", "Option: " };
00117                         JLabel[] labels = new JLabel[labelStrings.length];
00118                         JComponent[] fields = new JComponent[labelStrings.length];
00119                         int fieldNum = 0;
00120                         for (fieldNum = 0; fieldNum < labels.length - 1; fieldNum++) {
00121                                 fields[fieldNum] = new JLabel();
00122                         }
00123                         entryField = new JTextField();
00124                         entryField.setColumns(20);
00125                         fields[fieldNum++] = entryField;
00126                         for (int i = 0; i < labelStrings.length; i++) {
00127                                 labels[i] = new JLabel(labelStrings[i], JLabel.TRAILING);
00128                                 labels[i].setLabelFor(fields[i]);
00129                                 statusPanel.add(labels[i]);
00130                                 statusPanel.add(fields[i]);
00131                                 JTextField tf = null;
00132                                 if (fields[i] instanceof JTextField) {
00133                                         labelField = labels[i];
00134                                         tf = (JTextField) fields[i];
00135                                         tf.addActionListener(instance);
00136                                         tf.addFocusListener(instance);
00137                                 }
00138                         }
00139                         int GAP = 10;
00140                         makeCompactGrid(statusPanel, labelStrings.length, 2, GAP, GAP, GAP, GAP / 2);
00141                 }
00142                 return statusPanel;
00143         }
00144 
00145         public static void makeCompactGrid(Container parent, int rows, int cols, int initialX, int initialY,
00146                         int xPad, int yPad) {
00147                 SpringLayout layout;
00148                 try {
00149                         layout = (SpringLayout) parent.getLayout();
00150                 } catch (ClassCastException exc) {
00151                         System.err.println("The first argument to makeCompactGrid must use SpringLayout.");
00152                         return;
00153                 }
00154 
00155                 // Align all cells in each column and make them the same width.
00156                 Spring x = Spring.constant(initialX);
00157                 for (int c = 0; c < cols; c++) {
00158                         Spring width = Spring.constant(0);
00159                         for (int r = 0; r < rows; r++) {
00160                                 width = Spring.max(width, getConstraintsForCell(r, c, parent, cols).getWidth());
00161                         }
00162                         for (int r = 0; r < rows; r++) {
00163                                 SpringLayout.Constraints constraints = getConstraintsForCell(r, c, parent, cols);
00164                                 constraints.setX(x);
00165                                 constraints.setWidth(width);
00166                         }
00167                         x = Spring.sum(x, Spring.sum(width, Spring.constant(xPad)));
00168                 }
00169 
00170                 // Align all cells in each row and make them the same height.
00171                 Spring y = Spring.constant(initialY);
00172                 for (int r = 0; r < rows; r++) {
00173                         Spring height = Spring.constant(0);
00174                         for (int c = 0; c < cols; c++) {
00175                                 height = Spring.max(height, getConstraintsForCell(r, c, parent, cols).getHeight());
00176                         }
00177                         for (int c = 0; c < cols; c++) {
00178                                 SpringLayout.Constraints constraints = getConstraintsForCell(r, c, parent, cols);
00179                                 constraints.setY(y);
00180                                 constraints.setHeight(height);
00181                         }
00182                         y = Spring.sum(y, Spring.sum(height, Spring.constant(yPad)));
00183                 }
00184 
00185                 // Set the parent's size.
00186                 SpringLayout.Constraints pCons = layout.getConstraints(parent);
00187                 pCons.setConstraint(SpringLayout.SOUTH, y);
00188                 pCons.setConstraint(SpringLayout.EAST, x);
00189         }
00190 
00191         private transient ArrayList<ShapeContext> bestExamples = null;
00192 
00193         private transient ArrayList<Gesture> bestGestures = null;
00194 
00195         private transient boolean commandMode;
00196 
00197         private transient Thread commandThread;
00198 
00199         String databaseName;
00200 
00201         ArrayList<Gesture> gestures = new ArrayList<Gesture>();
00202 
00203         private transient boolean inputAvailable;
00204 
00205         private transient String inputString;
00206 
00207         PenGestureListener listener;
00208 
00209         ArrayList<Gesture> testGestures = new ArrayList<Gesture>();
00210 
00211         ArrayList<ShapeContext> unlabeledContexts = new ArrayList<ShapeContext>();
00212 
00213         private ArrayList<ShapeContext> unlabeledTestContexts = new ArrayList<ShapeContext>();
00214 
/**
 * Creates an empty gesture database with its own pen gesture listener.
 *
 * @param databaseName base name for the files written by the save command
 */
public GestureDatabase(String databaseName) {
    this.databaseName = databaseName;
    this.listener = new PenGestureListener();
}
00219 
00220         public void actionPerformed(ActionEvent e) {
00221                 if (commandMode) {
00222                         inputString = e.getActionCommand();
00223                         inputAvailable = true;
00224                 } else {
00225                         try {
00226                                 final int option = Integer.parseInt(e.getActionCommand());
00227                                 commandMode = true;
00228                                 commandThread = new Thread(new Runnable() {
00229                                         public void run() {
00230                                                 try {
00231                                                         GestureDatabase.instance.command(option);
00232                                                 } catch (NumberFormatException e) {
00233                                                         // TODO Auto-generated catch block
00234                                                         e.printStackTrace();
00235                                                 } catch (IOException e) {
00236                                                         // TODO Auto-generated catch block
00237                                                         e.printStackTrace();
00238                                                 }
00239                                         }
00240                                 });
00241                                 commandThread.start();
00242                         } catch (NumberFormatException error) {
00243                         }
00244                 }
00245                 entryField.setText("");
00246         }
00247 
00248         public void autotest() {
00249                 autotest(gestures);
00250         }
00251 
00252         public void autotest(ArrayList<Gesture> gestures) {
00253                 // use knn, most occurences or best average weight
00254                 int tested = 0;
00255                 int correct = 0;
00256                 for (int i = 0; i < gestures.size(); i++) {
00257                         Gesture gesture = gestures.get(i);
00258                         Gesture testGesture = testGestures.get(i);
00259                         for (ShapeContext context : testGesture.contexts) {
00260                                 tested++;
00261                                 String assignment = test(context, false, gestures);
00262                                 if (assignment.compareTo(gesture.name) == 0)
00263                                         correct++;
00264                                 else {
00265                                         test(context, true, gestures);
00266                                         System.out.println(gesture.name + " misclassified as " + assignment);
00267                                 }
00268                         }
00269                 }
00270                 System.out.println("Tested " + tested + " contexts, " + correct + " correct.");
00271         }
00272 
00273         private void bestExamples() {
00274                 double totalN = 0;
00275                 int count = 0;
00276                 for (Gesture gesture : gestures) {
00277                         for (ShapeContext context : gesture.contexts) {
00278                                 count++;
00279                                 totalN += context.controlPoints.size();
00280                         }
00281                 }
00282                 System.out.println("Average points per context: " + totalN / count);
00283                 bestExamples = new ArrayList<ShapeContext>();
00284                 bestGestures = new ArrayList<Gesture>();
00285                 for (int i = 0; i < gestures.size(); i++) {
00286                         Gesture gesture = gestures.get(i);
00287                         ShapeContext best = null;
00288                         double distance = Double.MAX_VALUE;
00289                         System.out.println("Computing best example for " + gesture.name);
00290                         if (gesture.contexts.size() == 1)
00291                                 best = gesture.contexts.get(0);
00292                         else {
00293                                 for (int j = gesture.contexts.size() - 1; j >= 0; j--) {
00294                                         ShapeContext context = gesture.contexts.get(j);
00295                                         gesture.contexts.remove(context);
00296                                         // double d = gesture.bestMatch(context); // dumb; this just tells me who's
00297                                         // closest to their nearest neighbor. We want centroid.
00298                                         double d = gesture.averageMatch(context);
00299                                         if (d < distance) {
00300                                                 distance = d;
00301                                                 best = context;
00302                                         }
00303                                         gesture.contexts.add(context);
00304                                         // AVINOTE: remove this
00305                                         break;
00306                                 }
00307                         }
00308                         bestExamples.add(best);
00309                         Gesture bestGesture = new Gesture(gesture.name);
00310                         bestGesture.addGesture(best);
00311                         bestGestures.add(bestGesture);
00312                         print("Best example for " + gesture.name);
00313                         display(best, mainPanel.getWidth() / 2, mainPanel.getHeight() / 2);
00314                 }
00315         }
00316 
00317         private void chunk(double threshold) {
00318                 Random rand = new Random();
00319                 moveAllToTrain();
00320                 for (int i = 0; i < gestures.size(); i++) {
00321                         ArrayList<ShapeContext> contexts = gestures.get(i).contexts;
00322                         ArrayList<ShapeContext> testContexts = testGestures.get(i).contexts;
00323                         for (int j = contexts.size() - 1; j >= 0; j--) {
00324                                 if (rand.nextDouble() < threshold) {
00325                                         ShapeContext context = contexts.get(j);
00326                                         contexts.remove(context);
00327                                         testContexts.add(context);
00328                                 }
00329                         }
00330                 }
00331         }
00332 
00333         public void command(int option) throws IOException {
00334                 // build a set of gestures
00335                 // options: add a gesture (N times)
00336                 // add more of a gesture (N)
00337                 // remove some of a gesture (? with images?)
00338                 // int option;
00339                 // do{
00340                 /*
00341                  * System.out.println("Options:\n" + "0: Add a gesture\n" + "1: Add more of an existing gesture\n" +
00342                  * "2: Save\n" + "3: List\n" + "4: Test (interactive)\n" + "5: Add a test gesture (for existing)\n" +
00343                  * "6: Run autotest\n" + "7: Determine class parameters\n" + "8: Assign class labels to unlabeled
00344                  * gestures\n" + "9: Add unlabeled gestures\n" + "-1: Exit");
00345                  */
00346                 // option = Integer.parseInt(stdin.readLine());
00347                 int count, index;
00348                 String author;
00349                 Gesture gesture;
00350                 switch (option) {
00351                 case 0: // add a gesture
00352                         print("Name: ");
00353                         String name = getString();
00354                         print("Examples: ");
00355                         count = Integer.parseInt(getString());
00356                         gesture = new Gesture(name);
00357                         gestures.add(gesture);
00358                         testGestures.add(new Gesture(name + "TEST"));
00359                         listener.setContexts(gesture.contexts, count);
00360                         break;
00361                 case 1: // add examples to a gesture
00362                         print("Index: ");
00363                         index = Integer.parseInt(getString());
00364                         gesture = gestures.get(index);
00365                         System.out.println("Name: " + gesture.name);
00366                         print("Examples: ");
00367                         count = Integer.parseInt(getString());
00368                         listener.setContexts(gesture.contexts, count);
00369                         break;
00370                 case 3:
00371                         System.out.println("Gestures:");
00372                         for (int i = 0; i < gestures.size(); i++)
00373                                 System.out.println("Gesture " + gestures.get(i).name + ": " + gestures.get(i).size());
00374                         System.out.println("Tests:");
00375                         for (int i = 0; i < testGestures.size(); i++)
00376                                 System.out.println("Gesture " + testGestures.get(i).name + ": " + testGestures.get(i).size());
00377                         System.out.println("Queued gestures: " + unlabeledContexts.size());
00378                         System.out.println("Queued test gestures: " + unlabeledTestContexts.size());
00379                         break;
00380                 case 4:
00381                         listener.setContexts(null, 0);
00382                         listener.setDatabase(this);
00383                         print("Hit enter to end test.");
00384                         getString();
00385                         listener.setDatabase(null);
00386                 case 5: // add test examples to a gesture
00387                         print("Name: ");
00388                         name = getString();
00389                         gesture = null;
00390                         for (Gesture g : testGestures)
00391                                 if (g.name.compareTo(name + "TEST") == 0) {
00392                                         gesture = g;
00393                                         break;
00394                                 }
00395                         print("Examples: ");
00396                         count = Integer.parseInt(getString());
00397                         listener.setContexts(gesture.contexts, count);
00398                         break;
00399                 case 6:
00400                         // System.out.println("Test without time: ");
00401                         // ShapeContext.bands = 2;
00402                         // autotest();
00403                         System.out.println("Test using time: ");
00404                         ShapeContext.bands = 3;
00405                         autotest();
00406                         break;
00407                 case 7:
00408                         determineClassParameters();
00409                         break;
00410                 case 8:
00411                         labelGestures();
00412                         break;
00413                 case 9:
00414                         print("Author: ");
00415                         author = getString();
00416                         System.out.println(author);
00417                         listener.setAuthor(author);
00418                         listener.setContexts(unlabeledContexts, 1000);
00419                         print("Hit enter to end test.");
00420                         getString();
00421                         break;
00422                 case 10:
00423                         print("Author: ");
00424                         author = getString();
00425                         System.out.println(author);
00426                         listener.setAuthor(author);
00427                         listener.setContexts(unlabeledTestContexts, 1000);
00428                         print("Hit enter to end test.");
00429                         getString();
00430                         break;
00431                 case 11:
00432                         leaveOneOutOptimizeCostWeighting();
00433                         break;
00434                 case 12:
00435                         leaveChunkOut(10);
00436                         break;
00437                 case 13:
00438                         moveAllToTrain();
00439                         break;
00440                 case 14:
00441                         leaveOneUserOut();
00442                         break;
00443                 case 15:
00444                         bestExamples();
00445                         break;
00446                 case 16:
00447                         testChunkOutOnBest(20);
00448                         break;
00449                 case 17:
00450                         createBestImage();
00451                         break;
00452                 case 18:
00453                         generateGesture();
00454                         break;
00455                 case -1:
00456                         System.out.println("Exiting.");
00457                         System.exit(0); // no automatic save on exit
00458                 case 2:
00459                         Writer writer = new FileWriter(new File("C:\\dev\\quill\\data\\" + databaseName + ".gp"));
00460                         quillWrite(writer);
00461                         Save(new FileWriter(new File(databaseName + ".xml")));
00462                         break;
00463                 }
00464                 // Thread.yield();
00465                 // } while (option > -1);
00466                 commandMode = false;
00467                 labelField.setText("Option:");
00468         }
00469 
00470         public void createBestImage() {
00471                 if (bestExamples == null)
00472                         bestExamples();
00473                 // write 'em out, staggered.
00474                 Ink ink = new Ink();
00475                 inkPanel.clear();
00476                 double max_range_x = 0;
00477                 double max_range_y = 0;
00478                 for (int i = 0; i < bestExamples.size(); i++) {
00479                         double min_x = Double.MAX_VALUE;
00480                         double min_y = Double.MAX_VALUE;
00481                         double max_x = Double.MIN_VALUE;
00482                         double max_y = Double.MIN_VALUE;
00483                         for (PenSample sample : bestExamples.get(i).controlPoints) {
00484                                 min_x = Math.min(sample.x, min_x);
00485                                 min_y = Math.min(sample.y, min_y);
00486                                 max_x = Math.max(sample.x, max_x);
00487                                 max_y = Math.max(sample.y, max_y);
00488 
00489                         }
00490                         max_range_x = Math.max(max_range_x, max_x - min_x);
00491                         max_range_y = Math.max(max_range_y, max_y - min_y);
00492                 }
00493                 for (int i = 0; i < bestExamples.size(); i++) {
00494                         ShapeContext context = bestExamples.get(i);
00495                         display(context, max_range_x * (2 * (i % 3) + 1), max_range_y * (2 * (i / 3) + 1));
00496                         ink.addStroke(new InkStroke(context.controlPoints, new Points()));
00497                 }
00498                 InkRenderer renderer = new InkRenderer(ink);
00499                 renderer.renderToJPEG(new File("best.jpg"), new Pixels(300), new Points(max_range_x * 7), new Points(
00500                                 max_range_y * 7));
00501         }
00502 
00503         public void determineClassParameters() {
00504                 for (Gesture gesture : gestures)
00505                         gesture.determineClassParameters();
00506         }
00507 
00508         public void display(ShapeContext context, double w, double h) {
00509                 double min_x = Double.MAX_VALUE;
00510                 double min_y = Double.MAX_VALUE;
00511                 for (PenSample sample : context.controlPoints) {
00512                         min_x = Math.min(sample.x, min_x);
00513                         min_y = Math.min(sample.y, min_y);
00514                 }
00515                 for (PenSample sample : context.controlPoints) {
00516                         sample.x += w - min_x;
00517                         sample.y += h - min_y;
00518                 }
00519                 InkStroke stroke = new InkStroke(context.controlPoints, new Points());
00520                 Ink ink = new Ink();
00521                 ink.addStroke(stroke);
00522                 inkPanel.clear();
00523                 inkPanel.addInk(ink);
00524         }
00525 
00526         public void focusGained(FocusEvent arg0) {
00527                 // TODO Auto-generated method stub
00528 
00529         }
00530 
00531         public void focusLost(FocusEvent arg0) {
00532                 // TODO Auto-generated method stub
00533 
00534         }
00535 
00536         public void generateGesture() {
00537                 if (bestExamples == null)
00538                         bestExamples();
00539                 ShapeContext.bands = 3;
00540                 int max_points = 0;
00541                 for (ShapeContext context : bestExamples) {
00542                         max_points = Math.max(max_points, context.controlPoints.size());
00543                 }
00544                 Random random = new Random();
00545                 ArrayList<PenSample> newSamples = new ArrayList<PenSample>();
00546                 for (int i = 0; i < max_points; i++) {
00547                         newSamples.add(new PenSample(random.nextDouble(), random.nextDouble(), 0, i));
00548                 }
00549                 ShapeContext noiseExample = new ShapeContext(newSamples, "noise");
00550                 max_points = 10;
00551                 ArrayList<PenSample> testSamples = new ArrayList<PenSample>();
00552                 for (int i = 0; i < max_points; i++) {
00553                         testSamples.add(new PenSample(random.nextDouble() * max_points * max_points, random.nextDouble()
00554                                         * max_points * max_points, 0, i));
00555                 }
00556                 ShapeContext testExample = new ShapeContext(testSamples, "test");
00557                 double distanceMetric = -Double.MAX_VALUE;
00558                 int maxIteration = 1001;
00559                 int acceptedWeak = 0;
00560                 int scale = ((max_points * max_points) / 8) * 2 + 1;
00561                 for (int i = 0; i < maxIteration; i++) {
00562                         // pick a random point in testExample
00563                         int point = random.nextInt(max_points);
00564                         PenSample sample = testExample.controlPoints.get(point);
00565                         // scale is max_points
00566                         int dim = random.nextInt(2);
00567                         int mod = random.nextInt(scale) - (scale / 2);
00568                         if (dim == 0) {
00569                                 mod = Math.max((int) -sample.x, mod);
00570                                 mod = Math.min((int) (max_points * max_points - sample.x), mod);
00571                                 sample.x += mod;
00572                         } else {
00573                                 mod = Math.max((int) -sample.y, mod);
00574                                 mod = Math.min((int) (max_points * max_points - sample.y), mod);
00575                                 sample.y += mod;
00576                         }
00577                         double[] distances = new double[bestExamples.size()];
00578                         double average = 0;
00579                         for (int c = 0; c < bestExamples.size(); c++) {
00580                                 ShapeContext context = bestExamples.get(c);
00581                                 distances[c] = ShapeHistogram.shapeContextMetric(context, testExample, false, false, false);
00582                                 average += distances[c];
00583                         }
00584                         average /= bestExamples.size();
00585                         // compute stddev distance
00586                         double stddev = 0;
00587                         for (int c = 0; c < bestExamples.size(); c++)
00588                                 stddev += Math.pow(distances[c] - average, 2);
00589                         stddev /= bestExamples.size() - 1;
00590                         stddev = Math.sqrt(stddev);
00591                         double angleCost = 0;
00592                         for (int c = 1; c < max_points - 1; c++) {
00593                                 // compute angles
00594                                 PenSample first = testExample.controlPoints.get(c - 1);
00595                                 PenSample second = testExample.controlPoints.get(c);
00596                                 PenSample third = testExample.controlPoints.get(c + 1);
00597                                 double x1 = second.x - first.x;
00598                                 double y1 = second.y - first.y;
00599                                 double x2 = third.x - second.x;
00600                                 double y2 = third.y - second.y;
00601                                 double dot = (x1 * x2 + y1 * y2) / Math.sqrt(x1 * x1 + y1 * y1)
00602                                                 / Math.sqrt(x2 * x2 + y2 * y2);
00603                                 // from -1 to 1
00604                                 angleCost += dot;
00605                         }
00606                         angleCost /= (max_points - 2);
00607                         double springCost = 0;
00608                         double averageLength = 0;
00609                         double[] length = new double[max_points - 1];
00610                         for (int c = 1; c < max_points; c++) {
00611                                 PenSample first = testExample.controlPoints.get(c - 1);
00612                                 PenSample second = testExample.controlPoints.get(c);
00613                                 length[c - 1] = Math.sqrt(Math.pow(first.x - second.x, 2) + Math.pow(first.y - second.y, 2));
00614                                 averageLength += length[c - 1];
00615                         }
00616                         averageLength /= max_points - 1;
00617                         for (int c = 0; c < max_points - 1; c++) {
00618                                 springCost += Math.pow(length[c] - averageLength, 2);
00619                         }
00620                         springCost /= max_points - 2;
00621                         springCost = Math.sqrt(springCost);
00622                         double noiseDistance = ShapeHistogram.shapeContextMetric(noiseExample, testExample, false, false,
00623                                         false);
00624                         double metric = -10 * stddev + noiseDistance + average / 3 + angleCost * 10000 - springCost
00625                                         * 5000 + averageLength * 5000;
00626                         System.out.println("iteration " + i + " distance " + metric + " stddev " + stddev + " noise "
00627                                         + noiseDistance + " average " + average + " angle " + angleCost + " sc " + springCost
00628                                         + " al " + averageLength);
00629                         if (!Double.isNaN(metric) && (metric > distanceMetric || random.nextDouble() < /*
00630                                                                                                                                                                                          * Math.exp((-Math.abs(mod) -
00631                                                                                                                                                                                          * scale)*(2*i/(double)maxIteration)/5.)))
00632                                                                                                                                                                                          */Math.exp((metric - distanceMetric) / (3000 - i)))) {
00633                                 if (metric < distanceMetric)
00634                                         acceptedWeak++;
00635                                 distanceMetric = metric;
00636                         } else {
00637                                 if (dim == 0)
00638                                         sample.x -= mod;
00639                                 else
00640                                         sample.y -= mod;
00641                         }
00642                         if (i % 100 == 0) {
00643                                 System.out.println("Accepted " + acceptedWeak + " weak samples.");
00644                                 display(testExample, max_points * max_points, max_points * max_points);
00645                                 Ink ink = new Ink();
00646                                 ink.addStroke(new InkStroke(testExample.controlPoints, new Points()));
00647                                 InkRenderer renderer = new InkRenderer(ink);
00648                                 renderer.renderToJPEG(new File(databaseName + "_generated_" + (i / 100) + ".jpg"),
00649                                                 new Pixels(300), new Points(max_points * max_points * 3), new Points(max_points
00650                                                                 * max_points * 3));
00651                         }
00652                 }
00653         }
00654 
00655         public PenGestureListener getListener() {
00656                 return listener;
00657         }
00658 
00659         private String getString() {
00660                 while (!inputAvailable)
00661                         Thread.yield();
00662                 inputAvailable = false;
00663                 return inputString;
00664         }
00665 
00666         public void labelGestures() {
00667                 labelField.setText("Class label:");
00668                 while (unlabeledContexts.size() > 0) {
00669                         ShapeContext context = unlabeledContexts.get(unlabeledContexts.size() - 1);
00670                         if (context.authorName.contains("jerry"))
00671                                 for (PenSample sample : context.controlPoints) { // that clown
00672                                         sample.x *= -1;
00673                                         sample.y *= -1;
00674                                 }
00675                         display(context, mainPanel.getWidth() / 2, mainPanel.getHeight() / 2);
00676                         String label = getString();
00677                         if (label.compareTo("-1") == 0) { // discard this gesture and continue
00678                                 unlabeledContexts.remove(context);
00679                                 continue;
00680                         } else if (label.compareTo("") == 0) {
00681                                 // stop for the moment
00682                                 break;
00683                         }
00684                         for (Gesture gesture : gestures) {
00685                                 if (gesture.name.compareTo(label) == 0) {
00686                                         gesture.addGesture(context);
00687                                         unlabeledContexts.remove(context);
00688                                         break;
00689                                 }
00690                         }
00691                 }
00692                 labelField.setText("Class label (test):");
00693                 while (unlabeledTestContexts.size() > 0) {
00694                         ShapeContext context = unlabeledTestContexts.get(unlabeledTestContexts.size() - 1);
00695                         display(context, mainPanel.getWidth() / 2, mainPanel.getHeight() / 2);
00696                         String label = getString();
00697                         if (label.compareTo("-1") == 0) { // discard this gesture and continue
00698                                 unlabeledTestContexts.remove(context);
00699                                 continue;
00700                         } else if (label.compareTo("") == 0) {
00701                                 // stop for the moment
00702                                 break;
00703                         }
00704                         for (Gesture gesture : testGestures) {
00705                                 if (gesture.name.compareTo(label + "TEST") == 0) {
00706                                         gesture.addGesture(context);
00707                                         unlabeledTestContexts.remove(context);
00708                                         break;
00709                                 }
00710                         }
00711                 }
00712         }
00713 
00714         private void leaveChunkOut(int trials) throws IOException {
00715                 for (int trial = 0; trial < trials; trial++) {
00716                         chunk(.1);
00717                         Writer writer = new FileWriter(new File("C:\\dev\\quill\\data\\" + databaseName + "_" + trial
00718                                         + ".gp"));
00719                         quillWrite(writer);
00720                         Save(new FileWriter(new File(databaseName + "_" + trial + ".xml")));
00721                         Date before = new Date();
00722                         System.out.println("Beginning no-time run leaving out " + trial);
00723                         ShapeContext.bands = 2;
00724                         autotest();
00725                         Date after = new Date();
00726                         double secondsElapsed = (after.getTime() - before.getTime()) / 1000.;
00727                         System.out.println("Elapsed time was: " + secondsElapsed);
00728 
00729                         before = new Date();
00730                         System.out.println("Beginning time run leaving out " + trial);
00731                         ShapeContext.bands = 3;
00732                         autotest();
00733                         after = new Date();
00734                         secondsElapsed = (after.getTime() - before.getTime()) / 1000.;
00735                         System.out.println("Elapsed time was: " + secondsElapsed);
00736 
00737                 }
00738         }
00739 
00740         public void leaveOneOutOptimizeCostWeighting() {
00741                 // try a range from .1 to 2 for kicks
00742                 int[] errors = new int[11];
00743                 for (int i = 1; i <= 10; i++) {
00744                         ShapeHistogram.costWeighting = i * .1;
00745                         for (Gesture gesture : gestures) {
00746                                 for (int j = gesture.contexts.size() - 1; j >= 0; j--) {
00747                                         ShapeContext context = gesture.contexts.get(j);
00748                                         gesture.contexts.remove(context);
00749                                         String assignment = test(context, false);
00750                                         if (assignment.compareTo(gesture.name) != 0) {
00751                                                 errors[i]++;
00752                                                 test(context, true);
00753                                                 System.out.println(gesture.name + " misclassified as " + assignment);
00754                                         }
00755                                         gesture.addGesture(context);
00756                                 }
00757                         }
00758                         System.out
00759                                         .println("Error with cost weighting " + ShapeHistogram.costWeighting + ": " + errors[i]);
00760                 }
00761                 ShapeHistogram.costWeighting = .3;
00762         }
00763 
00764         private void leaveOneUserOut() throws IOException {
00765                 // go through the users, pull one out into the test set at a time
00766                 moveAllToTrain();
00767                 Gesture gesture = gestures.get(0);
00768                 HashMap<String, Boolean> names = new HashMap<String, Boolean>();
00769                 for (int i = 0; i < gesture.contexts.size(); i++) {
00770                         names.put(gesture.contexts.get(i).authorName, true);
00771                 }
00772                 for (String name : names.keySet()) {
00773                         for (int i = 0; i < gestures.size(); i++) {
00774                                 ArrayList<ShapeContext> contexts = gestures.get(i).contexts;
00775                                 ArrayList<ShapeContext> testContexts = testGestures.get(i).contexts;
00776                                 for (int j = contexts.size() - 1; j >= 0; j--) {
00777                                         ShapeContext context = contexts.get(j);
00778                                         if (context.authorName.compareTo(name) == 0) {
00779                                                 contexts.remove(context);
00780                                                 testContexts.add(context);
00781                                         }
00782                                 }
00783                         }
00784                         Writer writer = new FileWriter(new File("C:\\dev\\quill\\data\\" + databaseName + "_" + name
00785                                         + ".gp"));
00786                         quillWrite(writer);
00787                         Save(new FileWriter(new File(databaseName + "_" + name + ".xml")));
00788                         Date before = new Date();
00789                         System.out.println("Beginning no-time run leaving out " + name);
00790                         ShapeContext.bands = 2;
00791                         autotest();
00792                         Date after = new Date();
00793                         double secondsElapsed = (after.getTime() - before.getTime()) / 1000.;
00794                         System.out.println("Elapsed time was: " + secondsElapsed);
00795 
00796                         before = new Date();
00797                         System.out.println("Beginning time run leaving out " + name);
00798                         ShapeContext.bands = 3;
00799                         autotest();
00800                         after = new Date();
00801                         secondsElapsed = (after.getTime() - before.getTime()) / 1000.;
00802                         System.out.println("Elapsed time was: " + secondsElapsed);
00803                 }
00804         }
00805 
00806         private void moveAllToTrain() {
00807                 for (int i = 0; i < gestures.size(); i++) {
00808                         ArrayList<ShapeContext> contexts = gestures.get(i).contexts;
00809                         ArrayList<ShapeContext> testContexts = testGestures.get(i).contexts;
00810                         for (int j = testContexts.size() - 1; j >= 0; j--) {
00811                                 ShapeContext context = testContexts.get(j);
00812                                 testContexts.remove(context);
00813                                 contexts.add(context);
00814                         }
00815                 }
00816         }
00817 
00818         private void print(String string) {
00819                 labelField.setText(string);
00820         }
00821 
00822         public void quillWrite(Writer writer) throws IOException {
00823                 final String VERSION = "gdt 2.0";
00824 
00825                 writer.write(VERSION + "\n");
00826                 // gesture package
00827                 String name = "bob";
00828                 writer.write("name\t" + name + "\n");
00829                 writer.write("training\n");
00830                 // gesture set
00831                 writer.write("name\t" + name + "\n");
00832                 for (Gesture gesture : (bestGestures == null ? gestures : bestGestures)) {
00833                         writer.write("category\n");
00834                         gesture.quillWrite(writer);
00835                 }
00836                 writer.write("endset\n");
00837                 // end set
00838                 writer.write("test\n");
00839                 // gesture set
00840                 writer.write("name\t" + name + "\n");
00841                 writer.write("set\n");
00842                 writer.write("name\t testset1\n");
00843                 for (Gesture gesture : testGestures) {
00844                         writer.write("category\n");
00845                         gesture.quillWrite(writer);
00846                 }
00847                 writer.write("endset\n");
00848                 writer.write("endmetaset\n");
00849                 writer.write("endpackage\n");
00850                 writer.close();
00851         }
00852 
00853         public void Save(Writer writer) throws IOException {
00854                 XStream xstream = new XStream();
00855                 xstream.toXML(this, writer);
00856         }
00857 
00858         public String test(ShapeContext context, boolean verbose) {
00859                 return test(context, verbose, gestures);
00860         }
00861 
00862         public String test(ShapeContext context, boolean verbose, ArrayList<Gesture> gestures) {
00863                 // do KNN
00864                 int k = 3;
00865                 double[] distance = new double[k];
00866                 int[] index = new int[k];
00867                 String[] clazz = new String[k];
00868                 for (int i = 0; i < k; i++) {
00869                         distance[i] = Double.MAX_VALUE;
00870                         index[i] = -1;
00871                 }
00872                 for (int i = 0; i < gestures.size(); i++) {
00873                         gestures.get(i).knnMatch(context, k, distance, clazz, verbose);
00874                         // double d = gestures.get(i).bestMatch(context);
00875                         // System.out.println("Category " + i + "(" + gestures.get(i).size() + " items): " + d);
00876                 }
00877                 if (verbose) {
00878                         System.out.println("Best matches are:");
00879                         for (int i = 0; i < k; i++)
00880                                 System.out.println(i + ": category " + clazz[i] + " with distance " + distance[i] + " using "
00881                                                 + context.size() + " points.");
00882                 }
00883                 HashMap<String, Integer> counts = new HashMap<String, Integer>();
00884                 HashMap<String, Double> costs = new HashMap<String, Double>();
00885                 int max_count = 1;
00886                 boolean unique = true;
00887                 for (int i = 0; i < k; i++) {
00888                         Integer total = counts.get(clazz[i]);
00889                         int count;
00890                         double d;
00891                         if (total == null) {
00892                                 count = 1;
00893                                 d = distance[i];
00894                         } else {
00895                                 count = total + 1;
00896                                 d = costs.get(clazz[i]) + distance[i];
00897                         }
00898                         if (count == max_count)
00899                                 unique = false;
00900                         else if (count > max_count) {
00901                                 unique = true;
00902                                 max_count = count;
00903                         }
00904                         counts.put(clazz[i], count);
00905                         costs.put(clazz[i], d);
00906                 }
00907                 if (unique)
00908                         for (String c : counts.keySet()) {
00909                                 if (counts.get(c) == max_count) {
00910                                         return c;
00911                                 }
00912                         }
00913                 double min_cost = Double.MAX_VALUE;
00914                 String min_clazz = null;
00915                 for (String c : counts.keySet()) {
00916                         double cost = costs.get(c) / counts.get(c);
00917                         if (cost < min_cost) {
00918                                 min_cost = cost;
00919                                 min_clazz = c;
00920                         }
00921                 }
00922                 return min_clazz;
00923         }
00924 
00925         private void testChunkOutOnBest(int trials) throws IOException {
00926                 for (int trial = 0; trial < trials; trial++) {
00927                         chunk(.1);
00928                         bestExamples();
00929                         Writer writer = new FileWriter(new File("C:\\dev\\quill\\data\\" + databaseName + "_best_"
00930                                         + trial + ".gp"));
00931                         quillWrite(writer);
00932                         Save(new FileWriter(new File(databaseName + "_best_" + trial + ".xml")));
00933                         Date before = new Date();
00934                         System.out.println("Beginning no-time run on best leaving out " + trial);
00935                         ShapeContext.bands = 2;
00936                         autotest(bestGestures);
00937                         Date after = new Date();
00938                         double secondsElapsed = (after.getTime() - before.getTime()) / 1000.;
00939                         System.out.println("Elapsed time was: " + secondsElapsed);
00940 
00941                         before = new Date();
00942                         System.out.println("Beginning time run on best leaving out " + trial);
00943                         ShapeContext.bands = 3;
00944                         autotest(bestGestures);
00945                         after = new Date();
00946                         secondsElapsed = (after.getTime() - before.getTime()) / 1000.;
00947                         System.out.println("Elapsed time was: " + secondsElapsed);
00948                 }
00949         }
00950 }

Generated on Sat Apr 14 18:21:35 2007 for R3 Paper Toolkit by  doxygen 1.4.7