Appearance
AI classification new
This example shows how to use a pre-trained Graph Neural Network (GNN) model to classify nodes. It uses the "Cora" citation network dataset. The model was trained on a small subset of 500 of the dataset's nodes, using both the feature vectors extracted from the content of the papers associated with the nodes and the citations between papers. This allows it to predict the category of nodes it did not see during training.
You can try re-training the model on different parts of the dataset to see how this affects its accuracy.
ts
import Ogma, { Node, Overlay, RawGraph, Point } from '@linkurious/ogma';
import { failIcon, checkIcon } from './icons';
import * as tf from '@tensorflow/tfjs';
/** Data attached to every node of the Cora graph. */
interface NodeData {
  // --- required fields (expected in the loaded graph JSON) ---
  /** Not read by the code in this example. */
  id: number;
  /** Index of the paper's true category in CATEGORIES / COLORS. */
  category: number;
  /** Feature vector given to the model as its first input. */
  features: number[];
  /** Not read by the code in this example. */
  value: number;

  // --- flags written while the example runs ---
  /** Category index returned by the latest prediction for this node. */
  predictedCategory?: number;
  /** Set once the node has been run through predict(). */
  processed?: boolean;
  /** Set while the node is part of the current training batch. */
  trained?: boolean;
  /** Set by applyHighlightMistakes() on wrongly-predicted nodes while the checkbox is ticked. */
  showBadPrediction?: boolean;
}
// The seven Cora paper categories; a node's `category` is an index into this list.
const CATEGORIES = [
  'Case Based',
  'Genetic Algorithms',
  'Neural Networks',
  'Probabilistic Methods',
  'Reinforcement Learning',
  'Rule Learning',
  'Theory'
];

// Colour for each category, in the same order as CATEGORIES (used for nodes and the legend).
const COLORS = ['#FA5319', '#2846FA', '#FA9419', '#19B1FA', '#A57B47', '#7A5649', '#C844D1'];

// Node pulse settings; only referenced by the commented-out pulse call in applyHighlightMistakes().
const pulse_options = {
  number: 5,
  duration: 1700,
  interval: 600,
  startColor: 'inherit',
  endColor: 'red',
  width: 5,
  startRatio: 1,
  endRatio: 5
};
// Lasso tool settings: the lassoed nodes become the selection and any selected edges are cleared.
const lassoOptions = {
  strokeWidth: 1,
  bothExtremities: false,
  callback: ({ nodes, edges }: any) => {
    nodes.setSelected(true);
    edges.setSelected(false);
  }
};
// Single Ogma instance rendering into #graph-container; edge detection is turned off.
const ogma = new Ogma<NodeData>({
  container: 'graph-container',
  options: {
    detect: { edges: false },
    backgroundColor: '#F5F6F6'
  }
});
/**
 * Fill #legend-list with one row per category: a colour swatch from COLORS
 * followed by the category name. Does nothing if the list element is missing.
 */
function createLegend() {
  const legendList = document.getElementById('legend-list');
  if (!legendList) return;
  for (const [index, category] of CATEGORIES.entries()) {
    const swatch = document.createElement('span');
    swatch.className = 'legend-color-box';
    swatch.style.backgroundColor = COLORS[index];

    const label = document.createElement('span');
    label.textContent = category;

    const row = document.createElement('li');
    row.appendChild(swatch);
    row.appendChild(label);
    legendList.appendChild(row);
  }
}
createLegend();
let popup: Overlay | null = null;

/**
 * Open an HTML popup anchored at `position` showing `content`, replacing any popup
 * already open. While it is open, the node selection is cleared and zoom, pan,
 * selection and node detection are disabled; closePopup() turns them back on.
 */
function openPopup(content: string, position: Point) {
  popup?.destroy();
  const overlay = ogma.layers.addOverlay({
    element: `
<div class="ogma-popup ogma-popup--top">
<div class="ogma-popup--body">
<div class="ogma-popup--close">×</div>
<div class="ogma-popup--content">${content}</div>
</div>
</div>
`,
    size: { width: 'auto', height: 'auto' },
    scaled: false,
    position
  });
  popup = overlay;

  // Stop clicks inside the popup from propagating any further.
  overlay.element.addEventListener('click', evt => evt.stopPropagation());
  const closeButton = overlay.element.querySelector('.ogma-popup--close')!;
  closeButton.addEventListener('click', evt => {
    evt.stopPropagation();
    closePopup();
  });

  ogma.getSelectedNodes().setSelected(false);
  ogma.setOptions({
    interactions: {
      zoom: { enabled: false },
      pan: { enabled: false },
      selection: { enabled: false }
    },
    detect: { nodes: false }
  });
}
/** Destroy the open popup, if any, and re-enable the interactions openPopup() disabled. */
function closePopup() {
  if (!popup) return;
  popup.destroy();
  popup = null;
  ogma.setOptions({
    interactions: {
      zoom: { enabled: true },
      pan: { enabled: true },
      selection: { enabled: true }
    },
    detect: { nodes: true }
  });
}
// Node styling driven by the prediction/training flags stored on each node.
ogma.styles.addNodeRule({
  // Training nodes are dark red; predicted nodes take their true category colour.
  color: node => {
    if (node.getData('trained')) return '#6e0606';
    return node.getData('processed') ? COLORS[node.getData('category')] : undefined;
  },
  // Predicted or training nodes are drawn larger.
  radius: node => (node.getData('processed') || node.getData('trained') ? 20 : 10),
  // With the checkbox ticked, wrong predictions are faded out.
  opacity: node => {
    if (node.getData('processed') && highlightCorrect.checked) {
      return node.getData('category') === node.getData('predictedCategory') ? 1 : 0.25;
    }
    return undefined;
  }
});
/**
 * Build a dense, symmetric 0/1 adjacency matrix with one row/column per node.
 * Edge endpoints are converted with unary `+` and used directly as row/column
 * indices, so node ids are assumed to be 0..nodes.length-1.
 */
function getAdjacencyMatrix({ nodes, edges }: RawGraph) {
  const size = nodes.length;
  const matrix: number[][] = [];
  for (let row = 0; row < size; row++) {
    matrix.push(new Array<number>(size).fill(0));
  }
  for (const { source, target } of edges) {
    const from = +source;
    const to = +target;
    matrix[from][to] = 1;
    matrix[to][from] = 1;
  }
  return matrix;
}
// Running tally of prediction outcomes shown in the info panel.
let correctPredictions = 0;
let totalPredictedNodes = 0;
let currentAccuracy = 0;

/** Record one prediction outcome and display the running accuracy as a percentage. */
function updateAccuracy(correct: boolean) {
  totalPredictedNodes += 1;
  correctPredictions += correct ? 1 : 0;
  currentAccuracy = (correctPredictions / totalPredictedNodes) * 100;
  const formatted = currentAccuracy.toFixed(2);
  showMessage(`Accuracy: <strong>${formatted}%</strong>`);
}
/**
 * Append a labelled `<progress id="progress-bar-${name}">` to the progress
 * container, starting at 0 out of `totalNodes` steps. Does nothing if a bar with
 * that name is already present.
 *
 * @param totalNodes number of steps the bar represents (nodes to predict, or training epochs)
 * @param name       bar name, used in the element id and in the label text
 */
function initializeProgressBar(totalNodes: number, name: string) {
  const progressBarId = `progress-bar-${name}`;
  if (document.getElementById(progressBarId)) return;

  // Build the elements directly rather than with `innerHTML +=`, which re-parses the
  // container and re-creates every bar already in it.
  const label = document.createElement('label');
  label.htmlFor = progressBarId;
  label.textContent = ` ${name} Progress:`;

  const progressBar = document.createElement('progress');
  progressBar.id = progressBarId;
  progressBar.max = totalNodes;
  progressBar.value = 0;

  progressBarContainer.append(label, progressBar);
}
/**
 * Set the current value of the `progress-bar-${name}` element.
 *
 * Does nothing when that bar is no longer in the page. That happens, for example,
 * when refresh() empties the progress container while predictions or training are
 * still running; throwing there would abort the caller's loop or training callback.
 */
function updateProgressBar(currentProgress: number, name: string) {
  const progressBar = document.getElementById(
    `progress-bar-${name}`
  ) as HTMLProgressElement | null;
  if (progressBar) progressBar.value = currentProgress;
}
/**
 * Classify one node with `model`, store the result on the node
 * (`predictedCategory`, `processed`) and add it to the running accuracy.
 *
 * @param adj   the node's adjacency row as a tensor (callers pass `tf.tensor([row])`).
 *              It is not disposed here; the caller owns it.
 * @param model model taking `[features, adjacency]` as inputs
 * @param node  node to classify
 * @returns the predicted category index and whether it matches the node's `category`,
 *          or `undefined` (nothing recorded) for nodes flagged `trained`
 */
function predict(adj: tf.Tensor, model: tf.LayersModel, node: Node<NodeData>) {
  if (node.getData('trained')) return;
  // tidy() frees the feature tensor and every tensor produced by the model once
  // the winning class index has been read back, so repeated predictions don't leak.
  const predictedCategory = tf.tidy(() => {
    const features = tf.tensor([node.getData('features')]);
    // Assumes a single-output model, as argMax below requires.
    const prediction = model.predict([features, adj]) as tf.Tensor;
    return prediction.argMax(1).dataSync()[0];
  });
  const correct = predictedCategory === node.getData('category');
  node.setData('predictedCategory', predictedCategory);
  node.setData('processed', true);
  updateAccuracy(correct);
  return { correct, predictedCategory };
}
/** Resolve after `ms` milliseconds; awaited to yield to the browser between chunks of work. */
const delay = (ms: number) =>
  new Promise(resolve => {
    setTimeout(resolve, ms);
  });
// Guards predictAllNodes() against running twice; reset by refresh().
let hasPredictedAllNodes = false;
/**
 * Run predict() on every node of the graph while driving a "Predict" progress bar.
 * Runs only once until refresh() resets `hasPredictedAllNodes`.
 *
 * @param adj   dense adjacency matrix indexed by node id (as built by getAdjacencyMatrix)
 * @param model model passed through to predict()
 */
async function predictAllNodes(adj: number[][], model: tf.LayersModel) {
  if (hasPredictedAllNodes) return;
  hasPredictedAllNodes = true;
  const allNodes = ogma.getNodes();
  const nodeCount = allNodes.size;
  initializeProgressBar(nodeCount, 'Predict');
  // Iterate over the actual graph size rather than a hard-coded node count, so the
  // loop and the progress bar always agree and never index past the last node.
  for (let i = 0; i < nodeCount; i++) {
    const node = allNodes.get(i);
    const adjTensor = tf.tensor([adj[node.getId() as number]]);
    predict(adjTensor, model, node);
    // predict() does not dispose its input tensor; free it here so memory doesn't grow per node.
    adjTensor.dispose();
    updateProgressBar(i + 1, 'Predict');
    // Yield to the event loop every 50 nodes so the page and the progress bar can repaint.
    if (i % 50 === 0) await delay(1);
  }
  setTimeout(() => {
    // Remove only this bar: emptying the whole container would also remove a
    // training bar that may still be updating.
    document.querySelector('label[for="progress-bar-Predict"]')?.remove();
    document.getElementById('progress-bar-Predict')?.remove();
  }, 5000);
}
/** Reveal the #warning-popup hint and schedule it to be hidden again after 3 seconds. */
function showWarningPopup() {
  const warningElement = document.getElementById('warning-popup');
  if (!warningElement) return;
  warningElement.style.display = 'block';
  setTimeout(hideWarningPopup, 3000);
}
/** Hide the #warning-popup hint; a no-op when the element is not in the page. */
function hideWarningPopup() {
  const warningElement = document.getElementById('warning-popup');
  if (warningElement !== null) {
    warningElement.style.display = 'none';
  }
}
/**
 * Re-train `model` on the currently selected nodes, using their features and
 * adjacency rows as inputs and their `category` (one-hot encoded) as labels.
 *
 * With fewer than 150 nodes selected, nothing is trained. Instead the warning is
 * shown, the lasso tool is enabled, and the Train button is relabelled once a
 * selection is made, so the next click starts training.
 *
 * @param adj   dense adjacency matrix indexed by node id (as built by getAdjacencyMatrix)
 * @param model model that is compiled and fitted in place
 */
