tensorflow实现简单线性回归( Linear Regression)
创始人
2025-06-01 20:43:42
0

tensorflow实现简单线性回归( Linear Regression)

线性回归过程

线性回归中线性的含义:因变量y对于未知的回归系数是线性的。

  1. 准备数据集
  2. 建立线性模型:
    随机初始化w和b
    y=w·x+b,目标:求出权重w和偏置b
  3. 确定损失函数(预测值与真实值之间的误差)–均方误差
  4. 梯度下降优化损失:需要指定学习率(超参数)

(0)导入依赖包

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import matplotlib.pyplot as plt

(1)创建模拟数据

创建一个线性回归的模拟数据,y = x,一共20个点,再在这些点上加一些随机的噪音,数据用numpy生成

x_data = np.linspace(0, 10, 20) + np.random.uniform(-1.5, 1.5, 20)
y_data = np.linspace(0, 10, 20) + np.random.uniform(-1.5, 1.5, 20)

(2)初始化w和b

y = W * x + b

线性回归中,训练的参数就是权重(weight)w 和偏移(bais)b,用numpy随机生成w和b的初始值。

w = np.random.uniform(-1, 1)
b = np.random.uniform(-1, 1)

w和b是tensorflow中训练的对象,需要转换成tensorflow的变量。

w_tf = tf.Variable(w)
b_tf = tf.Variable(b)

(3)确定损失函数

确定损失函数(预测值与真实值之间的误差)-均方误差(Mean Square Error MSE)

error = 0
for x, y in zip(x_data, y_data):error += (w_tf * x + b_tf - y) ** 2

(4)优化损失

线性回归,需要通过Gradient Decendent(梯度下降方法)训练w和b,优化损失。

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
train = optimizer.minimize(error)

(5)开始训练

init = tf.global_variables_initializer()train_steps = 10
with tf.Session() as sess:sess.run(init)for step in range(train_steps):sess.run(train)w_final, b_final = sess.run([w_tf, b_tf])

(6)检验训练成果

把原始数据和训练得到的结果通过plt.plot可视化

y_pred = w_final * x_data + b_finalprint(w_final, b_final)plt.plot(x_data, y_data, '*')
plt.plot(x_data, y_pred)
plt.show()

完整代码

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import matplotlib.pyplot as plt# # ------------1. 获取训练数据--------------------------------------
x_data = np.linspace(0, 10, 20) + np.random.uniform(-1.5, 1.5, 20)
y_data = np.linspace(0, 10, 20) + np.random.uniform(-1.5, 1.5, 20)# # ------------2.构造预测的线性回归函数 y = W * x + b------------------
w = tf.Variable(tf.random_uniform([1]))  # 构造一个0~1的随机数
b = tf.Variable(tf.zeros([1]))  # 设b的初始值为0
print(w, b)w_tf = tf.Variable(w)
b_tf = tf.Variable(b)# # ------------3.确定损失函数-----------------------------------------
error = 0
for x, y in zip(x_data, y_data):error += (w_tf * x + b_tf - y) ** 2# # ------------4.通过梯度下降方法优化损失--------------------------------
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
train = optimizer.minimize(error)# # ------------5.开始训练----------------------------------------------
init = tf.global_variables_initializer()train_steps = 10
with tf.Session() as sess:sess.run(init)for step in range(train_steps):sess.run(train)w_final, b_final = sess.run([w_tf, b_tf])# # ------------6.检验训练成果-----------------------------------------
y_pred = w_final * x_data + b_finalprint(w_final, b_final)plt.plot(x_data, y_data, '*')
plt.plot(x_data, y_pred)
plt.show()

相关内容

热门资讯

港股概念追踪|《稳定币条例》正... 智通财经获悉,5月30日,香港《稳定币条例》作为全球首个针对法币稳定币的专项立法正式落地,有效填补了...
原创 全... 既然美国在联合国只手遮天,影响相关机构的公平公正,那我们就“另起炉灶”再建新群,“国际调解院” 的签...
“新政策”彰显中国市场开放性 近日,在拖轮助力下,载着上万个集装箱的“中远海运杜鹃”轮缓缓驶进天津港。得益于天津港中远海运美东航线...
雅博光伏E周刊:“两办:202... 01E点聚焦 两办定目标:碳排放权、用水权交易制度2027年基本完善。 据新华社5月29日消息,中...
为孤独症人群提供全生命周期服务... 为孤独症人群提供全生命周期服务需体系化法律保障 □ 本报记者 文丽娟 周斌 陈磊 孤独症谱系障碍(...
青海省总就安全生产发出劳动法律... 原标题:青海省总就安全生产发出劳动法律监督提示函(引题) 严禁强令职工冒险作业(主题) 中工网讯 (...
普惠托育服务咋收费,我省发布最... 近日,山西晚报·山河+记者从省发改委获悉:我省发布完善普惠托育服务收费新政策,明确了涵盖公办托育机构...
生活垃圾管理条例实施细则印发 ... 近日,市政府办公室正式印发《昆明市生活垃圾管理条例实施细则》(以下简称《细则》),昆明将开展二手商品...
四维图新:正式发布《市值管理制... 金融界6月3日消息,有投资者在互动平台向四维图新提问:董秘,您好!公司股价近期出现底部放量上涨,是否...
云南省重点高原湖泊入湖河道保护... 云南省人民代表大会常务委员会公告 〔十四届〕 第四十四号 《云南省重点高原湖泊入湖河道保护条例》已由...