p神经网络算法 常见的神经网络算法
时间:2023-05-02 03:24/span>
作者:tiger
分类:
新知
浏览:1330
评论:0
使用最简单的自定义三层网络,仅用3个节点,尝试拟合各种函数。效果还行,肯定还有更多优化的余地,仅供入门参考。
包引入
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.preprocessing import MinMaxScaler
函数定义
function_select = 5
def myfun(x):
functions = {
1: np.power(x-7,2), 二次函数
2: np.sin(x), sin
3: np.sign(x), signum
4: np.exp(x), 指数
5: np.power(x,3) - 3*np.power(x,2) + 5, 多项式
6: 1+np.power(x,2)/4000-np.cos(x) 格里旺克函数
}
return functions.get(function_select)
构建模型
activation_function = &39;tanh&39;
def build_model(train_data, labels, units, epochs):
print(train_data.shape)
model = keras.Sequential()
model.add(keras.layers.Dense(units, input_dim=train_data.shape[1], kernel_initializer=&39;he_normal&39;, activation=activation_function))
model.add(keras.layers.Dense(1, kernel_initializer=&39;he_normal&39;, activation=&39;linear&39;))
Compile model
model.compile(optimizer=&39;adam&39;,
loss=&39;mse&39;,
metrics=[&39;mse&39;])
训练模型
model.fit(train_data, labels, epochs=epochs, batch_size=50, verbose=0)
return model
训练数据
batch_size = 20
x_train = np.linspace(-10, 10, num=300).reshape(-1,1)
计算真实样本y
y_train = myfun(x_train)
正规化
x_scaler = MinMaxScaler(feature_range=(-1, 1))
y_scaler = MinMaxScaler(feature_range=(-1, 1))
x_scaled = x_scaler.fit_transform(x_train)
y_scaled = y_scaler.fit_transform(y_train)
开始训练模型
units = 3
epochs = 2000
model_best = build_model(train_data=x_scaled, labels=y_scaled, units=units, epochs=epochs)
测试
测试集
x_eval = np.linspace(-8, 5, num=40).reshape(-1,1)
x_eval_scaled = x_scaler.transform(x_eval)
result = model_best.predict(x_eval_scaled, batch_size=50)
predictions = y_scaler.inverse_transform(result)
画图
fig = plt.figure(1, figsize=(20,10))
ax = fig.add_subplot(1, 2, 1)
plt.plot(x_eval, predictions, &39;.&39;, color=&39;red&39;, linewidth=2.0)
plt.plot(x_eval, myfun(x_eval), &39;-&39;, color=&39;blue&39;, linewidth=1.0)
plt.plot(x_train, myfun(x_train), &39;-&39;, color=&39;gray&39;, linewidth=1.0)
ax = fig.add_subplot(1, 2, 2)
plt.plot(x_eval, np.abs(predictions-myfun(x_eval)), &39;-&39;, label=&39;output&39;, color=&39;firebrick&39;, linewidth=2.0)
plt.show()
多项式
二次函数
sin
注意:针对sin的特点,你需要将训练数据的密集度调整好,否则可能不能很好的拟合。