mirror of
https://github.com/cupcakearmy/mnist.git
synced 2024-12-22 16:16:32 +00:00
68 lines
1.9 KiB
Python
68 lines
1.9 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""simple.ipynb
|
|
|
|
Automatically generated by Colaboratory.
|
|
|
|
Original file is located at
|
|
https://colab.research.google.com/drive/1CGpActfOTQuiUkle2q40rg8Bf7ijfyAU
|
|
"""
|
|
|
|
from tensorflow.keras.datasets import mnist
|
|
from tensorflow.keras.models import Sequential
|
|
from tensorflow.keras.layers import Dense, Dropout, Flatten
|
|
from tensorflow.keras.layers import Conv2D, MaxPooling2D
|
|
from tensorflow.keras.utils import to_categorical
|
|
|
|
# Load the train and test splits as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Reshape for channels_last (TensorFlow) by appending a single channel axis,
# and cast to float32 so the pixel values can later be scaled in place.
size = 28  # side length of the square input images
print(x_train.shape, x_test.shape)
x_train, x_test = (
    images.reshape(len(images), size, size, 1).astype('float32')
    for images in (x_train, x_test)
)
print(x_train.shape, x_test.shape)
# Normalize pixel values to [0, 1] with min-max scaling. The bounds are taken
# over both splits so train and test share exactly the same transform.
upper = max(x_train.max(), x_test.max())
lower = min(x_train.min(), x_test.min())
print(f'Max: {upper} Min: {lower}')

# Subtract the minimum as well as dividing by the range; dividing by the
# maximum alone only yields [0, 1] when the minimum happens to be 0.
value_range = upper - lower
if value_range == 0:
    # Degenerate constant input: avoid 0/0 (NaN); every value becomes 0.
    value_range = 1.0

x_train -= lower
x_train /= value_range
x_test -= lower
x_test /= value_range
# Convert integer class labels into one-hot vectors of length total_classes,
# matching the categorical_crossentropy loss used when compiling the model.
total_classes = 10
y_train, y_test = (
    to_categorical(labels, num_classes=total_classes)
    for labels in (y_train, y_test)
)
# Make the model: two 3x3 convolutions, 2x2 max pooling, then a dense
# classifier with dropout and a softmax over the classes.
model = Sequential([
    Conv2D(
        64,
        (3, 3),
        activation='relu',
        input_shape=(size, size, 1),
        data_format='channels_last',
    ),
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(total_classes, activation='softmax'),
])

# Categorical cross-entropy pairs with the one-hot labels and softmax output.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
# Train on the training split, then report loss and accuracy on the test split.
model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=12,
    verbose=True,
)

# evaluate() returns [loss, accuracy] because metrics=['accuracy'] was compiled in.
score = model.evaluate(x_test, y_test, verbose=0)
for label, value in zip(('Test loss:', 'Test accuracy:'), score):
    print(label, value)
# Save for keras (the .h5 extension selects Keras' HDF5 file format).
model.save("model.h5")

# The Colab original used the IPython shell escape `!pip install tensorflowjs`,
# which is a SyntaxError in a plain .py script. Install the converter through
# the running interpreter's pip only if it is not already importable.
import importlib
import subprocess
import sys

try:
    import tensorflowjs as tfjs
except ImportError:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'tensorflowjs'])
    # Packages installed after interpreter start-up may be invisible to the
    # cached import finders until their caches are cleared.
    importlib.invalidate_caches()
    import tensorflowjs as tfjs

# Save for the web: convert the Keras model for TensorFlow.js into ./js.
tfjs.converters.save_keras_model(model, './js')