This post documents the process of porting a PyTorch model to MLX.

What is MLX?

> MLX is an array framework for machine learning on Apple silicon, brought to you by Apple machine learning research.
> https://github.com/ml-explore/mlx/blob/main/README.md

MLX is a machine learning framework designed for Apple M-series chips (Apple Silicon).

The design of mlx's array is closer to numpy[1] than to PyTorch's tensor: it only stores structural information (such as shape and data type) and carries no training-related attributes (such as gradients). Unlike numpy and torch, an mlx array lives in Unified Memory and can be shared between the CPU and GPU without copies, which is also why mlx was developed as a standalone framework rather than as an extension of PyTorch's mps backend[2].
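What unified memory buys in practice can be shown with a tiny sketch (the `stream` argument follows the MLX unified-memory docs; shapes here are arbitrary): the same arrays are visible to both devices, and you only choose where each operation runs.

```python
import mlx.core as mx

a = mx.random.uniform(shape=(256, 256))
b = mx.random.uniform(shape=(256, 256))

# No .to(device) copies: both devices see the same buffers,
# the stream argument only selects where the op executes.
c_gpu = mx.add(a, b, stream=mx.gpu)
c_cpu = mx.add(a, b, stream=mx.cpu)
mx.eval(c_gpu, c_cpu)  # mlx is lazy, so force evaluation
```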

The difference between array and tensor

For the difference between np.array and torch.tensor, see What is a Tensor in Machine Learning?

Model conversion in practice

BigVGAN PyTorch -> mlx-BigVGAN

Basic mappings

| | torch | mlx |
| --- | --- | --- |
| Data types | Data types | Data Types |
| NN | `torch.nn.*` | `mlx.nn.*` |
| Parameters/Weight/Buffer | `torch.Tensor` | `mlx.core.array` |
| ModuleList | `torch.nn.ModuleList` | `list` |
| ModuleDict | `torch.nn.ModuleDict` | `dict` |
| Transform | `torch.fft` | `mlx.core.fft` |
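The ModuleList/ModuleDict rows work because mlx.nn.Module discovers parameters inside plain Python lists and dicts, so no dedicated container class is needed. A minimal sketch (the Stack module is made up for illustration):

```python
import mlx.core as mx
import mlx.nn as nn

class Stack(nn.Module):
    def __init__(self, num_layers: int, dims: int):
        super().__init__()
        # a plain list plays the role of torch.nn.ModuleList
        self.layers = [nn.Linear(dims, dims) for _ in range(num_layers)]

    def __call__(self, x: mx.array) -> mx.array:
        for layer in self.layers:
            x = layer(x)
        return x

model = Stack(3, 16)
print(len(model.parameters()["layers"]))  # 3: all layers were discovered
```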

Pad mode

torch's pad supports modes such as constant, reflect, and replicate, while mlx's pad only supports the constant and edge modes.

The reflect mode is not available in mlx, but we can implement it by hand:

```python
import mlx.core as mx


def pad_reflect(x: mx.array, padding: tuple | int) -> mx.array:
    """Pad the input array with `reflect` mode along the last axis."""
    if isinstance(padding, int):
        padding = (padding, padding)

    # mirror the samples next to each edge (excluding the edge itself)
    prefix = x[..., 1 : padding[0] + 1][..., ::-1]
    suffix = x[..., -(padding[1] + 1) : -1][..., ::-1]
    return mx.concatenate([prefix, x, suffix], axis=-1)
```

Assume x is a 2D tensor of shape (M, N) and we want to add padding of width 1 on the right side of the last axis. The torch and mlx pad calls compare as follows:

| pad mode | torch | mlx |
| --- | --- | --- |
| constant | `F.pad(x, (0, 1), mode="constant", value=0)` | `mx.pad(x, [(0, 0), (0, 1)], mode="constant", constant_values=0)` |
| replicate | `F.pad(x, (0, 1), mode="replicate")` | `mx.pad(x, [(0, 0), (0, 1)], mode="edge")` |
| reflect | `F.pad(x, (0, 1), mode="reflect")` | `pad_reflect(x, (0, 1))` (the custom pad_reflect above) |
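A quick sanity check of pad_reflect against numpy's reference implementation (a sketch with arbitrary shapes):

```python
import numpy as np
import mlx.core as mx

x = mx.arange(12, dtype=mx.float32).reshape(3, 4)
ours = pad_reflect(x, (2, 1))
ref = np.pad(np.array(x), [(0, 0), (2, 1)], mode="reflect")
print(np.allclose(np.array(ours), ref))  # True
```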

nn.Module

An mlx nn.Module does not define a forward method; the module implements __call__ directly.

For example, given a Snake module, the torch and mlx implementations look like this:

torch:
```python
import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        super(Snake, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:
            self.alpha = Parameter(torch.ones(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
        return x
```
mlx:
```python
import mlx.core as mx
import mlx.nn as nn


class Snake(nn.Module):
    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        super(Snake, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            self.alpha = mx.zeros(in_features) * alpha
        else:
            self.alpha = mx.ones(in_features) * alpha
        if not alpha_trainable:
            self.freeze(keys="alpha")
        self.no_div_by_zero = 0.000000001

    def __call__(self, x):
        # line up with x, which is [B, T, C] in mlx (channels last)
        alpha = mx.expand_dims(self.alpha, axis=(0, 1))
        if self.alpha_logscale:
            alpha = mx.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * mx.power(mx.sin(x * alpha), 2)
        return x
```

Differences

mlx uses the mlx.nn.Module.freeze[3] method to freeze parameters that should not be trained, whereas torch controls this per tensor with torch.Tensor.requires_grad_[4].
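A minimal sketch of the freeze/unfreeze API (nn.Linear is just a stand-in module here):

```python
import mlx.nn as nn

layer = nn.Linear(8, 8)
layer.freeze()               # freeze every parameter
layer.unfreeze(keys="bias")  # selectively thaw one of them
print(layer.trainable_parameters())  # only "bias" remains trainable
```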

weight_norm

torch provides a weight_norm function that improves training stability and generalization.
weight_norm decomposes a weight into two parts, v and g, where v is a vector and g is a scalar. The weight is computed as:

$$w = g \frac{v}{\|v\|}$$

However, mlx does not offer this yet (a community PR, Implement Weight Normalization, is still under review).
Fortunately, for inference-only models we can bake the norm in during weight conversion, much like remove_weight_norm:

```python
import torch

origin_state_dict = torch.load("origin_model.pth", map_location="cpu", weights_only=True)
out_weights = {}
for k, v in origin_state_dict.items():
    # handle weight norm: fold (weight_v, weight_g) back into a single weight
    if k.endswith(("weight_v", "weight_g")):
        basename, pname = k.rsplit(".", 1)
        if pname == "weight_v":
            g = origin_state_dict[basename + ".weight_g"]
            # compute the actual weight
            k = basename + ".weight"
            v = torch._weight_norm(v, g, dim=0)
        else:  # pname == "weight_g", already folded into ".weight"
            continue
    ...
    out_weights[k] = v
```
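To double-check that the folding matches what torch itself does, we can compare against remove_weight_norm (a sketch; the Conv1d shape is arbitrary):

```python
import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm

conv = weight_norm(nn.Conv1d(4, 8, kernel_size=3))
v = conv.weight_v.detach()
g = conv.weight_g.detach()
w = torch._weight_norm(v, g, dim=0)  # same call as in the conversion loop

remove_weight_norm(conv)  # torch folds g * v / ||v|| into .weight
print(torch.allclose(w, conv.weight.detach()))  # True
```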

Conv1d

mlx's Conv1d interface is essentially the same as torch's. The main difference is the data layout: torch's Conv1d takes input of shape (B, C, L), while mlx takes (B, L, C) (B is the batch size, C the number of channels, L the sequence length, e.g. the number of time frames).
torch's Conv1d weight has shape (out_channels, in_channels // groups, kernel_size), while mlx's has shape (out_channels, kernel_size, in_channels // groups).
The bias shape is identical in both: (out_channels,).

torch:
```python
import torch
from torch import nn

conv = nn.Conv1d(
    in_channels,
    out_channels,
    kernel_size,
    stride=1,
    dilation=dilation,
    padding=(kernel_size * dilation - dilation) // 2,
)
print(conv.weight.shape)  # (out_channels, in_channels, kernel_size)

x = ...  # in shape (B, in_channels, seq_len)
y = conv(x)
```
mlx:
```python
import mlx.core as mx
import mlx.nn as nn

conv = nn.Conv1d(
    in_channels,
    out_channels,
    kernel_size,
    stride=1,
    dilation=dilation,
    padding=(kernel_size * dilation - dilation) // 2,
)
print(conv.weight.shape)  # (out_channels, kernel_size, in_channels)

x = ...  # in shape (B, seq_len, in_channels)
y = conv(x)
```

Weight conversion:

```python
import mlx.core as mx

torch_conv1d_weight = ...  # torch.Tensor in shape (out_channels, in_channels, kernel_size)

# (out_channels, kernel_size, in_channels)
mlx_conv_weight = mx.array(torch_conv1d_weight.permute(0, 2, 1))
# or
mlx_conv_weight = mx.array(torch_conv1d_weight.moveaxis(1, 2))
```
| Conv1d | torch | mlx |
| --- | --- | --- |
| Input shape | (B, C_in, L_seq) | (B, L_seq, C_in) |
| Weight shape | (C_out, C_in // groups, kernel_size) | (C_out, kernel_size, C_in // groups) |
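An end-to-end parity check of a converted layer (a sketch; channel sizes are arbitrary, and the torch-to-mlx handoff goes through numpy):

```python
import numpy as np
import torch
from torch import nn as torch_nn
import mlx.core as mx
import mlx.nn as mlx_nn

torch_conv = torch_nn.Conv1d(4, 8, kernel_size=3, padding=1)
mlx_conv = mlx_nn.Conv1d(4, 8, kernel_size=3, padding=1)

# transpose the weight to mlx layout and copy the bias as-is
mlx_conv.weight = mx.array(torch_conv.weight.detach().permute(0, 2, 1).numpy())
mlx_conv.bias = mx.array(torch_conv.bias.detach().numpy())

x = torch.randn(2, 4, 10)                                  # (B, C, L) for torch
y_torch = torch_conv(x).detach().permute(0, 2, 1).numpy()  # -> (B, L, C)
y_mlx = mlx_conv(mx.array(x.permute(0, 2, 1).numpy()))     # (B, L, C) in and out
print(np.allclose(y_torch, np.array(y_mlx), atol=1e-5))    # True
```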

ConvTranspose1D

The input and weight layout differences for ConvTranspose1D mirror those of Conv1d:

| ConvTranspose1D | torch | mlx |
| --- | --- | --- |
| Input shape | (B, C_in, L_seq) | (B, L_seq, C_in) |
| Weight shape | (C_in, C_out, kernel_size) | (C_out, kernel_size, C_in) |

Note in particular that mlx's ConvTranspose1d does not support the groups parameter, i.e. only groups=1 works. One workaround is sketched below.
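Since a grouped (transposed) convolution is just independent convolutions over channel slices, groups > 1 can be emulated by splitting and concatenating. A sketch (grouped_conv_transpose1d and the per-group convs list are made up for illustration):

```python
import mlx.core as mx
import mlx.nn as nn

def grouped_conv_transpose1d(
    x: mx.array, convs: list[nn.ConvTranspose1d], groups: int
) -> mx.array:
    """Emulate groups > 1: run one ConvTranspose1d (built with
    C_in // groups -> C_out // groups channels) per input slice,
    then concatenate along the channel axis."""
    chunk = x.shape[-1] // groups  # channels sit on the last axis in mlx
    outs = [convs[g](x[..., g * chunk : (g + 1) * chunk]) for g in range(groups)]
    return mx.concatenate(outs, axis=-1)
```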

Performance falls short of expectations

After successfully porting BigVGAN to mlx, I compared it with the original implementation on my Apple M3 (16GB) MacBook Pro 14". To my surprise, it is slower than the original PyTorch version (so much for all that work...):

```
BigVGAN: 2.3289 seconds per inference
MLX BigVGAN: 4.5342 seconds per inference
Compiled MLX BigVGAN: 4.3205 seconds per inference
```

Further profiling[5] showed that mlx's conv1d and conv_transpose1d are slower than the torch mps backend. I filed issue mlx#2180, which has received no reply so far.

```
conv1d input: 8x256x1000 weight: 256x32x12 groups: 8
conv_transpose1d input: 8x256x1000 weight: 256x1x12 groups: 256
torch(mps) conv1d: 0.943 ms
mlx_conv1d: 3.906 ms
diff: -2.9631301250046818
torch(mps) conv_transpose1d: 2.912 ms
mlx conv_transpose1d: 5.282 ms
diff: -2.3704653340100776
```

Clearly MLX still has plenty of room for improvement. I look forward to upstream updates and community contributions.

Refs


  1. numpy: NumPy is the fundamental package for scientific computing in Python.
  2. awni's answer to "Why not implement this in Pytorch?"
  3. freeze: Freeze the Module's parameters or some of them.
  4. torch.Tensor.requires_grad_: Sets the requires_grad attribute of the tensor.
  5. benchmark scripts for conv1d and conv_transpose1d.