百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术文章 > 正文

NSFW 图片分类(图片cnn分类)

nanshan 2025-03-28 19:24 11 浏览 0 评论

NSFW指的是 不适宜工作场所("Not Safe (or Suitable) For Work;")。在本文中,我介绍如何创建一个检测NSFW图像的图像分类模型。

数据集

由于数据集的性质,我们无法从一些数据集的网站(如Kaggle等)获得所有图像。

但是我们找到了一个专门抓取这种类型图片的github库,所以我们可以直接使用。clone项目后可以运行下面的代码来创建文件夹,并将每个图像下载到其特定的文件夹中。

folders = ['drawings','hentai','neutral','porn','sexy']
urls = ['urls_drawings.txt','urls_hentai.txt','urls_neutral.txt','urls_porn.txt','urls_sexy.txt']
names = ['d','h','n','p','s']
for i,j,k in zip(folders,urls,names):
try:
#Specify the path of the folder that has to be made
folder_path = os.path.join('your directory',i)
os.mkdir(folder_path)
except:
pass
#setup the path of url text file
url_path = os.path.join('Datasets_Urls',j)
my_file = open(url_path, "r")
data = my_file.read()
#create a list with all urls
data_into_list = data.split("\n")
my_file.close()
icount = 0
for ii in data_into_list:
try:
#create a unique image names for each images
image_name = 'image'+str(icount)+str(k)+'.png'
image_path = os.path.join(folder_path,image_name)
#download it using the library
urllib.request.urlretrieve(ii, image_path)
icount+=1
except Exception as e:
pass
#this below code is done to make the count of the image same for all the data 
#you can use a big number if you are building a more complex model or if you have a good system
if icount == 2000:
break

这里的folder变量表示类的名称,urls变量用于获取URL文本文件(可以根据文本文件名更改它),name变量用于为每个图像创建唯一的名称。

上面代码将为每个类下载2000张图像,可以编辑最后一个“if”条件来更改下载图像的个数。

数据准备

我们下载的文件夹可能包含其他类型的文件,所以首先必须删除不需要的类型的文件。

image_exts = ['jpeg','.jpg','bmp','png']
path_list = ['drawings','hentai','neutral','porn','sexy']
cwd = os.getcwd()
def remove_other_images(path_list):
for ii in path_list:
data_dir = os.path.join(cwd,'DataSet',ii)
for image in os.listdir(os.path.join(data_dir)):
image_path = os.path.join(data_dir,image_class,image)
try:
img = cv2.imread(image_path)
tip = imghdr.what(image_path)
if tip not in image_exts:
print('Image not in ext list {}'.format(image_path))
os.remove(image_path)
except Exception as e:
print("Issue with image {}".format(image_path))
remove_other_images(path_list)

上面的代码删除了扩展名不是指定格式的图像。

另外图像可能包含许多重复的图像,所以我们必须从每个文件夹中删除重复的图像。

cwd = os.getcwd()
path_list = ['drawings','hentai','neutral','porn','sexy']
def remove_dup_images(path_list):
for ii in path_list:
os.chdir(os.path.join(cwd,'DataSet',ii))
filelist = os.listdir()
duplicates = []
hash_keys = dict()
for index, filename in enumerate(filelist):
if os.path.isfile(filename):
with open(filename,'rb') as f:
filehash = hashlib.md5(f.read()).hexdigest()
if filehash not in hash_keys:
hash_keys[filehash] = index
else:
duplicates.append((index,hash_keys[filehash]))

for index in duplicates:
os.remove(filelist[index[0]])
print('{} duplicates removed from {}'.format(len(duplicates),ii))
remove_dup_images(path_list)

这里我们使用hashlib.md5编码来查找每个类中的重复图像。

Md5为每个图像创建一个唯一的哈希值,如果哈希值重复(重复图像),那么我们将重复图片添加到一个列表中,稍后进行删除。

因为使用TensorFlow框架所以需要判断是否被TensorFlow支持,所以我们这里加一个判断:

import tensorflow as tf
os.chdir('{data-set} directory')
cwd = os.getcwd()
for ii in path_list:
os.chdir(os.path.join(cwd,ii))
filelist = os.listdir()
for image_file in filelist:
with open(image_file, 'rb') as f:
image_data = f.read()
# Check the file format
_, ext = os.path.splitext(image_file)
if ext.lower() not in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']:
print('Unsupported image format:', ext)
os.remove(os.path.join(cwd,ii,image_file)) 
else:
# Decode the image
try:
image = tf.image.decode_image(image_data)
except:
print(image_file)
print("unspported")
os.remove(os.path.join(cwd,ii,image_file))

