Python和TensorFlow构建了一个基于ResNet50的迁移学习模型

news/2025/2/25 4:10:46

在这篇博客中,我们将介绍如何实现一个车辆识别系统。我们将使用Python和TensorFlow来实现这个系统,同时也会涉及到使用预训练模型进行迁移学习。在这篇文章中,我们将详细介绍整个过程,包括数据准备、模型训练和模型评估等步骤。

目录

1. 数据准备

1.1 下载和解压数据集

1.2 划分数据集

1.3 数据预处理

2. 构建模型

3. 模型训练

3.1 编译模型

4. 模型评估


1. 数据准备

首先,我们需要一个包含大量车辆图片的数据集。一个常用的数据集是Stanford Cars Dataset,它包含196类车辆的图片。我们将使用这个数据集来训练和评估我们的模型。

1.1 下载和解压数据集

下载数据集并解压到一个目录,例如./data。数据集中包含两个文件,cars_traincars_test,分别用于训练和测试。

1.2 划分数据集

我们将数据集划分为训练集、验证集和测试集。使用train_test_split函数进行划分,保留20%的数据作为测试集。

from sklearn.model_selection import train_test_split

train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)

1.3 数据预处理

我们需要对图像数据进行预处理,包括缩放、归一化等操作。使用ImageDataGenerator类进行预处理,同时也可以进行数据增强。

from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')

test_datagen = ImageDataGenerator(rescale=1./255)

 

2. 构建模型

我们将使用预训练的ResNet50模型作为基础模型,并在此基础上添加自定义的分类层。

from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D

base_model = ResNet50(weights='imagenet', include_top=False)

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(196, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)

3. 模型训练

3.1 编译模型

我们需要定义损失函数、优化器和评估指标。这里我们使用交叉熵损失、Adam优化器和准确率。

model.compile(loss
(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])


3.2 训练模型

使用`fit_generator`函数进行模型训练。设置合适的批大小、迭代次数和回调函数。

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

batch_size = 32
epochs = 50

train_generator = train_datagen.flow_from_directory(train_data,
                                                    target_size=(224, 224),
                                                    batch_size=batch_size,
                                                    class_mode='categorical')

validation_generator = test_datagen.flow_from_directory(validation_data,
                                                        target_size=(224, 224),
                                                        batch_size=batch_size,
                                                        class_mode='categorical')

checkpointer = ModelCheckpoint(filepath='model.h5', verbose=1, save_best_only=True)
early_stopping = EarlyStopping(monitor='val_loss', patience=10)

history = model.fit_generator(train_generator,
                              steps_per_epoch=len(train_data) // batch_size,
                              epochs=epochs,
                              validation_data=validation_generator,
                              validation_steps=len(validation_data) // batch_size,
                              callbacks=[checkpointer, early_stopping])

4. 模型评估

我们将使用测试集来评估训练好的模型。首先,加载保存的最佳模型。

from tensorflow.keras.models import load_model

model = load_model('model.h5')

然后,使用evaluate_generator函数进行模型评估。

test_generator = test_datagen.flow_from_directory(test_data,
                                                  target_size=(224, 224),
                                                  batch_size=batch_size,
                                                  class_mode='categorical')

scores = model.evaluate_generator(test_generator, steps=len(test_data) // batch_size)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])

文章来源:https://blog.csdn.net/a871923942/article/details/130550512
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.niftyadmin.cn/n/298804.html

相关文章

Python使用imghdr模块检测图片类型

在Python中,我们经常需要处理图片文件。但是,有时候我们并不知道图片的具体类型,这时候就需要使用imghdr模块来检测图片类型。 imghdr模块是Python自带的模块,可以用来检测图片文件的类型。它可以检测常见的图片格式,…

探索编程语言的本质:了解编程语言的定义与分类

前言: 由于我看了一眼我的粉丝列表,发现好像关于开发语言的童鞋占比较多哈,所以出一下这篇专栏。 要关注的小伙伴可以提前订阅哈。 目录 前言: 引言 1.1. 编程语言的重要性 1.2. 本文的目的与结构 2.什么是编程语言&#…

Windows Information Protection(WIP)部署方案

目录 前言 一、方案准备工作 1、确定哪些数据需要保护 2、选择合适的加密方式

HLS直播与延迟时长的来源与超低延迟直播

1.HLS直播延迟时长(HTTP Live Streaming) HTTP Live Streaming(简称 HLS)是一个基于 HTTP 的视频流协议,由 Apple 公司实现,Mac OS 上的 QuickTime、Safari 以及 iOS 上的 Safari 都能很好的支持 HLS&…

day43—编程题

文章目录 1.第一题1.1题目1.2思路1.3解题 2.第二题2.1题目2.2思路2.3解题 1.第一题 1.1题目 描述: 输入两个整数 n 和 m,从数列1,2,3…n 中随意取几个数,使其和等于 m ,要求将其中所有的可能组合列出来 输入描述: 每个测试输入包…

9个加密货币交易所被查封,交易所安全审计后仍不安全

美国联邦调查局和乌克兰警方查封了九个加密货币交易网站,这些网站为包括勒索软件参与者在内的诈骗者和网络犯罪分子洗钱提供了便利。 联邦调查局 FBI 在其公告中表示,该行动是在虚拟货币响应小组、乌克兰国家警察和该国法律检察官的帮助下进行的。 此次…

nodejs+vue网上学生社团管理系统

并运用Photoshop CS6技术美化网页,辅之以CSS技术。系统是基于面向对象编程的web应用程序。本系统主要实现的功能有系统用户管理、社团信息管理、社团类别管理、社团活动管理、社团论坛管理、系统管理、个人资料管理、学生入团管理、社团公告管理、社团活动管理、社团…

【深度学习】基于PyTorch 迁移学习 实现医学影像识别(详细案例分析 + 源代码) | 附:深度学习在医学影像领域的应用

但是太阳,他每时每刻都是夕阳,也是旭日,当他熄灭着走下山去收尽苍凉残照之际,正是他在另一面燃烧着爬上山巅布散烈烈朝晖之时。 🎯作者主页: 追光者♂🔥 🌸个人简介: 💖[1] 计算机专业硕士研究生💖 🌟[2] 2022年度博客之星人工智能领域TOP4🌟…