Skip to content
  1. Examples

AI classification

This example shows how to use a pre-trained Graph Neural Network (GNN) model to classify nodes. It uses the "Cora" dataset, a citation network of scientific papers. The model was trained on a small subset of 500 nodes of the dataset, using the feature vectors extracted from the content of the papers associated with the nodes together with the citations between them. This allows it to predict the category of nodes it did not see during training.

You can try re-training the model on different parts of the dataset to see how this affects the accuracy.

ts
import Ogma, { Node, Overlay, RawGraph, Point } from '@linkurious/ogma';
import { failIcon, checkIcon } from './icons';
import * as tf from '@tensorflow/tfjs';

/**
 * Data attached to every node of the Cora graph.
 *
 * Fields set by the loaded graph file: `features`, `value`, `id`, `category`
 * and `paperId`. Fields written by this example at runtime: `predictedCategory`,
 * `processed`, `trained` and `showBadPrediction`.
 */
interface NodeData {
  /** Feature vector fed to the model as its first input. */
  features: number[];
  /** NOTE(review): not read anywhere in this example — meaning unknown. */
  value: number;
  /** NOTE(review): not read anywhere in this example; presumably mirrors the node id. */
  id: number;
  /** Ground-truth category: an index into CATEGORIES / COLORS. */
  category: number;
  /** Paper identifier shown in the prediction popup title. */
  paperId?: number | string;
  /** Category index predicted by the model (set once the node was processed). */
  predictedCategory?: number;
  /** True once the model has produced a prediction for this node. */
  processed?: boolean;
  /** True while the node is part of the current training set. */
  trained?: boolean;
  /** Set by the "Hide incorrect predictions" toggle for wrongly predicted nodes. */
  showBadPrediction?: boolean;
}

/**
 * Names of the seven Cora paper categories. A node's `category` and
 * `predictedCategory` are indices into this array.
 */
const CATEGORIES = [
  'Case Based',
  'Genetic Algorithms',
  'Neural Networks',
  'Probabilistic Methods',
  'Reinforcement Learning',
  'Rule Learning',
  'Theory'
];

/**
 * Options for Ogma's node pulse animation.
 * NOTE(review): only referenced from a commented-out `node.pulse(...)` call in
 * applyHighlightMistakes, so currently unused.
 */
const pulse_options = {
  number: 5,
  duration: 1700,
  interval: 600,
  startColor: 'inherit',
  endColor: 'red',
  width: 5,
  startRatio: 1,
  endRatio: 5
};

/** Display colour of each category, index-aligned with CATEGORIES. */
const COLORS = [
  '#FA5319',
  '#2846FA',
  '#FA9419',
  '#19B1FA',
  '#A57B47',
  '#7A5649',
  '#C844D1'
];

/**
 * Lasso tool options, used both when picking training nodes and on Ctrl+drag.
 * The callback selects the nodes inside the lasso and deselects the edges.
 */
const lassoOptions = {
  strokeWidth: 1,
  bothExtremities: false,
  callback: (evt: any) => {
    evt.nodes.setSelected(true);
    evt.edges.setSelected(false);
  }
};

/** Ogma instance rendering into #graph-container; edge detection is turned off. */
const ogma = new Ogma<NodeData>({
  container: 'graph-container',
  options: {
    backgroundColor: '#F5F6F6',
    detect: { edges: false }
  }
});

/**
 * Fills #legend-list with one entry per category: a colour swatch followed by
 * the category name. Does nothing when the list element is missing.
 */
function createLegend() {
  const legendList = document.getElementById('legend-list');
  if (!legendList) return;

  for (let i = 0; i < CATEGORIES.length; i++) {
    const swatch = document.createElement('span');
    swatch.className = 'legend-color-box';
    swatch.style.backgroundColor = COLORS[i];

    const name = document.createElement('span');
    name.textContent = CATEGORIES[i];

    const entry = document.createElement('li');
    entry.append(swatch, name);
    legendList.appendChild(entry);
  }
}

createLegend();

/** Popup overlay currently shown for a clicked node, or null when none is open. */
let popup: Overlay | null = null;

/**
 * Shows `content` (HTML) in a popup anchored at `position`, replacing any
 * popup that is already open.
 *
 * While the popup is open, the node selection is cleared and zoom, pan,
 * selection and node detection are disabled; closePopup() re-enables them.
 * Clicks inside the popup are stopped from reaching the graph.
 */
function openPopup(content: string, position: Point) {
  popup?.destroy();

  const markup = `
<div class="ogma-popup ogma-popup--top">
  <div class="ogma-popup--body">
    <div class="ogma-popup--close">&times;</div>
    <div class="ogma-popup--content">${content}</div>
  </div>
</div>
    `;

  const overlay = ogma.layers.addOverlay({
    element: markup,
    size: { width: 'auto', height: 'auto' },
    scaled: false,
    position
  });
  popup = overlay;

  const { element } = overlay;
  element.addEventListener('click', evt => evt.stopPropagation());

  const closeButton = element.querySelector('.ogma-popup--close')!;
  closeButton.addEventListener('click', evt => {
    evt.stopPropagation();
    closePopup();
  });

  ogma.getSelectedNodes().setSelected(false);
  ogma.setOptions({
    interactions: {
      zoom: { enabled: false },
      pan: { enabled: false },
      selection: { enabled: false }
    },
    detect: { nodes: false }
  });
}

/**
 * Destroys the open popup and re-enables the zoom, pan, selection and node
 * detection that openPopup() turned off. No-op when no popup is open.
 */
function closePopup() {
  if (!popup) return;

  popup.destroy();
  popup = null;
  ogma.setOptions({
    interactions: {
      zoom: { enabled: true },
      pan: { enabled: true },
      selection: { enabled: true }
    },
    detect: { nodes: true }
  });
}

// Node styling driven by the prediction state stored in node data:
// - training samples (`trained`) are dark red; predicted nodes (`processed`)
//   take the colour of their actual `category`;
// - both are drawn with radius 20 instead of 10;
// - while the #highlight-correct checkbox is ticked, predicted nodes whose
//   prediction is wrong are faded to 25% opacity.
ogma.styles.addNodeRule({
  color: node => {
    if (node.getData('trained')) return '#6e0606';
    if (!node.getData('processed')) return undefined;
    return COLORS[node.getData('category')];
  },
  radius: node =>
    node.getData('processed') || node.getData('trained') ? 20 : 10,
  opacity: node => {
    const dimmingActive =
      Boolean(node.getData('processed')) && highlightCorrect.checked;
    if (!dimmingActive) return undefined;
    const predictedRight =
      node.getData('category') === node.getData('predictedCategory');
    return predictedRight ? 1 : 0.25;
  }
});

/**
 * Builds a dense, symmetric 0/1 adjacency matrix of size
 * `nodes.length x nodes.length`.
 *
 * Edge endpoints are converted to numbers and used directly as row/column
 * indices, so node ids are expected to be 0..nodes.length-1 (rows are later
 * read as `adjacencyMatrix[nodeId]`). Edge direction is ignored.
 */
function getAdjacencyMatrix({ nodes, edges }: RawGraph) {
  const size = nodes.length;
  const matrix: number[][] = [];
  for (let row = 0; row < size; row++) {
    matrix.push(new Array<number>(size).fill(0));
  }
  for (const { source, target } of edges) {
    const from = Number(source);
    const to = Number(target);
    matrix[from][to] = 1;
    matrix[to][from] = 1;
  }
  return matrix;
}

// Running tally of the predictions made so far, shown in the #info panel.
let correctPredictions = 0;
let totalPredictedNodes = 0;
let currentAccuracy = 0;

/**
 * Records one prediction outcome and displays the updated accuracy
 * (percentage of correct predictions, two decimals) via showMessage().
 *
 * @param correct Whether the prediction just made matched the node's category.
 */
function updateAccuracy(correct: boolean) {
  totalPredictedNodes += 1;
  if (correct) correctPredictions += 1;
  currentAccuracy = (correctPredictions / totalPredictedNodes) * 100;
  const percentage = `${currentAccuracy.toFixed(2)}%`;
  showMessage(`Accuracy: <strong>${percentage}</strong>`);
}

/**
 * Adds a labelled `<progress id="progress-bar-<name>">` to the progress
 * container, unless a bar with that id already exists.
 *
 * @param totalNodes Maximum value of the bar (node count or epoch count).
 * @param name       Bar name; used in the element id and the label text.
 */
function initializeProgressBar(totalNodes: number, name: string) {
  const progressBarId = `progress-bar-${name}`;
  if (document.getElementById(progressBarId)) return;

  const label = document.createElement('label');
  label.htmlFor = progressBarId;
  label.textContent = ` ${name} Progress:`;

  const progressBar = document.createElement('progress');
  progressBar.id = progressBarId;
  progressBar.max = totalNodes;
  progressBar.value = 0;

  // append() inserts the new nodes only; `innerHTML +=` would serialize and
  // re-parse the container, rebuilding any bar that is already running.
  progressBarContainer.append(label, progressBar);
}

/**
 * Sets the value of the progress bar created for `name`.
 *
 * The container can be emptied while a run is still in progress (the Clear
 * button, or the delayed cleanup after "Predict all"). In that case the
 * update is ignored instead of throwing, which would otherwise abort
 * the running training or prediction loop.
 */
function updateProgressBar(currentProgress: number, name: string) {
  const progressBar = document.getElementById(
    `progress-bar-${name}`
  ) as HTMLProgressElement | null;

  if (progressBar) progressBar.value = currentProgress;
}

/**
 * Runs `model` on one node and records the outcome on that node.
 *
 * Sets the node's `predictedCategory` and `processed` data and updates the
 * running accuracy. A node that was already processed is not counted twice:
 * its previous outcome is retracted before the new one is recorded. This keeps
 * the accuracy a per-node figure when a node is clicked repeatedly or
 * re-predicted after training.
 *
 * @param adj   The node's adjacency row as a batch of one; created and
 *              disposed by the caller.
 * @param model Model taking `[features, adjacency]` as inputs.
 * @param node  Node to classify.
 * @returns `{ correct, predictedCategory }`, or `undefined` for nodes that are
 *          currently training samples (they are skipped).
 */
// @ts-expect-error
function predict(adj: tf.tensor, model: tf.LayersModel, node: Node<NodeData>) {
  if (node.getData('trained')) return;

  // tf.tidy frees the feature tensor and the model outputs once the class
  // index has been read back, so repeated predictions don't leak memory.
  const predictedCategory = tf.tidy(() => {
    const fts = tf.tensor([node.getData('features')]);
    const prediction = model.predict([fts, adj]);
    return prediction.argMax(1).dataSync()[0];
  });

  const category = node.getData('category');
  const correct = predictedCategory === category;

  if (node.getData('processed')) {
    // Retract this node's earlier contribution before counting it again.
    totalPredictedNodes--;
    if (node.getData('predictedCategory') === category) correctPredictions--;
  }

  node.setData('predictedCategory', predictedCategory);
  node.setData('processed', true);
  updateAccuracy(correct);
  return { correct, predictedCategory };
}

/** Returns a promise that resolves (with `undefined`) after `ms` milliseconds. */
const delay = (ms: number) =>
  new Promise(function wait(resolve) {
    setTimeout(() => resolve(undefined), ms);
  });

// Guards predictAllNodes() so a full prediction pass runs only once per Clear.
let hasPredictedAllNodes = false;

/**
 * Predicts every node of the graph, updating the "Predict" progress bar.
 *
 * Runs at most once until refresh() resets `hasPredictedAllNodes`. It yields
 * to the event loop every 50 nodes so the page stays responsive. Its own
 * progress bar is removed 5 s after finishing.
 *
 * @param adj   Adjacency matrix indexed by node id.
 * @param model Model passed through to predict().
 */
// @ts-expect-error
async function predictAllNodes(adj: tf.Tensor, model: tf.LayersModel) {
  if (hasPredictedAllNodes) return;
  hasPredictedAllNodes = true;
  const allNodes = ogma.getNodes();
  const nodeCount = allNodes.size;
  const barName = 'Predict';

  initializeProgressBar(nodeCount, barName);

  // Iterate over the nodes actually loaded. A hard-coded bound would crash
  // on smaller graphs and skip nodes on larger ones.
  for (let i = 0; i < nodeCount; i++) {
    const node = allNodes.get(i);
    const nodeId = node.getId() as number;
    const adjTensor = tf.tensor([adj[nodeId]]);
    try {
      predict(adjTensor, model, node);
    } finally {
      // The adjacency tensor is created here, so it is released here.
      adjTensor.dispose();
    }
    updateProgressBar(i + 1, barName);
    if (i % 50 === 0) await delay(1);
  }

  // Remove only this bar: clearing the whole container would also remove a
  // training bar that may be running at the same time.
  setTimeout(() => {
    document.getElementById(`progress-bar-${barName}`)?.remove();
    document.querySelector(`label[for="progress-bar-${barName}"]`)?.remove();
  }, 5000);
}

/**
 * Shows the #warning-popup message (not enough training nodes selected) and
 * hides it again after 3 seconds. No-op when the element is missing.
 */
function showWarningPopup() {
  const warning = document.getElementById('warning-popup');
  if (warning === null) return;

  warning.style.display = 'block';
  setTimeout(hideWarningPopup, 3000);
}

/** Hides the #warning-popup message, if the element exists. */
function hideWarningPopup() {
  const warning = document.getElementById('warning-popup');
  if (warning === null) return;
  warning.style.display = 'none';
}

