- 
  Notifications
 You must be signed in to change notification settings 
- Fork 14
Refactor #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
 
  Open
 
 
 
 
  Open
 Refactor #3
Changes from all commits
 Commits
 
 
 File filter
Filter by extension
Conversations
 Failed to load comments. 
 
 
 
  Loading
 
 Jump to
 
 Jump to file
 
 
 
 Failed to load files. 
 
 
 
  Loading
 
 Diff view
Diff view
There are no files selected for viewing
 
 
 This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
 Learn more about bidirectional Unicode characters
 
 
 
 
 | Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -3,4 +3,5 @@ data/ | |
| *.jpg | ||
| node_modules/ | ||
| .env | ||
| .DS_Store | ||
| .DS_Store | ||
| models/ | ||
 
 
 
 15 changes: 15 additions & 0 deletions
 
 
 
 ImageTransformer.js
 
 
 
 
  
 
 
 
 
 
 
 This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
 Learn more about bidirectional Unicode characters
 
 
 
 
 | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| import Jimp from 'jimp' | ||
|  | ||
/**
 * Converts pixel arrays to real images which can be saved to disk.
 */
export class ImageTransformer {
  /**
   * Builds a Jimp image from raw RGBA bytes and writes it to `path`.
   *
   * Returns nothing; callers get no signal for when the file has been written.
   *
   * @param {ArrayLike<number>} img - Flat RGBA values, four entries per pixel.
   * @param {number} [width=28] - Image width in pixels.
   * @param {number} [height=28] - Image height in pixels.
   * @param {string} path - Destination file path.
   * @throws {Error} From inside the Jimp callback when Jimp reports an error.
   */
  saveImage(img, width = 28, height = 28, path) {
    new Jimp({ width, height, data: Buffer.from(img) }, (err, image) => {
      if (err) {
        // Surface the real failure instead of a TypeError on `undefined.write`.
        throw new Error(`Could not create image for ${path}`, { cause: err })
      }
      image.write(path)
    })
  }

  /**
   * Saves every entry of `data` as a greyscale PNG named
   * `output/<filePrefix>_<index>.png`.
   *
   * @param {number[][]} data - One flat array of pixel intensities per image.
   *   Each intensity is multiplied by 255, so values are presumably expected
   *   in the range 0..1; anything outside is clamped to the byte range.
   * @param {string} [filePrefix='processed'] - File name prefix.
   * @param {number} [width=28] - Image width in pixels.
   * @param {number} [height=28] - Image height in pixels.
   */
  toImages(data, filePrefix = 'processed', width = 28, height = 28) {
    // Expand each grey value into an opaque RGBA pixel (r = g = b, alpha 255).
    const imgs = data.map(pixels => pixels.flatMap(val => {
      const grey = toChannelByte(val)
      return [grey, grey, grey, 255]
    }))
    imgs.forEach((img, i) => this.saveImage(img, width, height, `output/${filePrefix}_${i}.png`))
  }
}

/**
 * Scales an intensity to a byte, truncating and clamping to 0..255 so that
 * out-of-range values do not wrap around when stored in a Buffer.
 */
function toChannelByte(val) {
  return Math.min(255, Math.max(0, Math.floor(val * 255)))
}
 
 
 
 44 changes: 44 additions & 0 deletions
 
 
 
 arbitraryImageDataSource.js
 
 
 
 
  
 
 
 
 
 
 
 This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
 Learn more about bidirectional Unicode characters
 
 
 
 
 | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| import fs from 'fs' | ||
| import sharp from 'sharp' | ||
| import tf from '@tensorflow/tfjs-node' | ||
|  | ||
/**
 * Loads images from the images directory and processes them to fit the model.
 *
 * Relies on `Array.prototype.shuffle`, which this file installs at module level.
 */
export class ArbitraryImageDataSource {
  /**
   * Picks random training and test files from the `images` directory.
   *
   * The two sets are disjoint whenever there are more images than `countTest`.
   * With fewer images, both sets are drawn from all files (and may overlap),
   * so that a small image folder still yields usable data.
   *
   * @param {number} [countTraining=1000] - Maximum number of training images.
   * @param {number} [countTest=10] - Maximum number of test images.
   */
  constructor(countTraining = 1000, countTest = 10) {
    const files = fs.readdirSync('images')
      // Case-insensitive so that e.g. "photo.JPG" is picked up as well.
      .filter(f => /\.(jpe?g|png)$/i.test(f))
      .map(f => `images/${f}`)

    const shuffled = files.shuffle()
    if (shuffled.length > countTest) {
      // Reserve the test files first so none of them are also trained on.
      this.testFiles = shuffled.slice(0, countTest)
      this.trainingFiles = shuffled.slice(countTest, countTest + countTraining)
    } else {
      this.trainingFiles = shuffled.slice(0, countTraining)
      this.testFiles = files.shuffle().slice(0, countTest)
    }
  }

  /**
   * @returns {Promise<object>} Tensor of the training images, scaled by 1/255.
   */
  async getTrainingData() {
    return this._toTensor(this.trainingFiles)
  }

  /**
   * @returns {Promise<object>} Tensor of the test images, scaled by 1/255.
   */
  async getTestData() {
    return this._toTensor(this.testFiles)
  }

  /**
   * Processes all files in parallel and stacks the results into one tensor.
   */
  async _toTensor(files) {
    const data = await Promise.all(files.map(f => this._processImageFile(f)))
    return tf.tensor(data).div(255)
  }

  /**
   * Reads one image and converts it to 28x28 single-channel raw pixel data.
   *
   * @param {string} filename - Path of the image file.
   * @returns {Promise<Buffer>} 784 grey values, one byte each.
   * @throws {Error} If the pipeline does not yield exactly 28 * 28 bytes.
   */
  async _processImageFile(filename) {
    const size = 28
    const pixels = await sharp(filename)
      .resize(size, size, {
        fit: 'cover'
      })
      .gamma()
      .greyscale()
      // greyscale() keeps an alpha channel and may emit several identical
      // colour channels; force exactly one channel per pixel.
      .removeAlpha()
      .toColourspace('b-w')
      .raw()
      .toBuffer()

    if (pixels.length !== size * size) {
      throw new Error(`Expected ${size * size} grey values for ${filename}, got ${pixels.length}`)
    }
    return pixels
  }
}
|  | ||
/**
 * Installs `Array.prototype.shuffle()`, which returns a randomly ordered copy
 * of the array and leaves the original untouched.
 *
 * The method is defined as non-enumerable so it does not appear as an extra
 * key in `for...in` loops over arrays.
 */
Object.defineProperty(Array.prototype, 'shuffle', {
  value: function shuffle() {
    const copy = Array.from(this)
    // Fisher-Yates: swap each position with a random earlier-or-equal index.
    for (let i = copy.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1))
      const held = copy[i]
      copy[i] = copy[j]
      copy[j] = held
    }
    return copy
  },
  writable: true,
  configurable: true,
  enumerable: false,
})
 
 
 
 144 changes: 31 additions & 113 deletions
 
 
 
 index.js
 
 
 
 
  
 
 
 
 
 
 
 This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
 Learn more about bidirectional Unicode characters
 
 
 
 
 | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -1,119 +1,37 @@ | ||
