Restormer Model Code Walkthrough

35 minute read

Published:

In the previous post we went through the Restormer paper. This post walks through the Restormer code.

Paper: Restormer: Efficient Transformer for High-Resolution Image Restoration. Code: Restormer.

We use test.py from the Deraining project as the entry point and analyze the model from there.

test.py has been modified here to support the MPS backend on Apple Silicon Macs.

## Restormer: Efficient Transformer for High-Resolution Image Restoration
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
## https://arxiv.org/abs/2111.09881
import numpy as np
import os
import argparse
from tqdm import tqdm
import sys
import torch.nn as nn
import torch
import torch.nn.functional as F
import utils

from natsort import natsorted
from glob import glob
from restormer_arch import Restormer
from skimage import img_as_ubyte
from pdb import set_trace as stx

parser = argparse.ArgumentParser(description='Image Deraining using Restormer')

parser.add_argument('--input_dir', default='./Datasets/', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results')
parser.add_argument('--weights', default='./pretrained_models/deraining.pth', type=str, help='Path to weights')

args = parser.parse_args()

####### Load yaml #######
yaml_file = 'Options/Deraining_Restormer.yml'
import yaml

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader

x = yaml.load(open(yaml_file, mode='r'), Loader=Loader)

s = x['network_g'].pop('type')
##########################

# Detect available device (CUDA, MPS, or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

print(f"{x['network_g'] = }")
model_restoration = Restormer(**x['network_g'])
print(f'{model_restoration = }')

sys.exit()
checkpoint = torch.load(args.weights, map_location='cpu')
model_restoration.load_state_dict(checkpoint['params'])
print("===>Testing using weights: ", args.weights)

model_restoration.to(device)
if device.type == 'cuda' and torch.cuda.device_count() > 1:
    model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()

factor = 8
# datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800']
datasets = ['my']


for dataset in datasets:
    print(f'{dataset = }')

    result_dir = os.path.join(args.result_dir, dataset)
    os.makedirs(result_dir, exist_ok=True)

    inp_dir = os.path.join(args.input_dir, 'test', dataset, 'input')
    files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg')))
    with torch.no_grad():
        for file_ in tqdm(files):
            if device.type == 'cuda':
                torch.cuda.ipc_collect()
                torch.cuda.empty_cache()

            img = np.float32(utils.load_img(file_)) / 255.0
            img = torch.from_numpy(img).permute(2, 0, 1)
            input_ = img.unsqueeze(0).to(device)

            # Padding in case images are not multiples of 8
            h, w = input_.shape[2], input_.shape[3]
            H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
            padh = H - h if h % factor != 0 else 0
            padw = W - w if w % factor != 0 else 0
            input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

            restored = model_restoration(input_)

            # Unpad images to original dimensions
            restored = restored[:, :, :h, :w]

            restored = torch.clamp(restored, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()

            #utils.save_img(os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0] + '.png'), img_as_ubyte(restored))
            utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored))

Running the script (it calls sys.exit() right after printing the model, so only the inspection output below is produced), we first look at the arguments used to instantiate the Restormer class.

x['network_g'] = {'inp_channels': 3, 'out_channels': 3, 'dim': 48, 'num_blocks': [4, 6, 6, 8], 'num_refinement_blocks': 4, 'heads': [1, 2, 4, 8], 'ffn_expansion_factor': 2.66, 'bias': False, 'LayerNorm_type': 'WithBias', 'dual_pixel_task': False}

Printing the model gives:

model_restoration = Restormer(
  (patch_embed): OverlapPatchEmbed(
    (proj): Conv2d(3, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (encoder_level1): Sequential(
    (0): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(48, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
        (project_out): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(48, 254, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(254, 254, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=254, bias=False)
        (project_out): Conv2d(127, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(48, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
        (project_out): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(48, 254, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(254, 254, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=254, bias=False)
        (project_out): Conv2d(127, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (2): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(48, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
        (project_out): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(48, 254, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(254, 254, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=254, bias=False)
        (project_out): Conv2d(127, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (3): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(48, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
        (project_out): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(48, 254, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(254, 254, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=254, bias=False)
        (project_out): Conv2d(127, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
  (down1_2): Downsample(
    (body): Sequential(
      (0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): PixelUnshuffle(downscale_factor=2)
    )
  )
  (encoder_level2): Sequential(
    (0): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (2): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (3): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (4): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (5): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
  (down2_3): Downsample(
    (body): Sequential(
      (0): Conv2d(96, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): PixelUnshuffle(downscale_factor=2)
    )
  )
  (encoder_level3): Sequential(
    (0): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(192, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
        (project_out): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(192, 1020, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(1020, 1020, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1020, bias=False)
        (project_out): Conv2d(510, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(192, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
        (project_out): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(192, 1020, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(1020, 1020, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1020, bias=False)
        (project_out): Conv2d(510, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (2): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(192, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
        (project_out): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(192, 1020, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(1020, 1020, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1020, bias=False)
        (project_out): Conv2d(510, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (3): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(192, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
        (project_out): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(192, 1020, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(1020, 1020, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1020, bias=False)
        (project_out): Conv2d(510, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (4): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(192, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
        (project_out): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(192, 1020, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(1020, 1020, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1020, bias=False)
        (project_out): Conv2d(510, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (5): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(192, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
        (project_out): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(192, 1020, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(1020, 1020, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1020, bias=False)
        (project_out): Conv2d(510, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
  (down3_4): Downsample(
    (body): Sequential(
      (0): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): PixelUnshuffle(downscale_factor=2)
    )
  )
  (latent): Sequential(
    (0): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1152, bias=False)
        (project_out): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(384, 2042, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(2042, 2042, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2042, bias=False)
        (project_out): Conv2d(1021, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1152, bias=False)
        (project_out): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(384, 2042, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(2042, 2042, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2042, bias=False)
        (project_out): Conv2d(1021, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (2): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1152, bias=False)
        (project_out): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(384, 2042, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(2042, 2042, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2042, bias=False)
        (project_out): Conv2d(1021, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (3): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1152, bias=False)
        (project_out): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(384, 2042, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(2042, 2042, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2042, bias=False)
        (project_out): Conv2d(1021, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (4): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1152, bias=False)
        (project_out): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(384, 2042, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(2042, 2042, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2042, bias=False)
        (project_out): Conv2d(1021, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (5): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1152, bias=False)
        (project_out): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(384, 2042, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(2042, 2042, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2042, bias=False)
        (project_out): Conv2d(1021, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (6): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1152, bias=False)
        (project_out): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(384, 2042, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(2042, 2042, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2042, bias=False)
        (project_out): Conv2d(1021, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (7): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1152, bias=False)
        (project_out): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(384, 2042, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(2042, 2042, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2042, bias=False)
        (project_out): Conv2d(1021, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
  (up4_3): Upsample(
    (body): Sequential(
      (0): Conv2d(384, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): PixelShuffle(upscale_factor=2)
    )
  )
  (reduce_chan_level3): Conv2d(384, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (decoder_level3): Sequential(
    (0): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(192, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
        (project_out): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(192, 1020, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(1020, 1020, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1020, bias=False)
        (project_out): Conv2d(510, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(192, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
        (project_out): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(192, 1020, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(1020, 1020, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1020, bias=False)
        (project_out): Conv2d(510, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (2): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(192, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
        (project_out): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(192, 1020, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(1020, 1020, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1020, bias=False)
        (project_out): Conv2d(510, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (3): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(192, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
        (project_out): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(192, 1020, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(1020, 1020, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1020, bias=False)
        (project_out): Conv2d(510, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (4): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(192, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
        (project_out): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(192, 1020, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(1020, 1020, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1020, bias=False)
        (project_out): Conv2d(510, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (5): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(192, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
        (project_out): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(192, 1020, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(1020, 1020, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1020, bias=False)
        (project_out): Conv2d(510, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
  (up3_2): Upsample(
    (body): Sequential(
      (0): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): PixelShuffle(upscale_factor=2)
    )
  )
  (reduce_chan_level2): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (decoder_level2): Sequential(
    (0): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (2): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (3): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (4): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (5): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
  (up2_1): Upsample(
    (body): Sequential(
      (0): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): PixelShuffle(upscale_factor=2)
    )
  )
  (decoder_level1): Sequential(
    (0): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (2): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (3): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
  (refinement): Sequential(
    (0): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (2): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (3): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=288, bias=False)
        (project_out): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(96, 510, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(510, 510, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=510, bias=False)
        (project_out): Conv2d(255, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
  (output): Conv2d(96, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)

The basic structure is still an encoder-decoder architecture.

The Restormer class is implemented as follows:

##---------- Restormer -----------------------
class Restormer(nn.Module):
    def __init__(self, 
        inp_channels=3, 
        out_channels=3, 
        dim = 48,
        num_blocks = [4,6,6,8], 
        num_refinement_blocks = 4,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        dual_pixel_task = False        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    ):

        super(Restormer, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        
        self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        
        self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
        self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
        
        self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])


        self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        
        self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)

        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        
        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
        
        #### For Dual-Pixel Defocus Deblurring Task ####
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
        ###########################
            
        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):

        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        
        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3) 

        inp_enc_level4 = self.down3_4(out_enc_level3)        
        latent = self.latent(inp_enc_level4) 
                        
        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3) 

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2) 

        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)
        
        out_dec_level1 = self.refinement(out_dec_level1)

        #### For Dual-Pixel Defocus Deblurring Task ####
        if self.dual_pixel_task:
            out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
            out_dec_level1 = self.output(out_dec_level1)
        ###########################
        else:
            out_dec_level1 = self.output(out_dec_level1) + inp_img


        return out_dec_level1

Walkthrough of the Restormer class

This code implements the Restormer network, a Transformer-based deep learning model aimed mainly at image restoration. Its main components are analyzed below.


Main components:

  1. Patch embedding (self.patch_embed):
    • Extracts overlapping patches from the input image and embeds them into a higher-dimensional feature space.
    • Converts the input image into a tensor suitable for the Transformer blocks.
  2. Transformer blocks:
    • Each TransformerBlock extracts features with self-attention and a feed-forward network.
    • Transformer blocks are used throughout the encoder, the latent space, the decoder, and the final refinement stage.
  3. Multi-level encoder:
    • The encoder consists of several levels (encoder_level1, encoder_level2, encoder_level3, latent), each extracting increasingly abstract features.
    • Downsample layers reduce the spatial resolution between levels.
  4. Multi-level decoder:
    • The decoder reconstructs the output image through symmetric levels (decoder_level3, decoder_level2, decoder_level1).
    • Upsample layers increase the spatial resolution between levels.
    • Feature maps passed from the encoder are concatenated with the upsampled decoder outputs, preserving high-resolution detail.
  5. Refinement:
    • Additional Transformer blocks (self.refinement) further polish the decoder output and improve image quality.
  6. Dual-pixel defocus task (optional):
    • When dual_pixel_task=True, an extra skip connection (self.skip_conv) is added, used specifically for dual-pixel defocus deblurring.
  7. Final output:
    • A convolution layer (self.output) maps the processed feature map to the desired number of output channels.
    • By default the input image is added to the final output, so the model learns a residual (output + inp_img).

Parameter description:

  • inp_channels / out_channels: number of channels of the input and output images (e.g. 3 for RGB).
  • dim: initial feature dimension of the Transformer blocks.
  • num_blocks: number of Transformer blocks at each level.
  • heads: number of attention heads at each level.
  • ffn_expansion_factor: expansion factor of the feed-forward network inside each Transformer block.
  • LayerNorm_type: type of layer normalization (WithBias or BiasFree).

Forward pass:

  1. The input image is first processed by the patch embedding layer.
  2. The encoder extracts features over several levels; the spatial resolution shrinks after each level.
  3. In the latent space, Transformer blocks process the most abstract features.
  4. The decoder progressively reconstructs the image, using features passed from the encoder to preserve detail.
  5. The decoder output goes through the refinement stage, producing the final restored image.
  6. For ordinary restoration tasks the result is returned through the residual connection (output + inp_img); for specific tasks such as dual-pixel defocus deblurring, a dedicated skip connection is used instead; a shape trace is sketched right after this list.
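
To make the data flow concrete, here is a minimal sketch (assuming restormer_arch.py from the repository is importable, and using the deraining configuration printed above) that pushes a dummy image through the model:

import torch
from restormer_arch import Restormer

# Deraining configuration from Options/Deraining_Restormer.yml (printed above)
model = Restormer(inp_channels=3, out_channels=3, dim=48,
                  num_blocks=[4, 6, 6, 8], num_refinement_blocks=4,
                  heads=[1, 2, 4, 8], ffn_expansion_factor=2.66,
                  bias=False, LayerNorm_type='WithBias', dual_pixel_task=False)
model.eval()

x = torch.randn(1, 3, 64, 64)   # H and W must be multiples of 8 (three 2x downsamplings)
with torch.no_grad():
    y = model(x)
print(y.shape)                  # torch.Size([1, 3, 64, 64]): same resolution as the input

# Feature shapes inside the network for this input:
#   level 1: 48 x 64 x 64 -> level 2: 96 x 32 x 32 -> level 3: 192 x 16 x 16 -> latent: 384 x 8 x 8,
#   then the decoder mirrors this back to 96 x 64 x 64 before the final 3-channel output convolution.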

Typical applications:

  • Image denoising
  • Super-resolution
  • Deblurring
  • General image restoration tasks

1. The OverlapPatchEmbed module

Its implementation:

## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        x = self.proj(x)

        return x

Its key property is the overlap: neighboring patches share information (unlike non-overlapping patch embedding), which strengthens local feature extraction.
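
A quick shape check (a small sketch, assuming the OverlapPatchEmbed class above):

import torch

embed = OverlapPatchEmbed(in_c=3, embed_dim=48)
x = torch.randn(1, 3, 128, 128)
print(embed(x).shape)   # torch.Size([1, 48, 128, 128]): channels lifted to 48, spatial size unchanged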

2. The TransformerBlock module

Its implementation:

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))

        return x
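
Both residual branches preserve the tensor shape, so blocks can be stacked freely. A small sketch (assuming the supporting classes from restormer_arch.py, i.e. LayerNorm, Attention, and FeedForward, are in scope):

import torch

block = TransformerBlock(dim=48, num_heads=1, ffn_expansion_factor=2.66,
                         bias=False, LayerNorm_type='WithBias')
x = torch.randn(1, 48, 64, 64)
print(block(x).shape)   # torch.Size([1, 48, 64, 64]): same shape in, same shape out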

LayerNorm

class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type =='BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)
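
The forward pass relies on two helpers, to_3d and to_4d, that are not shown in this post: they flatten the spatial dimensions so that normalization runs over the channel dimension, and then restore the 4-D layout. In restormer_arch.py they are small einops wrappers, roughly:

from einops import rearrange

def to_3d(x):
    # (b, c, h, w) -> (b, h*w, c): the layer norm then acts over the channel dimension
    return rearrange(x, 'b c h w -> b (h w) c')

def to_4d(x, h, w):
    # (b, h*w, c) -> (b, c, h, w): restore the original spatial layout
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)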

In the model structure printed earlier, every normalization layer is a WithBias_LayerNorm.

Its implementation:

class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias

This code implements a custom layer-normalization layer, WithBias_LayerNorm, which normalizes the input and then applies a learnable scale and offset. The implementation keeps a learnable weight (self.weight) and bias (self.bias).


Purpose

Layer normalization is a normalization technique commonly used in the non-convolutional parts of deep networks.
It normalizes the input with the mean and variance computed over its last dimension, which speeds up training and improves model performance.


Code walkthrough

1. Constructor (__init__)

  • Parameters
    • normalized_shape: size of the dimension being normalized, usually the last dimension of the input tensor.
  • Initialization
    • self.weight: learnable parameter with shape normalized_shape, used to rescale the normalized output (initialized to ones).
    • self.bias: learnable parameter with shape normalized_shape, used to shift the normalized output (initialized to zeros).
    • normalized_shape is converted to a torch.Size object to ensure it is a valid shape.

2. Forward pass (forward)

  • The last dimension of the input tensor x is normalized:
    • Mean: mu = x.mean(-1, keepdim=True)
      computes the mean along the last dimension, keeping the dimension.
    • Variance: sigma = x.var(-1, keepdim=True, unbiased=False)
      computes the (biased, because unbiased=False) variance along the last dimension.
    • Normalization
      subtracts the mean and divides by sqrt(variance + 1e-5); the small constant keeps the computation numerically stable.
    • Scale and shift
      applies self.weight and self.bias to the normalized values.

The output has the same shape as the input.
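
Written out, the forward pass computes

$$
\mathrm{LN}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + 10^{-5}}} \cdot w + b,
\qquad \mu = \frac{1}{C}\sum_{i=1}^{C} x_i,
\qquad \sigma^2 = \frac{1}{C}\sum_{i=1}^{C} (x_i - \mu)^2,
$$

where C is the size of the last dimension and w, b are the learnable weight and bias.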


Usage example

Suppose the input tensor has shape [B, C, L]; the last dimension L is the one being normalized.

import torch
import torch.nn as nn

# Create a WithBias_LayerNorm instance
normalized_shape = 64
layer_norm = WithBias_LayerNorm(normalized_shape)

# Random input tensor (batch_size=8, sequence_length=128, features=64)
x = torch.randn(8, 128, 64)

# Apply the layer norm
output = layer_norm(x)

# Output shape
print(output.shape)  # torch.Size([8, 128, 64])

Highlights

  1. Learnable parameters
    • self.weight and self.bias let the model learn how to rescale and shift the normalized values during training.
  2. Last-dimension normalization
    • Only the last dimension is normalized, the usual choice for sequence data (NLP and Transformer-style models).
  3. Numerical stability
    • Adding 1e-5 inside the square root prevents division by zero.

Difference from PyTorch's built-in LayerNorm

  • PyTorch's nn.LayerNorm also supports a learnable weight and bias.
  • This implementation is a simplified version with essentially the same behavior, while being easier to customize; a quick numerical check follows below.
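
A small sketch of that check (assuming the WithBias_LayerNorm class defined above): with the default initialization, both layers use the same biased variance and the same 1e-5 epsilon, so their outputs coincide.

import torch
import torch.nn as nn

x = torch.randn(8, 128, 64)
custom = WithBias_LayerNorm(64)
builtin = nn.LayerNorm(64)   # eps defaults to 1e-5, weight initialized to 1, bias to 0

print(torch.allclose(custom(x), builtin(x), atol=1e-6))   # True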

Attention

## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        


    def forward(self, x):
        b,c,h,w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q,k,v = qkv.chunk(3, dim=1)   
        
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)
        
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out

This code implements Multi-DConv Head Transposed Self-Attention (MDTA), a self-attention mechanism for image feature maps. MDTA combines multi-head attention with depthwise convolution and extracts both local and global features effectively.


Main components

1. Constructor (__init__)

  • Arguments
    • dim: number of channels of the input features.
    • num_heads: number of attention heads.
    • bias: whether the convolutions use a bias term.
  • Components
    1. self.num_heads
      • the number of attention heads; each head learns on its own slice of the channels.
    2. self.temperature
      • a learnable parameter of shape [num_heads, 1, 1] that scales the attention logits.
    3. self.qkv
      • a 1×1 convolution that maps the input into the query (q), key (k), and value (v) feature spaces.
    4. self.qkv_dwconv
      • a depthwise convolution that mixes spatial information within the qkv features.
    5. self.project_out
      • a 1×1 convolution that projects the output back to the original channel count.

2. Forward pass (forward)

  1. Input shape
    • The input x has shape [b, c, h, w]: batch size, channels, height, width.
  2. Generating the qkv features
    • self.qkv: a 1×1 convolution produces the query, key, and value features.
    • self.qkv_dwconv: a depthwise convolution then aggregates local spatial information.
  3. Splitting
    • chunk splits qkv along the channel dimension into q (query), k (key), and v (value), each with dim channels.
  4. Rearranging
    • rearrange reshapes q, k, and v into multi-head form with shape [b, num_heads, c', h*w], where c' = dim / num_heads.
  5. Normalization
    • q and k are L2-normalized along their last dimension before the similarities are computed.
  6. Attention weights
    • attn = softmax((q @ k^T) * temperature). Because q @ k^T contracts over the spatial dimension, the attention map has shape [b, num_heads, c', c']: attention is computed across channels rather than across pixels, so its cost does not grow with the image resolution. This is the "transposed" part of MDTA; see the shape check after this list.
  7. Weighted aggregation
    • out = attn @ v applies the attention weights to v, producing the aggregated features.
  8. Output rearrangement and projection
    • rearrange restores the original channel layout.
    • self.project_out maps the features back to the input channel count.
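
The key point of the transposed attention is visible in the shapes: the attention map is computed between channels, so its size does not depend on the image resolution. A small sketch (assuming the Attention class above and einops are available):

import torch
from einops import rearrange

b, c, h, w, heads = 1, 64, 32, 32, 8
x = torch.randn(b, c, h, w)

attn_layer = Attention(dim=c, num_heads=heads, bias=False)
qkv = attn_layer.qkv_dwconv(attn_layer.qkv(x))
q, k, v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=heads)

attn = q @ k.transpose(-2, -1)
print(attn.shape)   # torch.Size([1, 8, 8, 8]): (b, heads, c', c'), not (h*w, h*w)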

Usage example

Suppose the input feature map has shape [1, 64, 32, 32]:

import torch
import torch.nn as nn
from einops import rearrange

# Create an Attention instance
attention = Attention(dim=64, num_heads=8, bias=False)

# Random input tensor
x = torch.randn(1, 64, 32, 32)

# Forward pass
output = attention(x)

# Output shape
print(output.shape)  # torch.Size([1, 64, 32, 32])

Highlights

  1. Multi-head attention
    • The input features are split into several subspaces and attention is computed independently in each, capturing richer features.
  2. Depthwise convolution
    • Strengthens local feature extraction while keeping parameter count and computation low.
  3. Learnable temperature
    • Gives each head an adjustable scaling factor for its attention weights.
  4. Global and local integration
    • Combines self-attention (global relationships) with convolution (local features) to improve performance.

Gated-Dconv Feed-Forward Network (GDFN)

## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim*ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x

Gated-Dconv Feed-Forward Network (GDFN) in detail

GDFN is a modified feed-forward architecture that combines convolutions with a gating mechanism to strengthen the Transformer's local feature extraction. Each part of the code is analyzed below.


Code walkthrough

1. Constructor: __init__

hidden_features = int(dim*ffn_expansion_factor)
  • dim: number of channels of the input features.
  • ffn_expansion_factor: expansion factor of the feed-forward network, used to enlarge the feature dimension and increase representational capacity.
  • hidden_features: the intermediate feature dimension, an expanded version of the input dimension.
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
  • What it does: a 1×1 convolution maps the input features from dim to 2*hidden_features channels.
  • Why:
    1. Expand the feature dimension so later operations have more representational room.
    2. Provide two halves of features for the gating mechanism.
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
  • What it does: a depthwise convolution extracts local features from the 2*hidden_features channels.
  • Properties:
    • Depthwise: with groups=hidden_features*2, each channel is convolved independently.
    • Why:
      1. Capture local spatial context.
      2. Stay lightweight: depthwise convolution is much cheaper than a standard convolution.
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
  • What it does: a 1×1 convolution maps the intermediate features back to the original dimension dim.
  • Why: keep the block's input and output dimensions identical.

2. Forward pass: forward

x = self.project_in(x)
  • The input tensor x has shape [b, dim, h, w].
  • After project_in the shape becomes [b, 2*hidden_features, h, w].
x1, x2 = self.dwconv(x).chunk(2, dim=1)
  • What happens:
    • dwconv extracts local spatial features.
    • chunk splits the result into two halves, x1 and x2:
      • x1 goes through the activation.
      • x2 acts as the gate.
  • Shapes:
    • dwconv keeps the shape: [b, 2*hidden_features, h, w].
    • chunk splits it into two tensors of shape [b, hidden_features, h, w] each.
x = F.gelu(x1) * x2
  • Gating mechanism:
    • x1 passes through the GELU nonlinearity: $\text{GELU}(x) \approx 0.5x\left(1 + \tanh\left(\sqrt{2/\pi}\,(x + 0.044715x^3)\right)\right)$
    • The activated x1 is multiplied element-wise with x2, forming the gate: $x = \text{GELU}(x_1) \odot x_2$
    • Purpose: x2 dynamically modulates x1, increasing the expressive power of the network.
x = self.project_out(x)
  • A 1×1 convolution maps the result back to the original dimension dim.
  • The output shape is [b, dim, h, w].
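
This also explains the channel numbers in the printed model: with dim=48 and ffn_expansion_factor=2.66, hidden_features = int(48 * 2.66) = 127, so project_in outputs 2 * 127 = 254 channels and project_out maps 127 back to 48, which is exactly the Conv2d(48, 254, ...) and Conv2d(127, 48, ...) layers shown earlier. A small shape check (assuming the FeedForward class above):

import torch

ffn = FeedForward(dim=48, ffn_expansion_factor=2.66, bias=False)
x = torch.randn(1, 48, 64, 64)
print(ffn.project_in)    # Conv2d(48, 254, kernel_size=(1, 1), stride=(1, 1), bias=False), as in the printed model
print(ffn(x).shape)      # torch.Size([1, 48, 64, 64]): the block preserves the input dimension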

Core design ideas of GDFN

  1. An expanded feed-forward network:
    • The standard Transformer FFN is two fully connected layers; GDFN replaces them with convolutions, improving local spatial feature modeling.
  2. Gating mechanism:
    • Splitting the intermediate features into x1 and x2 and gating one with the other increases the network's nonlinear expressiveness.
  3. Lightweight convolution:
    • Depthwise convolution (dwconv) replaces standard convolution, cutting computation while keeping local feature extraction.
  4. Combining nonlinearity with spatial feature extraction:
    • x1 provides the GELU nonlinearity while x2 provides dynamic, spatially varying modulation; together they improve the representation.

Summary

By combining depthwise convolution with a gating mechanism, GDFN improves local feature modeling and expressive flexibility while staying lightweight. This design suits image restoration well: it extracts local context effectively while the surrounding Transformer blocks supply global modeling.

The Downsample module

class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()

        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)

The Downsample module explained

Downsample is a downsampling module that halves the spatial resolution of the feature map while, overall, doubling its channel count. The code is analyzed below:


Code walkthrough

1. Constructor: __init__

self.body = nn.Sequential(
    nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False),
    nn.PixelUnshuffle(2)
)
a. Convolution layer
nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False)
  • Input: n_feat channels.
  • Output: the channel count is halved to n_feat // 2.
  • Convolution parameters:
    • kernel_size=3: a 3×3 kernel.
    • stride=1: no pixels are skipped.
    • padding=1: zero padding keeps the output resolution equal to the input.
  • Role: halve the channel count while extracting local information.
b. Pixel unshuffling (PixelUnshuffle)
nn.PixelUnshuffle(2)
  • What it does: rearranges pixels in 2×2 blocks, halving the spatial resolution while multiplying the channel count by 4.
    • For an input of shape [b, c, h, w]:
      • the output has shape [b, 4c, h/2, w/2].
  • Combined with the preceding convolution:
    • the convolution reduces the channels to n_feat // 2;
    • after PixelUnshuffle the channel count becomes (n_feat // 2) * 4 = 2 * n_feat.
  • Purpose: reduce the spatial resolution while increasing the local information density of the representation.

2. Forward pass: forward

def forward(self, x):
    return self.body(x)
  • The input x has shape [b, n_feat, h, w].
  • After the convolution: the channel count goes from n_feat to n_feat // 2 while the spatial size is unchanged: [b, n_feat, h, w] -> [b, n_feat//2, h, w]
  • After PixelUnshuffle: the spatial resolution shrinks to h/2 × w/2 and the channel count becomes 2*n_feat: [b, n_feat//2, h, w] -> [b, 2*n_feat, h/2, w/2]
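
A quick verification of the shape change (a small sketch, assuming the Downsample class above):

import torch

down = Downsample(n_feat=48)
x = torch.randn(1, 48, 64, 64)
print(down(x).shape)   # torch.Size([1, 96, 32, 32]): channels doubled, spatial size halved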

Summary

  1. Downsampling:
    • The spatial resolution of the input is halved.
    • At the same time the channel count is doubled, giving a denser feature representation.
  2. Pixel unshuffling:
    • PixelUnshuffle provides a lightweight way to downsample; unlike pooling or strided convolution, it simply rearranges pixels to reduce the resolution.
  3. Flexibility:
    • The convolution reduces the channel count and extracts local features, and PixelUnshuffle then performs the spatial downsampling, so the downsampled representation remains information-rich.

This design is common in image reconstruction, super-resolution, and similar tasks, because it preserves more detail for the following layers while downsampling.

The Upsample module

class Upsample(nn.Module):
    def __init__(self, n_feat):
        super(Upsample, self).__init__()

        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)

The Upsample module explained

Upsample is an upsampling module that doubles the spatial resolution of the feature map and halves its channel count. The code is analyzed below:


Code walkthrough

1. Constructor: __init__

self.body = nn.Sequential(
    nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
    nn.PixelShuffle(2)
)
a. Convolution layer
nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False)
  • Input: n_feat channels.
  • Output: the channel count is doubled to n_feat * 2.
  • Convolution parameters:
    • kernel_size=3: a 3×3 kernel.
    • stride=1: no pixels are skipped.
    • padding=1: zero padding keeps the output resolution equal to the input.
  • Purpose: increase the channel count while extracting features.
b. Pixel shuffling (PixelShuffle)
nn.PixelShuffle(2)
  • What it does: rearranges the channel dimension into 2×2 spatial blocks, dividing the channel count by 4 while doubling the spatial resolution.
    • For an input of shape [b, c, h, w]:
      • the output has shape [b, c/4, 2h, 2w].
  • Combined with the preceding convolution:
    • the convolution raises the channels to n_feat * 2;
    • after PixelShuffle the channel count becomes (n_feat * 2) / 4 = n_feat / 2.
    • So the module outputs half the input channel count; this matches the printed model, where up4_3 maps 384 channels to 192 at twice the resolution.

2. Forward pass: forward

def forward(self, x):
    return self.body(x)
  • The input x has shape [b, n_feat, h, w].
  • After the convolution: the channel count goes from n_feat to n_feat * 2 while the spatial size is unchanged: [b, n_feat, h, w] -> [b, n_feat*2, h, w]
  • After PixelShuffle: the spatial resolution grows to 2h × 2w and the channel count drops to n_feat / 2: [b, n_feat*2, h, w] -> [b, n_feat/2, 2h, 2w]; in the decoder, concatenation with the encoder skip connection brings the channels back up (a quick check follows below).
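
A quick verification (a small sketch, assuming the Upsample class above); the halved channel count is exactly why the decoder concatenates the encoder skip connection and, at levels 3 and 2, applies a 1×1 convolution to reduce the channels again:

import torch

up = Upsample(n_feat=384)
x = torch.randn(1, 384, 8, 8)
print(up(x).shape)   # torch.Size([1, 192, 16, 16]): channels halved, spatial size doubled (matches up4_3)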

Summary

  1. Upsampling:
    • The spatial resolution of the input is doubled.
    • At the same time the channel count is halved, mirroring the Downsample module.
  2. Pixel shuffling:
    • PixelShuffle provides a lightweight way to upsample, raising the spatial resolution at low computational cost.
  3. Flexibility from the convolution:
    • The convolution not only prepares the channel count for PixelShuffle but also extracts local features, giving the upsampled representation richer information.

This design is common in image generation, super-resolution, and denoising, because it raises the resolution efficiently while keeping the features consistent.

Analysis of the forward computation

Let us now walk through the forward pass of the model.

The forward method above shows that Restormer uses a classic encoder-decoder structure combined with multi-level Transformer blocks and downsampling/upsampling modules. A detailed breakdown of the structure follows:


Overall design

  1. Input stage
    • The OverlapPatchEmbed module extracts the input features, converting the raw image into embedded features with a convolution.
  2. Encoder
    • Several stages of Transformer blocks; between stages the feature map is downsampled by a Downsample module, progressively extracting higher-level features.
    • After each downsampling the feature dimension grows (the channel count doubles) while the spatial resolution shrinks.
  3. Bottleneck (latent)
    • The middle of the network processes the lowest-resolution features and uses additional Transformer blocks to capture global context.
  4. Decoder
    • The decoder gradually upsamples the features back to the original resolution with Upsample modules.
    • At each upsampling stage the features are concatenated with the output of the corresponding encoder stage (the skip connections).
    • The concatenated features pass through reduce_chan_level3 / reduce_chan_level2 to reduce the channel count before entering the Transformer blocks.
  5. Output stage
    • The refinement module further polishes the features, and a final convolution produces an output image at the same resolution as the input.

Key modules

  1. Overlap patch embedding
    • A convolution splits the input image into overlapping patches and maps them into the embedding space.
  2. Downsample and Upsample
    • Downsample combines a convolution with PixelUnshuffle to reduce spatial resolution and increase the channel count.
    • Upsample uses a convolution with PixelShuffle to reduce the channel count and restore spatial resolution.
  3. Transformer blocks
    • Extract features at every level and capture long-range dependencies and global structure through self-attention.
    • These blocks are reused throughout the encoder and decoder, fusing features across levels.
  4. Skip connections
    • The output of each encoder stage is concatenated with the input of the corresponding decoder stage; these skip connections preserve low-level features and improve reconstruction accuracy.
  5. Refinement
    • After decoding, extra Transformer blocks further refine the features to ensure the quality of the output image.

Strengths

  1. Multi-level feature extraction
    • The encoder gradually extracts high-level semantic features, the decoder gradually restores spatial resolution, and the skip connections retain low-level detail.
  2. Transformers
    • The Transformer blocks capture long-range dependencies and global context, which suits image restoration tasks well.
  3. Efficient downsampling/upsampling
    • PixelShuffle and PixelUnshuffle handle the resolution changes at low computational complexity.
  4. Refinement
    • The final refinement stage further improves the quality of the output image.

Typical applications

  • Image restoration tasks such as denoising, super-resolution, and deblurring.
  • High-quality image processing tasks that need both global dependencies and local detail.