/**
 * Fine-tunes `model` on the currently selected nodes.
 *
 * With fewer than 150 nodes selected, this only shows a warning, switches on
 * the lasso so the user can pick training nodes, and returns; the user then
 * clicks Train again. Otherwise, it fits the model for 30 epochs on the selected
 * nodes' features, adjacency rows and one-hot categories, advancing the
 * "Training" progress bar after each epoch.
 *
 * A `finally` block always re-enables the Train button, clears the nodes'
 * `trained` flags, removes the training progress bar and frees the training
 * tensors. A failing `fit` therefore no longer leaves the UI locked; the error
 * still propagates to the caller.
 *
 * @param adj   Adjacency matrix indexed by node id.
 * @param model Model whose weights are updated in place.
 */
// @ts-expect-error
async function trainModel(adj: tf.tensor, model: tf.LayersModel) {
  const selectedNodes = ogma.getSelectedNodes();
  const epochs = 30;
  const minTrainingNodes = 150;
  const trainButton = document.getElementById('train') as HTMLButtonElement;

  if (selectedNodes.size < minTrainingNodes) {
    // Not enough training data yet: let the user lasso-select nodes and
    // re-enable the button once a selection has been made.
    showWarningPopup();
    trainButton.disabled = true;
    trainButton.textContent = 'Selecting nodes for training';
    ogma.tools.lasso.enable(lassoOptions);
    ogma.events.once('nodesSelected', () => {
      trainButton.removeAttribute('disabled');
      trainButton.textContent = `Train on ${ogma.getSelectedNodes().size} nodes`;
    });
    return;
  }

  initializeProgressBar(epochs, 'Training');
  hideWarningPopup();
  trainButton.disabled = true;
  const sampleRatio = selectedNodes.size / ogma.getNodes().size;
  trainButton.textContent = `Training on ${selectedNodes.size} nodes (${Math.floor(sampleRatio * 100)}%)...`;

  // Short pause so the browser can repaint the button and progress bar before
  // the synchronous tensor setup below.
  await delay(200);

  const featuresArray = selectedNodes.getData('features');
  const labelsArray = selectedNodes.getData('category');
  const adjArray = selectedNodes.map(node => adj[node.getId() as number]);

  // Category indices -> one-hot rows. The intermediate int tensor is only
  // needed to build the encoding, so it is released right away.
  const labelIds = tf.tensor1d(labelsArray, 'int32');
  const labelsTensor = tf.oneHot(labelIds, CATEGORIES.length);
  labelIds.dispose();
  const featuresTensor = tf.tensor(featuresArray);
  const adjTensor = tf.tensor2d(adjArray);

  // Training samples get a distinct colour and are skipped by predict().
  selectedNodes.fillData('trained', true).setSelected(false);

  try {
    model.compile({
      optimizer: tf.train.adam(0.01),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy']
    });

    const start = Date.now();
    await model.fit([featuresTensor, adjTensor], labelsTensor, {
      epochs,
      batchSize: 32,
      shuffle: false,
      verbose: 1,
      validationSplit: 0.1,
      callbacks: {
        onEpochEnd: (epochIndex: number) => {
          // Epoch indices are 0-based; the bar counts completed epochs.
          updateProgressBar(epochIndex + 1, 'Training');
        }
      }
    });

    showMessage(
      `Training completed in <strong>${humanizeDuration(Date.now() - start)}</strong>`
    );
  } finally {
    trainButton.disabled = false;
    trainButton.textContent = 'Train';
    selectedNodes.fillData('trained', false);
    // Remove only the training bar so a concurrent "Predict" bar survives.
    document.getElementById('progress-bar-Training')?.remove();
    document.querySelector('label[for="progress-bar-Training"]')?.remove();

    featuresTensor.dispose();
    labelsTensor.dispose();
    adjTensor.dispose();
  }
}

/**
 * Replaces the content of the #info panel with `msg`.
 * The string is assigned as HTML, so callers may include markup such as <strong>.
 */
function showMessage(msg: string): void {
  const panel = info;
  panel.innerHTML = msg;
}

/**
 * Syncs each node's `showBadPrediction` flag with the #highlight-correct
 * checkbox.
 *
 * The flag is set on processed, non-training nodes whose prediction differs
 * from their category while the box is ticked. It is cleared on every other
 * node that currently has it.
 *
 * NOTE(review): nothing visible reads `showBadPrediction`; the setData calls
 * presumably matter because they make Ogma re-run the node style rules (whose
 * opacity depends on the checkbox) for those nodes — confirm before removing.
 */
function applyHighlightMistakes() {
  ogma.getNodes().forEach(node => {
    const { processed, trained, category, predictedCategory } = node.getData();
    const isMistake =
      Boolean(processed) &&
      !trained &&
      highlightCorrect.checked &&
      category !== predictedCategory;

    if (isMistake) {
      // (An attention-grabbing alternative: node.pulse(pulse_options).)
      node.setData('showBadPrediction', true);
    } else if (node.getData('showBadPrediction')) {
      node.setData('showBadPrediction', false);
    }
  });
}

/**
 * "Clear" button handler: resets all prediction/training state on the nodes,
 * allows "Predict all" to run again, resets the running accuracy and empties
 * the progress and info panels.
 */
function refresh() {
  ogma
    .getNodes()
    .fillData('processed', false)
    .fillData('trained', false)
    .fillData('showBadPrediction', false)
    .fillData('predictedCategory', undefined);
  hasPredictedAllNodes = false;
  // Reset the accuracy tally as well; otherwise predictions made after a Clear
  // would be mixed with the stale counts from before it.
  correctPredictions = 0;
  totalPredictedNodes = 0;
  currentAccuracy = 0;
  progressBarContainer.innerHTML = '';
  info.innerHTML = '';
}

// DOM handles shared by the helpers defined above (showMessage, refresh, the
// progress-bar helpers, the node style rule). They are module-level consts, so
// those helpers must not run before this point of the module.
const info = document.querySelector('#info')!;
const progressBarContainer = document.getElementById('progress-bar-container')!;
const highlightCorrect = document.getElementById(
  'highlight-correct'
) as HTMLInputElement;

// Recompute the `showBadPrediction` flags whenever the checkbox is toggled.
highlightCorrect.addEventListener('change', applyHighlightMistakes);

console.time('Loading graph');
// NOTE(review): judging by the file name, this is the Cora graph with
// precomputed node positions — confirm.
const graph = await Ogma.parse.jsonFromUrl<NodeData, unknown>(
  'files/cora_positions.json'
);
// Move the camera to the raw graph's extent before adding it (presumably so the
// graph appears already framed — see Ogma `view.locateRawGraph`).
await ogma.view.locateRawGraph(graph);
console.timeEnd('Loading graph');
// Alternative data source, kept commented out.
//const graph = await Ogma.parse.jsonFromUrl('files/cora.json');
await ogma.setGraph(graph);
// Dense adjacency matrix indexed by node id; its rows are the model's second input.
const adjacencyMatrix = getAdjacencyMatrix(graph);

// Pre-trained model in tfjs Layers format. Unless `weightPathPrefix` is given,
// tfjs resolves the weight files relative to the model JSON's location.
const model = await tf.loadLayersModel('files/pre_trained.json', {
  //weightPathPrefix: './'
});

/**
 * Predicts the category of the clicked node with the current model and shows
 * the predicted vs. actual category in a popup.
 *
 * Does nothing while a popup is already open. It also does nothing when
 * predict() skips the node because it is currently a training sample;
 * destructuring that missing result used to throw.
 */
function onNodeClick(node: Node<NodeData>) {
  if (popup) return null;
  const id = node.getId() as number;

  // One-row batch holding this node's adjacency row; freed once predict() is done.
  const adjacencyTensor = tf.tensor([adjacencyMatrix[id]]);
  let result: ReturnType<typeof predict> | undefined;
  try {
    // use pre-trained model to predict the category of the node
    result = predict(adjacencyTensor, model, node);
  } finally {
    adjacencyTensor.dispose();
  }
  if (!result) return null;
  const { correct, predictedCategory } = result;

  const icon = correct ? checkIcon : failIcon;
  const realColor = COLORS[node.getData('category')];
  const predictedColor = COLORS[predictedCategory];
  const content = `
<div class="prediction-result">
  <h2 class="${correct ? 'success' : 'danger'}"><span class="icon">${icon}</span> Paper #${node.getData('paperId')}</h2>
  <table>
    <tr>
      <td>Predicted</td>
      <td><span class="category" style="background-color: ${predictedColor}"></span></td>
      <td>${CATEGORIES[predictedCategory]}</td>
    </tr>
    <tr>
      <td>Actual</td>
      <td><span class="category" style="background-color: ${realColor}"></span></td>
      <td>${CATEGORIES[node.getData('category')]}</td>
    </tr>
  </table>
</div>
    `;
  openPopup(content, node.getPosition());
}

// Clicking a node shows its prediction; clicking anything else closes the popup.
ogma.events.on('click', ({ target }) => {
  if (target && target.isNode) {
    onNodeClick(target);
    return;
  }
  if (popup) closePopup();
});

// Control-panel buttons, registered in the order predict, train, refresh.
const buttonActions: Array<[string, () => void]> = [
  ['predict', () => { void predictAllNodes(adjacencyMatrix, model); }],
  ['train', () => { void trainModel(adjacencyMatrix, model); }],
  ['refresh', () => { refresh(); }]
];
for (const [buttonId, action] of buttonActions) {
  document.getElementById(buttonId)?.addEventListener('click', action);
}

// Ctrl+drag starts a fresh lasso selection of nodes.
ogma.events.on('dragStart', () => {
  if (!ogma.keyboard.isKeyPressed('ctrl')) return;
  ogma.getSelectedEdges().setSelected(false);
  ogma.getSelectedNodes().setSelected(false);
  ogma.tools.lasso.enable(lassoOptions);
});

/**
 * Formats a duration as a compact string, e.g. `"1min 5s 250 ms"`.
 *
 * Units are days (`d`), hours (`h`), minutes (`min`), seconds (`s`) and
 * milliseconds (` ms`). Zero-valued units are omitted, and `"0s"` is returned
 * for a zero duration. The input is first rounded to whole milliseconds and
 * negative values are clamped to 0. A fractional input such as 1999.6 therefore
 * yields `"2s"` instead of `"1s 1000 ms"`.
 *
 * @param ms Duration in milliseconds.
 */
function humanizeDuration(ms: number): string {
  // Define the various time units in milliseconds.
  const msInASecond = 1000;
  const msInAMinute = 60 * msInASecond;
  const msInAnHour = 60 * msInAMinute;
  const msInADay = 24 * msInAnHour;

  let remainingMs = Math.max(0, Math.round(ms));

  // Calculate the number of days, hours, minutes, seconds, and milliseconds.
  const days = Math.floor(remainingMs / msInADay);
  remainingMs %= msInADay;

  const hours = Math.floor(remainingMs / msInAnHour);
  remainingMs %= msInAnHour;

  const minutes = Math.floor(remainingMs / msInAMinute);
  remainingMs %= msInAMinute;

  const seconds = Math.floor(remainingMs / msInASecond);
  remainingMs %= msInASecond;

  // Construct the humanized duration string.
  const parts: string[] = [];

  if (days) parts.push(`${days}d`);
  if (hours) parts.push(`${hours}h`);
  if (minutes) parts.push(`${minutes}min`);
  // Seconds are shown when non-zero, or as "0s" when everything is zero.
  if (seconds || (!days && !hours && !minutes && !remainingMs))
    parts.push(`${seconds}s`);
  if (remainingMs) parts.push(`${remainingMs} ms`);

  return parts.join(' ');
}
html
<!doctype html>
<html lang="en">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <title>AI classification</title>

    <link type="text/css" rel="stylesheet" href="styles.css" />
    <!-- Maps the bare "@tensorflow/tfjs" specifier imported by index.ts to the
         ESM build on jsDelivr (same version as in package.json). -->
    <script type="importmap">
      {
        "imports": {
          "@tensorflow/tfjs": "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.20.0/+esm"
        }
      }
    </script>
    <link rel="preconnect" href="https://fonts.googleapis.com" />
    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />
    <link
      href="https://fonts.googleapis.com/css2?family=IBM+Plex+Sans:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;1,100;1,200;1,300;1,400;1,500;1,600;1,700&display=swap"
      rel="stylesheet"
    />
    <!-- <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.20.0/dist/tf.min.js"></script> -->
  </head>
  <body>
    <!-- Ogma renders the graph into this element. -->
    <div id="graph-container"></div>
    <!-- Control panel; the element ids below are looked up by index.ts. -->
    <div class="controls">
      <p>
        <button id="predict">Predict all</button>
      </p>
      <p>
        <button id="train">Train</button>
      </p>
      <p>
        <button id="refresh">Clear</button>
      </p>
      <p>
        <input type="checkbox" id="highlight-correct" />
        <label for="highlight-correct">Hide incorrect predictions</label>
      </p>
      <div id="info"></div>
      <div id="progress-bar-container"></div>
      <div id="warning-popup">
        Use lasso to select training data, try to select 150 nodes or more for
        higher accuracy.
      </div>
    </div>

    <div id="legend-container">
      <ul id="legend-list"></ul>
    </div>
    <script type="module" src="index.ts"></script>
  </body>
</html>
css
/*
 * Theme tokens. Only some are used by the rules in this file (for example
 * --base-color, --danger-color and the --overlay-* colours); others such as
 * --country-color or --group-color are not referenced here.
 */
:root {
  --base-color: #4999f7;
  --active-color: var(--base-color);
  --white: #ffffff;
  --gray: #a9a9a9;
  --lighter-gray: #f4f4f4;
  --light-gray: #e6e6e6;
  --inactive-color: #acacac;
  --group-color: #525fe1;
  --group-inactive-color: #c2c8ff;
  --selection-color: #04ddcb;
  --country-color: #044b87;
  --danger-color: #842224;
  --country-inactive-color: #bccddb;
  --dark-color: #3a3535;
  --edge-color: var(--dark-color);
  --border-radius: 3px;
  --button-border-radius: var(--border-radius);
  --edge-inactive-color: var(--gray);
  --button-background-color: #ffffff;
  --shadow-color: rgba(0, 0, 0, 0.25);
  --shadow-hover-color: rgba(0, 0, 0, 0.5);
  --button-shadow: 0 0 4px var(--shadow-color);
  --button-shadow-hover: 0 0 4px var(--shadow-hover-color);
  --button-icon-color: #000000;
  --button-icon-hover-color: var(--active-color);
  --overlay-background-color: #fff;
  --overlay-text-color: #444;
  --button-text-color: var(--white);
}

/* IBM Plex Sans is loaded from Google Fonts by the HTML page. */
body {
  font-family: 'IBM Plex Sans', sans-serif;
}

/* Full-viewport host element for the Ogma graph. */
#graph-container {
  top: 0;
  bottom: 0;
  left: 0;
  right: 0;
  position: absolute;
  margin: 0;
  overflow: hidden;
}

/*
 * Overlay (popup / tooltip) styles. Tooltip selectors share these rules; the
 * script in this example only creates `ogma-popup ogma-popup--top` popups.
 * z-index 401 keeps overlays above the .controls panel (z-index 400).
 */
.ogma-tooltip,
.ogma-popup {
  z-index: 401;
  box-sizing: border-box;
}

/* Bubble, horizontally centred on the anchor point. */
.ogma-tooltip--content,
.ogma-popup--body {
  transform: translate(-50%, 0);
  background-color: var(--overlay-background-color);
  color: var(--overlay-text-color);
  border-radius: 5px;
  padding: 5px;
  box-sizing: border-box;
  box-shadow: 0 8px 30px rgb(0 0 0 / 12%);
  width: auto;
  height: auto;
  position: relative;
}

/* Tooltips never intercept pointer events. */
.ogma-tooltip {
  pointer-events: none;
}

/* Popups sit above their anchor by default. */
.ogma-popup--body {
  transform: translate(-50%, -100%);
}

/* Arrow towards the anchor, drawn as a CSS border triangle (0x0 box). */
.ogma-tooltip--content:after,
.ogma-popup--body:after {
  content: '';
  width: 0;
  height: 0;
  border-style: solid;
  border-width: 6px 7px 6px 0;
  border-color: transparent var(--overlay-background-color) transparent
    transparent;
  position: absolute;
  left: 50%;
  top: auto;
  bottom: 3px;
  right: auto;
  transform: translate(-50%, 100%) rotate(270deg);
}

/* "x" close button in the popup's top-right corner. */
.ogma-popup--close {
  position: absolute;
  top: 0px;
  right: 5px;
  cursor: pointer;
}

/* Placement modifiers: --top / --bottom / --right / --left of the anchor. */
.ogma-popup--top .ogma-popup--body,
.ogma-tooltip--top .ogma-tooltip--content {
  bottom: 6px;
  transform: translate(-50%, -100%);
}

.ogma-popup--bottom .ogma-popup--body,
.ogma-tooltip--bottom .ogma-tooltip--content {
  transform: translate(-50%, 0%);
  top: 3px;
}

.ogma-popup--bottom .ogma-popup--body:after,
.ogma-tooltip--bottom .ogma-tooltip--content:after {
  top: 3px;
  bottom: auto;
  transform: translate(-50%, -100%) rotate(90deg);
}

.ogma-popup--right .ogma-popup--body,
.ogma-tooltip--right .ogma-tooltip--content {
  transform: translate(0, -50%);
  left: 6px;
}

.ogma-popup--right .ogma-popup--body:after,
.ogma-tooltip--right .ogma-tooltip--content:after {
  left: 0%;
  top: 50%;
  transform: translate(-100%, -50%) rotate(0deg);
}

.ogma-popup--left .ogma-popup--body,
.ogma-tooltip--left .ogma-tooltip--content {
  transform: translate(-100%, -50%);
  right: 6px;
}

.ogma-popup--left .ogma-popup--body:after,
.ogma-tooltip--left .ogma-tooltip--content:after {
  right: 0%;
  left: auto;
  top: 50%;
  transform: translate(100%, -50%) rotate(180deg);
}

.ogma-popup--content {
  padding: 10px;
}

/* Check / cross icon in the prediction popup heading. */
.icon {
  vertical-align: middle;
}

/* Heading colour for a correct prediction. */
.success {
  color: darkgreen;
}

/* Heading colour for a wrong prediction (also used by #warning-popup's colour token). */
.danger {
  color: var(--danger-color);
}

/* Round colour swatch next to the predicted / actual category names. */
.prediction-result .category {
  border-radius: 50%;
  width: 1em;
  height: 1em;
  display: inline-block;
  vertical-align: text-top;
}

/* Floating control panel in the top-right corner. */
.controls {
  position: absolute;
  top: 10px;
  right: 10px;
  z-index: 400;
  padding: 1em;
  border-radius: var(--border-radius);
  box-shadow: var(--button-shadow);
  min-width: 200px;
  background-color: var(--white);
  font-weight: 300;
}

/* Full-width action buttons. */
.controls p button {
  display: block;
  width: 100%;
  background-color: var(--base-color);
  border-radius: var(--button-border-radius);
  color: var(--button-text-color);
  padding: 0.5em 0.5em;
  border: 0;
}

/* Greyed out while training / selecting training nodes. */
.controls p button:disabled {
  background-color: var(--inactive-color);
}

/* Stack of "<name> Progress:" labels and their <progress> bars. */
#progress-bar-container {
  display: flex;
  flex-direction: column;
  flex-wrap: wrap;
  gap: 10px;
}

#progress-bar-container label {
  margin-right: 5px;
}

/* Rounded progress bar; vendor pseudo-elements style WebKit/Blink and Firefox. */
progress {
  width: 100%;
  border-radius: 15px;
  background-color: #e0e0e0;
  border: none;
  overflow: hidden;
}

progress::-webkit-progress-bar {
  background-color: #e0e0e0;
  border-radius: 15px;
}

progress::-webkit-progress-value {
  background-color: var(--base-color);
  border-radius: 15px;
  transition: width 0.4s ease;
}

progress::-moz-progress-bar {
  background-color: var(--base-color);
  border-radius: 15px;
  transition: width 0.4s ease;
}

/*
 * NOTE(review): the script generates labels with for="progress-bar-<name>"
 * (e.g. progress-bar-Predict), so this exact-match selector matches no element.
 */
label[for='progress-bar'] {
  display: block;
  margin-bottom: 8px;
}

/* Category legend pinned to the bottom-right corner. */
#legend-container {
  position: fixed;
  bottom: 20px;
  right: 20px;
  background-color: #fff;
  border: 1px solid #ddd;
  border-radius: 8px;
  padding: 10px;
  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
  z-index: 1000;
  font-weight: 300;
  font-size: 14px;
}

#legend-list {
  list-style: none;
  padding: 0;
  margin: 0;
}

/* One row per category: swatch + name. */
#legend-list li {
  display: flex;
  align-items: center;
  margin-bottom: 5px;
}

/* Square colour swatch; its background colour is set inline by the script. */
.legend-color-box {
  width: 16px;
  height: 16px;
  margin-right: 8px;
  border-radius: 4px;
  border: 1px solid #ccc;
}

/* Hidden until the script sets display: block (too few training nodes). */
#warning-popup {
  display: none;
  margin-top: 10px;
  color: var(--danger-color);
  max-width: 200px;
}

/* Status line (accuracy / training time). */
#info {
  max-width: 200px;
}

/* Keep the highlighted figure on one line. */
#info strong {
  white-space: nowrap;
}
json
{
  "dependencies": {
    "@tensorflow/tfjs": "4.20.0"
  }
}
ts
/** 24x24 SVG: check mark in a circle. Strokes use currentColor, so it follows the text colour of its container. */
export const checkIcon = `<svg width="24px" height="24px" stroke-width="1.5" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" color="currentColor"><path d="M7 12.5L10 15.5L17 8.5" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path><path d="M12 22C17.5228 22 22 17.5228 22 12C22 6.47715 17.5228 2 12 2C6.47715 2 2 6.47715 2 12C2 17.5228 6.47715 22 12 22Z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path></svg>`;
/** 24x24 SVG: cross in a circle. Strokes use currentColor, so it follows the text colour of its container. */
export const failIcon = `<svg width="24px" height="24px" stroke-width="1.5" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" color="currentColor"><path d="M9.17218 14.8284L12.0006 12M14.829 9.17157L12.0006 12M12.0006 12L9.17218 9.17157M12.0006 12L14.829 14.8284" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path><path d="M12 22C17.5228 22 22 17.5228 22 12C22 6.47715 17.5228 2 12 2C6.47715 2 2 6.47715 2 12C2 17.5228 6.47715 22 12 22Z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path></svg>`;