2024-12-01 17:36:36 +01:00
parent b5af0cae15
commit 75d19fa5d9
14 changed files with 266 additions and 446 deletions

@@ -25,18 +25,6 @@ Today we will look on how to train a simple mnist digit recogniser and then expo
Also, I am not going to explain what machine learning is: there are enough guides, videos, podcasts, ... that already do a much better job than I could, and it would be outside the scope of this article.
<figure>
![](images/natasha-connell-byp5TTxUbL0-unsplash-scaled-1.jpg)
<figcaption>
Photo by [Natasha Connell](https://unsplash.com/@natcon773?utm_source=unsplash&utm_medium=referral&utm_content=creditCopyText) on [Unsplash](https://unsplash.com/s/photos/brain?utm_source=unsplash&utm_medium=referral&utm_content=creditCopyText)
</figcaption>
</figure>
So the first thing we need to understand is that we will not train the model in the browser. That is a job for GPUs, and the goal here is only to use a pre-trained model inside the browser. Training is a much more resource-intensive task than simply using the net.
## Training the model
@@ -46,7 +34,7 @@ So, the first step is to actually have a model. I will do this in tensorflow 2.0
The code below is basically an adapted version of the [keras hello world example](https://keras.io/examples/mnist_cnn/).
If you want to run the code yourself (which you should!) simply head over to [Google Colab](https://colab.research.google.com), create a new file and just paste the code. There you can run it for free on GPUs which is pretty dope!
```
```py
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
@@ -106,7 +94,7 @@ Unfortunately this is not compatible with tensorflow-js. So we need another way.
There is a Python package called tensorflowjs (confusing, right? 😅) that provides the functionality we need.
```
```py
import tensorflowjs as tfjs
tfjs.converters.save_keras_model(model, './js')
@@ -119,19 +107,19 @@ Inside there you will find a `model.json` that basically describes the structure
Now we are ready to import that. First we need to install the `@tensorflow/tfjs` package.
```
import * as tf from '@tensorflow/tfjs';
```ts
import * as tf from '@tensorflow/tfjs'
let model
tf.loadLayersModel('/model.json').then(m => {
model = m
tf.loadLayersModel('/model.json').then((m) => {
model = m
})
```
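One thing to keep in mind: `model` stays `undefined` until the promise resolves, so anything that runs before the download finishes would fail. A minimal sketch of one way around that is to cache the loading promise and await it wherever the model is needed. The helper name `once` and the stand-in loader below are assumptions for illustration only; in the real page the loader would be `() => tf.loadLayersModel('/model.json')`.

```ts
// Wrap any async loader so it runs at most once and every caller
// awaits the same promise. `once` is a hypothetical helper name.
function once<T>(load: () => Promise<T>): () => Promise<T> {
  let pending: Promise<T> | undefined
  return () => {
    if (pending === undefined) {
      pending = load()
    }
    return pending
  }
}

// Stand-in loader so the sketch runs without tfjs or a server.
// In the page this would be: () => tf.loadLayersModel('/model.json')
let loadCalls = 0
const getModel = once(async () => {
  loadCalls += 1
  return { name: 'fake-model' }
})
```

Inside an async handler you would then write `const model = await getModel()` instead of reading a global that might not be set yet. Note that a failed load is cached as well, so a retry would need extra handling.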
Ok how do I use that now?
```
```ts
const tensor = tf.tensor(new Uint8Array(ourData), [1, 28, 28, 1])
const prediction = model.predict(tensor)
```
@@ -146,67 +134,65 @@ I'm not gonna talk about what bundler, etc. I'm using. If you interested simply
First let's write some basic HTML for the skeleton of our page.
```
```html
<html>
<head>
<head>
<style>
* {
box-sizing: border-box;
font-family: monospace;
}
* {
box-sizing: border-box;
font-family: monospace;
}
html,
body {
padding: 0;
margin: 0;
height: 100vh;
width: 100vw;
display: flex;
justify-content: center;
align-items: center;
}
html,
body {
padding: 0;
margin: 0;
height: 100vh;
width: 100vw;
display: flex;
justify-content: center;
align-items: center;
}
body>div {
text-align: center;
}
body > div {
text-align: center;
}
div canvas {
display: inline-block;
border: 1px solid;
}
div canvas {
display: inline-block;
border: 1px solid;
}
div input {
display: inline-block;
margin-top: .5em;
padding: .5em 2em;
background: white;
outline: none;
border: 1px solid;
font-weight: bold;
}
div input {
display: inline-block;
margin-top: 0.5em;
padding: 0.5em 2em;
background: white;
outline: none;
border: 1px solid;
font-weight: bold;
}
</style>
</head>
</head>
<body>
<body>
<div>
<h1>MNIST (Pretrained)</h1>
<canvas id="can" width="28" height="28"></canvas>
<br />
<input id="clear" type="button" value="clear">
<br />
<input id="test" type="button" value="test">
<br />
<h2 id="result"></h2>
<a href="https://github.com/cupcakearmy/mnist">
<h3>source code</h3>
</a>
<h1>MNIST (Pretrained)</h1>
<canvas id="can" width="28" height="28"></canvas>
<br />
<input id="clear" type="button" value="clear" />
<br />
<input id="test" type="button" value="test" />
<br />
<h2 id="result"></h2>
<a href="https://github.com/cupcakearmy/mnist">
<h3>source code</h3>
</a>
</div>
<script src="./tf.js"></script>
<script src="./canvas.js"></script>
</body>
</body>
</html>
```
@@ -215,85 +201,102 @@ The code is adapted from [this stackoverflow answer](https://stackoverflow.com/a
In essence, it's a canvas that listens to our mouse events and fills the pixels with black. Nothing more.
```
```ts
/* jslint esversion: 6, asi: true */
var canvas, ctx, flag = false,
prevX = 0,
currX = 0,
prevY = 0,
currY = 0,
dot_flag = false;
var canvas,
ctx,
flag = false,
prevX = 0,
currX = 0,
prevY = 0,
currY = 0,
dot_flag = false
var x = "black",
y = 2;
var x = 'black',
y = 2
function init() {
canvas = document.getElementById('can');
ctx = canvas.getContext("2d");
w = canvas.width;
h = canvas.height;
canvas = document.getElementById('can')
ctx = canvas.getContext('2d')
w = canvas.width
h = canvas.height
canvas.addEventListener("mousemove", function (e) {
findxy('move', e)
}, false);
canvas.addEventListener("mousedown", function (e) {
findxy('down', e)
}, false);
canvas.addEventListener("mouseup", function (e) {
findxy('up', e)
}, false);
canvas.addEventListener("mouseout", function (e) {
findxy('out', e)
}, false);
canvas.addEventListener(
'mousemove',
function (e) {
findxy('move', e)
},
false
)
canvas.addEventListener(
'mousedown',
function (e) {
findxy('down', e)
},
false
)
canvas.addEventListener(
'mouseup',
function (e) {
findxy('up', e)
},
false
)
canvas.addEventListener(
'mouseout',
function (e) {
findxy('out', e)
},
false
)
window.document.getElementById('clear').addEventListener('click', erase)
window.document.getElementById('clear').addEventListener('click', erase)
}
function draw() {
ctx.beginPath();
ctx.moveTo(prevX, prevY);
ctx.lineTo(currX, currY);
ctx.strokeStyle = x;
ctx.lineWidth = y;
ctx.stroke();
ctx.closePath();
ctx.beginPath()
ctx.moveTo(prevX, prevY)
ctx.lineTo(currX, currY)
ctx.strokeStyle = x
ctx.lineWidth = y
ctx.stroke()
ctx.closePath()
}
function erase() {
ctx.clearRect(0, 0, w, h);
ctx.clearRect(0, 0, w, h)
}
function findxy(res, e) {
if (res == 'down') {
prevX = currX;
prevY = currY;
currX = e.clientX - canvas.offsetLeft;
currY = e.clientY - canvas.offsetTop;
if (res == 'down') {
prevX = currX
prevY = currY
currX = e.clientX - canvas.offsetLeft
currY = e.clientY - canvas.offsetTop
flag = true;
dot_flag = true;
if (dot_flag) {
ctx.beginPath();
ctx.fillStyle = x;
ctx.fillRect(currX, currY, 2, 2);
ctx.closePath();
dot_flag = false;
}
flag = true
dot_flag = true
if (dot_flag) {
ctx.beginPath()
ctx.fillStyle = x
ctx.fillRect(currX, currY, 2, 2)
ctx.closePath()
dot_flag = false
}
if (res == 'up' || res == "out") {
flag = false;
}
if (res == 'move') {
if (flag) {
prevX = currX;
prevY = currY;
currX = e.clientX - canvas.offsetLeft;
currY = e.clientY - canvas.offsetTop;
draw();
}
}
if (res == 'up' || res == 'out') {
flag = false
}
if (res == 'move') {
if (flag) {
prevX = currX
prevY = currY
currX = e.clientX - canvas.offsetLeft
currY = e.clientY - canvas.offsetTop
draw()
}
}
}
init()
@@ -301,26 +304,26 @@ init()
And now the glue that puts this together: the piece of code that listens on the "test" button.
```
import * as tf from '@tensorflow/tfjs';
```ts
import * as tf from '@tensorflow/tfjs'
let model
tf.loadLayersModel('/model.json').then(m => {
model = m
tf.loadLayersModel('/model.json').then((m) => {
model = m
})
window.document.getElementById('test').addEventListener('click', async () => {
const canvas = window.document.querySelector('canvas')
const canvas = window.document.querySelector('canvas')
const { data, width, height } = canvas.getContext('2d').getImageData(0, 0, 28, 28)
const { data, width, height } = canvas.getContext('2d').getImageData(0, 0, 28, 28)
const tensor = tf.tensor(new Uint8Array(data.filter((_, i) => i % 4 === 3)), [1, 28, 28, 1])
const prediction = model.predict(tensor)
const result = await prediction.data()
const guessed = result.indexOf(1)
console.log(guessed)
window.document.querySelector('#result').innerText = guessed
const tensor = tf.tensor(new Uint8Array(data.filter((_, i) => i % 4 === 3)), [1, 28, 28, 1])
const prediction = model.predict(tensor)
const result = await prediction.data()
const guessed = result.indexOf(1)
console.log(guessed)
window.document.querySelector('#result').innerText = guessed
})
```
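A note on `result.indexOf(1)`: it only finds a digit when the model outputs exactly `1` for one class. Assuming the last layer is a softmax (as in the Keras MNIST example, though the layer definitions are not shown here), a less certain prediction would give `-1`. A more forgiving variant picks the index of the largest score. This is only a sketch; `argMax` is a helper written for this example and the scores are made up.

```ts
// Return the index of the largest value, i.e. the most likely digit.
function argMax(scores: ArrayLike<number>): number {
  if (scores.length === 0) {
    throw new Error('argMax needs at least one score')
  }
  let best = 0
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) {
      best = i
    }
  }
  return best
}

// Made-up model output for illustration: no entry is exactly 1,
// so indexOf(1) would not find anything here.
const exampleScores = new Float32Array([0.01, 0.02, 0.9, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01])
const guessedDigit = argMax(exampleScores)
```

In the click handler this would replace the `indexOf` line with `const guessed = argMax(result)`.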
@@ -329,7 +332,7 @@ Here we need to explain a few things.
Then, instead of simply passing the data to the tensor, we need to do some magic with `data.filter` in order to keep only every 4th value. This is because the canvas image data has 4 bytes per pixel (3 color channels + 1 alpha), but we only need to know if the pixel is black or not. Since we draw in black on a transparent canvas, the alpha byte alone tells us that, and it sits at every index where the index mod 4 equals 3:
```
```ts
data.filter((_, i) => i % 4 === 3)
```
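To make the byte layout more concrete, here is a small sketch of the same idea written as an explicit loop. The function name `alphaChannel` and the two sample pixels are made up for illustration; the only assumption is the RGBA layout of canvas image data, with 4 bytes per pixel.

```ts
// Keep only the alpha byte of every RGBA pixel.
function alphaChannel(rgba: Uint8ClampedArray): Uint8Array {
  const pixels = Math.floor(rgba.length / 4)
  const alpha = new Uint8Array(pixels)
  for (let p = 0; p < pixels; p++) {
    // Pixel p occupies bytes 4p (R), 4p+1 (G), 4p+2 (B), 4p+3 (A).
    alpha[p] = rgba[p * 4 + 3]
  }
  return alpha
}

// Two made-up pixels: one untouched (fully transparent)
// and one drawn on (opaque black).
const sample = new Uint8ClampedArray([0, 0, 0, 0, 0, 0, 0, 255])
const mask = alphaChannel(sample)
```

Both pixels have the same color bytes (all zero), which is why the color channels are useless here and only the alpha byte separates drawn pixels from empty ones.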