async function trainModel(adj: number[][], model: tf.LayersModel) {
  const selectedNodes = ogma.getSelectedNodes();
  const epochs = 30;
  const trainButton = document.getElementById('train') as HTMLButtonElement;

  if (selectedNodes.size < 150) {
    showWarningPopup();
    trainButton.disabled = true;
    trainButton.textContent = 'Selecting nodes for training';
    ogma.tools.lasso.enable(lassoOptions);
    ogma.events.once('nodesSelected', () => {
      trainButton.removeAttribute('disabled');
      trainButton.textContent = `Train on ${ogma.getSelectedNodes().size} nodes`;
    });
    return;
  }

  initializeProgressBar(epochs, 'Training');
  hideWarningPopup();
  trainButton.disabled = true;
  const sampleSize = selectedNodes.size / ogma.getNodes().size;
  trainButton.textContent = `Training on ${selectedNodes.size} nodes (${Math.floor(sampleSize * 100)}%)...`;

  let featuresTensor: tf.Tensor | undefined;
  let adjTensor: tf.Tensor2D | undefined;
  let labelsTensor: tf.Tensor | undefined;
  try {
    // Let the browser repaint the button and progress bar before the heavy work starts.
    await delay(200);
    const featuresArray = selectedNodes.getData('features');
    const labelsArray = selectedNodes.getData('category');
    const adjArray = selectedNodes.map(node => adj[node.getId() as number]);
    // Flag the training nodes (drawn dark red and skipped by predict()) for the duration of the fit.
    selectedNodes.fillData('trained', true).setSelected(false);

    featuresTensor = tf.tensor(featuresArray);
    adjTensor = tf.tensor2d(adjArray);
    // tidy() disposes the intermediate int32 index tensor; only the one-hot labels are kept.
    labelsTensor = tf.tidy(() =>
      tf.oneHot(tf.tensor1d(labelsArray, 'int32'), CATEGORIES.length)
    );

    model.compile({
      optimizer: tf.train.adam(0.01),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy']
    });
    const start = Date.now();
    await model.fit([featuresTensor, adjTensor], labelsTensor, {
      epochs,
      batchSize: 32,
      shuffle: false,
      verbose: 1,
      validationSplit: 0.1,
      callbacks: {
        onEpochEnd: (epochIndex: number) => {
          updateProgressBar(epochIndex + 1, 'Training');
        }
      }
    });
    showMessage(
      `Training completed in <strong>${humanizeDuration(Date.now() - start)}</strong>`
    );
  } finally {
    // Always restore the UI and free tensor memory, even if training failed.
    trainButton.disabled = false;
    trainButton.textContent = 'Train';
    selectedNodes.fillData('trained', false);
    // Remove only the training bar so a concurrent "Predict" bar is left intact.
    document.querySelector('label[for="progress-bar-Training"]')?.remove();
    document.getElementById('progress-bar-Training')?.remove();
    featuresTensor?.dispose();
    adjTensor?.dispose();
    labelsTensor?.dispose();
  }
}
/** Replace the contents of the #info panel with `msg`, which is inserted as HTML. */
function showMessage(msg: string): void {
  const panel = info;
  panel.innerHTML = msg;
}
/**
 * Flag every predicted, non-training node whose prediction is wrong with
 * `showBadPrediction` while the checkbox is ticked; clear the flag everywhere else.
 */
function applyHighlightMistakes() {
  ogma.getNodes().forEach(node => {
    const { processed, trained, category, predictedCategory } = node.getData();
    const isMistake =
      processed && !trained && highlightCorrect.checked && category !== predictedCategory;
    if (isMistake) {
      // node.pulse(pulse_options);
      node.setData('showBadPrediction', true);
    } else if (node.getData('showBadPrediction')) {
      // Only touch nodes that were flagged before.
      node.setData('showBadPrediction', false);
    }
  });
}
/**
 * Reset the example: clear every prediction/training flag on the nodes, allow
 * "Predict all" to run again, reset the running accuracy, and empty the progress
 * and info panels.
 */
