
Commit 96e51a9

Merge pull request #1069 from deeplearning4j/ag_fix_1063
Fix up generate text computation graph
2 parents: 686db99 + 4838e3e

1 file changed: 15 additions, 11 deletions

dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/charmodelling/generatetext/GenerateTxtCharCompGraphModel.java

@@ -55,7 +55,7 @@ public class GenerateTxtCharCompGraphModel {
 
     @SuppressWarnings("ConstantConditions")
     public static void main(String[] args ) throws Exception {
-        int lstmLayerSize = 200;    //Number of units in each LSTM layer
+        int lstmLayerSize = 77;     //Number of units in each LSTM layer
         int miniBatchSize = 32;     //Size of mini batch to use when training
         int exampleLength = 1000;   //Length of each training example sequence to use. This could certainly be increased
         int tbpttLength = 50;       //Length for truncated backpropagation through time. i.e., do parameter updates ever 50 characters
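Dropping lstmLayerSize from 200 to 77 mostly shrinks the model. As a rough estimate using the standard non-peephole LSTM parameter count, params = 4*(nIn*nOut + nOut^2 + nOut): a 200-unit layer fed by 200 units holds 4*(40,000 + 40,000 + 200) = 320,800 parameters, while a 77-unit layer fed by 77 units holds 4*(5,929 + 5,929 + 77) = 47,740. (DL4J's GravesLSTM variant would add a few peephole weights on top of this.)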
@@ -90,18 +90,20 @@ public static void main(String[] args ) throws Exception {
             //Output layer, name "outputlayer" with inputs from the two layers called "first" and "second"
             .addLayer("outputLayer", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                 .activation(Activation.SOFTMAX)
-                .nIn(2*lstmLayerSize).nOut(nOut).build(),"first","second")
+                .nIn(lstmLayerSize).nOut(lstmLayerSize).build(),"second")
             .setOutputs("outputLayer") //List the output. For a ComputationGraph with multiple outputs, this also defines the input array orders
-            .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
+            .backpropType(BackpropType.TruncatedBPTT)
+            .tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
             .build();
 
         ComputationGraph net = new ComputationGraph(conf);
         net.init();
         net.setListeners(new ScoreIterationListener(1));
+        System.out.println(net.summary());
 
         //Print the number of parameters in the network (and for each layer)
         long totalNumParams = 0;
-        for( int i=0; i<net.getNumLayers(); i++ ){
+        for( int i = 0; i < net.getNumLayers(); i++) {
             long nParams = net.getLayer(i).numParams();
             System.out.println("Number of parameters in layer " + i + ": " + nParams);
             totalNumParams += nParams;
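For orientation, a minimal, self-contained sketch of the topology this hunk produces. Only the output-layer and TBPTT lines come from the diff; the "first"/"second" LSTM layers, their stacked wiring, the activation choices, and the nIn value are assumptions made for illustration, since the surrounding builder code is not shown in the hunk.

import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class GraphAfterFix {
    public static void main(String[] args) {
        int nIn = 77;           // assumed input (character) count, not shown in this hunk
        int lstmLayerSize = 77;
        int tbpttLength = 50;
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("input")
            // "first" and "second" stand in for the example's two LSTM layers;
            // their exact wiring is an assumption here
            .addLayer("first", new LSTM.Builder().nIn(nIn).nOut(lstmLayerSize)
                .activation(Activation.TANH).build(), "input")
            .addLayer("second", new LSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
                .activation(Activation.TANH).build(), "first")
            // After the fix the output layer reads only from "second",
            // so its nIn is lstmLayerSize rather than 2*lstmLayerSize
            .addLayer("outputLayer", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX)
                .nIn(lstmLayerSize).nOut(lstmLayerSize).build(), "second")
            .setOutputs("outputLayer")
            .backpropType(BackpropType.TruncatedBPTT)
            .tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
            .build();

        ComputationGraph net = new ComputationGraph(conf);
        net.init();
        System.out.println(net.summary()); // same summary printout the commit adds
    }
}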
@@ -110,16 +112,18 @@ public static void main(String[] args ) throws Exception {
 
         //Do training, and then generate and print samples from network
         int miniBatchNumber = 0;
-        for( int i=0; i<numEpochs; i++ ){
+        for( int i = 0; i < numEpochs; i++) {
             while(iter.hasNext()){
                 DataSet ds = iter.next();
+                System.out.println("Input shape " + ds.getFeatures().shapeInfoToString());
+                System.out.println("Labels " + ds.getLabels().shapeInfoToString());
                 net.fit(ds);
                 if(++miniBatchNumber % generateSamplesEveryNMinibatches == 0){
                     System.out.println("--------------------");
                     System.out.println("Completed " + miniBatchNumber + " minibatches of size " + miniBatchSize + "x" + exampleLength + " characters" );
                     System.out.println("Sampling characters from network given initialization \"" + (generationInitialization == null ? "" : generationInitialization) + "\"");
                     String[] samples = sampleCharactersFromNetwork(generationInitialization,net,iter,rng,nCharactersToSample,nSamplesToGenerate);
-                    for( int j=0; j<samples.length; j++ ){
+                    for( int j = 0; j < samples.length; j++) {
                         System.out.println("----- Sample " + j + " -----");
                         System.out.println(samples[j]);
                         System.out.println();
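The two new println calls expose the rank-3 layout DL4J uses for RNN data: features and labels are shaped [miniBatchSize, nIn, timeSeriesLength]. A standalone sketch of what shapeInfoToString reports; the array sizes below are illustrative, taken from the example's miniBatchSize and exampleLength defaults:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ShapeCheck {
    public static void main(String[] args) {
        // RNN features in DL4J: [miniBatchSize, nIn, timeSeriesLength]
        INDArray features = Nd4j.zeros(32, 77, 1000);
        // Prints rank, shape, strides, ordering and data type
        System.out.println(features.shapeInfoToString());
    }
}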
@@ -135,7 +139,7 @@ public static void main(String[] args ) throws Exception {
 
     /** Generate a sample from the network, given an (optional, possibly null) initialization. Initialization
      * can be used to 'prime' the RNN with a sequence you want to extend/continue.<br>
-     * Note that the initalization is used for all samples
+     * Note that the initialization is used for all samples
      * @param initialization String, may be null. If null, select a random character as initialization for all samples
      * @param charactersToSample Number of characters to sample from network (excluding initialization)
      * @param net MultiLayerNetwork with one or more LSTM/RNN layers and a softmax output layer
@@ -151,9 +155,9 @@ private static String[] sampleCharactersFromNetwork( String initialization, Comp
         //Create input for initialization
         INDArray initializationInput = Nd4j.zeros(numSamples, iter.inputColumns(), initialization.length());
         char[] init = initialization.toCharArray();
-        for( int i=0; i<init.length; i++ ){
+        for( int i=0; i<init.length; i++) {
             int idx = iter.convertCharacterToIndex(init[i]);
-            for( int j=0; j<numSamples; j++ ){
+            for( int j = 0; j<numSamples; j++ ){
                 initializationInput.putScalar(new int[]{j,idx,i}, 1.0f);
             }
         }
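These loops (only reformatted by the commit) one-hot encode the priming string: for each time step i, the row of the character index idx is set to 1.0 for every sample. A toy sketch of the same encoding, using a hypothetical 4-character vocabulary that is not part of the example:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class OneHotInit {
    public static void main(String[] args) {
        // Hypothetical vocabulary {a, b, c, d}; encode the priming string "ab"
        // for one sample. Shape: [numSamples, vocabSize, sequenceLength]
        INDArray init = Nd4j.zeros(1, 4, 2);
        init.putScalar(new int[]{0, 0, 0}, 1.0f); // time step 0: 'a' -> index 0
        init.putScalar(new int[]{0, 1, 1}, 1.0f); // time step 1: 'b' -> index 1
        System.out.println(init);
    }
}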
@@ -167,13 +171,13 @@ private static String[] sampleCharactersFromNetwork( String initialization, Comp
         INDArray output = net.rnnTimeStep(initializationInput)[0];
         output = output.tensorAlongDimension((int)output.size(2)-1,1,0); //Gets the last time step output
 
-        for( int i=0; i<charactersToSample; i++ ){
+        for( int i = 0; i < charactersToSample; i++ ){
             //Set up next input (single time step) by sampling from previous output
             INDArray nextInput = Nd4j.zeros(numSamples,iter.inputColumns());
             //Output is a probability distribution. Sample from this for each example we want to generate, and add it to the new input
             for( int s=0; s<numSamples; s++ ){
                 double[] outputProbDistribution = new double[iter.totalOutcomes()];
-                for( int j=0; j<outputProbDistribution.length; j++) outputProbDistribution[j] = output.getDouble(s,j);
+                for( int j = 0; j < outputProbDistribution.length; j++) outputProbDistribution[j] = output.getDouble(s,j);
                 int sampledCharacterIdx = GenerateTxtModel.sampleFromDistribution(outputProbDistribution,rng);
 
                 nextInput.putScalar(new int[]{s,sampledCharacterIdx}, 1.0f); //Prepare next time step input
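GenerateTxtModel.sampleFromDistribution is not part of this diff; the name and call signature are taken from the call site above, and the body below is a minimal sketch of the usual inverse-CDF sampling technique, not the library's actual implementation:

import java.util.Random;

public class SamplingSketch {
    // Draw an index from a discrete distribution: walk the cumulative sum
    // until it passes a uniform random draw in [0, 1)
    public static int sampleFromDistribution(double[] distribution, Random rng) {
        double d = rng.nextDouble();
        double sum = 0.0;
        for (int i = 0; i < distribution.length; i++) {
            sum += distribution[i];
            if (d <= sum) return i;
        }
        // Only reachable if the probabilities do not sum to ~1.0
        throw new IllegalArgumentException("Invalid distribution: sums to " + sum);
    }

    public static void main(String[] args) {
        double[] dist = {0.1, 0.7, 0.2};
        System.out.println(sampleFromDistribution(dist, new Random(42)));
    }
}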
