
[Large Models] How to Draw Animated GIFs of 3D Causal Convolution

Contents

Preface

1. 1D Convolution and 1D Causal Convolution

2. 2D Convolution and 2D Causal Convolution

3. 3D Convolution and 3D Causal Convolution


Preface

This post records the code used to draw the animated figures in this article. The 3D figures are built with the mpl_toolkits.mplot3d.art3d module.

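mpl_toolkits.mplot3d.art3d is the Matplotlib module that handles 3D graphics elements: it provides 3D versions of 2D artist objects, plus functions that convert 2D objects so they can be displayed in 3D space.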

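The module mainly contains:
1. Classes that bring 2D artists into 3D:
Line3D: 3D line
Path3DCollection: 3D path collection
PathPatch3D: 3D path patch
Poly3DCollection: 3D polygon collection
Line3DCollection: 3D line collection

2. Conversion functions:
path_to_3d_segment: converts a 2D path into a 3D segment
path_to_3d_segment_with_codes: converts a 2D path into a 3D segment and also returns the path codes
paths_to_3d_segments: converts several 2D paths into 3D segments
paths_to_3d_segments_with_codes: converts several 2D paths into 3D segments and also returns the path codes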
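As a minimal illustration of the module, the sketch below builds an axis-aligned box out of six quadrilateral faces and passes them to Poly3DCollection, the same approach the 3D section uses for every cell. The helper name box_faces is illustrative, and the Agg backend is chosen only so the example runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend: render without opening a window
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection


def box_faces(origin, size):
    """Return the 6 quadrilateral faces of an axis-aligned box (hypothetical helper)."""
    x, y, z = origin
    dx, dy, dz = size
    # The 8 corners, keyed by (i, j, k) in {0, 1}^3
    c = {(i, j, k): (x + i * dx, y + j * dy, z + k * dz)
         for i in (0, 1) for j in (0, 1) for k in (0, 1)}
    return [
        [c[0, 0, 0], c[1, 0, 0], c[1, 1, 0], c[0, 1, 0]],  # bottom (z = z0)
        [c[0, 0, 1], c[1, 0, 1], c[1, 1, 1], c[0, 1, 1]],  # top (z = z0 + dz)
        [c[0, 0, 0], c[1, 0, 0], c[1, 0, 1], c[0, 0, 1]],  # front (y = y0)
        [c[0, 1, 0], c[1, 1, 0], c[1, 1, 1], c[0, 1, 1]],  # back (y = y0 + dy)
        [c[0, 0, 0], c[0, 1, 0], c[0, 1, 1], c[0, 0, 1]],  # left (x = x0)
        [c[1, 0, 0], c[1, 1, 0], c[1, 1, 1], c[1, 0, 1]],  # right (x = x0 + dx)
    ]


fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(111, projection="3d")
faces = box_faces((0, 0, 0), (1, 2, 3))
ax.add_collection3d(Poly3DCollection(faces, facecolors="orange", alpha=0.5, edgecolor="k"))
ax.set_xlim(0, 1)
ax.set_ylim(0, 2)
ax.set_zlim(0, 3)
fig.canvas.draw()  # force a render; use fig.savefig("box.png") to write it to disk
plt.close(fig)
```

A translucent alpha is what lets the later animations show kernel cells sitting inside the input block.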

1. 1D Convolution and 1D Causal Convolution

The code is as follows:

import os

import numpy as np
import matplotlib.pyplot as plt
import imageio.v2 as imageio  # v2 API provides imread / mimsave


def visualize_1d_convolution(output_path="1d_convolution.gif"):
    # Parameters
    input_length = 10
    kernel_size = 3
    output_length = input_length - kernel_size + 1  # "valid" convolution, no padding

    # Temporary directory for the per-frame PNGs
    os.makedirs("data/tmp", exist_ok=True)
    frames = []

    for i in range(output_length):
        fig, ax = plt.subplots(figsize=(10, 4))

        # Draw input array
        input_x = np.arange(input_length)
        ax.plot(input_x, np.zeros_like(input_x), 'bo-', markersize=10, label='Input')
        for x in input_x:
            ax.text(x, 0.02, f'x{x}', ha='center', va='bottom', color='blue')

        # Draw kernel at its current position
        kernel_x = np.arange(i, i + kernel_size)
        ax.plot(kernel_x, np.zeros_like(kernel_x), 'ro-', markersize=10, label='Kernel')
        for x in kernel_x:
            ax.text(x, -0.02, f'w{x - i}', ha='center', va='top', color='red')

        # Draw output array
        output_x = np.arange(output_length)
        output_y = np.ones(output_length) * 0.5
        ax.plot(output_x, output_y, 'go-', markersize=10, label='Output')

        # Highlight current output position and label the outputs computed so far
        ax.plot(i, 0.5, 'yo', markersize=15, alpha=0.5)
        for x in output_x:
            if x <= i:
                ax.text(x, 0.52, f'y{x}', ha='center', va='bottom', color='green')

        # Draw operation lines from the inputs under the kernel to the current output
        for k in range(kernel_size):
            ax.plot([i + k, i], [0, 0.5], 'k:', alpha=0.3)

        # Annotations
        ax.text(input_length / 2, -0.2, f'Input length: {input_length}', ha='center')
        ax.text(input_length / 2, -0.3, f'Kernel size: {kernel_size}', ha='center')
        ax.text(input_length / 2, -0.4, f'Output length: {output_length}', ha='center')

        # Formatting
        ax.set_xlim(-0.5, max(input_length, output_length) + 0.5)
        ax.set_ylim(-0.5, 0.7)
        ax.set_title(f'1D Convolution - Step {i + 1}/{output_length}\n'
                     f'Computing output y[{i}] = x[{i}:{i + kernel_size}] · w[0:{kernel_size}]')
        ax.legend(loc='upper right')
        ax.axis('off')

        # Save the frame to a temporary PNG, read it back, then delete the file
        temp_file = f"data/tmp/temp_frame_{i}.png"
        plt.savefig(temp_file, dpi=100, bbox_inches='tight')
        frames.append(imageio.imread(temp_file))
        plt.close()
        os.remove(temp_file)

    # Create GIF (depending on the imageio version, duration may be read as seconds or milliseconds)
    imageio.mimsave(output_path, frames, duration=2.0)
    print(f"✅ GIF saved to: {output_path}")


def visualize_1d_causal_convolution(output_path="1d_causal_convolution_window2.gif"):
    # Parameters
    input_length = 10
    kernel_size = 2               # window size 2
    padding = kernel_size - 1     # causal padding: left side only
    output_length = input_length  # padding keeps the output as long as the input

    os.makedirs("data/tmp", exist_ok=True)
    frames = []

    # Padded input: padding positions are marked with -1 so they can be masked out below
    padded_input = np.pad(np.arange(input_length), (padding, 0), constant_values=-1)

    for i in range(output_length):
        fig, ax = plt.subplots(figsize=(12, 5))

        # Draw original input (without padding)
        original_x = np.arange(input_length)
        ax.plot(original_x, np.zeros_like(original_x), 'bo-', markersize=10, label='Original Input')
        for x in original_x:
            ax.text(x, 0.02, f'x{x}', ha='center', va='bottom', color='blue')

        # Draw padded input; padded index p sits at original position p - padding
        padded_x = np.arange(len(padded_input)) - padding
        mask = padded_input >= 0  # only show real values, not padding
        ax.plot(padded_x[mask], np.zeros_like(padded_x)[mask], 'o-',
                color='lightblue', markersize=8, label='Padded Input')

        # Draw kernel position (in padded coordinates the window starts at i)
        kernel_start = i
        kernel_x = np.arange(kernel_start, kernel_start + kernel_size)
        valid_kernel_pos = kernel_x[(kernel_x >= 0) & (kernel_x < len(padded_input))]
        ax.plot(valid_kernel_pos - padding, np.zeros_like(valid_kernel_pos),
                'ro-', markersize=10, label='Kernel')

        # Annotate kernel weights
        for pos in valid_kernel_pos:
            k = pos - kernel_start
            ax.text(pos - padding, -0.02, f'w{k}', ha='center', va='top', color='red')

        # Draw output array
        output_x = np.arange(output_length)
        output_y = np.ones(output_length) * 0.5
        ax.plot(output_x, output_y, 'go-', markersize=10, label='Output')

        # Highlight current output position and fill in the outputs computed so far
        ax.plot(i, 0.5, 'yo', markersize=15, alpha=0.5)
        for x in output_x:
            if x <= i:
                ax.text(x, 0.52, f'y{x}', ha='center', va='bottom', color='green')

        # Draw operation lines (only to real, non-padded inputs)
        for k in range(kernel_size):
            input_pos = i + k - (kernel_size - 1)  # causal: only past and current positions
            if 0 <= input_pos < input_length:
                ax.plot([input_pos, i], [0, 0.5], 'k:', alpha=0.3)

        # Annotations
        ax.text(input_length / 2, -0.25, f'Causal Convolution (kernel_size={kernel_size})',
                ha='center', fontsize=12)
        ax.text(0, -0.4, f'• Each output y[t] depends only on x[t-{kernel_size - 1}:t+1]', ha='left')
        ax.text(0, -0.5, f'• Output length equals input length ({output_length})', ha='left')

        # Formatting
        ax.set_xlim(-padding - 0.5, input_length + 0.5)
        ax.set_ylim(-0.6, 0.7)
        ax.set_title(f'1D Causal Convolution (Window={kernel_size}) - Step {i + 1}/{output_length}\n'
                     f'Computing y[{i}] = x[{max(0, i - kernel_size + 1)}:{i + 1}] · w[0:{kernel_size}]')
        ax.legend(loc='upper right')
        ax.axis('off')

        # Save frame
        temp_file = f"data/tmp/temp_frame_{i}.png"
        plt.savefig(temp_file, dpi=100, bbox_inches='tight')
        frames.append(imageio.imread(temp_file))
        plt.close()
        os.remove(temp_file)

    # Create GIF
    imageio.mimsave(output_path, frames, duration=1.0)
    print(f"✅ GIF saved to: {output_path}")


if __name__ == '__main__':
    # visualize_1d_convolution("1d_conv_example.gif")
    visualize_1d_causal_convolution('1d_causal_conv_example.gif')
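To check the two animations against actual numbers, here is a minimal NumPy sketch (the names conv1d_valid and conv1d_causal are illustrative, not from the code above). It reproduces the lengths shown in the GIFs, 10 inputs with kernel_size 3 giving 8 outputs and left-padded kernel_size 2 giving 10, and confirms that changing "future" inputs leaves earlier causal outputs untouched. Like deep-learning frameworks, it computes cross-correlation (no kernel flip), matching the y[t] = x[...] · w[...] titles.

```python
import numpy as np


def conv1d_valid(x, w):
    """'Valid' 1D convolution (no padding): y[t] = sum_k x[t + k] * w[k]."""
    k = len(w)
    return np.array([np.dot(x[t:t + k], w) for t in range(len(x) - k + 1)])


def conv1d_causal(x, w):
    """Causal 1D convolution: left-pad with len(w) - 1 zeros, then run the
    valid convolution, so y[t] only sees x[t - len(w) + 1 .. t]."""
    pad = np.zeros(len(w) - 1)
    return conv1d_valid(np.concatenate([pad, x]), w)


x = np.arange(1, 11, dtype=float)      # x0..x9 = 1..10 (input_length = 10)
y_valid = conv1d_valid(x, np.ones(3))  # kernel_size = 3 -> 8 outputs
w = np.array([0.5, 1.0])               # kernel_size = 2, as in the causal GIF
y_causal = conv1d_causal(x, w)         # 10 outputs, same length as x
print(len(y_valid), len(y_causal))     # 8 10
print(y_causal[:3].tolist())           # [1.0, 2.5, 4.0]

# Causality check: changing the "future" inputs x5..x9 leaves y0..y4 unchanged
x_future = x.copy()
x_future[5:] = 0.0
print(np.allclose(conv1d_causal(x_future, w)[:5], y_causal[:5]))  # True
```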

2. 2D Convolution and 2D Causal Convolution

The code is as follows:

import os

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import imageio.v2 as imageio  # v2 API provides imread / mimsave


def draw_grid(ax, origin, size, color='gray', linewidth=0.5):
    """Draw a dx-by-dy grid of unit cells whose lower-left corner is at origin."""
    x0, y0 = origin
    dx, dy = size
    for x in range(dx + 1):
        ax.plot([x0 + x, x0 + x], [y0, y0 + dy], color=color, linewidth=linewidth)
    for y in range(dy + 1):
        ax.plot([x0, x0 + dx], [y0 + y, y0 + y], color=color, linewidth=linewidth)


def visualize_2d_convolution(output_path="2d_convolution.gif"):
    # Parameters: horizontal axis = w (time t), vertical axis = h (vector dimension v)
    input_h, input_w = 6, 6
    kernel_h, kernel_w = 3, 3
    output_h = input_h - kernel_h + 1
    output_w = input_w - kernel_w + 1
    output_offset = input_w + 2  # horizontal gap between input and output grids

    frames = []
    os.makedirs("data/tmp", exist_ok=True)

    # Row-major traversal: sweep along t (w) first, then move to the next v (h)
    for h in range(output_h):
        for w in range(output_w):
            fig, ax = plt.subplots(figsize=(10, 5))

            # Input grid (Rectangle takes width, height)
            draw_grid(ax, (0, 0), (input_w, input_h), 'skyblue')
            ax.add_patch(Rectangle((0, 0), input_w, input_h,
                                   fill=False, edgecolor='skyblue', linewidth=2))

            # Input dimensions, labeled t (time, horizontal) and v (vector, vertical)
            ax.plot([0, input_w], [-0.5, -0.5], color='black')
            ax.text(input_w / 2, -0.8, 't', color='black', ha='center')
            ax.plot([-0.5, -0.5], [0, input_h], color='black')
            ax.text(-0.8, input_h / 2, 'v', color='black', va='center')

            # Kernel position (highlighted cells)
            for kh in range(kernel_h):
                for kw in range(kernel_w):
                    ax.add_patch(Rectangle((w + kw, h + kh), 1, 1,
                                           facecolor='orange', alpha=0.5, edgecolor='black'))

            # Output grid
            draw_grid(ax, (output_offset, 0), (output_w, output_h), 'green')
            ax.add_patch(Rectangle((output_offset, 0), output_w, output_h,
                                   fill=False, edgecolor='green', linewidth=2))

            # Output dimensions, labeled t and v
            ax.plot([output_offset, output_offset + output_w], [-0.5, -0.5], color='black')
            ax.text(output_offset + output_w / 2, -0.8, 't', color='black', ha='center')
            ax.plot([output_offset - 0.5, output_offset - 0.5], [0, output_h], color='black')
            ax.text(output_offset - 0.8, output_h / 2, 'v', color='black', va='center')

            # Current output cell
            ax.add_patch(Rectangle((output_offset + w, h), 1, 1,
                                   facecolor='red', alpha=0.7, edgecolor='darkred'))

            # Kernel size annotations
            start = (w, h)
            ax.plot([start[0], start[0] + kernel_w], [start[1] - 0.3, start[1] - 0.3], color='black')
            ax.text(start[0] + kernel_w / 2, start[1] - 0.6, f'{kernel_w}', color='black', ha='center')
            ax.plot([start[0] - 0.3, start[0] - 0.3], [start[1], start[1] + kernel_h], color='black')
            ax.text(start[0] - 0.6, start[1] + kernel_h / 2, f'{kernel_h}', color='black', va='center')

            # Output cell size annotations (a single 1 x 1 cell)
            out_start = (output_offset + w, h)
            ax.plot([out_start[0], out_start[0] + 1], [out_start[1] - 0.3, out_start[1] - 0.3], color='black')
            ax.text(out_start[0] + 0.5, out_start[1] - 0.6, '1', color='black', ha='center')
            ax.plot([out_start[0] - 0.3, out_start[0] - 0.3], [out_start[1], out_start[1] + 1], color='black')
            ax.text(out_start[0] - 0.6, out_start[1] + 0.5, '1', color='black', va='center')

            # Connection line from the kernel centre to the output cell
            ax.plot([w + kernel_w / 2, output_offset + w + 0.5],
                    [h + kernel_h / 2, h + 0.5],
                    color='purple', linestyle='--', alpha=0.5)

            # Formatting
            ax.set_xlim(-1, output_offset + output_w + 1)
            ax.set_ylim(-1, max(input_h, output_h) + 1)
            ax.set_aspect('equal')
            ax.axis('off')
            plt.title(f"2D Convolution (Row-major order)\n"
                      f"Input (blue) → Kernel (orange) → Output (green)\n"
                      f"Current position: V={h}, T={w}", pad=20)

            temp_file = f"data/tmp/temp_frame_{h}_{w}.png"
            plt.savefig(temp_file, dpi=100, bbox_inches='tight')
            frames.append(imageio.imread(temp_file))
            plt.close()
            os.remove(temp_file)

    # Depending on the imageio version, duration may be read as seconds or milliseconds
    imageio.mimsave(output_path, frames, duration=0.7)
    print(f"✅ GIF saved to: {output_path}")


def visualize_2d_causal_convolution(output_path="2d_causal_convolution.gif"):
    # Parameters: v = vector dimension (vertical), t = time dimension (horizontal)
    input_v, input_t = 6, 6
    kernel_v, kernel_t = 3, 3
    output_v = input_v - kernel_v + 1
    output_t = input_t - kernel_t + 1
    output_offset = input_t + 2  # horizontal gap between input and output grids
    center_t = kernel_t // 2     # kernel column treated as the "current" time step

    frames = []
    os.makedirs("data/tmp", exist_ok=True)

    # Row-major traversal: sweep along t first, then move to the next v
    for v in range(output_v):
        for t in range(output_t):
            fig, ax = plt.subplots(figsize=(10, 5))

            # Input grid
            draw_grid(ax, (0, 0), (input_t, input_v), 'skyblue')
            ax.add_patch(Rectangle((0, 0), input_t, input_v,
                                   fill=False, edgecolor='skyblue', linewidth=2))

            # Input dimensions
            ax.plot([0, input_t], [-0.5, -0.5], color='black')
            ax.text(input_t / 2, -0.8, 't', color='black', ha='center')
            ax.plot([-0.5, -0.5], [0, input_v], color='black')
            ax.text(-0.8, input_v / 2, 'v', color='black', va='center')

            # Kernel cells with the causal constraint along t: columns after the
            # kernel centre would look into the "future" and are grayed out
            for kv in range(kernel_v):
                for kt in range(kernel_t):
                    if kt <= center_t:
                        facecolor, alpha = 'orange', 0.5
                    else:
                        facecolor, alpha = 'gray', 0.2
                    ax.add_patch(Rectangle((t + kt, v + kv), 1, 1,
                                           facecolor=facecolor, alpha=alpha, edgecolor='black'))

            # Output grid
            draw_grid(ax, (output_offset, 0), (output_t, output_v), 'green')
            ax.add_patch(Rectangle((output_offset, 0), output_t, output_v,
                                   fill=False, edgecolor='green', linewidth=2))

            # Output dimensions
            ax.plot([output_offset, output_offset + output_t], [-0.5, -0.5], color='black')
            ax.text(output_offset + output_t / 2, -0.8, 't', color='black', ha='center')
            ax.plot([output_offset - 0.5, output_offset - 0.5], [0, output_v], color='black')
            ax.text(output_offset - 0.8, output_v / 2, 'v', color='black', va='center')

            # Current output cell
            ax.add_patch(Rectangle((output_offset + t, v), 1, 1,
                                   facecolor='red', alpha=0.7, edgecolor='darkred'))

            # Kernel size annotations
            start = (t, v)
            ax.plot([start[0], start[0] + kernel_t], [start[1] - 0.3, start[1] - 0.3], color='black')
            ax.text(start[0] + kernel_t / 2, start[1] - 0.6, f'{kernel_t}', color='black', ha='center')
            ax.plot([start[0] - 0.3, start[0] - 0.3], [start[1], start[1] + kernel_v], color='black')
            ax.text(start[0] - 0.6, start[1] + kernel_v / 2, f'{kernel_v}', color='black', va='center')

            # Causal constraint indicator: dashed line on the boundary between the
            # allowed (orange) and blocked (gray) kernel columns
            boundary = t + center_t + 1
            ax.plot([boundary, boundary], [v - 0.5, v + kernel_v + 0.5],
                    color='red', linestyle='--', linewidth=1)
            ax.text(boundary + 0.2, v + kernel_v / 2, 'Causal\nConstraint',
                    color='red', va='center', fontsize=8)

            # Connection line from the current kernel column to the output cell
            ax.plot([t + center_t + 0.5, output_offset + t + 0.5],
                    [v + kernel_v / 2, v + 0.5],
                    color='purple', linestyle='--', alpha=0.5)

            # Formatting
            ax.set_xlim(-1, output_offset + output_t + 1)
            ax.set_ylim(-1, max(input_v, output_v) + 1)
            ax.set_aspect('equal')
            ax.axis('off')
            plt.title(f"2D Causal Convolution\n"
                      f"Input (blue) → Kernel (orange=allowed, gray=blocked) → Output (green)\n"
                      f"Current position: v={v}, t={t}", pad=20)

            temp_file = f"data/tmp/temp_frame_{v}_{t}.png"
            plt.savefig(temp_file, dpi=100, bbox_inches='tight')
            frames.append(imageio.imread(temp_file))
            plt.close()
            os.remove(temp_file)

    imageio.mimsave(output_path, frames, duration=0.7)
    print(f"✅ GIF saved to: {output_path}")


if __name__ == '__main__':
    # Example usage
    # visualize_2d_convolution(output_path="2d_convolution.gif")
    visualize_2d_causal_convolution(output_path="2d_causal_convolution.gif")
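The masking scheme drawn in the 2D causal GIF can be checked numerically. The NumPy sketch below assumes the same convention as the animation: a 3 × 3 kernel whose columns after the centre (along t) are zeroed, applied as a "valid" cross-correlation on a 6 × 6 input. causal_mask_2d and conv2d_valid are illustrative names. Under this convention, the "current" time step for output column t is input column t + kernel_t // 2, so that column and everything before it may contribute, and nothing after it does.

```python
import numpy as np


def causal_mask_2d(kernel_v, kernel_t):
    """Kernel mask matching the GIF: keep columns up to the centre along t
    (orange cells) and zero out the "future" columns (gray cells)."""
    mask = np.zeros((kernel_v, kernel_t))
    mask[:, : kernel_t // 2 + 1] = 1.0
    return mask


def conv2d_valid(x, w):
    """'Valid' 2D cross-correlation; arrays are indexed as [v, t]."""
    kv, kt = w.shape
    out = np.empty((x.shape[0] - kv + 1, x.shape[1] - kt + 1))
    for v in range(out.shape[0]):
        for t in range(out.shape[1]):
            out[v, t] = np.sum(x[v:v + kv, t:t + kt] * w)
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 6))                         # 6 (v) x 6 (t), as in the GIF
w = rng.standard_normal((3, 3)) * causal_mask_2d(3, 3)  # masked 3 x 3 kernel
y = conv2d_valid(x, w)
print(y.shape)  # (4, 4)

# Output column t only reads input columns t .. t + 1, so perturbing
# input columns 4 and 5 leaves output columns 0..2 unchanged.
x_future = x.copy()
x_future[:, 4:] += 1.0
print(np.allclose(conv2d_valid(x_future, w)[:, :3], y[:, :3]))  # True
```

Zeroing kernel weights like this is one way to get causality; the 1D example instead shifts the window with left padding, which keeps the output aligned with the current step.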

3. 3D Convolution and 3D Causal Convolution

The code is as follows:

import os

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import imageio.v2 as imageio


def draw_box(ax, origin, size, color, alpha=0.3, edge_color='k'):
    """Draw an axis-aligned box as 6 quadrilateral faces (Poly3DCollection)."""
    x, y, z = origin
    dx, dy, dz = size
    box = [
        [(x, y, z), (x + dx, y, z), (x + dx, y + dy, z), (x, y + dy, z)],
        [(x, y, z + dz), (x + dx, y, z + dz), (x + dx, y + dy, z + dz), (x, y + dy, z + dz)],
        [(x, y, z), (x + dx, y, z), (x + dx, y, z + dz), (x, y, z + dz)],
        [(x, y + dy, z), (x + dx, y + dy, z), (x + dx, y + dy, z + dz), (x, y + dy, z + dz)],
        [(x, y, z), (x, y + dy, z), (x, y + dy, z + dz), (x, y, z + dz)],
        [(x + dx, y, z), (x + dx, y + dy, z), (x + dx, y + dy, z + dz), (x + dx, y, z + dz)],
    ]
    ax.add_collection3d(Poly3DCollection(box, facecolors=color, alpha=alpha, edgecolor=edge_color))


def draw_grid(ax, origin, size, color='gray', linewidth=0.2):
    """Draw a wireframe with lines at every unit step of a dx x dy x dz block."""
    x0, y0, z0 = origin
    dx, dy, dz = size
    for x in range(dx + 1):
        for y in range(dy + 1):
            ax.plot([x0 + x, x0 + x], [y0 + y, y0 + y], [z0, z0 + dz], color=color, linewidth=linewidth)
    for x in range(dx + 1):
        for z in range(dz + 1):
            ax.plot([x0 + x, x0 + x], [y0, y0 + dy], [z0 + z, z0 + z], color=color, linewidth=linewidth)
    for y in range(dy + 1):
        for z in range(dz + 1):
            ax.plot([x0, x0 + dx], [y0 + y, y0 + y], [z0 + z, z0 + z], color=color, linewidth=linewidth)


def annotate_dims(ax, origin, lengths, labels, x_label_dy=-0.5):
    """Draw edge lines along x (depth), y (height) and z (width) from origin, with size labels."""
    x0, y0, z0 = origin
    lx, ly, lz = lengths
    ax.plot([x0, x0 + lx], [y0, y0], [z0, z0], color='black')
    ax.text(x0 + lx / 2, y0 + x_label_dy, z0 - 0.5, f'{labels[0]}', color='black', fontsize=10)
    ax.plot([x0, x0], [y0, y0 + ly], [z0, z0], color='black')
    ax.text(x0 - 5, y0 + ly / 2, z0 - 0.5, f'{labels[1]}', color='black', fontsize=10)
    ax.plot([x0, x0], [y0, y0], [z0, z0 + lz], color='black')
    ax.text(x0 - 5, y0 - 0.5, z0 + lz / 2, f'{labels[2]}', color='black', fontsize=10)


def finish_frame(ax, frames, x_max, y_max, z_max, title, temp_file):
    """Set limits and view angle, hide the axes, and append the rendered frame."""
    ax.set_xlim(-1, x_max)
    ax.set_ylim(-1, y_max)
    ax.set_zlim(-1, z_max)
    ax.set_axis_off()  # hides panes, ticks and axis labels
    ax.view_init(25, 140)
    plt.title(title, pad=20)
    plt.savefig(temp_file, dpi=100, bbox_inches='tight')
    frames.append(imageio.imread(temp_file))
    plt.close()
    os.remove(temp_file)


def visualize_3d_convolution(output_path="3d_convolution_separated.gif"):
    # Parameters (D = depth/time, H = height, W = width)
    input_d, input_h, input_w = 6, 6, 6
    kernel_d, kernel_h, kernel_w = 3, 3, 3
    output_d = input_d - kernel_d + 1
    output_h = input_h - kernel_h + 1
    output_w = input_w - kernel_w + 1
    depth_scale = 10                             # stretch depth so the slices are visible
    output_offset = input_d * depth_scale + 200  # x position of the output block
    in_size = (input_d * depth_scale, input_h, input_w)
    out_size = (output_d * depth_scale, output_h, output_w)

    frames = []
    os.makedirs("data/tmp", exist_ok=True)

    for d in range(output_d):
        for h in range(output_h):
            for w in range(output_w):
                fig = plt.figure(figsize=(8, 8))
                ax = fig.add_subplot(111, projection='3d')

                # Input block, grid and dimension labels
                draw_box(ax, (0, 0, 0), in_size, 'skyblue', alpha=0.1)
                draw_grid(ax, (0, 0, 0), in_size, color='skyblue')
                annotate_dims(ax, (0, 0, 0), in_size, (input_d, input_h, input_w), x_label_dy=-0.8)

                # Kernel: one cell per (kd, kh, kw)
                for kd in range(kernel_d):
                    for kh in range(kernel_h):
                        for kw in range(kernel_w):
                            draw_box(ax, ((d + kd) * depth_scale, h + kh, w + kw), (depth_scale, 1, 1),
                                     'orange', alpha=0.5, edge_color='black')

                # Output block, grid and dimension labels
                draw_box(ax, (output_offset, 0, 0), out_size, 'green', alpha=0.1)
                draw_grid(ax, (output_offset, 0, 0), out_size, color='green')
                annotate_dims(ax, (output_offset, 0, 0), out_size, (output_d, output_h, output_w),
                              x_label_dy=-0.8)

                # Current output cell
                out_start = (output_offset + d * depth_scale, h, w)
                draw_box(ax, out_start, (depth_scale, 1, 1), 'red', alpha=0.9, edge_color='darkred')

                # Kernel size labels (depth edge scaled by depth_scale) and output cell labels
                annotate_dims(ax, (d * depth_scale, h, w), (kernel_d * depth_scale, kernel_h, kernel_w),
                              (kernel_d, kernel_h, kernel_w))
                annotate_dims(ax, out_start, (depth_scale, 1, 1), (1, 1, 1))

                finish_frame(ax, frames,
                             x_max=output_offset + output_d * depth_scale + 2,
                             y_max=input_h + 1, z_max=input_w + 1,
                             title=f"3D Convolution\nInput (blue grid) → Kernel (orange) → Output (green grid)\n"
                                   f"Current position: D={d}, H={h}, W={w}",
                             temp_file=f"data/tmp/temp_frame_{d}_{h}_{w}.png")

    # Depending on the imageio version, duration may be read as seconds or milliseconds
    imageio.mimsave(output_path, frames, duration=0.9)
    print(f"✅ GIF saved to: {output_path}")


def visualize_3d_causal_convolution(output_path="3d_causal_convolution.gif"):
    # Same geometry as above; kernel slices after the centre along D are "future" and masked
    input_d, input_h, input_w = 6, 6, 6
    kernel_d, kernel_h, kernel_w = 3, 3, 3
    output_d = input_d - kernel_d + 1
    output_h = input_h - kernel_h + 1
    output_w = input_w - kernel_w + 1
    depth_scale = 10
    output_offset = input_d * depth_scale + 300
    in_size = (input_d * depth_scale, input_h, input_w)
    out_size = (output_d * depth_scale, output_h, output_w)

    frames = []
    os.makedirs("data/tmp", exist_ok=True)

    for d in range(output_d):
        for h in range(output_h):
            for w in range(output_w):
                fig = plt.figure(figsize=(10, 10))
                ax = fig.add_subplot(111, projection='3d')

                # Input block, grid and dimension labels
                draw_box(ax, (0, 0, 0), in_size, 'skyblue', alpha=0.1)
                draw_grid(ax, (0, 0, 0), in_size, color='skyblue')
                annotate_dims(ax, (0, 0, 0), in_size, (input_d, input_h, input_w), x_label_dy=-0.8)

                # Kernel: slices with kd > kernel_d // 2 lie in the future and are drawn gray
                for kd in range(kernel_d):
                    is_future = kd > kernel_d // 2
                    color = 'gray' if is_future else 'orange'
                    edge = 'gray' if is_future else 'black'
                    for kh in range(kernel_h):
                        for kw in range(kernel_w):
                            draw_box(ax, ((d + kd) * depth_scale, h + kh, w + kw), (depth_scale, 1, 1),
                                     color, alpha=0.5, edge_color=edge)

                # Kernel size labels, anchored at the kernel's first slice
                annotate_dims(ax, (d * depth_scale, h, w), (kernel_d * depth_scale, kernel_h, kernel_w),
                              (kernel_d, kernel_h, kernel_w))

                # Output block, grid and dimension labels
                draw_box(ax, (output_offset, 0, 0), out_size, 'green', alpha=0.1)
                draw_grid(ax, (output_offset, 0, 0), out_size, color='green')
                annotate_dims(ax, (output_offset, 0, 0), out_size, (output_d, output_h, output_w),
                              x_label_dy=-0.8)

                # Current output cell and its size labels
                out_start = (output_offset + d * depth_scale, h, w)
                draw_box(ax, out_start, (depth_scale, 1, 1), 'red', alpha=0.9, edge_color='darkred')
                annotate_dims(ax, out_start, (depth_scale, 1, 1), (1, 1, 1))

                finish_frame(ax, frames,
                             x_max=output_offset + output_d * depth_scale + 2,
                             y_max=input_h + 1, z_max=input_w + 1,
                             title=f"3D Causal Convolution\nMasked future input in gray\n"
                                   f"Position: D={d}, H={h}, W={w}",
                             temp_file=f"data/tmp/temp_frame_{d}_{h}_{w}.png")

    imageio.mimsave(output_path, frames, duration=1)
    print(f"✅ GIF saved to: {output_path}")


if __name__ == '__main__':
    # Each call renders every output position (4 x 4 x 4 = 64 frames)
    visualize_3d_convolution("data/tmp/3d_convolution.gif")
    visualize_3d_causal_convolution("data/tmp/3d_causal_convolution.gif")
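The 3D animation masks the kernel's future depth slices. Video models commonly get the same property a different way: pad kT - 1 zero frames at the front of the time axis only and pad the spatial axes symmetrically, so each output frame sees only current and earlier frames while the clip length is preserved. The sketch below is a deliberately naive NumPy version of that padding-based variant, assuming a single channel and no batch dimension; causal_conv3d is an illustrative name, not a library function.

```python
import numpy as np


def causal_conv3d(x, w):
    """Naive causal 3D cross-correlation on a single-channel clip.

    x: (T, H, W) frames; w: (kT, kH, kW) kernel with odd kH and kW.
    Time: kT - 1 zero frames padded at the front only (causal).
    Space: symmetric zero padding so H and W are preserved.
    """
    kt, kh, kw = w.shape
    ph, pw = kh // 2, kw // 2
    xp = np.pad(x, ((kt - 1, 0), (ph, ph), (pw, pw)))  # zeros by default
    T, H, W = x.shape
    y = np.empty((T, H, W))
    for t in range(T):
        for i in range(H):
            for j in range(W):
                # Padded frames t .. t + kt - 1 are original frames t - kt + 1 .. t
                y[t, i, j] = np.sum(xp[t:t + kt, i:i + kh, j:j + kw] * w)
    return y


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 8, 8))  # 6 frames of 8 x 8
w = rng.standard_normal((3, 3, 3))  # 3 in time, 3 x 3 in space
y = causal_conv3d(x, w)
print(y.shape)  # (6, 8, 8): temporal length preserved

x_future = x.copy()
x_future[4:] = 0.0                  # change frames 4 and 5 only
y_future = causal_conv3d(x_future, w)
print(np.allclose(y[:4], y_future[:4]))  # True: frames 0..3 are unaffected
```

The triple loop is only for readability; a framework convolution with the same front-only temporal padding computes the same values far faster.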

http://www.lryc.cn/news/609622.html
