深度学习之搭建LSTM模型预测股价
大家好,我是带我去滑雪!
本期利用股价数据集,该数据集中in.csv为训练集,t.csv为测试集,里面有开盘价、最高股价、最低股价、收盘价、调整后的收盘价、成交量,2021年11月以前,可以在美国Yahoo网站下载股价历史数据,但现在对中国已经禁用了,可以去其他地方进行下载。本次使用调整后的收盘价进行预测。
目录
1、导入相关模块和数据集
2、产生训练所需的特征和标签数据
3、转换数据为(样本数,时步、特征)的张量
4、定义LSTM模型
5、使用已经训练好的LSTM模型预测股价
6、绘制真实股价与预测股价的对比图
1、导入相关模块和数据集
numpy as np
as pd
from .
from keras.
from keras. Dense, ,LSTM,,GRU
# 载入股价数据集
= pd.(r'E:\工作\硕士\博客\博客37-\in.csv',="Date",=True)
print()
= pd.(r'E:\工作\硕士\博客\博客37-\t.csv',="Date",=True)
print( )
输出结果:
Open High Low Close Adj Close \ Date 2012-01-03 324.360352 331.916199 324.077179 330.555054 330.555054 2012-01-04 330.366272 332.959412 328.175537 331.980774 331.980774 2012-01-05 328.925659 329.839722 325.994720 327.375732 327.375732 2012-01-06 327.445282 327.867523 322.795532 322.909790 322.909790 2012-01-09 321.161163 321.409546 308.607819 309.218842 309.218842 ... ... ... ... ... ... 2016-12-23 790.900024 792.739990 787.280029 789.909973 789.909973 2016-12-27 790.679993 797.859985 787.656982 791.549988 791.549988 2016-12-28 793.700012 794.229980 783.200012 785.049988 785.049988 2016-12-29 783.330017 785.929993 778.919983 782.789978 782.789978 2016-12-30 782.750000 782.780029 770.409973 771.820007 771.820007 Volume Date 2012-01-03 7400800 2012-01-04 5765200 2012-01-05 6608400 2012-01-06 5420700 2012-01-09 11720900 ... ... 2016-12-23 623400 2016-12-27 789100 2016-12-28 1153800 2016-12-29 742200 2016-12-30 1770000 [1258 rows x 6 columns]Open High Low Close Adj Close \ Date 2017-01-03 778.809998 789.630005 775.799988 786.140015 786.140015 2017-01-04 788.359985 791.340027 783.159973 786.900024 786.900024 2017-01-05 786.080017 794.479980 785.020020 794.020020 794.020020 2017-01-06 795.260010 807.900024 792.203979 806.150024 806.150024 2017-01-09 806.400024 809.966003 802.830017 806.650024 806.650024 ... ... ... ... ... ... 2017-04-24 851.200012 863.450012 849.859985 862.760010 862.760010 2017-04-25 865.000000 875.000000 862.809998 872.299988 872.299988 2017-04-26 874.229980 876.049988 867.747986 871.729980 871.729980 2017-04-27 873.599976 875.400024 870.380005 874.250000 874.250000 2017-04-28 910.659973 916.849976 905.770020 905.960022 905.960022 Volume Date 2017-01-03 1657300 2017-01-04 1073000 2017-01-05 1335200 2017-01-06 1640200 2017-01-09 1272400 ... ... 2017-04-24 1372500 2017-04-25 1672000 2017-04-26 1237200 2017-04-27 2026800 2017-04-28 3219500 [81 rows x 6 columns]
2、产生训练所需的特征和标签数据
= .iloc[:,4:5].
#数据归一化
sc = ()
= sc.()
def (ds, =1):
, = [],[]
for i in range(len(ds)-):
.(ds[i:(i+), 0])
.(ds[i+, 0])
np.array(), np.array()
= 60
print("回看天数:", )
# 分割成特征数据和标签数据
, = (, )
输出结果:
回看天数: 60
Out[5]:
array([0.08291369, 0.07626093, 0.0815312 , ..., 0.94758974, 0.94336851,0.92287887])
3、转换数据为(样本数,时步、特征)的张量
= np.(, (.shape[0], .shape[1], 1))
.shape
输出结果:
(1198, 60, 1)
4、定义LSTM模型
在编译模型中,损失函数为MSE,优化器为adam。在训练模型中,训练周期为100,批次尺寸为32。
model = ()
model.add(LSTM(50, =True,
=(.shape[1], 1)))
model.add((0.2))
model.add(LSTM(50, =True))
model.add((0.2))
model.add(LSTM(50))
model.add((0.2))
model.add(Dense(1))
model.()
#编译模型
pile(loss="mse", ="adam")
#训练模型
model.fit(, , =100, =32)
输出结果:
38/38 [==============================] - 2s 46ms/step - loss: 0.0013 Epoch 94/100 38/38 [==============================] - 2s 46ms/step - loss: 0.0013 Epoch 95/100 38/38 [==============================] - 2s 47ms/step - loss: 0.0012 Epoch 96/100 38/38 [==============================] - 2s 46ms/step - loss: 0.0013 Epoch 97/100 38/38 [==============================] - 2s 46ms/step - loss: 0.0013 Epoch 98/100 38/38 [==============================] - 2s 47ms/step - loss: 0.0013 Epoch 99/100 38/38 [==============================] - 2s 46ms/step - loss: 0.0012 Epoch 100/100 38/38 [==============================] - 2s 46ms/step - loss: 0.0013
5、使用已经训练好的LSTM模型预测股价
测试集为2017年1月到3月的股价,因为使用的是前60天的股价数据,使用预测的是4月份股价。
= .iloc[:,4:5].
# 产生标签数据
_, = (, )
#特征数据和标准化
= sc.()
,_ = (, )
# 转换成(样本数, 时步, 特征)张量
= np.(, (.shape[0], .shape[1], 1))
= model.()
# 将预测值转换回股价
= sc.()
输出结果:
array([[814.5596 ],[819.2384 ],[821.1239 ],[823.5624 ],[824.0013 ],[822.3476 ],[819.3523 ],[816.00055],[813.82117],[812.62726],[812.6262 ],[812.9471 ],[817.2544 ],[821.539 ],[824.44244],[826.5891 ],[828.0157 ],[834.4217 ],[843.3087 ],[849.4051 ],[852.694 ]], dtype=float32)
6、绘制真实股价与预测股价的对比图
. as plt
plt.plot(, color="red", label="Real Stock Price")
plt.plot(, color="blue", label=" Stock Price")
plt.title("2017 Stock Price ")
plt.("Time")
plt.(" Time Price")
plt.()
plt.("E:\工作\硕士\博客\博客37-/.png",
="tight",
= 1,
= True,
="w",
='w',
dpi=300,
='')
输出结果:
更多优质内容持续发布中,请移步主页查看。