Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit d0a5b20

Browse files
Merge pull request #1070 from DonaldAlan/meldoy4j-review-971
Meldoy4j review 971
2 parents 96e51a9 + d7adac9 commit d0a5b20

File tree

9 files changed

+439
-251
lines changed

9 files changed

+439
-251
lines changed
Lines changed: 102 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
/*******************************************************************************
2-
*
3-
*
42
*
53
* This program and the accompanying materials are made available under the
64
* terms of the Apache License, Version 2.0 which is available at
@@ -17,7 +15,7 @@
1715
* SPDX-License-Identifier: Apache-2.0
1816
******************************************************************************/
1917

20-
package org.deeplearning4j.examples.wip.advanced.modelling.melodl4j;
18+
package org.deeplearning4j.examples.advanced.modelling.charmodelling.melodl4j;
2119

2220
import org.apache.commons.io.FileUtils;
2321
import org.deeplearning4j.examples.advanced.modelling.charmodelling.utils.CharacterIterator;
@@ -31,39 +29,56 @@
3129
import org.deeplearning4j.nn.weights.WeightInit;
3230
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
3331
import org.deeplearning4j.util.ModelSerializer;
32+
import org.nd4j.common.util.ArchiveUtils;
3433
import org.nd4j.linalg.activations.Activation;
3534
import org.nd4j.linalg.api.ndarray.INDArray;
3635
import org.nd4j.linalg.dataset.DataSet;
3736
import org.nd4j.linalg.factory.Nd4j;
37+
import org.nd4j.linalg.learning.config.Adam;
3838
import org.nd4j.linalg.learning.config.RmsProp;
3939
import org.nd4j.linalg.lossfunctions.LossFunctions;
4040

41+
import javax.sound.midi.InvalidMidiDataException;
4142
import java.io.*;
4243
import java.net.URL;
4344
import java.nio.charset.Charset;
45+
import java.nio.file.Files;
46+
import java.nio.file.Path;
47+
import java.text.NumberFormat;
4448
import java.util.ArrayList;
4549
import java.util.List;
4650
import java.util.Random;
51+
import java.util.zip.ZipEntry;
52+
import java.util.zip.ZipInputStream;
4753

4854
/**
4955
* LSTM Symbolic melody modelling example, to compose music from symbolic melodies extracted from MIDI.
50-
* Based closely on LSTMCharModellingExample.java.
56+
* LSTM logic is based closely on LSTMCharModellingExample.java.
5157
* See the README file in this directory for documentation.
5258
*
5359
* @author Alex Black, Donald A. Smith.
5460
*/
5561
public class MelodyModelingExample {
56-
final static String inputSymbolicMelodiesFilename = "bach-melodies-input.txt";
57-
// Examples: bach-melodies-input.txt, beatles-melodies-input.txt , pop-melodies-input.txt (large)
62+
// If you want to change the MIDI files used in learning, create a zip file containing your MIDI
63+
// files and replace the following path. For example, you might use something like:
64+
//final static String midiFileZipFileUrlPath = "file:d:/music/midi/classical-midi.zip";
65+
final static String midiFileZipFileUrlPath = "http://waliberals.org/truthsite/music/bach-midi.zip";
5866

59-
final static String tmpDir = System.getProperty("java.io.tmpdir");
67+
// For example "bach-midi.txt"
68+
final static String inputSymbolicMelodiesFilename = getMelodiesFileNameFromURLPath(midiFileZipFileUrlPath);
6069

61-
final static String symbolicMelodiesInputFilePath = tmpDir + "/" + inputSymbolicMelodiesFilename; // Point to melodies created by MidiMelodyExtractor.java
70+
// Examples: bach-melodies-input.txt, beatles-melodies-input.txt , pop-melodies-input.txt (large)
71+
final static String tmpDir = System.getProperty("java.io.tmpdir");
72+
final static String inputSymbolicMelodiesFilePath = tmpDir + "/" + inputSymbolicMelodiesFilename; // Point to melodies created by MidiMelodyExtractor.java
6273
final static String composedMelodiesOutputFilePath = tmpDir + "/composition.txt"; // You can listen to these melodies by running PlayMelodyStrings.java against this file.
6374

6475
//final static String symbolicMelodiesInputFilePath = "D:/tmp/bach-melodies.txt";
6576
//final static String composedMelodiesOutputFilePath = tmpDir + "/bach-composition.txt"; // You can listen to these melodies by running PlayMelodyStrings.java against this file.
66-
77+
final static NumberFormat numberFormat = NumberFormat.getNumberInstance();
78+
static {
79+
numberFormat.setMinimumFractionDigits(1);
80+
numberFormat.setMaximumFractionDigits(1);
81+
}
6782
//....
6883
public static void main(String[] args) throws Exception {
6984
String loadNetworkPath = null; //"/tmp/MelodyModel-bach.zip"; //null;
@@ -73,6 +88,8 @@ public static void main(String[] args) throws Exception {
7388
generationInitialization = args[1];
7489
}
7590

91+
makeMidiStringFileIfNecessary();
92+
7693
int lstmLayerSize = 200; //Number of units in each LSTM layer
7794
int miniBatchSize = 32; //Size of mini batch to use when training
7895
int exampleLength = 500; //1000; //Length of each training example sequence to use.
@@ -107,9 +124,10 @@ public static void main(String[] args) throws Exception {
107124

108125
//Set up network configuration:
109126
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
110-
.updater(new RmsProp(0.1))
111-
.seed(12345)
112-
.l2(0.001)
127+
//.updater(new RmsProp(0.1))
128+
.updater(new Adam(0.005))
129+
.seed(System.currentTimeMillis()) // So each run generates new melodies
130+
.l2(0.0001)
113131
.weightInit(WeightInit.XAVIER)
114132
.list()
115133
.layer(0, new LSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize)
@@ -123,7 +141,6 @@ public static void main(String[] args) throws Exception {
123141
.backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
124142
.build();
125143

126-
127144
learn(miniBatchSize, exampleLength, numEpochs, generateSamplesEveryNMinibatches, nSamplesToGenerate, nCharactersToSample, generationInitialization, rng, startTime, iter, conf);
128145
}
129146

@@ -154,6 +171,7 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
154171
// order, so that the best melodies are at the start of the file.
155172
//Do training, and then generate and print samples from network
156173
int miniBatchNumber = 0;
174+
long lastTime = System.currentTimeMillis();
157175
for (int epoch = 0; epoch < numEpochs; epoch++) {
158176
System.out.println("Starting epoch " + epoch);
159177
while (iter.hasNext()) {
@@ -176,12 +194,19 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
176194
}
177195
}
178196
iter.reset(); //Reset iterator for another epoch
197+
final double secondsForEpoch = 0.001 * (System.currentTimeMillis() - startTime);
198+
final long now = System.currentTimeMillis();
179199
if (melodies.size() > 0) {
180200
String melody = melodies.get(melodies.size() - 1);
181201
int seconds = 25;
182202
System.out.println("\nFirst " + seconds + " seconds of " + melody);
183203
PlayMelodyStrings.playMelody(melody, seconds);
184204
}
205+
double seconds = 0.001*(now - lastTime);
206+
lastTime = now;
207+
System.out.println("\nEpoch " + epoch + " time in seconds: " + numberFormat.format(seconds));
208+
// 531.9 for GPU GTX 1070
209+
// 821.4 for CPU i7-6700K @ 4GHZ
185210
}
186211
int indexOfLastPeriod = inputSymbolicMelodiesFilename.lastIndexOf('.');
187212
String saveFileName = inputSymbolicMelodiesFilename.substring(0, indexOfLastPeriod > 0 ? indexOfLastPeriod : inputSymbolicMelodiesFilename.length());
@@ -193,42 +218,82 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
193218
printWriter.println(melodies.get(i));
194219
}
195220
printWriter.close();
196-
double seconds = 0.001 * (System.currentTimeMillis() - startTime);
197221

198-
System.out.println("\n\nExample complete in " + seconds + " seconds");
199222
System.exit(0);
200223
}
201224

202-
public static void makeSureFileIsInTmpDir(String filename) {
225+
public static File makeSureFileIsInTmpDir(String urlString) throws IOException {
226+
final URL url = new URL(urlString);
227+
final String filename = urlString.substring(1+urlString.lastIndexOf("/"));
203228
final File f = new File(tmpDir + "/" + filename);
204-
if (!f.exists()) {
205-
URL url = null;
206-
try {
207-
url = new URL("http://truthsite.org/music/" + filename);
208-
FileUtils.copyURLToFile(url, f);
209-
} catch (Exception exc) {
210-
System.err.println("Error copying " + url + " to " + f);
211-
throw new RuntimeException(exc);
212-
}
229+
if (f.exists()) {
230+
System.out.println("Using existing " + f.getAbsolutePath());
231+
} else {
232+
FileUtils.copyURLToFile(url, f);
213233
if (!f.exists()) {
214234
throw new RuntimeException(f.getAbsolutePath() + " does not exist");
215235
}
216236
System.out.println("File downloaded to " + f.getAbsolutePath());
217-
} else {
218-
System.out.println("Using existing text file at " + f.getAbsolutePath());
219237
}
238+
return f;
220239
}
221240

241+
//https://stackoverflow.com/questions/10633595/java-zip-how-to-unzip-folder
242+
public static void unzip(File zipFile, File targetDirFile) throws IOException {
243+
InputStream is = new FileInputStream(zipFile);
244+
Path targetDir = targetDirFile.toPath();
245+
targetDir = targetDir.toAbsolutePath();
246+
try (ZipInputStream zipIn = new ZipInputStream(is)) {
247+
for (ZipEntry ze; (ze = zipIn.getNextEntry()) != null; ) {
248+
Path resolvedPath = targetDir.resolve(ze.getName()).normalize();
249+
if (!resolvedPath.startsWith(targetDir)) {
250+
// see: https://snyk.io/research/zip-slip-vulnerability
251+
throw new RuntimeException("Entry with an illegal path: "
252+
+ ze.getName());
253+
}
254+
if (ze.isDirectory()) {
255+
Files.createDirectories(resolvedPath);
256+
} else {
257+
Files.createDirectories(resolvedPath.getParent());
258+
Files.copy(zipIn, resolvedPath);
259+
}
260+
}
261+
}
262+
is.close();
263+
}
264+
private static void makeMidiStringFileIfNecessary() throws IOException, InvalidMidiDataException {
265+
final File inputMelodiesFile = new File(inputSymbolicMelodiesFilePath);
266+
if (inputMelodiesFile.exists() && inputMelodiesFile.length()>1000) {
267+
System.out.println("Using existing " + inputSymbolicMelodiesFilePath);
268+
return;
269+
}
270+
final File midiZipFile = makeSureFileIsInTmpDir(midiFileZipFileUrlPath);
271+
final String midiZipFileName = midiZipFile.getName();
272+
final String midiZipFileNameWithoutSuffix = midiZipFileName.substring(0,midiZipFileName.lastIndexOf("."));
273+
final File outputDirectoryFile = new File(tmpDir,midiZipFileNameWithoutSuffix);
274+
final String outputDirectoryPath = outputDirectoryFile.getAbsolutePath();
275+
if (!outputDirectoryFile.exists()) {
276+
outputDirectoryFile.mkdir();
277+
}
278+
if (!outputDirectoryFile.exists() || !outputDirectoryFile.isDirectory()) {
279+
throw new IllegalStateException(outputDirectoryFile + " is not a directory or can't be created");
280+
}
281+
final PrintStream printStream = new PrintStream(inputSymbolicMelodiesFilePath);
282+
System.out.println("Unzipping "+ midiZipFile.getAbsolutePath() + " to " + outputDirectoryPath);
283+
unzip(midiZipFile, outputDirectoryFile);
284+
System.out.println("Extracted " + midiZipFile.getAbsolutePath() + " to " + outputDirectoryPath);
285+
MidiMelodyExtractor.processDirectoryAndWriteMelodyFile(outputDirectoryFile,inputMelodiesFile);
286+
printStream.close();
287+
}
222288
/**
223289
* Sets up and return a simple DataSetIterator that does vectorization based on the melody sample.
224290
*
225291
* @param miniBatchSize Number of text segments in each training mini-batch
226292
* @param sequenceLength Number of characters in each text segment.
227293
*/
228294
public static CharacterIterator getMidiIterator(int miniBatchSize, int sequenceLength) throws Exception {
229-
makeSureFileIsInTmpDir(inputSymbolicMelodiesFilename);
230295
final char[] validCharacters = MelodyStrings.allValidCharacters.toCharArray(); //Which characters are allowed? Others will be removed
231-
return new CharacterIterator(symbolicMelodiesInputFilePath, Charset.forName("UTF-8"),
296+
return new CharacterIterator(inputSymbolicMelodiesFilePath, Charset.forName("UTF-8"),
232297
miniBatchSize, sequenceLength, validCharacters, new Random(12345), MelodyStrings.COMMENT_STRING);
233298
}
234299

@@ -312,5 +377,13 @@ public static int sampleFromDistribution(double[] distribution, Random rng) {
312377
//Should be extremely unlikely to happen if distribution is a valid probability distribution
313378
throw new IllegalArgumentException("Distribution is invalid? d=" + d + ", sum=" + sum);
314379
}
380+
private static String getMelodiesFileNameFromURLPath(String midiFileZipFileUrlPath) {
381+
if (!(midiFileZipFileUrlPath.endsWith(".zip") || midiFileZipFileUrlPath.endsWith(".ZIP"))) {
382+
throw new IllegalStateException("zipFilePath must end with .zip");
383+
}
384+
midiFileZipFileUrlPath = midiFileZipFileUrlPath.replace('\\','/');
385+
String fileName = midiFileZipFileUrlPath.substring(midiFileZipFileUrlPath.lastIndexOf("/") + 1);
386+
return fileName + ".txt";
387+
}
315388
}
316389

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /