tSNE算法的原理和Python代码实现详解

2024年 1月 23日 58.6k 0

T分布随机邻域嵌入(t-SNE)算法原理及Python代码实现t-SNE算法

T分布随机邻域嵌入(t-SNE),是一种用于可视化的无监督机器学习算法,使用非线性降维技术,根据数据点与特征的相似性,试图最小化高维和低维空间中这些条件概率(或相似性)之间的差异,以在低维空间中完美表示数据点。

因此,t-SNE擅长在二维或三维的低维空间中嵌入高维数据以进行可视化。需要注意的是,t-SNE使用重尾分布来计算低维空间中两点之间的相似度,而不是高斯分布,这有助于解决拥挤和优化问题。而且离群值不影响t-SNE。

t-SNE算法步骤

1.找出高维空间中相邻点之间的成对相似性。

2.根据高维空间中点的成对相似性,将高维空间中的每个点映射到低维映射。

3.使用基于Kullback-Leibler散度(KL散度)的梯度下降找到最小化条件概率分布之间的不匹配的低维数据表示。

4.使用Student-t分布计算低维空间中两点之间的相似度。

MNIST数据集上实现t-SNE的Python代码

导入模块

# Importing Necessary Modules.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

登录后复制

读取数据

# Reading the data using pandas
df = pd.read_csv('mnist_train.csv')

# print first five rows of df
print(df.head(4))

# save the labels into a variable l.
l = df['label']

# Drop the label feature and store the pixel data in d.
d = df.drop("label", axis = 1)

登录后复制

数据预处理

# Data-preprocessing: Standardizing the data
from sklearn.preprocessing import StandardScaler

standardized_data = StandardScaler().fit_transform(data)
print(standardized_data.shape)

登录后复制

输出

# TSNE
# Picking the top 1000 points as TSNE
# takes a lot of time for 15K points
data_1000 = standardized_data[0:1000, :]
labels_1000 = labels[0:1000]

model = TSNE(n_components = 2, random_state = 0)
# configuring the parameters
# the number of components = 2
# default perplexity = 30
# default learning rate = 200
# default Maximum number of iterations
# for the optimization = 1000

tsne_data = model.fit_transform(data_1000)

# creating a new data frame which
# help us in plotting the result data
tsne_data = np.vstack((tsne_data.T, labels_1000)).T
tsne_df = pd.DataFrame(data = tsne_data,
columns =("Dim_1", "Dim_2", "label"))

# Plotting the result of tsne
sn.FacetGrid(tsne_df, hue ="label", size = 6).map(
plt.scatter, 'Dim_1', 'Dim_2').add_legend()

plt.show()

登录后复制

以上就是t-SNE算法的原理和Python代码实现详解的详细内容,更多请关注每日运维网(www.mryunwei.com)其它相关文章!

相关文章

JavaScript2024新功能:Object.groupBy、正则表达式v标志
PHP trim 函数对多字节字符的使用和限制
新函数 json_validate() 、randomizer 类扩展…20 个PHP 8.3 新特性全面解析
使用HTMX为WordPress增效:如何在不使用复杂框架的情况下增强平台功能
为React 19做准备:WordPress 6.6用户指南
如何删除WordPress中的所有评论

发布评论