来源:自学PHP网 时间:2020-11-02 10:42 作者:小飞侠 阅读:次
[导读] 详解tensorflow之过拟合问题实战...
今天带来详解tensorflow之过拟合问题实战教程详解
过拟合问题实战 1.构建数据集 我们使用的数据集样本特性向量长度为 2,标签为 0 或 1,分别代表了 2 种类别。借助于 scikit-learn 库中提供的 make_moons 工具我们可以生成任意多数据的训练集。 import matplotlib.pyplot as plt # 导入数据集生成工具 import numpy as np import seaborn as sns from sklearn.datasets import make_moons from sklearn.model_selection import train_test_split from tensorflow.keras import layers, Sequential, regularizers from mpl_toolkits.mplot3d import Axes3D 为了演示过拟合现象,我们只采样了 1000 个样本数据,同时添加标准差为 0.25 的高斯噪声数据: def load_dataset(): # 采样点数 N_SAMPLES = 1000 # 测试数量比率 TEST_SIZE = None # 从 moon 分布中随机采样 1000 个点,并切分为训练集-测试集 X, y = make_moons(n_samples=N_SAMPLES, noise=0.25, random_state=100) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=42) return X, y, X_train, X_test, y_train, y_test make_plot 函数可以方便地根据样本的坐标 X 和样本的标签 y 绘制出数据的分布图: def make_plot(X, y, plot_name, file_name, XX=None, YY=None, preds=None, dark=False, output_dir=OUTPUT_DIR): # 绘制数据集的分布, X 为 2D 坐标, y 为数据点的标签 if dark: plt.style.use('dark_background') else: sns.set_style("whitegrid") axes = plt.gca() axes.set_xlim([-2, 3]) axes.set_ylim([-1.5, 2]) axes.set(xlabel="$x_1$", ylabel="$x_2$") plt.title(plot_name, fontsize=20, fontproperties='SimHei') plt.subplots_adjust(left=0.20) plt.subplots_adjust(right=0.80) if XX is not None and YY is not None and preds is not None: plt.contourf(XX, YY, preds.reshape(XX.shape), 25, alpha=0.08, cmap=plt.cm.Spectral) plt.contour(XX, YY, preds.reshape(XX.shape), levels=[.5], cmap="Greys", vmin=0, vmax=.6) # 绘制散点图,根据标签区分颜色m=markers markers = ['o' if i == 1 else 's' for i in y.ravel()] mscatter(X[:, 0], X[:, 1], c=y.ravel(), s=20, cmap=plt.cm.Spectral, edgecolors='none', m=markers, ax=axes) # 保存矢量图 plt.savefig(output_dir + '/' + file_name) plt.close() def mscatter(x, y, ax=None, m=None, **kw): import matplotlib.markers as mmarkers if not ax: ax = plt.gca() sc = ax.scatter(x, y, **kw) if (m is not None) and (len(m) == len(x)): paths = [] for marker in m: if isinstance(marker, mmarkers.MarkerStyle): marker_obj = marker else: marker_obj = mmarkers.MarkerStyle(marker) path = marker_obj.get_path().transformed( marker_obj.get_transform()) paths.append(path) sc.set_paths(paths) return sc X, y, X_train, X_test, y_train, y_test = load_dataset() make_plot(X,y,"haha",'月牙形状二分类数据集分布.svg') 2.网络层数的影响 为了探讨不同的网络深度下的过拟合程度,我们共进行了 5 次训练实验。在 |
自学PHP网专注网站建设学习,PHP程序学习,平面设计学习,以及操作系统学习
京ICP备14009008号-1@版权所有www.zixuephp.com
网站声明:本站所有视频,教程都由网友上传,站长收集和分享给大家学习使用,如由牵扯版权问题请联系站长邮箱904561283@qq.com