function refresh() {
  ogma
    .getNodes()
    .fillData('processed', false)
    .fillData('trained', false)
    .fillData('showBadPrediction', false)
    .fillData('predictedCategory', undefined);
  hasPredictedAllNodes = false;
  // Reset the accuracy tally too; otherwise the next predictions would be averaged
  // together with the ones that were just cleared.
  correctPredictions = 0;
  totalPredictedNodes = 0;
  currentAccuracy = 0;
  progressBarContainer.innerHTML = '';
  info.innerHTML = '';
}
// DOM handles used throughout the example (the elements are declared in the accompanying HTML).
const info = document.getElementById('info')!;
const progressBarContainer = document.querySelector<HTMLElement>('#progress-bar-container')!;
const highlightCorrect = document.querySelector<HTMLInputElement>('#highlight-correct')!;
// Re-evaluate the mistake flags whenever the checkbox is toggled.
highlightCorrect.addEventListener('change', applyHighlightMistakes);
// Load the graph JSON and centre the view on it before adding it to Ogma.
// The timer only covers parsing and view placement, not setGraph().
console.time('Loading graph');
const graph = await Ogma.parse.jsonFromUrl<NodeData, unknown>(
'files/cora_positions.json'
);
await ogma.view.locateRawGraph(graph);
console.timeEnd('Loading graph');
// Alternative graph file (unused).
//const graph = await Ogma.parse.jsonFromUrl('files/cora.json');
await ogma.setGraph(graph);
// Dense adjacency matrix built once and shared by prediction and training.
const adjacencyMatrix = getAdjacencyMatrix(graph);
// Pre-trained model that predict() and trainModel() work on.
const model = await tf.loadLayersModel('files/pre_trained.json', {
//weightPathPrefix: './'
});
/**
 * Predict the category of a clicked node with the current model and show the
 * predicted and actual categories in a popup next to it.
 * Does nothing while a popup is already open, or when predict() returns no result
 * (nodes currently flagged as training data).
 */
function onNodeClick(node: Node<NodeData>) {
  if (popup) return null;
  const id = node.getId() as number;
  const adjacencyTensor = tf.tensor([adjacencyMatrix[id]]);
  // use the current model to predict the category of the node
  const result = predict(adjacencyTensor, model, node);
  // predict() does not dispose its input; free it to avoid leaking one tensor per click.
  adjacencyTensor.dispose();
  // predict() returns nothing for training nodes; destructuring that would throw.
  if (!result) return;
  const { correct, predictedCategory } = result;
  const icon = correct ? checkIcon : failIcon;
  const realColor = COLORS[node.getData('category')];
  const predictedColor = COLORS[predictedCategory];
  const content = `
<div class="prediction-result">
<h2 class="${correct ? 'success' : 'danger'}"><span class="icon">${icon}</span> Paper #${node.getData('paperId')}</h2>
<table>
<tr>
<td>Predicted</td>
<td><span class="category" style="background-color: ${predictedColor}"></span></td>
<td>${CATEGORIES[predictedCategory]}</td>
</tr>
<tr>
<td>Actual</td>
<td><span class="category" style="background-color: ${realColor}"></span></td>
<td>${CATEGORIES[node.getData('category')]}</td>
</tr>
</table>
</div>
`;
  openPopup(content, node.getPosition());
}
// Clicking a node shows its prediction; clicking anything else closes an open popup.
ogma.events.on('click', ({ target }) => {
  if (target && target.isNode) {
    onNodeClick(target);
    return;
  }
  if (popup) closePopup();
});

// Control-panel buttons, keyed by element id (registered in this order).
const controlHandlers: Record<string, () => void> = {
  predict: () => {
    void predictAllNodes(adjacencyMatrix, model);
  },
  train: () => {
    void trainModel(adjacencyMatrix, model);
  },
  refresh: () => {
    refresh();
  }
};
for (const [elementId, handler] of Object.entries(controlHandlers)) {
  document.getElementById(elementId)?.addEventListener('click', handler);
}

