Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 701c772

Browse files
3d graph
1 parent dd925a7 commit 701c772

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

‎Summer20/NeuralNetwork/tf3d.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from tensorflow.keras import datasets, models, layers, losses
2+
import tensorflow as tf
3+
from mpl_toolkits import mplot3d
4+
import numpy as np
5+
import matplotlib.pyplot as plt
6+
import random
7+
8+
def cone(x,y):
9+
return np.sqrt(x**2 + y**2)
10+
11+
def ripple(x,y):
12+
return np.sin(10 * (x**2 + y**2)) / 10
13+
14+
def makeTuple(X,Y):
15+
inputList = []
16+
for index, value in enumerate(X):
17+
for index1, value1 in enumerate(value):
18+
inputList.append([value1, Y[index][index1]])
19+
return inputList
20+
21+
def unpackTuple(A):
22+
X = []
23+
Y = []
24+
for item in A:
25+
X.append([item[0]])
26+
Y.append([item[1]])
27+
return X, Y
28+
29+
def makeArray(Z):
30+
zList = []
31+
for subList in Z:
32+
for value in subList:
33+
zList.append([value])
34+
return zList
35+
36+
def randomPoints(number, bounds):
37+
inputList = []
38+
outputList = []
39+
while(number > 0):
40+
value1 = random.uniform(bounds[0],bounds[1])
41+
value2 = random.uniform(bounds[0],bounds[1])
42+
inputList.append([value1, value2])
43+
outputList.append([ripple(value1, value2)])
44+
number = number -1
45+
return inputList, outputList
46+
47+
bounds = (-1,1)
48+
inputList, outputList = randomPoints(50000, bounds)
49+
X_Train, Y_Train = unpackTuple(inputList)
50+
51+
model = models.Sequential()
52+
model.add(layers.Dense(32, activation='exponential', input_shape=(2,)))
53+
model.add(layers.Dense(48, activation='tanh'))
54+
model.add(layers.Dense(1, activation=None))
55+
model.compile(optimizer='Adam',
56+
loss=losses.MeanSquaredError(),
57+
metrics=['mean_squared_error'])
58+
59+
history = model.fit(np.array(inputList),np.array(outputList), epochs=300)
60+
#print(model.get_weights())
61+
62+
63+
# plots out learning curve
64+
# plt.plot(history.history['mean_squared_error'], label='mean_squared_error')
65+
# plt.xlabel('Epoch')
66+
# plt.ylabel('MSE')
67+
# plt.ylim([0.0, 0.2])
68+
# plt.legend(loc='lower right')
69+
# plt.show()
70+
71+
# generate test data
72+
inputTest, outputTest = randomPoints(10, bounds)
73+
X_Test, Y_Test = unpackTuple(inputTest)
74+
print(model.predict(np.array(inputTest)))
75+
print(outputTest)
76+
77+
x = np.linspace(-1, 1, 800)
78+
y = np.linspace(-1, 1, 800)
79+
80+
X, Y = np.meshgrid(x, y)
81+
Z = ripple(X, Y)
82+
83+
fig = plt.figure()
84+
ax = plt.axes(projection="3d")
85+
86+
ax.plot_wireframe(X, Y, Z, color='c')
87+
ax.scatter3D(X_Test, Y_Test, model.predict(np.array(inputTest)), c='r')
88+
ax.set_xlabel('x')
89+
ax.set_ylabel('y')
90+
ax.set_zlabel('z')
91+
92+
plt.show()

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /