Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit d7e51b8

Browse files
committed
Changed the get_prices function to use pandas DataReader because
the yahoo library doesn't work any more. Also changed to tf.global_variables_initializer() to get rid of the deprecation warning
1 parent 4461953 commit d7e51b8

File tree

2 files changed

+13
-15
lines changed

2 files changed

+13
-15
lines changed

‎ch08_rl/prices.png

-1000 Bytes
Loading[フレーム]

‎ch08_rl/rl.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from yahoo_finance import Share
21
from matplotlib import pyplot as plt
32
import numpy as np
43
import random
54
import tensorflow as tf
65
import random
7-
6+
import pandas as pd
7+
pd.core.common.is_list_like = pd.api.types.is_list_like
8+
from pandas_datareader import data
9+
import datetime
10+
import requests_cache
811

912
class DecisionPolicy:
1013
def select_action(self, current_state, step):
@@ -43,7 +46,7 @@ def __init__(self, actions, input_dim):
4346
loss = tf.square(self.y - self.q)
4447
self.train_op = tf.train.AdagradOptimizer(0.01).minimize(loss)
4548
self.sess = tf.Session()
46-
self.sess.run(tf.initialize_all_variables())
49+
self.sess.run(tf.global_variables_initializer())
4750

4851
def select_action(self, current_state, step):
4952
threshold = min(self.epsilon, step / 1000.)
@@ -108,17 +111,12 @@ def run_simulations(policy, budget, num_stocks, prices, hist):
108111
return avg, std
109112

110113

111-
def get_prices(share_symbol, start_date, end_date, cache_filename='stock_prices.npy'):
112-
try:
113-
stock_prices = np.load(cache_filename)
114-
except IOError:
115-
share = Share(share_symbol)
116-
stock_hist = share.get_historical(start_date, end_date)
117-
stock_prices = [stock_price['Open'] for stock_price in stock_hist]
118-
np.save(cache_filename, stock_prices)
119-
120-
return stock_prices
121-
114+
def get_prices(share_symbol, start_date, end_date):
115+
expire_after = datetime.timedelta(days=3)
116+
session = requests_cache.CachedSession(cache_name='cache', backend='sqlite', expire_after=expire_after)
117+
stock_hist = data.DataReader(share_symbol, 'iex', start_date, end_date, session=session)
118+
open_prices = stock_hist['open']
119+
return open_prices.values.tolist()
122120

123121
def plot_prices(prices):
124122
plt.title('Opening stock prices')
@@ -129,7 +127,7 @@ def plot_prices(prices):
129127

130128

131129
if __name__ == '__main__':
132-
prices = get_prices('MSFT', '1992-07-22', '2016-07-22')
130+
prices = get_prices('MSFT', '2013-07-22', '2018-07-22')
133131
plot_prices(prices)
134132
actions = ['Buy', 'Sell', 'Hold']
135133
hist = 200

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /