Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 40c13b7

Browse files
add stock and transformers.
1 parent d0f7348 commit 40c13b7

27 files changed

+192561
-0
lines changed

‎38stock/LSTM实现股票预测--pytorch 版本-V1.0.ipynb‎

Lines changed: 533 additions & 0 deletions
Large diffs are not rendered by default.

‎38stock/LSTM实现股票预测--pytorch 版本-V2.0.ipynb‎

Lines changed: 706 additions & 0 deletions
Large diffs are not rendered by default.

‎38stock/README.md‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Pytorch Stock LSTM
2+
3+
预测股票价格的简单小程序,LSTM 实现,基于 Pytorch。仅供娱乐,并不实用。
4+
5+
![全局](1.png)
6+
![局部](2.png)

‎38stock/__init__.py‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
@author:XuMing(xuming624@qq.com)
4+
@description:
5+
"""

‎38stock/lstm_demo/README.md‎

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# hs300_stock_predict
2+
该项目用于对hs300股票的预测,包括股票下载,数据清洗,LSTM 模型的训练,测试,以及实时预测。<br>
3+
4+
## 文件构成
5+
1、data_utils.py 用于股票数据下载,清洗,合并等。该文件共有9个函数。
6+
get_stock_data(code, date1, date2, filename, length=-1)<br>
7+
该函数用于下载股票数据,保存开、高、收、低、量、涨跌幅等6维数据。<br>
8+
由于用的tushare接口,因此只能下载最近两年的数据。(从新浪网易财经的数据爬虫接口后续开放)<br>
9+
共有`5个`参数<br>
10+
`code`是需要下载的股票代码,如000001是平安银行的股票代码,输入'000001'既下载平安银行的股票数据。<br>
11+
`date1`是开始日期,格式如"2019年01月03日",`date2`是结束日期,格式同上。<br>
12+
`filename`是存放数据的目录,如"D:\data\"<br>
13+
`length`是筛选股票长度,默认为-1,既对下载的股票数据长度上不做筛选,如果人为指定长度如200,既会将时间长度200以下的数据剔除不予以保存。<br><br>
14+
get_hs300_data(date1, date2, filename)<br>
15+
该函数用于下载沪深300指数数据,参数格式同get_stock_data<br><br>
16+
update_stock_data(filename)<br>
17+
该函数将股票数据从本地文件的最后日期更新至当日,`filename`是指定的单只股票路径名称,如"d:\data000001円.csv"<br><br>
18+
get_data_len(file_path)<br>
19+
该函数过去单只股票的时间长度,`file_path`是单只股票路径名称,如"d:\data000001円.csv"<br><br>
20+
select_stock_data(file1, file2, date1, date2)<br>
21+
该函数对已经再本地的文件按照日期筛选,`date1`是开始数据,`date2`是结束数据,`file1`是源文件夹,`file2`是筛选日期后文件存放的文件夹<br><br>
22+
crop_stock(df, date)<br>
23+
该函数暂时不使用<br><br>
24+
fill_stock_data(target, sfile, tfile)<br>
25+
该函数按照沪深300指数的时间长度来对个股停盘数据进行填充,填充为该股上一交易日的数据。该函数是对所选文件夹下所有文件进行处理。<br>
26+
注意,如果开始日期是属于停牌的,那么该段停牌将不会被填充,后续会有更新。<br>
27+
`target`为参照股票,一般选择同时间段的沪深300指数文件,`sfile`为原文件夹,`tfile`为填充完要存放文件夹。<br><br>
28+
merge_stock_data(filename, tfile)<br>
29+
该函数将多个文件按序合并为一个文件,如讲沪深300只个股文件合并为一个总文件,方便后续模型输入。<br>
30+
`filename`是需要合并的文件夹路径,`tfile`是存放合并后文件的文件夹路径。<br><br>
31+
quchong(file)<br>
32+
该函数暂不使用。<br><br>
33+
34+
2、dataprocess.py 用于训练数据的处理,归一化等,模型的输入都由该文件的接口输出提供。
35+
get_train_data(batch_size=args.batch_size, time_step=args.time_step)<br>
36+
该函数用于处理训练数据,参数默认,有配置文件给定。该函数返回五个变量:`batch_index`训练集分批处理的批次, `val_index`验证集批次, `train_x_1`, 训练集输入,`train_y_1`, 训练集标签,`val_x`, y验证集输入,`val_y`验证集标签<br>
37+
备注:由于整个数据处理是对多只股票合成的总文件处理,所以在时间步长迭代添加时须在各自股票时间长度内完成,因此,需要在配置文件中指定股票长度。<br><br>
38+
get_test_data(time_step=args.time_step)<br>
39+
该函数用于处理测试集数据,返回两个变量:`test_x`测试集输入, `test_y`测试集标签。<br><br>
40+
get_update_data(time_step=args.time_step)<br>
41+
该函数将更新数据加历史数据的前time_step-1拼接,用于整批处理,如2019 1-3月数据和2018.12的数据拼接,返回拼接后的`train_x`, `train_y`<br>
42+
get_predict_data(file_name)<br>
43+
该函数完成下载实时股票数据,与之前的数据拼接后返回输入x。有一个参数需要填充,`file_name`既要预测的单只股票文件。
44+
45+
3、config.py 配置文件,所有接口内超参数,路径等,在该文件修改,即可在全局生效。
46+
47+
4、lstm_model 模型,包括训练,微调,测试,及预测。
48+
train_lstm(time_step=args.time_step, val=True)<br>
49+
用于训练的函数,val既是否验证,默认开启。其数据来自`get_train_data()`<br><br>
50+
fining_tune_train(time_step=args.time_step)<br>
51+
用于微调模型,如新增数据在原模型继续训练,或者迁移学习等。其数据可以来自`get_update_data()`<br><br>
52+
test(time_step=args.time_step)<br>
53+
用于测量测试集的准确率和F1,数据来自`get_test_data()`<br><br>
54+
predict(time_step=args.time_step)<br>
55+
用于预测第二天的收盘价,数据来自`get_predict_data(args.predict_dir)`<br><br>
56+
57+
5、stock_main.py 主函数
58+
可以调用上述所有函数接口,实现相关功能。<br><br>
59+
## 相关论文
60+
《基于LSTM的股票价格的多分类预测》<br><br>
61+
论文地址:https://www.hanspub.org/journal/PaperInformation.aspx?paperID=32542<br><br>

‎38stock/lstm_demo/__init__.py‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
@author:XuMing(xuming624@qq.com)
4+
@description:
5+
"""

‎38stock/lstm_demo/config.py‎

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import pandas as pd
2+
3+
4+
def get_data_len(file_path):
5+
if file_path:
6+
with open(file_path) as f:
7+
df = pd.read_csv(f)
8+
return len(df)
9+
10+
11+
# -------------------参数配置----------------- #
12+
class Arg:
13+
def __init__(self):
14+
# 训练集数据存放路径
15+
self.train_dir = './data/train_mix-17-18.csv'
16+
# 测试集数据存放路径
17+
self.test_dir = './data/test_mix.csv'
18+
# 更新数据存放路径
19+
self.new_dir = './data/train_mix-19.csv'
20+
# 要预测的数据存放路径
21+
self.predict_dir = './data/000001.csv'
22+
# 模型存放路径
23+
self.train_model_dir = './model/'
24+
# fining-turn模型存放路径
25+
self.fining_turn_model_dir = './data/finet/'
26+
# 训练图存放路径
27+
self.train_graph_dir = './data/graph/train_270/'
28+
# 验证loss存放路径
29+
self.val_graph_dir = './data/graph/val_270/'
30+
# 模型名称
31+
self.model_name = 'model-270-17-19'
32+
self.model_name_ft = 'model-ft-01-03'
33+
self.rnn_unit = 128 # 隐层节点数
34+
self.input_size = 6 # 输入维度(既用几个特征)
35+
self.output_size = 6 # 输出维度(既使用分类类数预测)
36+
self.layer_num = 3 # 隐藏层层数
37+
self.lr = 0.0006 # 学习率
38+
self.time_step = 20 # 时间步长
39+
self.epoch = 50 # 训练次数
40+
self.epoch_fining = 30 # 微调的迭代次数
41+
# 单只股票的长度(同一数据集股票长度应处理等长)
42+
self.stock_len = get_data_len('./data/399300.csv')
43+
# 更新后单只股票的长度(同一数据集股票长度应处理等长)
44+
self.stock_len_new = get_data_len('./data/399300.csv')
45+
self.batch_size = 1024 # batch_size
46+
self.ratio = 0.8 # 训练集验证集比例

‎38stock/lstm_demo/data/readme‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
2017-2018年沪深300 270只股票的合成文件train和2019年1-3月的test

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /