diff --git a/Season1/4-6/load.py b/Season1/4-6/load.py index 0a67ffe..708c1be 100644 --- a/Season1/4-6/load.py +++ b/Season1/4-6/load.py @@ -69,11 +69,14 @@ def distribution(labels, name): def inspect(dataset, labels, i): # 显示图片看看 + print(labels[i]) + ''' if dataset.shape[3] == 1: shape = dataset.shape dataset = dataset.reshape(shape[0], shape[1], shape[2]) - print(labels[i]) plt.imshow(dataset[i]) + '''#可以改为以下 + plt.imshow(dataset[i].squeeze()) plt.show()