88import re
99import logging
1010import sys
11+ import math
1112import unicodedata
1213import codecs
1314
@@ -324,40 +325,48 @@ def make_sequences_same_length(sequences, sequence_lengths, default_value=0.0):
324325
325326
326327def main (argv ):
327- # Read text file.
328- text_file_path = TRAIN_DIR + "211-122425-0059.txt"
329- text = read_text_file (text_file_path )
330- text = normalize_text (text )
331- 332- # Read audio file.
333- wav_file_path = TRAIN_DIR + "211-122425-0059.wav"
334- audio_rate , audio_data = wav .read (wav_file_path )
335- inputs = mfcc (audio_data , samplerate = audio_rate )
336- 337- # Make text as as char array.
338- labels = make_char_array (text , SPACE_TOKEN )
339- labels = np .asarray ([SPACE_INDEX if x == SPACE_TOKEN else ord (x ) - FIRST_INDEX for x in labels ])
340- 341- # Labels sparse representation for feeding the placeholder.
342- train_labels = sparse_tuples_from_sequences ([labels ])
343- 344- # Train inputs.
345- train_inputs = np .asarray (inputs [np .newaxis , :])
328+ # Read train data files.
329+ train_texts = read_text_files (TRAIN_DIR )
330+ train_labels = texts_encoder (train_texts ,
331+ first_index = FIRST_INDEX ,
332+ space_index = SPACE_INDEX ,
333+ space_token = SPACE_TOKEN )
334+ train_inputs = read_audio_files (TRAIN_DIR )
346335 train_inputs = standardize_audios (train_inputs )
347- train_sequence_length = [train_inputs .shape [1 ]]
348- 349- # TODO(ugnelis): define different validation variables.
350- validation_inputs = train_inputs
351- validation_labels = train_labels
352- validation_sequence_length = train_sequence_length
336+ train_sequence_lengths = get_sequence_lengths (train_inputs )
337+ train_inputs = make_sequences_same_length (train_inputs , train_sequence_lengths )
338+ 339+ # Read validation data files.
340+ validation_texts = read_text_files (DEV_DIR )
341+ validation_labels = texts_encoder (validation_texts ,
342+ first_index = FIRST_INDEX ,
343+ space_index = SPACE_INDEX ,
344+ space_token = SPACE_TOKEN )
345+ validation_labels = sparse_tuples_from_sequences (validation_labels )
346+ validation_inputs = read_audio_files (DEV_DIR )
347+ validation_inputs = standardize_audios (validation_inputs )
348+ validation_sequence_lengths = get_sequence_lengths (validation_inputs )
349+ validation_inputs = make_sequences_same_length (validation_inputs , validation_sequence_lengths )
350+ 351+ # Read test data files.
352+ test_texts = read_text_files (TEST_DIR )
353+ test_labels = texts_encoder (test_texts ,
354+ first_index = FIRST_INDEX ,
355+ space_index = SPACE_INDEX ,
356+ space_token = SPACE_TOKEN )
357+ test_labels = sparse_tuples_from_sequences (test_labels )
358+ test_inputs = read_audio_files (TEST_DIR )
359+ test_inputs = standardize_audios (test_inputs )
360+ test_sequence_lengths = get_sequence_lengths (test_inputs )
361+ test_inputs = make_sequences_same_length (test_inputs , test_sequence_lengths )
353362
354363 with tf .device ('/gpu:0' ):
355364 config = tf .ConfigProto (allow_soft_placement = True )
356365 config .gpu_options .allow_growth = True
357366
358367 graph = tf .Graph ()
359368 with graph .as_default ():
360- 369+ logging . debug ( "Starting new TensorFlow graph." )
361370 inputs_placeholder = tf .placeholder (tf .float32 , [None , None , NUM_FEATURES ])
362371
363372 # SparseTensor placeholder required by ctc_loss op.
@@ -382,13 +391,13 @@ def main(argv):
382391 # Reshaping to apply the same weights over the time steps.
383392 outputs = tf .reshape (outputs , [- 1 , NUM_HIDDEN ])
384393
385- weights = tf .Variable (tf .truncated_normal ([NUM_HIDDEN ,
394+ weights = tf .Variable (tf .truncated_normal ([NUM_HIDDEN ,
386395 NUM_CLASSES ],
387396 stddev = 0.1 ))
388397 biases = tf .Variable (tf .constant (0. , shape = [NUM_CLASSES ]))
389398
390399 # Doing the affine projection.
391- logits = tf .matmul (outputs , weights ) + biases
400+ logits = tf .matmul (outputs , weights ) + biases
392401
393402 # Reshaping back to the original shape.
394403 logits = tf .reshape (logits , [batch_s , - 1 , NUM_CLASSES ])
@@ -402,53 +411,83 @@ def main(argv):
402411 optimizer = tf .train .MomentumOptimizer (INITIAL_LEARNING_RATE , 0.9 ).minimize (cost )
403412
404413 # CTC decoder.
405- decoded , log_prob = tf .nn .ctc_greedy_decoder (logits , sequence_length_placeholder )
414+ decoded , neg_sum_logits = tf .nn .ctc_greedy_decoder (logits , sequence_length_placeholder )
406415
407416 label_error_rate = tf .reduce_mean (tf .edit_distance (tf .cast (decoded [0 ], tf .int32 ),
408417 labels_placeholder ))
409- with tf .Session (graph = graph ) as session :
410- # Initialize the weights and biases.
411- tf .global_variables_initializer ().run ()
412- 413- for current_epoch in range (NUM_EPOCHS ):
414- train_cost = train_label_error_rate = 0
415- start_time = time .time ()
416- 417- for batch in range (NUM_BATCHES_PER_EPOCH ):
418- feed = {inputs_placeholder : train_inputs ,
419- labels_placeholder : train_labels ,
420- sequence_length_placeholder : train_sequence_length }
421- 422- batch_cost , _ = session .run ([cost , optimizer ], feed )
423- train_cost += batch_cost * BATCH_SIZE
424- train_label_error_rate += session .run (label_error_rate , feed_dict = feed ) * BATCH_SIZE
425- 426- train_cost /= NUM_EXAMPLES
427- train_label_error_rate /= NUM_EXAMPLES
428- 429- val_feed = {inputs_placeholder : validation_inputs ,
430- labels_placeholder : validation_labels ,
431- sequence_length_placeholder : validation_sequence_length }
432- 433- validation_cost , validation_label_error_rate = session .run ([cost , label_error_rate ], feed_dict = val_feed )
434- 435- # Output intermediate step information.
436- logging .info ("Epoch %d/%d (time: %.3f s)" ,
437- current_epoch + 1 ,
438- NUM_EPOCHS ,
439- time .time () - start_time )
440- logging .info ("Train cost: %.3f, train label error rate: %.3f" ,
441- train_cost ,
442- train_label_error_rate )
443- logging .info ("Validation cost: %.3f, validation label error rate: %.3f" ,
444- validation_cost ,
445- validation_label_error_rate )
446- 447- # Decoding.
448- decoded_outputs = session .run (decoded [0 ], feed_dict = feed )
449- decoded_text = sequence_decoder (decoded_outputs [1 ])
450- 451- logging .info ("Original:\n %s" , text )
418+ 419+ with tf .Session (graph = graph ) as session :
420+ logging .debug ("Starting TensorFlow session." )
421+ # Initialize the weights and biases.
422+ tf .global_variables_initializer ().run ()
423+ 424+ train_num = train_inputs .shape [0 ]
425+ validation_num = validation_inputs .shape [0 ]
426+ 427+ # Check if there is any example.
428+ if train_num <= 0 :
429+ logging .error ("There are no training examples." )
430+ return
431+ 432+ num_batches_per_epoch = math .ceil (train_num / BATCH_SIZE )
433+ 434+ for current_epoch in range (NUM_EPOCHS ):
435+ train_cost = 0
436+ train_label_error_rate = 0
437+ start_time = time .time ()
438+ 439+ for batch in range (num_batches_per_epoch ):
440+ # Format batches.
441+ if int (train_num / ((batch + 1 ) * BATCH_SIZE )) >= 1 :
442+ indexes = [i % train_num for i in range (batch * BATCH_SIZE , (batch + 1 ) * BATCH_SIZE )]
443+ else :
444+ indexes = [i % train_num for i in range (batch * BATCH_SIZE , train_num )]
445+ 446+ batch_train_inputs = train_inputs [indexes ]
447+ batch_train_sequence_lengths = train_sequence_lengths [indexes ]
448+ batch_train_targets = sparse_tuples_from_sequences (train_labels [indexes ])
449+ 450+ feed = {inputs_placeholder : batch_train_inputs ,
451+ labels_placeholder : batch_train_targets ,
452+ sequence_length_placeholder : batch_train_sequence_lengths }
453+ 454+ batch_cost , _ = session .run ([cost , optimizer ], feed )
455+ train_cost += batch_cost * BATCH_SIZE
456+ train_label_error_rate += session .run (label_error_rate , feed_dict = feed ) * BATCH_SIZE
457+ 458+ train_cost /= train_num
459+ train_label_error_rate /= train_num
460+ 461+ validation_feed = {inputs_placeholder : validation_inputs ,
462+ labels_placeholder : validation_labels ,
463+ sequence_length_placeholder : validation_sequence_lengths }
464+ 465+ validation_cost , validation_label_error_rate = session .run ([cost , label_error_rate ],
466+ feed_dict = validation_feed )
467+ 468+ validation_cost /= validation_num
469+ validation_label_error_rate /= validation_num
470+ 471+ # Output intermediate step information.
472+ logging .info ("Epoch %d/%d (time: %.3f s)" ,
473+ current_epoch + 1 , NUM_EPOCHS , time .time () - start_time )
474+ logging .info ("Train cost: %.3f, train label error rate: %.3f" ,
475+ train_cost , train_label_error_rate )
476+ logging .info ("Validation cost: %.3f, validation label error rate: %.3f" ,
477+ validation_cost , validation_label_error_rate )
478+ 479+ test_feed = {inputs_placeholder : test_inputs ,
480+ sequence_length_placeholder : test_sequence_lengths }
481+ # Decoding.
482+ decoded_outputs = session .run (decoded [0 ], feed_dict = test_feed )
483+ dense_decoded = tf .sparse_tensor_to_dense (decoded_outputs , default_value = - 1 ).eval (session = session )
484+ 485+ for i , sequence in enumerate (dense_decoded ):
486+ sequence = [s for s in sequence if s != - 1 ]
487+ decoded_text = sequence_decoder (sequence )
488+ 489+ logging .info ('Sequence %d' , i )
490+ logging .info ("Original:\n %s" , test_texts [i ])
452491 logging .info ("Decoded:\n %s" , decoded_text )
453492
454493
0 commit comments