import os
import time
import torch
import faiss
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
%matplotlib inline
# GPU acceleration
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cpu' if not torch.cuda.is_available() else 'cuda')
print(device)
# cuda
# Preprocessing pipeline: resize, center-crop to the 224x224 input expected
# by ResNet-18, convert to a tensor, and normalize with the ImageNet channel
# statistics the pretrained weights were trained with.
# Fixed: scraping fused "# cuda" with this assignment, turning the whole
# statement into a comment so `transform` was never defined.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
class MyDataset(Dataset):
    """Image dataset listing RGB images from ``<data_path>/img.txt``.

    Each line of ``img.txt`` is an image filename relative to ``data_path``.
    Non-RGB images (grayscale, CMYK, ...) are filtered out at construction
    time so every sample survives the 3-channel normalization transform.
    """

    def __init__(self, data_path, transform=None):
        super().__init__()
        self.transform = transform
        self.data_path = data_path
        self.data = []  # filenames (relative to data_path) of RGB images

        list_file = os.path.join(data_path, 'img.txt')
        with open(list_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines in the listing
                img_name = os.path.join(data_path, line)
                # Context manager closes each probed image immediately;
                # the original leaked one open file handle per listed image.
                with Image.open(img_name) as img:
                    if img.mode == 'RGB':
                        self.data.append(line)

    def __getitem__(self, idx):
        """Return ``{'index': idx, 'img': <transformed image>}`` for sample idx."""
        img_path = os.path.join(self.data_path, self.data[idx])
        img = Image.open(img_path)
        if self.transform:
            img = self.transform(img)
        return {'index': idx, 'img': img}

    def __len__(self):
        """Number of RGB images found in the listing."""
        return len(self.data)
img_folder = 'JPEGImages'
val_dataset = MyDataset(img_folder, transform=transform)
batch_size = 64
# shuffle=False keeps batch order aligned with dataset indices so the
# per-image feature files written below match their images.
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
print('Val_dataset: ', len(val_dataset))
# Ceiling division for the batch count; the old int(n / batch) + 1 formula
# over-counted by one whenever n was an exact multiple of batch_size.
print('iter: ', (len(val_dataset) + batch_size - 1) // batch_size)
# Val_dataset:  17125
# iter:  268
# Load the pretrained model
def load_model():
    """Return an ImageNet-pretrained ResNet-18 on `device`, in eval mode."""
    net = models.resnet18(pretrained=True)
    net = net.to(device)
    net.eval()  # disable dropout / use running BN stats for inference
    return net
# Feature extractor: everything in ResNet-18 except the final fc classifier.
def feature_extract(model, x):
    """Run `x` through every ResNet stage up to and including global
    average pooling, then flatten to an (N, C) feature matrix."""
    stages = (
        model.conv1, model.bn1, model.relu, model.maxpool,
        model.layer1, model.layer2, model.layer3, model.layer4,
        model.avgpool,
    )
    for stage in stages:
        x = stage(x)
    return torch.flatten(x, 1)
model = load_model()
# Extract a 512-d feature vector per image and write it next to the image
# as "<filename>.txt" (space-separated floats), one file per image.
with torch.no_grad():  # inference only: skip autograd bookkeeping / memory
    for idx, batch in enumerate(val_dataloader):
        img = batch['img']
        index = batch['index']
        # Fixed: this line was garbled to "Img To Website(device)" by the
        # scraper; it must move the batch to the compute device.
        img = img.to(device)
        feature = feature_extract(model, img)
        feature = feature.cpu().numpy()
        imgs_path = [os.path.join(img_folder, val_dataset.data[i] + '.txt') for i in index]
        assert len(feature) == len(imgs_path)

        for i in range(len(imgs_path)):
            feature_list = [str(f) for f in feature[i]]
            img_path = imgs_path[i]

            with open(img_path, 'w', encoding='utf-8') as f:
                f.write(" ".join(feature_list))
        print('*' * 60)
        print(idx * batch_size)
# Read an image's feature vector back from its .txt file.
def img2feat(pic_file):
    """Parse the first line of `pic_file` as a list of space-separated floats."""
    with open(pic_file, 'r', encoding='utf-8') as fh:
        first_line = fh.readlines()[0]
    return [float(tok) for tok in first_line.split()]
# Load every previously written feature file into two parallel arrays:
# ids[i] is the numeric image id, data[i] its 512-d feature vector.
ids = []
data = []
img_folder = 'VOC2012'  # 'VOC2012_small/'
img_path = os.path.join(img_folder, 'img.txt')
with open(img_path, 'r', encoding='utf-8') as f:
    for line in f:
        img_name = line.strip()
        img_id = img_name.split('.')[0]
        pic_txt_file = os.path.join(img_folder, "{}.txt".format(img_name))

        # Skip images whose feature file was never produced.
        if not os.path.exists(pic_txt_file):
            continue

        feat_vec = img2feat(pic_txt_file)
        ids.append(int(img_id))
        data.append(np.array(feat_vec))

# Build the <id, data> pair as numpy arrays for faiss.
ids = np.array(ids)
data = np.array(data).astype('float32')
d = 512  # feature dimensionality (ResNet-18 penultimate layer)
print(" 特征向量记录数: ", data.shape)
print(" 特征向量ID的记录数:", ids.shape)
# 特征向量记录数:  (17125, 512)
# 特征向量ID的记录数: (17125,)
# Build the image-feature index - option 1
# index = faiss.index_factory(d,"IDMap,Flat")
# index.add_with_ids(data,ids)
# Build the image-feature index - option 2 (limited resources, better trade-off)
### IDMap supports add_with_ids
### If storing all raw vectors is too costly, use a "PCARx,...,SQ8"-style
### index instead. It has three parts:
# 1. dimensionality reduction (PCA with rotation, down to 16 dims)
# 2. clustering (IVF with 50 inverted lists)
# 3. scalar quantization: each component encoded in 8 bits (not GPU-supported)
index = faiss.index_factory(d, "IDMap,PCAR16,IVF50,SQ8")
index.train(data)
index.add_with_ids(data, ids)
# Persist the index to disk
faiss.write_index(index,'index_file.index')  # save the index as the file index_file.index
# index = faiss.read_index("index_file.index")
# print(index.ntotal)  # check the index size
# Load the Faiss index file back (sanity check of the saved artifact)
index = faiss.read_index('index_file.index')
print('索引记录数:', index.ntotal)
# 索引记录数: 17125
def index_search(feat, topK):
    """Search the global faiss `index` for the nearest neighbours of a query.

    Args:
        feat: feature vector (1-D list/array) of the query image.
        topK: number of most-similar images to return.

    Returns:
        (dis, ind): distances and similar-image ids, each shaped (1, topK).
    """
    feat = np.expand_dims(np.array(feat), axis=0)
    feat = feat.astype('float32')

    start_time = time.time()
    dis, ind = index.search(feat, topK)
    end_time = time.time()

    # Fixed: the original did int(end_time - start_time) * 1000, truncating
    # any sub-second search to 0 ms before multiplying.
    print('index_search consume time:{}ms'.format(int((end_time - start_time) * 1000)))
    return dis, ind  # distances, similar image ids


def visual_plot(ind, dis, topK, query_img=None):
    """Plot the query image (optional, slot 0) and its topK matches in a
    rows x 4 grid, then save the figure as 'pic'."""
    cols = 4
    rows = int(topK / cols)
    idx = 0

    fig, axes = plt.subplots(rows, cols, figsize=(20, 5 * rows), tight_layout=True)

    for row in range(rows):
        for col in range(cols):
            _id = ind[0][idx]
            _dis = dis[0][idx]

            img_path = os.path.join(img_folder, '{}.jpg'.format(_id))

            # The first cell shows the query itself when provided.
            if query_img is not None and idx == 0:
                axes[row, col].imshow(query_img)
                axes[row, col].set_title('query', fontsize=20)
            else:
                img = plt.imread(img_path)
                axes[row, col].imshow(img)
                axes[row, col].set_title('matched_-{}_{}'.format(_id, int(_dis)), fontsize=20)
            idx += 1

    plt.savefig('pic')
img_folder = 'VOC2012/'
# img_id = '100211.jpg'
img_id = '100002.jpg'
topK = 20
img_path = os.path.join(img_folder, img_id)
print(img_path)  # show which image we look up similar images for

# Preprocess the query exactly like the indexed images were.
query = Image.open(img_path)
query = transform(query)     # torch.Size([3, 224, 224])
query = query.unsqueeze(0)   # add batch dim -> torch.Size([1, 3, 224, 224])
img = query.to(device)

# Run the retrieval pipeline for this query image.
with torch.no_grad():
    # image -> feature vector
    print('1.图片特征提取')
    feature = feature_extract(model, img)
    feature_list = feature.data.cpu().tolist()[0]
    # feature -> nearest-neighbour search
    print('2.基于特征的检索,从faiss获取相似度图片')
    dis, ind = index_search(feature_list, topK=topK)
    print('ind = ', ind)
    print('3.图片可视化展示')
    # visualize the query next to its matches
    query_img = plt.imread(img_path)
    visual_plot(ind, dis, topK, query_img)
# VOC2012/100002.jpg
# 1.图片特征提取
# 2.基于特征的检索,从faiss获取相似度图片
# index_search consume time:0ms
# ind =  [[100002 101430 116500 101585 116528 100507 104768 107651 112514 102820
#  112416 116458 106167 111781 116247 103299 103154 106012 115086 111156]]
# 3.图片可视化展示
# --- Scraped-page footer (non-code metadata, kept as comments so the file parses) ---
# Original URL: (link)
# Created: 2025-08-23 14:37:49
# Category: default
# Tags: none
# Articles not marked original are collected; copyright belongs to the author.