| console.log("Hello Autoencoder 🚂"); | ||
| import { Model } from './model.js' | ||
| import { MnistDataSource } from './mnistDataSource.js' | ||
| import { ImageTransformer } from './ImageTransformer.js' | ||
| import { RandomDataSource } from './randomDataSource.js' | ||
| import { ArbitraryImageDataSource } from './arbitraryImageDataSource.js' | ||
|  | ||
| import * as tf from "@tensorflow/tfjs-node"; | ||
| // import canvas from "canvas"; | ||
| // const { loadImage } = canvas; | ||
| import Jimp from "jimp"; | ||
| import numeral from "numeral"; | ||
|  | ||
| main(); | ||
| main() | ||
|  | ||
| async function main() { | ||
| // Build the model | ||
| const autoencoder = buildModel(); | ||
| // load all image data | ||
| const images = await loadImages(550); | ||
|  | ||
| // train the model | ||
| const x_train = tf.tensor2d(images.slice(0, 500)); | ||
| await trainModel(autoencoder, x_train, 250); | ||
|  | ||
| // test the model | ||
| const x_test = tf.tensor2d(images.slice(500)); | ||
| await generateTests(autoencoder, x_test); | ||
| } | ||
|  | ||
| async function generateTests(autoencoder, x_test) { | ||
| const output = autoencoder.predict(x_test); | ||
| // output.print(); | ||
|  | ||
| const newImages = await output.array(); | ||
| for (let i = 0; i < newImages.length; i++) { | ||
| const img = newImages[i]; | ||
| const buffer = []; | ||
| for (let n = 0; n < img.length; n++) { | ||
| const val = Math.floor(img[n] * 255); | ||
| buffer[n * 4 + 0] = val; | ||
| buffer[n * 4 + 1] = val; | ||
| buffer[n * 4 + 2] = val; | ||
| buffer[n * 4 + 3] = 255; | ||
| // Instantiate the model | ||
| const model = new Model() | ||
|  | ||
| // Instantiate a data source | ||
| const dataSource = new MnistDataSource() | ||
| // const dataSource = new RandomDataSource() | ||
| // const dataSource = new ArbitraryImageDataSource() | ||
|  | ||
| // Instantiate the Image transformer | ||
| const transformer = new ImageTransformer() | ||
|  | ||
| // Check if there is a pretrained model. If it exists load it, or train the model | ||
| if (model.pretrainedModelExists()) { | ||
| await model.load() | ||
| } else { | ||
| // Create the layers | ||
| model.configure() | ||
| // and train | ||
| await model.train(await dataSource.getTrainingData(), 200) | ||
| } | ||
| const image = new Jimp( | ||
| { | ||
| data: Buffer.from(buffer), | ||
| width: 28, | ||
| height: 28, | ||
| }, | ||
| (err, image) => { | ||
| const num = numeral(i).format("000"); | ||
| image.write(`output/square${num}.png`); | ||
| } | ||
| ); | ||
| } | ||
| } | ||
|  | ||
| function buildModel() { | ||
| const autoencoder = tf.sequential(); | ||
| // Build the model | ||
| autoencoder.add( | ||
| tf.layers.dense({ | ||
| units: 256, | ||
| inputShape: [784], | ||
| activation: "relu", | ||
| }) | ||
| ); | ||
| autoencoder.add( | ||
| tf.layers.dense({ | ||
| units: 128, | ||
| activation: "relu", | ||
| }) | ||
| ); | ||
| // Test the model with testing data from the data source | ||
| const testData = await dataSource.getTestData() | ||
| const autoEncodedImages = model.autoencode(testData) | ||
|  | ||
| autoencoder.add( | ||
| tf.layers.dense({ | ||
| units: 256, | ||
| activation: "sigmoid", | ||
| }) | ||
| ); | ||
|  | ||
| autoencoder.add( | ||
| tf.layers.dense({ | ||
| units: 784, | ||
| activation: "sigmoid", | ||
| }) | ||
| ); | ||
| autoencoder.compile({ | ||
| optimizer: "adam", | ||
| loss: "binaryCrossentropy", | ||
| metrics: ["accuracy"], | ||
| }); | ||
| return autoencoder; | ||
| } | ||
|  | ||
| async function trainModel(autoencoder, x_train, epochs) { | ||
| await autoencoder.fit(x_train, x_train, { | ||
| epochs: epochs, | ||
| batch_size: 32, | ||
| shuffle: true, | ||
| verbose: true, | ||
| }); | ||
| } | ||
|  | ||
| async function loadImages(total) { | ||
| const allImages = []; | ||
| for (let i = 0; i < total; i++) { | ||
| const num = numeral(i).format("000"); | ||
| const img = await Jimp.read(`data/square${num}.png`); | ||
|  | ||
| let rawData = []; | ||
| for (let n = 0; n < 28 * 28; n++) { | ||
| let index = n * 4; | ||
| let r = img.bitmap.data[index + 0]; | ||
| // let g = img.bitmap.data[n + 1]; | ||
| // let b = img.bitmap.data[n + 2]; | ||
| rawData[n] = r / 255.0; | ||
| } | ||
| allImages[i] = rawData; | ||
| } | ||
| return allImages; | ||
| // save the images to disk | ||
| transformer.toImages(testData.arraySync(), 'org') | ||
| transformer.toImages(autoEncodedImages.arraySync()) | ||
| } | 
 
 
 
 20 changes: 20 additions & 0 deletions
 
 
 
 mnistDataSource.js
 
 
 
 
  
 
 
 
 
 
 
 This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
 Learn more about bidirectional Unicode characters
 
 
 
 
 | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| import mnist from 'mnist' | ||
| import tf from '@tensorflow/tfjs-node' | ||
|  | ||
/**
 * Data source that serves samples from the mnist data set as tensors.
 */
export class MnistDataSource {
  /**
   * Draws the requested number of samples and keeps only their `input` values.
   *
   * @param {number} [countTraining=1000] - Number of training samples to request.
   * @param {number} [countTest=10] - Number of test samples to request.
   */
  constructor(countTraining = 1000, countTest = 10) {
    const sets = mnist.set(countTraining, countTest)
    const inputOf = sample => sample.input

    this.training = sets.training.map(inputOf)
    this.test = sets.test.map(inputOf)
  }

  /**
   * @returns {object} A new tensor built from the stored training inputs.
   */
  getTrainingData() {
    const { training } = this
    return tf.tensor(training)
  }

  /**
   * @returns {object} A new tensor built from the stored test inputs.
   */
  getTestData() {
    const { test } = this
    return tf.tensor(test)
  }
}
 
 
 
 67 changes: 67 additions & 0 deletions
 
 
 
 model.js
 
 
 
 
  
 
 
 
 
 
 
 This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
 Learn more about bidirectional Unicode characters
 
 
 
 
 | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| import fs from 'fs' | ||
| import tf from '@tensorflow/tfjs-node' | ||
|  | ||
/**
 * Abstraction for tfjs: a dense autoencoder plus separate encoder and decoder
 * models that are built from the same layer instances.
 */
export class Model {
  /**
   * @returns {boolean} True only if all three saved models (autoencoder,
   *   encoder and decoder) have a `model.json` under `models/`.
   */
  pretrainedModelExists() {
    return ['autoencoder', 'encoder', 'decoder']
      .every(name => fs.existsSync(`models/${name}/model.json`))
  }

  /**
   * Loads the three previously saved models from `models/`.
   */
  async load() {
    // The three files are independent, so load them concurrently.
    const [autoencoder, encoder, decoder] = await Promise.all([
      tf.loadLayersModel('file://models/autoencoder/model.json'),
      tf.loadLayersModel('file://models/encoder/model.json'),
      tf.loadLayersModel('file://models/decoder/model.json'),
    ])
    this.autoencoder = autoencoder
    this.encoder = encoder
    this.decoder = decoder
  }

  /**
   * Creates the layers (784 -> 128 -> 64 -> 32 -> 64 -> 128 -> 784) and
   * compiles the autoencoder.
   *
   * The encoder and decoder models are assembled from the same layer objects
   * as the autoencoder. Only the autoencoder is compiled; the other two are
   * only used for `predict` in this class.
   */
  configure() {
    const encoded = [
      tf.layers.dense({ units: 128, inputShape: [784], activation: "relu" }),
      tf.layers.dense({ units: 64, activation: "relu" }),
      tf.layers.dense({ units: 32, activation: "relu" }),
    ]
    const decoded = [
      tf.layers.dense({ units: 64, activation: "relu" }),
      tf.layers.dense({ units: 128, activation: "relu" }),
      tf.layers.dense({ units: 784, activation: "sigmoid" }),
    ]

    this.autoencoder = tf.sequential({ layers: [...encoded, ...decoded] })
    this.encoder = tf.sequential({ layers: encoded })

    // The decoder starts at the 32-unit bottleneck, so it needs its own input.
    const encoded_input = tf.layers.inputLayer({ inputShape: [32] })
    this.decoder = tf.sequential({ layers: [encoded_input, ...decoded] })

    this.autoencoder.compile({
      optimizer: 'adam',
      loss: 'binaryCrossentropy',
    })
  }

  /**
   * Fits the autoencoder to reproduce its own input, then saves all three
   * models below `models/`.
   *
   * @param {object} x_train - Training data, used as both input and target.
   * @param {number} [epochs=100] - Number of training epochs.
   */
  async train(x_train, epochs = 100) {
    await this.autoencoder.fit(x_train, x_train, {
      epochs,
      batchSize: 32,
      shuffle: true,
    })
    // `recursive` makes this a no-op when the directory already exists (for
    // example after an incomplete earlier save) instead of throwing EEXIST
    // once training has already finished.
    fs.mkdirSync('models', { recursive: true })
    await this.autoencoder.save('file://models/autoencoder')
    await this.encoder.save('file://models/encoder')
    await this.decoder.save('file://models/decoder')
  }

  /**
   * Runs data through the encoder and then the decoder.
   *
   * @param {object} data - Input passed to `encode`.
   * @returns {object} The decoder's prediction for the encoded data.
   */
  autoencode(data) {
    return this.decode(this.encode(data))
  }

  /**
   * @param {object} data - Input for the encoder model.
   * @returns {object} The encoder's prediction.
   */
  encode(data) {
    return this.encoder.predict(data)
  }

  /**
   * @param {object} encoded - Input for the decoder model.
   * @returns {object} The decoder's prediction.
   */
  decode(encoded) {
    return this.decoder.predict(encoded)
  }
}
 
 Oops, something went wrong.
 
 
 
 Add this suggestion to a batch that can be applied as a single commit.
 This suggestion is invalid because no changes were made to the code.
 Suggestions cannot be applied while the pull request is closed.
 Suggestions cannot be applied while viewing a subset of changes.
 Only one suggestion per line can be applied in a batch.
 Add this suggestion to a batch that can be applied as a single commit.
 Applying suggestions on deleted lines is not supported.
 You must change the existing code in this line in order to create a valid suggestion.
 Outdated suggestions cannot be applied.
 This suggestion has been applied or marked resolved.
 Suggestions cannot be applied from pending reviews.
 Suggestions cannot be applied on multi-line comments.
 Suggestions cannot be applied while the pull request is queued to merge.
 Suggestion cannot be applied right now. Please check back later.