// Ctrl + drag starts a fresh lasso selection.
ogma.events.on('dragStart', () => {
  if (!ogma.keyboard.isKeyPressed('ctrl')) return;
  ogma.getSelectedEdges().setSelected(false);
  ogma.getSelectedNodes().setSelected(false);
  ogma.tools.lasso.enable(lassoOptions);
});
/**
 * Format a duration in milliseconds as a short string such as "1min 5s" or "1s 250 ms".
 * Units with a zero count are omitted; a zero duration is rendered as "0s".
 */
function humanizeDuration(ms: number): string {
  const SECOND = 1000;
  const MINUTE = 60 * SECOND;
  const HOUR = 60 * MINUTE;
  const DAY = 24 * HOUR;

  // Take whole units off `rest`, largest first; what remains is the sub-second part.
  let rest = ms;
  const take = (unit: number): number => {
    const whole = Math.floor(rest / unit);
    rest %= unit;
    return whole;
  };
  const days = take(DAY);
  const hours = take(HOUR);
  const minutes = take(MINUTE);
  const seconds = take(SECOND);

  const labelled: Array<[number, string]> = [
    [days, 'd'],
    [hours, 'h'],
    [minutes, 'min']
  ];
  const parts = labelled
    .filter(([count]) => count)
    .map(([count, suffix]) => `${count}${suffix}`);
  // Seconds appear when non-zero, or as "0s" when nothing else would be printed.
  if (seconds || (parts.length === 0 && !rest)) parts.push(`${seconds}s`);
  if (rest) parts.push(`${rest.toFixed()} ms`);
  return parts.join(' ');
}
html
<!doctype html>
<!-- lang is required for screen readers to pick the right pronunciation (WCAG 3.1.1). -->
<html lang="en">
<head>
<meta charset="utf-8" />
<title>AI classification</title>
<link type="text/css" rel="stylesheet" href="styles.css" />
<!-- Resolve the bare "@tensorflow/tfjs" import used by the script to its ESM build. -->
<script type="importmap">
{
"imports": {
"@tensorflow/tfjs": "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.20.0/+esm"
}
}
</script>
<link rel="preconnect" href="https://fonts.googleapis.com" />
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />
<link
href="https://fonts.googleapis.com/css2?family=IBM+Plex+Sans:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;1,100;1,200;1,300;1,400;1,500;1,600;1,700&display=swap"
rel="stylesheet"
/>
<!-- <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.20.0/dist/tf.min.js"></script> -->
</head>
<body>
<div id="graph-container"></div>
<div class="controls">
<p>
<button id="predict">Predict all</button>
</p>
<p>
<button id="train">Train</button>
</p>
<p>
<button id="refresh">Clear</button>
</p>
<p>
<input type="checkbox" id="highlight-correct" />
<label for="highlight-correct">Hide incorrect predictions</label>
</p>
<div id="info"></div>
<div id="progress-bar-container"></div>
<div id="warning-popup">
Use lasso to select training data, try to select 150 nodes or more for
higher accuracy.
</div>
</div>
<div id="legend-container">
<ul id="legend-list"></ul>
</div>
<script type="module" src="index.ts"></script>
</body>
</html>
css
/* Design tokens shared by the rules below. Several of them (e.g. --country-color,
   --group-color, --selection-color) are not referenced anywhere in this stylesheet. */
:root {
--base-color: #4999f7;
--active-color: var(--base-color);
--white: #ffffff;
--gray: #a9a9a9;
--lighter-gray: #f4f4f4;
--light-gray: #e6e6e6;
--inactive-color: #acacac;
--group-color: #525fe1;
--group-inactive-color: #c2c8ff;
--selection-color: #04ddcb;
--country-color: #044b87;
--danger-color: #842224;
--country-inactive-color: #bccddb;
--dark-color: #3a3535;
--edge-color: var(--dark-color);
--border-radius: 3px;
--button-border-radius: var(--border-radius);
--edge-inactive-color: var(--gray);
--button-background-color: #ffffff;
--shadow-color: rgba(0, 0, 0, 0.25);
--shadow-hover-color: rgba(0, 0, 0, 0.5);
--button-shadow: 0 0 4px var(--shadow-color);
--button-shadow-hover: 0 0 4px var(--shadow-hover-color);
--button-icon-color: #000000;
--button-icon-hover-color: var(--active-color);
--overlay-background-color: #fff;
--overlay-text-color: #444;
--button-text-color: var(--white);
}
/* IBM Plex Sans is loaded through the Google Fonts <link> in the page head. */
body {
font-family: 'IBM Plex Sans', sans-serif;
}
/* The Ogma graph canvas is stretched to cover the whole page. */
#graph-container {
top: 0;
bottom: 0;
left: 0;
right: 0;
position: absolute;
margin: 0;
overflow: hidden;
}
/* Overlay (popup/tooltip) styles. The script only creates `.ogma-popup--top`
   popups; the tooltip and the other direction variants are not used by it.
   z-index 401 is presumably meant to keep overlays above the .controls panel (z-index: 400). */
.ogma-tooltip,
.ogma-popup {
z-index: 401;
box-sizing: border-box;
}
/* Bubble around the overlay content, horizontally centred on the anchor point. */
.ogma-tooltip--content,
.ogma-popup--body {
transform: translate(-50%, 0);
background-color: var(--overlay-background-color);
color: var(--overlay-text-color);
border-radius: 5px;
padding: 5px;
box-sizing: border-box;
box-shadow: 0 8px 30px rgb(0 0 0 / 12%);
width: auto;
height: auto;
position: relative;
}
/* Tooltips must not intercept pointer events. */
.ogma-tooltip {
pointer-events: none;
}
/* Popup bubbles sit above their anchor (overrides the transform set above). */
.ogma-popup--body {
transform: translate(-50%, -100%);
}
/* Small triangular arrow pointing from the bubble towards the anchor. */
.ogma-tooltip--content:after,
.ogma-popup--body:after {
content: '';
width: 0;
height: 0;
border-style: solid;
border-width: 6px 7px 6px 0;
border-color: transparent var(--overlay-background-color) transparent
transparent;
position: absolute;
left: 50%;
top: auto;
bottom: 3px;
right: auto;
transform: translate(-50%, 100%) rotate(270deg);
}
/* Close button (×) in the popup's top-right corner. */
.ogma-popup--close {
position: absolute;
top: 0px;
right: 5px;
cursor: pointer;
}
/* Placement modifiers: --top / --bottom / --right / --left position the bubble
   on that side of the anchor and rotate its arrow accordingly. */
.ogma-popup--top .ogma-popup--body,
.ogma-tooltip--top .ogma-tooltip--content {
bottom: 6px;
transform: translate(-50%, -100%);
}
.ogma-popup--bottom .ogma-popup--body,
.ogma-tooltip--bottom .ogma-tooltip--content {
transform: translate(-50%, 0%);
top: 3px;
}
.ogma-popup--bottom .ogma-popup--body:after,
.ogma-tooltip--bottom .ogma-tooltip--content:after {
top: 3px;
bottom: auto;
transform: translate(-50%, -100%) rotate(90deg);
}
.ogma-popup--right .ogma-popup--body,
.ogma-tooltip--right .ogma-tooltip--content {
transform: translate(0, -50%);
left: 6px;
}
.ogma-popup--right .ogma-popup--body:after,
.ogma-tooltip--right .ogma-tooltip--content:after {
left: 0%;
top: 50%;
transform: translate(-100%, -50%) rotate(0deg);
}
.ogma-popup--left .ogma-popup--body,
.ogma-tooltip--left .ogma-tooltip--content {
transform: translate(-100%, -50%);
right: 6px;
}
.ogma-popup--left .ogma-popup--body:after,
.ogma-tooltip--left .ogma-tooltip--content:after {
right: 0%;
left: auto;
top: 50%;
transform: translate(100%, -50%) rotate(180deg);
}
.ogma-popup--content {
padding: 10px;
}
/* Prediction popup content: result icon, success/failure title colours. */
.icon {
vertical-align: middle;
}
.success {
color: darkgreen;
}
.danger {
color: var(--danger-color);
}
/* Round colour swatch next to each category name in the prediction table. */
.prediction-result .category {
border-radius: 50%;
width: 1em;
height: 1em;
display: inline-block;
vertical-align: text-top;
}
/* Floating control panel in the top-right corner. */
.controls {
position: absolute;
top: 10px;
right: 10px;
z-index: 400;
padding: 1em;
border-radius: var(--border-radius);
box-shadow: var(--button-shadow);
min-width: 200px;
background-color: var(--white);
font-weight: 300;
}
/* Full-width action buttons (Predict all / Train / Clear). */
.controls p button {
display: block;
width: 100%;
background-color: var(--base-color);
border-radius: var(--button-border-radius);
color: var(--button-text-color);
padding: 0.5em 0.5em;
border: 0;
}
/* Greyed out while disabled (e.g. during training). */
.controls p button:disabled {
background-color: var(--inactive-color);
}
/* Progress bars added at runtime, stacked vertically. */
#progress-bar-container {
display: flex;
flex-direction: column;
flex-wrap: wrap;
gap: 10px;
}
#progress-bar-container label {
margin-right: 5px;
}
progress {
width: 100%;
border-radius: 15px;
background-color: #e0e0e0;
border: none;
overflow: hidden;
}
/* Vendor pseudo-elements: track and filled part of the bar. */
progress::-webkit-progress-bar {
background-color: #e0e0e0;
border-radius: 15px;
}
progress::-webkit-progress-value {
background-color: var(--base-color);
border-radius: 15px;
transition: width 0.4s ease;
}
progress::-moz-progress-bar {
background-color: var(--base-color);
border-radius: 15px;
transition: width 0.4s ease;
}
/* Progress bar labels. The generated `for` values are "progress-bar-<name>"
   (e.g. progress-bar-Predict), so match on the prefix. */
label[for^='progress-bar'] {
display: block;
margin-bottom: 8px;
}
/* Category legend pinned to the bottom-right corner; its list is filled by the script. */
#legend-container {
position: fixed;
bottom: 20px;
right: 20px;
background-color: #fff;
border: 1px solid #ddd;
border-radius: 8px;
padding: 10px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
z-index: 1000;
font-weight: 300;
font-size: 14px;
}
#legend-list {
list-style: none;
padding: 0;
margin: 0;
}
/* One row per category: colour swatch followed by the category name. */
#legend-list li {
display: flex;
align-items: center;
margin-bottom: 5px;
}
.legend-color-box {
width: 16px;
height: 16px;
margin-right: 8px;
border-radius: 4px;
border: 1px solid #ccc;
}
/* Lasso / 150-node hint; hidden until the script sets display: block. */
#warning-popup {
display: none;
margin-top: 10px;
color: var(--danger-color);
max-width: 200px;
}
/* Status text (accuracy, training time). */
#info {
max-width: 200px;
}
/* Keep the bold figure on a single line. */
#info strong {
white-space: nowrap;
}
json
{
"dependencies": {
"@tensorflow/tfjs": "4.20.0"
}
}
ts
// Inline 24×24 SVG icons drawn with `currentColor`, so they take the surrounding text colour.
/** Check mark inside a circle; shown in the popup title for correct predictions. */
export const checkIcon = `<svg width="24px" height="24px" stroke-width="1.5" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" color="currentColor"><path d="M7 12.5L10 15.5L17 8.5" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path><path d="M12 22C17.5228 22 22 17.5228 22 12C22 6.47715 17.5228 2 12 2C6.47715 2 2 6.47715 2 12C2 17.5228 6.47715 22 12 22Z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path></svg>`;
/** Cross inside a circle; shown in the popup title for wrong predictions. */
export const failIcon = `<svg width="24px" height="24px" stroke-width="1.5" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" color="currentColor"><path d="M9.17218 14.8284L12.0006 12M14.829 9.17157L12.0006 12M12.0006 12L9.17218 9.17157M12.0006 12L14.829 14.8284" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path><path d="M12 22C17.5228 22 22 17.5228 22 12C22 6.47715 17.5228 2 12 2C6.47715 2 2 6.47715 2 12C2 17.5228 6.47715 22 12 22Z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path></svg>`;