Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 68296f7

Browse files
committed
Wrap image writes in try/except
Writing an image to the summary writer sometimes fails with an "Unsupported format" error. Report the exception as a warning and move on.
1 parent 6aef4c8 commit 68296f7

File tree

1 file changed

+57
-24
lines changed

1 file changed

+57
-24
lines changed

‎train.py‎

Lines changed: 57 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -409,11 +409,17 @@ def eval_model(global_step, writer, device, model, checkpoint_dir, ismultispeake
409409
global_step, idx, speaker_str))
410410
save_alignment(path, alignment)
411411
tag = "eval_averaged_alignment_{}_{}".format(idx, speaker_str)
412-
writer.add_image(tag, np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255), global_step)
412+
try:
413+
writer.add_image(tag, np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255), global_step)
414+
except Exception as e:
415+
warn(str(e))
413416

414417
# Mel
415-
writer.add_image("(Eval) Predicted mel spectrogram text{}_{}".format(idx, speaker_str),
416-
prepare_spec_image(mel), global_step)
418+
try:
419+
writer.add_image("(Eval) Predicted mel spectrogram text{}_{}".format(idx, speaker_str),
420+
prepare_spec_image(mel), global_step)
421+
except Exception as e:
422+
warn(str(e))
417423

418424
# Audio
419425
path = join(eval_output_dir, "step{:09d}_text{}_{}_predicted.wav".format(
@@ -442,44 +448,63 @@ def save_states(global_step, writer, mel_outputs, linear_outputs, attn, mel, y,
442448
for i, alignment in enumerate(attn):
443449
alignment = alignment[idx].cpu().data.numpy()
444450
tag = "alignment_layer{}".format(i + 1)
445-
writer.add_image(tag, np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255), global_step)
446-
447-
# save files as well for now
448-
alignment_dir = join(checkpoint_dir, "alignment_layer{}".format(i + 1))
449-
os.makedirs(alignment_dir, exist_ok=True)
450-
path = join(alignment_dir, "step{:09d}_layer_{}_alignment.png".format(
451-
global_step, i + 1))
452-
save_alignment(path, alignment)
451+
try:
452+
writer.add_image(tag, np.uint8(cm.viridis(
453+
np.flip(alignment, 1).T) * 255), global_step)
454+
# save files as well for now
455+
alignment_dir = join(
456+
checkpoint_dir, "alignment_layer{}".format(i + 1))
457+
os.makedirs(alignment_dir, exist_ok=True)
458+
path = join(alignment_dir, "step{:09d}_layer_{}_alignment.png".format(
459+
global_step, i + 1))
460+
save_alignment(path, alignment)
461+
except Exception as e:
462+
warn(str(e))
453463

454464
# Save averaged alignment
455465
alignment_dir = join(checkpoint_dir, "alignment_ave")
456466
os.makedirs(alignment_dir, exist_ok=True)
457-
path = join(alignment_dir, "step{:09d}_alignment.png".format(global_step))
467+
path = join(alignment_dir, "step{:09d}_layer_alignment.png".format(global_step))
458468
alignment = attn.mean(0)[idx].cpu().data.numpy()
459469
save_alignment(path, alignment)
460-
461470
tag = "averaged_alignment"
462-
writer.add_image(tag, np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255), global_step)
471+
472+
try:
473+
writer.add_image(tag, np.uint8(cm.viridis(
474+
np.flip(alignment, 1).T) * 255), global_step)
475+
except Exception as e:
476+
warn(str(e))
463477

464478
# Predicted mel spectrogram
465479
if mel_outputs is not None:
466480
mel_output = mel_outputs[idx].cpu().data.numpy()
467481
mel_output = prepare_spec_image(audio._denormalize(mel_output))
468-
writer.add_image("Predicted mel spectrogram", mel_output, global_step)
482+
try:
483+
writer.add_image("Predicted mel spectrogram",
484+
mel_output, global_step)
485+
except Exception as e:
486+
warn(str(e))
487+
pass
469488

470489
# Predicted spectrogram
471490
if linear_outputs is not None:
472491
linear_output = linear_outputs[idx].cpu().data.numpy()
473492
spectrogram = prepare_spec_image(audio._denormalize(linear_output))
474-
writer.add_image("Predicted linear spectrogram", spectrogram, global_step)
493+
try:
494+
writer.add_image("Predicted linear spectrogram",
495+
spectrogram, global_step)
496+
except Exception as e:
497+
warn(str(e))
498+
pass
475499

476500
# Predicted audio signal
477501
signal = audio.inv_spectrogram(linear_output.T)
478502
signal /= np.max(np.abs(signal))
479503
path = join(checkpoint_dir, "step{:09d}_predicted.wav".format(
480504
global_step))
481505
try:
482-
writer.add_audio("Predicted audio signal", signal, global_step, sample_rate=hparams.sample_rate)
506+
writer.add_audio("Predicted audio signal", signal,
507+
global_step, sample_rate=hparams.sample_rate)
483508
except Exception as e:
484509
warn(str(e))
485510
pass
@@ -489,13 +514,22 @@ def save_states(global_step, writer, mel_outputs, linear_outputs, attn, mel, y,
489514
if mel_outputs is not None:
490515
mel_output = mel[idx].cpu().data.numpy()
491516
mel_output = prepare_spec_image(audio._denormalize(mel_output))
492-
writer.add_image("Target mel spectrogram", mel_output, global_step)
517+
try:
518+
writer.add_image("Target mel spectrogram", mel_output, global_step)
519+
except Exception as e:
520+
warn(str(e))
521+
pass
493522

494523
# Target spectrogram
495524
if linear_outputs is not None:
496525
linear_output = y[idx].cpu().data.numpy()
497526
spectrogram = prepare_spec_image(audio._denormalize(linear_output))
498-
writer.add_image("Target linear spectrogram", spectrogram, global_step)
527+
try:
528+
writer.add_image("Target linear spectrogram",
529+
spectrogram, global_step)
530+
except Exception as e:
531+
warn(str(e))
532+
pass
499533

500534

501535
def logit(x, eps=1e-8):
@@ -712,7 +746,8 @@ def train(device, model, data_loader, optimizer, writer,
712746
train_seq2seq, train_postnet)
713747

714748
if global_step > 0 and global_step % hparams.eval_interval == 0:
715-
eval_model(global_step, writer, device, model, checkpoint_dir, ismultispeaker)
749+
eval_model(global_step, writer, device, model,
750+
checkpoint_dir, ismultispeaker)
716751

717752
# Update
718753
loss.backward()
@@ -731,8 +766,7 @@ def train(device, model, data_loader, optimizer, writer,
731766
if train_postnet:
732767
writer.add_scalar("linear_loss", float(linear_loss.item()), global_step)
733768
writer.add_scalar("linear_l1_loss", float(linear_l1_loss.item()), global_step)
734-
writer.add_scalar("linear_binary_div_loss", float(
735-
linear_binary_div.item()), global_step)
769+
writer.add_scalar("linear_binary_div_loss", float(linear_binary_div.item()), global_step)
736770
if train_seq2seq and hparams.use_guided_attention:
737771
writer.add_scalar("attn_loss", float(attn_loss.item()), global_step)
738772
if clip_thresh > 0:
@@ -963,8 +997,7 @@ def restore_parts(path, model):
963997
# Setup summary writer for tensorboard
964998
if log_event_path is None:
965999
if platform.system() == "Windows":
966-
log_event_path = "log/run-test" + \
967-
str(datetime.now()).replace(" ", "_").replace(":", "_")
1000+
log_event_path = "log/run-test" + str(datetime.now()).replace(" ", "_").replace(":", "_")
9681001
else:
9691002
log_event_path = "log/run-test" + str(datetime.now()).replace(" ", "_")
9701003
print("Log event path: {}".format(log_event_path))

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /