将PyTorch模型适配到MLX
本文记录一下将PyTorch
模型适配到MLX
的过程。
什么是MLX?
MLX is an array framework for machine learning on Apple silicon, brought to you by Apple machine learning research.
https://github.com/ml-explore/mlx/blob/main/README.md
MLX 是适应于苹果M系列芯片(Apple Silicon)的机器学习框架。
mlx的array
设计更加接近于numpy
[1],而不是PyTorch的tensor
,即只存有结构信息(如形状、数据类型等),没有其它与深度学习训练相关的属性(如梯度)。与numpy
和torch
不同的是,mlx 的array
是 Unified Memory 可以在CPU和GPU之间共享,这也是mlx
被单独开发而非拓展pytorch的 mps backend的理由[2]。
array与tensor的区别
np.array
和torch.tensor
的区别可以阅读What is a Tensor in Machine Learning?
模型转换实践
BigVGAN PyTorch -> mlx-BigVGAN
基本的映射
torch | mlx | |
---|---|---|
DataTypes | Data types | Data Types |
NN | torch.nn.* | mlx.nn.* |
Parameters/Weight/Buffer | torch.Tensor | mlx.core.array |
ModuleList | torch.nn.ModuleList | list |
ModuleDict | torch.nn.ModuleDict | dict |
Transform | torch.fft | mlx.core.fft |
Pad mode
torch 的 pad
方法支持constant
, reflect
, replicate
等模式,而mlx
的pad
则支持constant
、edge
模式。
reflect
模式在mlx
中并不支持,可以自定义实现:
1 | import mlx.core as mx |
假设 x 为 2D (M, N) 的张量,在最后一个维度的右边增加宽度为1的padding,torch
和mlx
的pad 对照如下:
pad mode | torch | mlx |
---|---|---|
constant | F.pad(x, (0, 1), mode="constant", value=0) | mx.pad(x, [(0, 0), (0, 1)], mode="constant", constant_values=1) |
replicate | F.pad(x, (0, 1), mode="replicate") | mx.pad(x, [(0, 0), (0, 1)], mode="edge") |
reflect | F.pad(x, (0, 1), mode="reflect") | pad_reflect(x, (0, 1)) (custom pad_reflect function) |
nn.Module
mlx
的nn.Module
,可以不用定义forward
方法,直接使用__call__
。
如有一个 Snake
的 Module,torch
和mlx
的实现如下:
torch | mlx | ||||
---|---|---|---|---|---|
|
|
weight_norm
torch
提供了一个weight_norm
方法,用于优化训练的稳定性和泛化能力。weight_norm
将权重分解为两个部分:v
和g
,其中v
是一个向量,g
是一个标量。权重的计算公式如下:
然而mlx
暂时没有提供类似的功能(社区的PR还在review中 Implement Weight Normalization
还好,我们可以在权重转换时做一次类似remove_weight_norm
的操作,针对仅需推理的模型。。。
1 | import torch |
Conv1d
mlx
的Conv1d
和torch
的接口设计基本一致,主要差异在于torch
的Conv1d
的输入数据格式为(B, C, L)
,而mlx
的输入数据格式为(B, L, C)
。(B为batch size,C为通道数,L为序列长度如时间帧数量)torch
的Conv1d
的权重格式为(out_channels, in_channels // groups, kernel_size)
,而mlx
的Conv1d
的权重格式为(out_channels, kernel_size, in_channels // groups)
。bias
的形状则是一样的,都是(out_channels)
。
torch | mlx | ||||
---|---|---|---|---|---|
|
|
权重转换:
1 | torch_conv1d_weight = ... # Tensor in shape: (out_channels, in_channels, kernel_size) |
Conv1d | torch | mlx |
---|---|---|
Input shape | (B, C_in, L_seq) | (B, L_seq, C_in) |
Weight shape | (C_out, C_in // groups, kernel_size) | (C_out, kernel_size, C_in // groups) |
ConvTranspose1D
ConvTranspose1D 的输入和权重的差异点也和 Conv1d 类似
ConvTranspose1D | torch | mlx |
---|---|---|
Input shape | (B, C_in, L_seq) | (B, L_seq, C_in) |
Weight shape | (C_in, C_out, kernel_size) | (C_out, kernel_size, C_in) |
特别注意,mlx ConvTranspose1D 不支持
groups
参数,即groups=1
。
性能并不如预期
在将BigVGAN
成功适配到mlx
之后,用我这台 Apple M3 (16G) Macbook Pro 14” 与原版实现进行对比,好家伙,发现还不如原来的pytorch
实现(白忙活了一场…)
BigVGAN: 2.3289 seconds per inference
MLX BigVGAN: 4.5342 seconds per inference
Compiled MLX BigVGAN: 4.3205 seconds per inference
进一步 Profile 发现[5],mlx
的 conv1d
和 conv_transpose1d
的性能不如 torch
mps backend,提了个Issue mlx#2180, 截至当前没得到回复。
conv1d input: 8x256x1000 weight: 256x32x12 groups: 8
conv_transpose1d input: 8x256x1000 weight: 256x1x12 groups: 256
torch(mps) conv1d: 0.943 ms
mlx_conv1d: 3.906 ms
diff: -2.9631301250046818
torch(mps) conv_transpose1d: 2.912 ms
mlx conv_transpose1d: 5.282 ms
diff: -2.3704653340100776
由此可见,MLX
还有很大的提升空间~ 期待官方的更新和社区的贡献。
Refs
- 1.numpy NumPy is the fundamental package for scientific computing in Python. ↩
- 2.awni's answer to Why not implement this in Pytorch? ↩
- 3.freeze Freeze the Module’s parameters or some of them ↩
- 4.torch.Tensor.requires_grad_ Sets the requires_grad attribute of the tensor ↩
- 5.benchmark scripts for conv1d and conv_transpose1d ↩
Author: Yrom
Link: https://yrom.net/blog/2025/05/14/adapt-a-pytorch-model-to-mlx/
License: 知识共享署名-非商业性使用 4.0 国际许可协议