Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit abab5ab

Browse files
author
codebasics
committed
word embeddings
1 parent 409e7e2 commit abab5ab

File tree

1 file changed

+317
-0
lines changed

1 file changed

+317
-0
lines changed
Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 19,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import numpy as np\n",
10+
"from tensorflow.keras.preprocessing.text import one_hot\n",
11+
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
12+
"from tensorflow.keras.models import Sequential\n",
13+
"from tensorflow.keras.layers import Dense\n",
14+
"from tensorflow.keras.layers import Flatten\n",
15+
"from tensorflow.keras.layers import Embedding\n",
16+
"\n",
17+
"reviews = ['nice food',\n",
18+
" 'amazing restaurant',\n",
19+
" 'too good',\n",
20+
" 'just loved it!',\n",
21+
" 'will go again',\n",
22+
" 'horrible food',\n",
23+
" 'never go there',\n",
24+
" 'poor service',\n",
25+
" 'poor quality',\n",
26+
" 'needs improvement']\n",
27+
"\n",
28+
"sentiment = np.array([1,1,1,1,1,0,0,0,0,0])"
29+
]
30+
},
31+
{
32+
"cell_type": "code",
33+
"execution_count": 20,
34+
"metadata": {},
35+
"outputs": [
36+
{
37+
"data": {
38+
"text/plain": [
39+
"[4, 23]"
40+
]
41+
},
42+
"execution_count": 20,
43+
"metadata": {},
44+
"output_type": "execute_result"
45+
}
46+
],
47+
"source": [
48+
"one_hot(\"amazing restaurant\",30)"
49+
]
50+
},
51+
{
52+
"cell_type": "code",
53+
"execution_count": 21,
54+
"metadata": {},
55+
"outputs": [
56+
{
57+
"name": "stdout",
58+
"output_type": "stream",
59+
"text": [
60+
"[[13, 21], [4, 23], [14, 17], [8, 15, 16], [22, 15, 29], [8, 21], [26, 15, 24], [16, 4], [16, 12], [4, 29]]\n"
61+
]
62+
}
63+
],
64+
"source": [
65+
"vocab_size = 30\n",
66+
"encoded_reviews = [one_hot(d, vocab_size) for d in reviews]\n",
67+
"print(encoded_reviews)"
68+
]
69+
},
70+
{
71+
"cell_type": "code",
72+
"execution_count": 22,
73+
"metadata": {},
74+
"outputs": [
75+
{
76+
"name": "stdout",
77+
"output_type": "stream",
78+
"text": [
79+
"[[13 21 0 0]\n",
80+
" [ 4 23 0 0]\n",
81+
" [14 17 0 0]\n",
82+
" [ 8 15 16 0]\n",
83+
" [22 15 29 0]\n",
84+
" [ 8 21 0 0]\n",
85+
" [26 15 24 0]\n",
86+
" [16 4 0 0]\n",
87+
" [16 12 0 0]\n",
88+
" [ 4 29 0 0]]\n"
89+
]
90+
}
91+
],
92+
"source": [
93+
"max_length = 4\n",
94+
"padded_reviews = pad_sequences(encoded_reviews, maxlen=max_length, padding='post')\n",
95+
"print(padded_reviews)"
96+
]
97+
},
98+
{
99+
"cell_type": "code",
100+
"execution_count": 23,
101+
"metadata": {
102+
"scrolled": true
103+
},
104+
"outputs": [],
105+
"source": [
106+
"embeded_vector_size = 5\n",
107+
"\n",
108+
"model = Sequential()\n",
109+
"model.add(Embedding(vocab_size, embeded_vector_size, input_length=max_length,name=\"embedding\"))\n",
110+
"model.add(Flatten())\n",
111+
"model.add(Dense(1, activation='sigmoid'))"
112+
]
113+
},
114+
{
115+
"cell_type": "code",
116+
"execution_count": 24,
117+
"metadata": {},
118+
"outputs": [],
119+
"source": [
120+
"X = padded_reviews\n",
121+
"y = sentiment"
122+
]
123+
},
124+
{
125+
"cell_type": "code",
126+
"execution_count": 25,
127+
"metadata": {},
128+
"outputs": [
129+
{
130+
"name": "stdout",
131+
"output_type": "stream",
132+
"text": [
133+
"Model: \"sequential_1\"\n",
134+
"_________________________________________________________________\n",
135+
"Layer (type) Output Shape Param # \n",
136+
"=================================================================\n",
137+
"embedding (Embedding) (None, 4, 5) 150 \n",
138+
"_________________________________________________________________\n",
139+
"flatten_1 (Flatten) (None, 20) 0 \n",
140+
"_________________________________________________________________\n",
141+
"dense_1 (Dense) (None, 1) 21 \n",
142+
"=================================================================\n",
143+
"Total params: 171\n",
144+
"Trainable params: 171\n",
145+
"Non-trainable params: 0\n",
146+
"_________________________________________________________________\n",
147+
"None\n"
148+
]
149+
}
150+
],
151+
"source": [
152+
"model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n",
153+
"print(model.summary())"
154+
]
155+
},
156+
{
157+
"cell_type": "code",
158+
"execution_count": 26,
159+
"metadata": {
160+
"scrolled": true
161+
},
162+
"outputs": [
163+
{
164+
"data": {
165+
"text/plain": [
166+
"<tensorflow.python.keras.callbacks.History at 0x1bb8daa5a30>"
167+
]
168+
},
169+
"execution_count": 26,
170+
"metadata": {},
171+
"output_type": "execute_result"
172+
}
173+
],
174+
"source": [
175+
"model.fit(X, y, epochs=50, verbose=0)"
176+
]
177+
},
178+
{
179+
"cell_type": "code",
180+
"execution_count": 29,
181+
"metadata": {},
182+
"outputs": [
183+
{
184+
"name": "stdout",
185+
"output_type": "stream",
186+
"text": [
187+
"1/1 [==============================] - 0s 1ms/step - loss: 0.6384 - accuracy: 1.0000\n"
188+
]
189+
},
190+
{
191+
"data": {
192+
"text/plain": [
193+
"1.0"
194+
]
195+
},
196+
"execution_count": 29,
197+
"metadata": {},
198+
"output_type": "execute_result"
199+
}
200+
],
201+
"source": [
202+
"# evaluate the model\n",
203+
"loss, accuracy = model.evaluate(X, y)\n",
204+
"accuracy"
205+
]
206+
},
207+
{
208+
"cell_type": "code",
209+
"execution_count": 30,
210+
"metadata": {},
211+
"outputs": [
212+
{
213+
"data": {
214+
"text/plain": [
215+
"30"
216+
]
217+
},
218+
"execution_count": 30,
219+
"metadata": {},
220+
"output_type": "execute_result"
221+
}
222+
],
223+
"source": [
224+
"weights = model.get_layer('embedding').get_weights()[0]\n",
225+
"len(weights)"
226+
]
227+
},
228+
{
229+
"cell_type": "code",
230+
"execution_count": 31,
231+
"metadata": {},
232+
"outputs": [
233+
{
234+
"data": {
235+
"text/plain": [
236+
"array([-0.08330977, -0.06752131, -0.04629624, -0.00765801, -0.02024159],\n",
237+
" dtype=float32)"
238+
]
239+
},
240+
"execution_count": 31,
241+
"metadata": {},
242+
"output_type": "execute_result"
243+
}
244+
],
245+
"source": [
246+
"weights[13]"
247+
]
248+
},
249+
{
250+
"cell_type": "code",
251+
"execution_count": 32,
252+
"metadata": {
253+
"scrolled": false
254+
},
255+
"outputs": [
256+
{
257+
"data": {
258+
"text/plain": [
259+
"array([-0.07935128, -0.08574004, 0.06615968, -0.02349528, 0.00917289],\n",
260+
" dtype=float32)"
261+
]
262+
},
263+
"execution_count": 32,
264+
"metadata": {},
265+
"output_type": "execute_result"
266+
}
267+
],
268+
"source": [
269+
"weights[4]"
270+
]
271+
},
272+
{
273+
"cell_type": "code",
274+
"execution_count": 33,
275+
"metadata": {
276+
"scrolled": true
277+
},
278+
"outputs": [
279+
{
280+
"data": {
281+
"text/plain": [
282+
"array([ 0.0128377 , 0.03549778, 0.05134471, -0.07147218, 0.03261041],\n",
283+
" dtype=float32)"
284+
]
285+
},
286+
"execution_count": 33,
287+
"metadata": {},
288+
"output_type": "execute_result"
289+
}
290+
],
291+
"source": [
292+
"weights[16]"
293+
]
294+
}
295+
],
296+
"metadata": {
297+
"kernelspec": {
298+
"display_name": "Python 3",
299+
"language": "python",
300+
"name": "python3"
301+
},
302+
"language_info": {
303+
"codemirror_mode": {
304+
"name": "ipython",
305+
"version": 3
306+
},
307+
"file_extension": ".py",
308+
"mimetype": "text/x-python",
309+
"name": "python",
310+
"nbconvert_exporter": "python",
311+
"pygments_lexer": "ipython3",
312+
"version": "3.8.5"
313+
}
314+
},
315+
"nbformat": 4,
316+
"nbformat_minor": 4
317+
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /