mirror of
https://github.com/cupcakearmy/nicco.io.git
synced 2025-09-06 02:30:45 +00:00
cleanup
@@ -25,18 +25,6 @@ Today we will look on how to train a simple mnist digit recogniser and then expo

 Also, I am not going to explain what machine learning is: there are enough guides, videos and podcasts that already do a much better job than I could, and it would be outside the scope of this article.

-<figure>
-
-
-
-<figcaption>
-
-Photo by [Natasha Connell](https://unsplash.com/@natcon773?utm_source=unsplash&utm_medium=referral&utm_content=creditCopyText) on [Unsplash](https://unsplash.com/s/photos/brain?utm_source=unsplash&utm_medium=referral&utm_content=creditCopyText)
-
-</figcaption>
-
-</figure>
-
 The first thing to understand is that we will not train the model in the browser. Training is a job for GPUs and is far more resource-intensive than simply running the net; the goal here is only to use a pre-trained model inside the browser.

 ## Training the model
@@ -46,7 +34,7 @@ So, the first step is to actually have a model. I will do this in tensorflow 2.0
 The code below is basically an adapted version of the [Keras hello world example](https://keras.io/examples/mnist_cnn/).
 If you want to run the code yourself (which you should!), simply head over to [Google Colab](https://colab.research.google.com), create a new notebook and paste the code. There you can run it for free on GPUs, which is pretty dope!

-```
+```py
 from tensorflow.keras.datasets import mnist
 from tensorflow.keras.models import Sequential
 from tensorflow.keras.layers import Dense, Dropout, Flatten
@@ -106,7 +94,7 @@ Unfortunately this is not compatible with tensorflow-js. So we need another way.

 There is a Python package called `tensorflowjs` (confusing, right? 😅) that provides the functionality we need:

-```
+```py
 import tensorflowjs as tfjs

 tfjs.converters.save_keras_model(model, './js')
@@ -119,19 +107,19 @@ Inside there you will find a `model.json` that basically describes the structure

 Now we are ready to import the model. First we need to install the `@tensorflow/tfjs` package.

-```
-import * as tf from '@tensorflow/tfjs';
+```ts
+import * as tf from '@tensorflow/tfjs'

 let model

-tf.loadLayersModel('/model.json').then(m => {
-    model = m
+tf.loadLayersModel('/model.json').then((m) => {
+  model = m
 })
 ```

 OK, how do we use that now?

-```
+```ts
 const tensor = tf.tensor(new Uint8Array(ourData), [1, 28, 28, 1])
 const prediction = model.predict(tensor)
 ```
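
One caveat with the loading snippet above: `model` stays `undefined` until `loadLayersModel` resolves, so a prediction triggered before that point would fail. A minimal sketch of one way around it is to keep the loading promise and await it before predicting. `lazy`, `fakeLoad`, `getModel` and `demo` are hypothetical names, not part of tfjs; in the real page the loader would be something like `() => tf.loadLayersModel('/model.json')`.

```ts
// Wraps an async loader so it runs at most once;
// every caller awaits the same promise.
function lazy<T>(load: () => Promise<T>): () => Promise<T> {
  let pending: Promise<T> | undefined
  return () => {
    if (pending === undefined) {
      pending = load()
    }
    return pending
  }
}

// Stand-in for tf.loadLayersModel so the sketch runs without tfjs (assumption).
let loads = 0
const fakeLoad = async (): Promise<{ name: string }> => {
  loads += 1
  return { name: 'mnist' }
}

const getModel = lazy(fakeLoad)

// Two concurrent callers share one load and receive the same object.
async function demo(): Promise<number> {
  const [a, b] = await Promise.all([getModel(), getModel()])
  console.log(a === b, loads)
  return loads
}
```

A click handler could then start with `const model = await getModel()` instead of relying on a module-level variable being set in time.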
@@ -146,67 +134,65 @@ I'm not gonna talk about what bundler, etc. I'm using. If you interested simply

 First, let's write some basic HTML for the skeleton of our page.

-```
+```html
 <html>
-
-<head>
+  <head>
     <style>
-        * {
-            box-sizing: border-box;
-            font-family: monospace;
-        }
+      * {
+        box-sizing: border-box;
+        font-family: monospace;
+      }

-        html,
-        body {
-            padding: 0;
-            margin: 0;
-            height: 100vh;
-            width: 100vw;
-            display: flex;
-            justify-content: center;
-            align-items: center;
-        }
+      html,
+      body {
+        padding: 0;
+        margin: 0;
+        height: 100vh;
+        width: 100vw;
+        display: flex;
+        justify-content: center;
+        align-items: center;
+      }

-        body>div {
-            text-align: center;
-        }
+      body > div {
+        text-align: center;
+      }

-        div canvas {
-            display: inline-block;
-            border: 1px solid;
-        }
+      div canvas {
+        display: inline-block;
+        border: 1px solid;
+      }

-        div input {
-            display: inline-block;
-            margin-top: .5em;
-            padding: .5em 2em;
-            background: white;
-            outline: none;
-            border: 1px solid;
-            font-weight: bold;
-        }
+      div input {
+        display: inline-block;
+        margin-top: 0.5em;
+        padding: 0.5em 2em;
+        background: white;
+        outline: none;
+        border: 1px solid;
+        font-weight: bold;
+      }
     </style>
-</head>
+  </head>

-<body>
+  <body>
     <div>
-        <h1>MNIST (Pretrained)</h1>
-        <canvas id="can" width="28" height="28"></canvas>
-        <br />
-        <input id="clear" type="button" value="clear">
-        <br />
-        <input id="test" type="button" value="test">
-        <br />
-        <h2 id="result"></h2>
-        <a href="https://github.com/cupcakearmy/mnist">
-            <h3>source code</h3>
-        </a>
+      <h1>MNIST (Pretrained)</h1>
+      <canvas id="can" width="28" height="28"></canvas>
+      <br />
+      <input id="clear" type="button" value="clear" />
+      <br />
+      <input id="test" type="button" value="test" />
+      <br />
+      <h2 id="result"></h2>
+      <a href="https://github.com/cupcakearmy/mnist">
+        <h3>source code</h3>
+      </a>
     </div>

     <script src="./tf.js"></script>
     <script src="./canvas.js"></script>
-</body>
-
+  </body>
 </html>
 ```

@@ -215,85 +201,102 @@ The code is adapted from [this stackoverflow answer](https://stackoverflow.com/a

 In essence it's a canvas that listens to our mouse events and fills the pixels with black. Nothing more.

-```
+```ts
 /* jslint esversion: 6, asi: true */

-var canvas, ctx, flag = false,
-    prevX = 0,
-    currX = 0,
-    prevY = 0,
-    currY = 0,
-    dot_flag = false;
+var canvas,
+  ctx,
+  flag = false,
+  prevX = 0,
+  currX = 0,
+  prevY = 0,
+  currY = 0,
+  dot_flag = false

-var x = "black",
-    y = 2;
+var x = 'black',
+  y = 2

 function init() {
-    canvas = document.getElementById('can');
-    ctx = canvas.getContext("2d");
-    w = canvas.width;
-    h = canvas.height;
+  canvas = document.getElementById('can')
+  ctx = canvas.getContext('2d')
+  w = canvas.width
+  h = canvas.height

-    canvas.addEventListener("mousemove", function (e) {
-        findxy('move', e)
-    }, false);
-    canvas.addEventListener("mousedown", function (e) {
-        findxy('down', e)
-    }, false);
-    canvas.addEventListener("mouseup", function (e) {
-        findxy('up', e)
-    }, false);
-    canvas.addEventListener("mouseout", function (e) {
-        findxy('out', e)
-    }, false);
+  canvas.addEventListener(
+    'mousemove',
+    function (e) {
+      findxy('move', e)
+    },
+    false
+  )
+  canvas.addEventListener(
+    'mousedown',
+    function (e) {
+      findxy('down', e)
+    },
+    false
+  )
+  canvas.addEventListener(
+    'mouseup',
+    function (e) {
+      findxy('up', e)
+    },
+    false
+  )
+  canvas.addEventListener(
+    'mouseout',
+    function (e) {
+      findxy('out', e)
+    },
+    false
+  )

-
-    window.document.getElementById('clear').addEventListener('click', erase)
+  window.document.getElementById('clear').addEventListener('click', erase)
 }

 function draw() {
-    ctx.beginPath();
-    ctx.moveTo(prevX, prevY);
-    ctx.lineTo(currX, currY);
-    ctx.strokeStyle = x;
-    ctx.lineWidth = y;
-    ctx.stroke();
-    ctx.closePath();
+  ctx.beginPath()
+  ctx.moveTo(prevX, prevY)
+  ctx.lineTo(currX, currY)
+  ctx.strokeStyle = x
+  ctx.lineWidth = y
+  ctx.stroke()
+  ctx.closePath()
 }

 function erase() {
-    ctx.clearRect(0, 0, w, h);
+  ctx.clearRect(0, 0, w, h)
 }

 function findxy(res, e) {
-    if (res == 'down') {
-        prevX = currX;
-        prevY = currY;
-        currX = e.clientX - canvas.offsetLeft;
-        currY = e.clientY - canvas.offsetTop;
+  if (res == 'down') {
+    prevX = currX
+    prevY = currY
+    currX = e.clientX - canvas.offsetLeft
+    currY = e.clientY - canvas.offsetTop

-        flag = true;
-        dot_flag = true;
-        if (dot_flag) {
-            ctx.beginPath();
-            ctx.fillStyle = x;
-            ctx.fillRect(currX, currY, 2, 2);
-            ctx.closePath();
-            dot_flag = false;
-        }
+    flag = true
+    dot_flag = true
+    if (dot_flag) {
+      ctx.beginPath()
+      ctx.fillStyle = x
+      ctx.fillRect(currX, currY, 2, 2)
+      ctx.closePath()
+      dot_flag = false
+    }
   }
-    if (res == 'up' || res == "out") {
-        flag = false;
-    }
-    if (res == 'move') {
-        if (flag) {
-            prevX = currX;
-            prevY = currY;
-            currX = e.clientX - canvas.offsetLeft;
-            currY = e.clientY - canvas.offsetTop;
-            draw();
-        }
-    }
+  if (res == 'up' || res == 'out') {
+    flag = false
+  }
+  if (res == 'move') {
+    if (flag) {
+      prevX = currX
+      prevY = currY
+      currX = e.clientX - canvas.offsetLeft
+      currY = e.clientY - canvas.offsetTop
+      draw()
+    }
+  }
 }

 init()
@@ -301,26 +304,26 @@ init()

 And now the glue that puts this together: the piece of code that listens for clicks on the "test" button.

-```
-import * as tf from '@tensorflow/tfjs';
+```ts
+import * as tf from '@tensorflow/tfjs'

 let model

-tf.loadLayersModel('/model.json').then(m => {
-    model = m
+tf.loadLayersModel('/model.json').then((m) => {
+  model = m
 })

 window.document.getElementById('test').addEventListener('click', async () => {
-    const canvas = window.document.querySelector('canvas')
+  const canvas = window.document.querySelector('canvas')

-    const { data, width, height } = canvas.getContext('2d').getImageData(0, 0, 28, 28)
+  const { data, width, height } = canvas.getContext('2d').getImageData(0, 0, 28, 28)

-    const tensor = tf.tensor(new Uint8Array(data.filter((_, i) => i % 4 === 3)), [1, 28, 28, 1])
-    const prediction = model.predict(tensor)
-    const result = await prediction.data()
-    const guessed = result.indexOf(1)
-    console.log(guessed)
-    window.document.querySelector('#result').innerText = guessed
+  const tensor = tf.tensor(new Uint8Array(data.filter((_, i) => i % 4 === 3)), [1, 28, 28, 1])
+  const prediction = model.predict(tensor)
+  const result = await prediction.data()
+  const guessed = result.indexOf(1)
+  console.log(guessed)
+  window.document.querySelector('#result').innerText = guessed
 })
 ```

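
A note on `result.indexOf(1)` in the handler above: it only finds a digit when one output is exactly `1`. If the network ends in a softmax layer (an assumption here, since the training code is not shown in full), the outputs are probabilities such as `0.9`, and `indexOf(1)` returns `-1`. Picking the index of the largest value is more robust. `argMax` below is a hypothetical helper and the scores are made up for illustration; this is a sketch, not the author's code.

```ts
// Returns the index of the largest value, or -1 for an empty input.
// On ties the first occurrence wins.
function argMax(values: ArrayLike<number>): number {
  let best = -1
  let bestValue = -Infinity
  for (let i = 0; i < values.length; i++) {
    if (values[i] > bestValue) {
      bestValue = values[i]
      best = i
    }
  }
  return best
}

// Made-up scores for the ten digits, shaped like the array
// that `await prediction.data()` resolves to.
const scores = new Float32Array([0.01, 0.02, 0.9, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01])

console.log(argMax(scores)) // index of the highest score
console.log(Array.from(scores).indexOf(1)) // no entry is exactly 1
```

In the click handler this would read `const guessed = argMax(result)`.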
@@ -329,7 +332,7 @@ Here we need to explain a few things.

 Then, instead of simply passing the data to the tensor, we need to do some magic with `data.filter` in order to keep only every 4th value. This is because the canvas stores 4 values per pixel (3 colour channels + 1 alpha), but we only need to know whether a pixel was drawn on or not. The strokes are black on a transparent background, so the alpha value carries that information, and we select it by keeping the indices where `i % 4 === 3`:

-```
+```ts
 data.filter((_, i) => i % 4 === 3)
 ```

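
To make the index arithmetic concrete, here is a small sketch on a hand-made buffer of two pixels laid out like `ImageData.data` (`R, G, B, A` per pixel). `alphaChannel` is a hypothetical helper name; it is meant to do the same job as the filter above, written as an explicit loop.

```ts
// Keeps every 4th value (the alpha byte) of an RGBA buffer.
function alphaChannel(rgba: Uint8ClampedArray): Uint8Array {
  const out = new Uint8Array(rgba.length / 4)
  for (let i = 0; i < out.length; i++) {
    out[i] = rgba[i * 4 + 3]
  }
  return out
}

// Two pixels: an untouched one (fully transparent) and a drawn one (opaque black).
const pixels = new Uint8ClampedArray([0, 0, 0, 0, 0, 0, 0, 255])

const viaLoop = alphaChannel(pixels)
const viaFilter = pixels.filter((_, i) => i % 4 === 3)

console.log(Array.from(viaLoop), Array.from(viaFilter))
```

For the 28×28 canvas the buffer holds 28 · 28 · 4 = 3136 bytes and the result 784 values, which is what the `[1, 28, 28, 1]` shape expects. The values range from 0 to 255; if the model was trained on inputs scaled to 0..1, as the Keras example does, dividing by 255 before predicting would match that. The training code is not shown in full here, so treat that as an assumption.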