视频抽取关键帧算法
可直接运行代码:
https://colab.research.google.com/drive/1iXgzIB8k-_ZpgCiGn-r9WgU7mvdRrKVB?usp=sharing
1. 计算帧间差分,取局部极大值(抽帧较少)
# -*- coding: utf-8 -*-
import cv2
import operator
import numpy as np
import matplotlib.pyplot as plt
import sys
from scipy.signal import argrelextrema
import os


def smooth(x, window_len=13, window='hanning'):
    """Smooth a 1-D signal by convolving it with a normalised window.

    The signal is extended on both ends with point-reflected copies of
    itself before the convolution, and the padding is trimmed afterwards so
    that the result has the same length as ``x``.

    Args:
        x: 1-D NumPy array holding the signal.
        window_len: Size of the smoothing window. Values below 3 leave the
            signal untouched.
        window: ``'flat'`` for a moving average, otherwise the name of a
            NumPy window function (e.g. ``'hanning'``, ``'hamming'``).

    Returns:
        The smoothed signal, same length as ``x``.

    Raises:
        ValueError: If ``x`` is not longer than ``window_len``; the padding
            slices would then come up short and the output would silently
            have the wrong length.
    """
    print(len(x), window_len)
    if window_len < 3:
        # A window this small cannot smooth anything (and the trimming slice
        # below would return an empty array for window_len == 1).
        return x
    if len(x) <= window_len:
        raise ValueError("Input vector needs to be longer than the window size.")
    s = np.r_[2 * x[0] - x[window_len:1:-1],
              x,
              2 * x[-1] - x[-1:-window_len:-1]]
    if window == 'flat':  # moving average
        w = np.ones(window_len, 'd')
    else:
        w = getattr(np, window)(window_len)
    y = np.convolve(w / w.sum(), s, mode='same')
    return y[window_len - 1:-window_len + 1]


class Frame:
    """Pairs a frame index with its mean difference to the previous frame.

    Ordering and equality are based on ``id`` only; ``diff`` is ignored.
    """

    def __init__(self, id, diff):
        self.id = id
        self.diff = diff

    def __lt__(self, other):
        return self.id < other.id

    def __gt__(self, other):
        return other.__lt__(self)

    def __eq__(self, other):
        return self.id == other.id

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ alone makes the class unhashable; hash on the same
        # field that equality uses.
        return hash(self.id)


def rel_change(a, b):
    """Return the change from ``a`` to ``b`` relative to the larger of the two.

    Intended for non-negative frame-difference values. Returns 0.0 when the
    larger value is zero (e.g. two consecutive identical frames) instead of
    dividing by zero.
    """
    denom = max(a, b)
    x = (b - a) / denom if denom else 0.0
    print(x)
    return x


def _compute_frame_diffs(videopath):
    """Read a video and return (frame_diffs, frames) for consecutive frames.

    ``frame_diffs[k]`` is the summed absolute LUV difference between frame
    ``k + 1`` and frame ``k`` divided by the pixel count, and ``frames[k]``
    is the matching ``Frame(k + 1, diff)``.
    """
    cap = cv2.VideoCapture(str(videopath))
    frame_diffs = []
    frames = []
    prev_frame = None
    i = 0
    try:
        success, image = cap.read()
        while success:
            curr_frame = cv2.cvtColor(image, cv2.COLOR_BGR2LUV)
            if prev_frame is not None:
                diff = cv2.absdiff(curr_frame, prev_frame)
                diff_sum_mean = np.sum(diff) / (diff.shape[0] * diff.shape[1])
                frame_diffs.append(diff_sum_mean)
                frames.append(Frame(i, diff_sum_mean))
            prev_frame = curr_frame
            i += 1
            success, image = cap.read()
    finally:
        cap.release()
    return frame_diffs, frames


def _save_keyframes(videopath, out_dir, keyframe_ids):
    """Write every frame whose index is in ``keyframe_ids`` as a JPEG."""
    remaining = set(keyframe_ids)
    cap = cv2.VideoCapture(str(videopath))
    idx = 0
    try:
        success, image = cap.read()
        # Stop decoding as soon as every requested frame has been written.
        while success and remaining:
            if idx in remaining:
                cv2.imwrite(os.path.join(out_dir, "keyframe_" + str(idx) + ".jpg"), image)
                remaining.remove(idx)
            idx += 1
            success, image = cap.read()
    finally:
        cap.release()


if __name__ == "__main__":
    print(sys.executable)
    # Select frames whose relative change exceeds a fixed threshold
    USE_THRESH = False
    # Fixed threshold value
    THRESH = 0.6
    # Select the frames with the largest differences
    USE_TOP_ORDER = False
    # Select local maxima of the smoothed difference curve
    USE_LOCAL_MAXIMA = True
    # Number of top sorted frames
    NUM_TOP_FRAMES = 50

    # Process every MP4 file in the current directory
    for filename in os.listdir("."):
        if not filename.endswith(".mp4"):
            continue
        videopath = filename
        name = os.path.splitext(filename)[0]  # file name without extension
        out_dir = f"./extract_result/{name}/"  # directory for the key frames
        os.makedirs(out_dir, exist_ok=True)
        len_window = 50  # smoothing window size

        print("Target video :" + videopath)
        print("Frame save directory: " + out_dir)

        # Load the video and compute the difference between frames
        frame_diffs, frames = _compute_frame_diffs(videopath)

        # Compute the key frames
        keyframe_id_set = set()
        if USE_TOP_ORDER:
            # Sort a copy: the branches below index `frames` by position and
            # rely on it staying in chronological order.
            by_diff = sorted(frames, key=operator.attrgetter("diff"), reverse=True)
            keyframe_id_set.update(f.id for f in by_diff[:NUM_TOP_FRAMES])
        if USE_THRESH:
            print("Using Threshold")
            for prev, curr in zip(frames, frames[1:]):
                # np.float was removed in NumPy 1.24; the builtin is equivalent.
                if rel_change(float(prev.diff), float(curr.diff)) >= THRESH:
                    keyframe_id_set.add(curr.id)
        if USE_LOCAL_MAXIMA:
            print("Using Local Maxima")
            diff_array = np.array(frame_diffs)
            if len(diff_array) < 3:
                print("Too few frames for local-maxima detection; skipping")
            else:
                # smooth() needs the signal to be longer than the window, so
                # shrink the window for short videos.
                window = min(len_window, len(diff_array) - 1)
                sm_diff_array = smooth(diff_array, window)
                frame_indexes = np.asarray(argrelextrema(sm_diff_array, np.greater))[0]
                for i in frame_indexes:
                    keyframe_id_set.add(frames[i - 1].id)

                # Plot the smoothed differences
                plt.figure(figsize=(40, 20))
                plt.gca().xaxis.set_major_locator(plt.MaxNLocator(100))  # number of x-axis ticks
                plt.gca().yaxis.set_major_locator(plt.MaxNLocator(10))  # number of y-axis ticks
                plt.stem(sm_diff_array)
                plt.savefig(out_dir + 'plot.png')
                # Close the figure so memory does not grow with every video.
                plt.close()

        # Save all key frames as images
        _save_keyframes(videopath, out_dir, keyframe_id_set)
        print(f"关键帧已保存到:{out_dir}")
2. 基于帧间差分统计阈值的方法(原文称“光流方法”,但代码实际比较的是模糊后的灰度帧差;结果为关键帧图片以及保存帧信息的 json 文件)
import cv2
import json
import os
import numpy as np


def getInfo(sourcePath):
    """Return frame count, fps, width, height and FOURCC codec of a video."""
    cap = cv2.VideoCapture(sourcePath)
    try:
        info = {
            "framecount": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
            "fps": cap.get(cv2.CAP_PROP_FPS),
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            "codec": int(cap.get(cv2.CAP_PROP_FOURCC)),
        }
    finally:
        cap.release()
    return info


def scale(img, xScale, yScale):
    """Resize ``img`` by the given x/y factors using area interpolation."""
    return cv2.resize(img, None, fx=xScale, fy=yScale, interpolation=cv2.INTER_AREA)


def resize(img, width, height):
    """Resize ``img`` to exactly ``width`` x ``height`` using area interpolation."""
    return cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)


def extract_cols(image, numCols):
    """Cluster the pixels of a 3-channel image into ``numCols`` colours.

    Returns:
        A list of ``{"count": int, "col": [c0, c1, c2]}`` dicts, one per
        k-means cluster. ``col`` is the cluster centre with its channel
        order reversed (so B, G, R input yields R, G, B output).
    """
    Z = image.reshape((-1, 3)).astype(np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 20, 1.0)
    _, labels, centers = cv2.kmeans(Z, numCols, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
    clusterCounts = [int(np.sum(labels == i)) for i in range(numCols)]
    rgbCenters = [center.tolist()[::-1] for center in centers]
    return [{"count": count, "col": col} for count, col in zip(clusterCounts, rgbCenters)]


def calculateFrameStats(sourcePath, after_frame=0):
    """Measure how much each frame differs from the previous one.

    Every frame is converted to grayscale, downscaled to a quarter of its
    size and Gaussian-blurred; the difference measure is the number of
    non-zero pixels in the absolute difference to the previous processed
    frame.

    Args:
        sourcePath: Path of the video file.
        after_frame: Frames with a smaller frame number are not recorded.

    Returns:
        ``{"frame_info": [{"frame_number", "diff_count"}, ...]}`` plus a
        ``"stats"`` entry (num/min/max/mean/median/sd of the diff counts)
        when at least one difference was recorded.
    """
    cap = cv2.VideoCapture(sourcePath)
    data = {"frame_info": []}
    lastFrame = None
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if frame is None:
                break
            # POS_FRAMES points at the next frame to decode, hence the -1.
            frame_number = int(cap.get(cv2.CAP_PROP_POS_FRAMES) - 1)
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            gray = scale(gray, 0.25, 0.25)
            gray = cv2.GaussianBlur(gray, (9, 9), 0.0)
            if frame_number >= after_frame and lastFrame is not None:
                diff = cv2.absdiff(gray, lastFrame)
                diffMag = int(cv2.countNonZero(diff))  # plain int so it is JSON-serialisable
                data["frame_info"].append({"frame_number": frame_number, "diff_count": diffMag})
            lastFrame = gray
    finally:
        cap.release()

    diff_counts = [fi["diff_count"] for fi in data["frame_info"]]
    if diff_counts:
        data["stats"] = {
            "num": len(diff_counts),
            "min": int(np.min(diff_counts)),
            "max": int(np.max(diff_counts)),
            "mean": float(np.mean(diff_counts)),
            "median": float(np.median(diff_counts)),
            "sd": float(np.std(diff_counts)),
        }
    return data


def detectScenes(sourcePath, destPath, data):
    """Save the frames whose difference count marks a scene change.

    A frame is a key frame when its ``diff_count`` is at least
    ``mean + 2.05 * sd`` of all recorded diff counts. Each key frame is
    written to ``destPath`` as ``key_frame_<n>.jpg`` and its ``frame_info``
    entry gains a ``"dominant_cols"`` list (5 dominant colours of a
    100x100 thumbnail), which is what the keyframe metadata export filters
    on.

    Args:
        sourcePath: Path of the video file.
        destPath: Output directory; created if missing.
        data: Result of ``calculateFrameStats``; modified in place.

    Returns:
        The same ``data`` dict. When it carries no ``"stats"`` (no frame
        differences were recorded) it is returned unchanged.
    """
    os.makedirs(destPath, exist_ok=True)
    stats = data.get("stats")
    if not stats:
        # Nothing to threshold against, e.g. a video with fewer than two frames.
        return data

    diff_threshold = stats["sd"] * 2.05 + stats["mean"]
    cap = cv2.VideoCapture(sourcePath)
    try:
        for fi in data["frame_info"]:
            if fi["diff_count"] < diff_threshold:
                continue
            # Seek to the key frame and read it
            cap.set(cv2.CAP_PROP_POS_FRAMES, fi["frame_number"])
            ret, frame = cap.read()
            if not ret:
                continue
            fi["dominant_cols"] = extract_cols(resize(frame, 100, 100), 5)
            # Save the key frame image to the destination folder
            frame_filename = os.path.join(destPath, f"key_frame_{fi['frame_number']}.jpg")
            cv2.imwrite(frame_filename, frame)
    finally:
        cap.release()
    return data


# Process every MP4 file in the current directory
for filename in os.listdir("."):
    if not filename.endswith(".mp4"):
        continue
    source = filename  # MP4 file in the current directory
    name = os.path.splitext(filename)[0]  # file name without extension
    dest = name  # output folder is named after the video
    after_frame = 0  # first frame to consider

    print(f"处理视频: {source}")
    info = getInfo(source)
    print("视频信息: ", info)

    # Compute the frame-difference data and detect scene changes
    data = calculateFrameStats(source, after_frame)
    if "stats" not in data:
        # No frame differences were recorded (fewer than two readable
        # frames); skip this file instead of aborting the whole batch.
        print(f"跳过视频(可读取的帧不足,无法计算帧差): {source}")
        continue
    data = detectScenes(source, dest, data)

    # Save the metadata
    os.makedirs(dest, exist_ok=True)
    data_fp = os.path.join(dest, f"{name}-meta.json")
    with open(data_fp, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)

    # Only frames annotated with "dominant_cols" count as key frames here.
    keyframe_info_fp = os.path.join(dest, f"{name}-keyframe-meta.json")
    keyframeInfo = [frame_info for frame_info in data["frame_info"] if "dominant_cols" in frame_info]
    with open(keyframe_info_fp, 'w', encoding='utf-8') as f:
        json.dump(keyframeInfo, f, indent=4)

    print(f"关键帧数据和图片已保存到:{dest}")
3. 基于颜色直方图聚类 (抽帧较多)
import cv2
import numpy as np
import os

# Weights applied to the R, G and B histogram similarities, in that order.
CHANNEL_WEIGHTS = (0.30, 0.59, 0.11)
# A frame joins the current cluster when its similarity exceeds this value.
threshold = 0.91


def channel_histograms(frame):
    """Return a (3, 256) float64 array with the R, G, B histograms of a frame.

    OpenCV decodes video frames in B, G, R channel order, so the red
    histogram comes from channel 2 and the blue one from channel 0.
    """
    return np.stack([
        cv2.calcHist([frame], [channel], None, [256], [0, 256]).flatten()
        for channel in (2, 1, 0)
    ]).astype(np.float64)


def weighted_similarity(centroid, hists):
    """Return the weighted histogram-intersection similarity.

    Args:
        centroid: (3, 256) array of R, G, B cluster-centroid histograms.
        hists: (3, 256) array of R, G, B histograms of one frame.

    Returns:
        ``0.30 * dR + 0.59 * dG + 0.11 * dB`` where each ``d`` is the sum of
        the bin-wise minimum of centroid and frame histogram divided by the
        frame histogram's total.
    """
    overlap = np.minimum(centroid, hists).sum(axis=1)
    ratios = overlap / hists.sum(axis=1)
    return float(np.dot(CHANNEL_WEIGHTS, ratios))


# Process every MP4 file in the current directory
for filename in os.listdir("."):
    if not filename.endswith(".mp4"):
        continue
    video_path = filename  # MP4 file in the current directory
    name = os.path.splitext(filename)[0]  # file name without extension
    output_folder = f'key_frames/{name}'  # directory for the key frames
    os.makedirs(output_folder, exist_ok=True)

    print(f"处理视频: {video_path}")
    print(f"关键帧保存路径: {output_folder}")

    # Pass 1: assign every frame to a cluster. Storage grows with the number
    # of clusters/frames actually seen, so it does not depend on the frame
    # count reported by the container.
    centroids = []         # per cluster: (3, 256) mean histograms
    cluster_sizes = []     # per cluster: number of member frames
    cluster_of_frame = []  # per frame: index into `centroids`

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError("无法打开视频文件!")
        ret, frame = cap.read()
        if not ret:
            raise ValueError("无法读取第一帧!")
        while ret:
            hists = channel_histograms(frame)
            # A frame is only compared against the most recent cluster, so
            # clusters are contiguous runs of frames.
            if centroids and weighted_similarity(centroids[-1], hists) > threshold:
                size = cluster_sizes[-1]
                # Running mean of the member histograms
                centroids[-1] = (centroids[-1] * size + hists) / (size + 1)
                cluster_sizes[-1] = size + 1
            else:
                centroids.append(hists)
                cluster_sizes.append(1)
            cluster_of_frame.append(len(centroids) - 1)
            ret, frame = cap.read()
    finally:
        cap.release()

    # Pass 2: per cluster, pick the frame closest to the final centroid.
    best_similarity = [0.0] * len(centroids)
    frame_indices = [-1] * len(centroids)  # -1: no frame selected
    cap = cv2.VideoCapture(video_path)
    try:
        for frame_number, cluster_id in enumerate(cluster_of_frame):
            ret, frame = cap.read()
            if not ret:
                break
            d = weighted_similarity(centroids[cluster_id], channel_histograms(frame))
            if d > best_similarity[cluster_id]:
                best_similarity[cluster_id] = d
                frame_indices[cluster_id] = frame_number
    finally:
        cap.release()

    # Pass 3: save the selected key frames to the output folder.
    cap = cv2.VideoCapture(video_path)
    try:
        for idx in frame_indices:
            if idx < 0:
                continue
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, img = cap.read()
            if ret:
                frame_filename = os.path.join(output_folder, f'key_frame_{int(idx)}.jpg')
                cv2.imwrite(frame_filename, img)
    finally:
        cap.release()

    print(f"关键帧已保存到:{output_folder}")