网站地图    收藏   

主页 > 后端 > python >

详解tensorflow之过拟合问题实战

来源:自学PHP网    时间:2020-11-02 10:42 作者:小飞侠 阅读:

[导读] 详解tensorflow之过拟合问题实战...

今天带来详解tensorflow之过拟合问题实战教程详解

过拟合问题实战

1.构建数据集

我们使用的数据集样本特性向量长度为 2,标签为 0 或 1,分别代表了 2 种类别。借助于 scikit-learn 库中提供的 make_moons 工具我们可以生成任意多数据的训练集。

import matplotlib.pyplot as plt
# 导入数据集生成工具
import numpy as np
import seaborn as sns
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, Sequential, regularizers
from mpl_toolkits.mplot3d import Axes3D

为了演示过拟合现象,我们只采样了 1000 个样本数据,同时添加标准差为 0.25 的高斯噪声数据:

def load_dataset():
 # 采样点数
 N_SAMPLES = 1000
 # 测试数量比率
 TEST_SIZE = None

 # 从 moon 分布中随机采样 1000 个点,并切分为训练集-测试集
 X, y = make_moons(n_samples=N_SAMPLES, noise=0.25, random_state=100)
 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=42)
 return X, y, X_train, X_test, y_train, y_test

make_plot 函数可以方便地根据样本的坐标 X 和样本的标签 y 绘制出数据的分布图:

def make_plot(X, y, plot_name, file_name, XX=None, YY=None, preds=None, dark=False, output_dir=OUTPUT_DIR):
 # 绘制数据集的分布, X 为 2D 坐标, y 为数据点的标签
 if dark:
  plt.style.use('dark_background')
 else:
  sns.set_style("whitegrid")
 axes = plt.gca()
 axes.set_xlim([-2, 3])
 axes.set_ylim([-1.5, 2])
 axes.set(xlabel="$x_1$", ylabel="$x_2$")
 plt.title(plot_name, fontsize=20, fontproperties='SimHei')
 plt.subplots_adjust(left=0.20)
 plt.subplots_adjust(right=0.80)
 if XX is not None and YY is not None and preds is not None:
  plt.contourf(XX, YY, preds.reshape(XX.shape), 25, alpha=0.08, cmap=plt.cm.Spectral)
  plt.contour(XX, YY, preds.reshape(XX.shape), levels=[.5], cmap="Greys", vmin=0, vmax=.6)
 # 绘制散点图,根据标签区分颜色m=markers
 markers = ['o' if i == 1 else 's' for i in y.ravel()]
 mscatter(X[:, 0], X[:, 1], c=y.ravel(), s=20, cmap=plt.cm.Spectral, edgecolors='none', m=markers, ax=axes)
 # 保存矢量图
 plt.savefig(output_dir + '/' + file_name)
 plt.close()
def mscatter(x, y, ax=None, m=None, **kw):
 import matplotlib.markers as mmarkers
 if not ax: ax = plt.gca()
 sc = ax.scatter(x, y, **kw)
 if (m is not None) and (len(m) == len(x)):
  paths = []
  for marker in m:
   if isinstance(marker, mmarkers.MarkerStyle):
    marker_obj = marker
   else:
    marker_obj = mmarkers.MarkerStyle(marker)
   path = marker_obj.get_path().transformed(
    marker_obj.get_transform())
   paths.append(path)
  sc.set_paths(paths)
 return sc
X, y, X_train, X_test, y_train, y_test = load_dataset()
make_plot(X,y,"haha",'月牙形状二分类数据集分布.svg')

2.网络层数的影响

为了探讨不同的网络深度下的过拟合程度,我们共进行了 5 次训练实验。在

自学PHP网专注网站建设学习,PHP程序学习,平面设计学习,以及操作系统学习

京ICP备14009008号-1@版权所有www.zixuephp.com

网站声明:本站所有视频,教程都由网友上传,站长收集和分享给大家学习使用,如由牵扯版权问题请联系站长邮箱904561283@qq.com

添加评论