以上就是数据准备的所有工作,在清理完数据后,我们可以拆分数据。比如分割创建一个训练、验证和测试文件夹,并手动添加文件夹中的图像,我们将80%用于训练,10%用于验证,10%用于测试。

模型

首先导入tensorflow

import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
import hashlib
from imageio import imread
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.layers import Flatten,Dense,Input
from tensorflow.keras.models import Model,Sequential
from keras import optimizers

对于图像,默认大小设置为224,224。

IMAGE_SIZE = [224,224]

可以使用ImageDataGenerator库,进行数据增强。数据增强也叫数据扩充,是为了增加数据集的大小。ImageDataGenerator根据给定的参数创建新图像,并将其用于训练(注意:当使用ImageDataGenerator时,原始数据将不用于训练)。

train_datagen = ImageDataGenerator(
rescale=1./255,
preprocessing_function=preprocess_input,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')

对于测试集也是这样:

test_datagen = ImageDataGenerator(rescale=1./255)

为了演示,我们直接使用VGG模型

vgg = VGG16(input_shape=IMAGE_SIZE+[3],weights='imagenet',include_top=False

然后冻结前面的层:

for layer in vgg.layers:
layer.trainable = False

最后我们加入自己的分类头:

x = Flatten()(vgg.output)
prediction = Dense(5,activation='softmax')(x)
model = Model(inputs=vgg.input, outputs=prediction)
model.summary()

模型是这样的:

训练

看看我们训练集:

train_set = train_datagen.flow_from_directory('DataSet/train',
target_size=(224,224),
batch_size=32,
class_mode='sparse')

验证集

val_set = train_datagen.flow_from_directory('DataSet/validation',
target_size=(224,224),
batch_size=32,
class_mode='sparse')

使用'
sparse_categorical_crossentropy '损失,这样可以将标签编码为整数而不是独热编码。

from tensorflow.keras.metrics import MeanSquaredError
from tensorflow.keras.metrics import CategoricalAccuracy
adam = optimizers.Adam()
model.compile(loss='sparse_categorical_crossentropy',
optimizer=adam,
metrics=['accuracy',MeanSquaredError(name='val_loss'),CategoricalAccuracy(name='val_accuracy')])

然后就可以训练了:

from datetime import datetime
from keras.callbacks import ModelCheckpoint
log_dir = 'vg_log'
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir = log_dir)
start = datetime.now()
history = model.fit_generator(train_set,
validation_data=val_set,
epochs=100,
steps_per_epoch=len(train_set)// batch_size,
validation_steps=len(val_set)//batch_size,
callbacks=[tensorboard_callback],
verbose=1)
duration = datetime.now() - start
print("Time taken for training is ",duration)

模型训练了100次。得到了80%的验证准确率。f1得分为93%

预测

下面的函数将获取一个图像列表并根据该列表进行预测。

import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
def print_classes(images,model):
classes = ['Drawing','Hentai','Neutral','Porn','Sexual']
fig, ax = plt.subplots(ncols=len(images), figsize=(20,20))
for idx,img in enumerate(images):
img = mpimg.imread(img)
resize = tf.image.resize(img,(224,224))
result = model.predict(np.expand_dims(resize/255,0))
result = np.argmax(result)
if classes[result] == 'Porn':
img = gaussian_filter(img, sigma=6)
elif classes[result] == 'Sexual':
img = gaussian_filter(img, sigma=6)
elif classes[result] == 'Hentai':
img = gaussian_filter(img, sigma=6)
ax[idx].imshow(img)
ax[idx].title.set_text(classes[result])
li = ['test1.jpeg','test2.jpeg','test3.jpeg','test4.jpeg','test5.jpeg']
print_classes(li,model)

看结果还是可以的。

最后,本文的源代码:

https://github.com/Nikhilthalappalli/NSFW-Classifier

但是我觉得源代码不重要,Alex Kim的项目才是你需要的:

https://github.com/alex000kim/nsfw_data_scraper

作者:Nikhil Thalappalli

相关推荐

ubuntu24.04下kubernetes1.30环境搭建

设置root用户密码#在Ubuntu系统中,默认情况下root用户是被禁用的(没有设置密码)#而是通过sudo命令让普通用户临时获取管理员权限,#如果需要启用或修改root密...

Canonical 在 Ubuntu 24.10 发布之前对 Snap 进行了更多改进

作为Ubuntu桌面临时工程总监,OliverSmith介绍了Ubuntu24.10的最新进展。在Ubuntu24.10功能冻结之前,GNOME47测试版已经登陆Ubuntu...

Ubuntu Touch OTA-5手机系统发布:细化电源配置等

IT之家8月2日消息,UBports基金会于7月30日发布UbuntuTouch20.04OTA-5版本更新,距离上次OTA-4更新发布相隔6个月时间。Ubuntu...

Ubuntu更契合英特尔酷睿Ultra,综合性能比Win11高15%

IT之家12月23日消息,英特尔本月推出酷睿UltraMeteorLake处理器,那么Win11和Ubuntu发行版两者时间,谁能更好地发挥其性能呢?国外科技媒体phoron...

针对英特尔酷睿CPU优化,Canonical发布Ubuntu实时内核

IT之家7月27日消息,Canonical今天宣布针对支持时序协调运算(TCC)和时间敏感网络(IEEETSN)的英特尔酷睿处理器,推出优化版实时Ubuntu内核。Canonical...

在Ubuntu/Debian上设置永久DNS域名服务器

在Linux上设置自定义DNS服务器可以提高性能和安全性,甚至可以通过DNS阻止一些使用地理屏蔽的网站。有几种方法可以做到这一点,包括在许多Linux发行版中包括的NetworkManagerGUI...

宣布延期:Ubuntu 24.04 LTS第一个版本发布推迟两周

Ubuntu开发团队原计划于8月19日星期四发布Ubuntu24.04.1LTS。然而,由于发现几个重大升级错误,发布被推迟。Ubuntu24.04.1LTS的新发布日期现定为...

Ubuntu系统已经十岁了 10月新推14.10版

|责编:李鑫比较非主流的Ubuntu系统刚刚推出了14.10版,同时大家也可能不知道其实它已经十岁了!它的第一次公布时间为2004年的10月呢。在这次十周年更新中,Ubuntu为用...

wsl2在休眠后的时间偏差问题的修复

笔记本电脑在日常使用中,常常会有进入休眠状态的情况。休眠对于wsl2而言,却造成了时间偏差的问题,休眠期间wsl2的时间停止了。这个问题的根治,需要等微软。本文提供的是一种简单的修复办法。ntp是网络...

基于Ubuntu22.04源码安装配置RabbitVCS过程记录

基于Ubuntu22.04源码安装配置RabbitVCS过程记录安装开始时间开始时间:2025年7月18日17:09(北京时间)系统:Ubuntu22.04用户:itgather时区:A...

GNOME 46桌面环境发布,Ubuntu 24.04 LTS和Fedora 40率先预装

IT之家3月21日消息,GNOME团队今天发布公告,正式推出代号为“Kathmandu”的GNOME46桌面环境,并已经开放下载。Fedora40发行版将于4月发布,率先预装...

如果大家同意的话 Ubuntu可能很快就会有一个新的垃圾桶图标

Ubuntu贡献者目前正在构思一个新的垃圾桶图标,该图标最早可能在10月份Ubuntu25.10发布时出现在Dock栏中。关于Ubuntu垃圾桶图标外观的讨论在2019年持续进...

Ubuntu 25.10 通过更安全地获取时间来提供进一步的安全性提升

Canonical宣布将从Ubuntu25.10开始使用一款名为chrony的软件,以实现更安全的时间管理。最终用户无需过于担心这一变化,但它将增强系统安全性,尤其是在加密操作和证书验证方...

Linux 修改系统时间的两种方式

一:更新系统时间的方式1、手动修改通过相关工具来手动修改系统的时间。2、自动同步使用NTP自动同步系统时间。二:手动修改系统时间1、date工具作用:显示和设置系统时间选项:-d<字符串&g...

Ubuntu计划下版本为RISC-V设置RVA23基线,大量硬件无法升级

IT之家7月14日消息,主要Linux发行版之一的Ubuntu计划在其接下来的一个大版本25.10中将对RISC-V处理器的准入门槛设置从此前的RVA20配置文件更新至最新...

取消回复欢迎 发表评论: