生信碱移

torch模块可视化

小编近期冲浪的时候发现一个torch模型架构可视化的神级python库VisualTorch,给各位铁子分享一下doge。

VisualTorch旨在帮助可视化基于Torch的神经网络架构,似乎是今年才上传到github上。它目前支持为PyTorch的Sequential和Custom模型生成分层风格、图形风格和LeNet风格的架构。工具的灵感源自visualkeras、pytorchviz和pytorch-summary。

图片

▲ 可视化示例

0.安装

使用以下代码安装该库

pip install visualtorch

环境依赖如下,实测的时候发现python版本还需要大于3.10:

pillow>=10.0.0
numpy>=1.18.1
aggdraw>=1.3.11
torch>=2.0.0

1.Layered可视化

2D可视化

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)
img = visualtorch.layered_view(model, input_shape=input_shape, draw_volume=False)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

基础自定义模型的可视化

图片

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as func
import visualtorch
from torch import nn


# Example of a simple CNN model
class SimpleCNN(nn.Module):
    """Simple CNN Model."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Define the forward pass."""
        x = self.conv1(x)
        x = func.relu(x)
        x = func.max_pool2d(x, 2, 2)
        x = self.conv2(x)
        x = func.relu(x)
        x = func.max_pool2d(x, 2, 2)
        x = self.conv3(x)
        x = func.relu(x)
        x = func.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = func.relu(x)
        return self.fc2(x)


# Create an instance of the SimpleCNN
model = SimpleCNN()

input_shape = (1, 3, 224, 224)

img = visualtorch.layered_view(model, input_shape=input_shape, legend=True)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

基本Sequential模型的可视化

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)

img = visualtorch.layered_view(model, input_shape=input_shape, legend=True)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

自定义模块的颜色

图片

from collections import defaultdict

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

color_map: dict = defaultdict(dict)
color_map[nn.Conv2d]["fill"] = "LightSlateGray"  # Light Slate Gray
color_map[nn.ReLU]["fill"] = "#87CEFA"  # Light Sky Blue
color_map[nn.MaxPool2d]["fill"] = "LightSeaGreen"  # Light Sea Green
color_map[nn.Flatten]["fill"] = "#98FB98"  # Pale Green
color_map[nn.Linear]["fill"] = "LightSteelBlue"  # Light Steel Blue

input_shape = (1, 3, 224, 224)
img = visualtorch.layered_view(model, input_shape=input_shape, color_map=color_map)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

自定义模块的不透明度

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)

img = visualtorch.layered_view(model, input_shape=input_shape, opacity=100)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

自定义模块的方向

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)

img = visualtorch.layered_view(
    model,
    input_shape=input_shape,
    one_dim_orientation="x",
    spacing=40,
)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

自定义模块的阴影

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)
img = visualtorch.layered_view(model, input_shape=input_shape, shade_step=50)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

自定义模块间空间距离

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)

img = visualtorch.layered_view(model, input_shape=input_shape, spacing=50)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

忽略某些模块,即仅可视化某些层

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)

ignored_layers = [nn.ReLU, nn.Flatten]

input_shape = (1, 3, 224, 224)
img = visualtorch.layered_view(
    model,
    input_shape=input_shape,
    type_ignore=ignored_layers,
)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

2.全连接层可视化

可视化基本的全连接层,当然像颜色、空间啥的也都可以调整:

图片

import matplotlib.pyplot as plt
import torch
import visualtorch
from torch import nn


class SimpleDense(nn.Module):
    """Simple Dense Model."""

    def __init__(self) -> None:
        super().__init__()
        self.h0 = nn.Linear(4, 8)
        self.h1 = nn.Linear(8, 8)
        self.h2 = nn.Linear(8, 4)
        self.out = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Define the forward pass."""
        x = self.h0(x)
        x = self.h1(x)
        x = self.h2(x)
        return self.out(x)


model = SimpleDense()

input_shape = (1, 4)

img = visualtorch.graph_view(model, input_shape)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

LeNet风格示例

图片

import matplotlib.pyplot as plt
import visualtorch
from torch import nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.MaxPool2d(2, 2),
)

input_shape = (1, 3, 128, 128)

img = visualtorch.lenet_view(model, input_shape=input_shape)

plt.axis("off")
plt.tight_layout()
plt.imshow(img)
plt.show()

上面用到的几个API参数,这里就介绍了,可以自行查看文档:

  • https://visualtorch.readthedocs.io/en/latest/index.html

够猛,宝

赶紧收藏关注起来

Logo

永洪科技,致力于打造全球领先的数据技术厂商,具备从数据应用方案咨询、BI、AIGC智能分析、数字孪生、数据资产、数据治理、数据实施的端到端大数据价值服务能力。

更多推荐