将PyTorch模型适配到MLX
本文记录一下将PyTorch模型适配到MLX的过程。
什么是MLX?
MLX is an array framework for machine learning on Apple silicon, brought to you by Apple machine learning research.
https://github.com/ml-explore/mlx/blob/main/README.md
MLX 是适应于苹果M系列芯片(Apple Silicon)的机器学习框架。
mlx的array设计更加接近于numpy[1],而不是PyTorch的tensor,即只存有结构信息(如形状、数据类型等),没有其它与深度学习训练相关的属性(如梯度)。与numpy和torch不同的是,mlx 的array是 Unified Memory 可以在CPU和GPU之间共享,这也是mlx被单独开发而非拓展pytorch的 mps backend的理由[2]。
array与tensor的区别
np.array和torch.tensor的区别可以阅读What is a Tensor in Machine Learning?
模型转换实践
BigVGAN PyTorch -> mlx-BigVGAN
基本的映射
| torch | mlx | |
|---|---|---|
| DataTypes | Data types | Data Types |
| NN | torch.nn.* | mlx.nn.* |
| Parameters/Weight/Buffer | torch.Tensor | mlx.core.array |
| ModuleList | torch.nn.ModuleList | list |
| ModuleDict | torch.nn.ModuleDict | dict |
| Transform | torch.fft | mlx.core.fft |
Pad mode
torch 的 pad 方法支持constant, reflect, replicate等模式,而mlx的pad则支持constant、edge模式。
reflect模式在mlx中并不支持,可以自定义实现:
1 | import mlx.core as mx |
假设 x 为 2D (M, N) 的张量,在最后一个维度的右边增加宽度为1的padding,torch和mlx的pad 对照如下:
| pad mode | torch | mlx |
|---|---|---|
| constant | F.pad(x, (0, 1), mode="constant", value=0) | mx.pad(x, [(0, 0), (0, 1)], mode="constant", constant_values=1) |
| replicate | F.pad(x, (0, 1), mode="replicate") | mx.pad(x, [(0, 0), (0, 1)], mode="edge") |
| reflect | F.pad(x, (0, 1), mode="reflect") | pad_reflect(x, (0, 1)) (custom pad_reflect function) |
nn.Module
mlx的nn.Module,可以不用定义forward方法,直接使用__call__。
如有一个 Snake 的 Module,torch和mlx的实现如下:
| torch | mlx | ||||
|---|---|---|---|---|---|
|
|
weight_norm
torch 提供了一个weight_norm 方法,用于优化训练的稳定性和泛化能力。weight_norm将权重分解为两个部分:v和g,其中v是一个向量,g是一个标量。权重的计算公式如下:
然而mlx暂时没有提供类似的功能(社区的PR还在review中 Implement Weight Normalization
还好,我们可以在权重转换时做一次类似remove_weight_norm的操作,针对仅需推理的模型。。。
1 | import torch |
Conv1d
mlx的Conv1d和torch的接口设计基本一致,主要差异在于torch的Conv1d的输入数据格式为(B, C, L),而mlx的输入数据格式为(B, L, C)。(B为batch size,C为通道数,L为序列长度如时间帧数量)torch的Conv1d的权重格式为(out_channels, in_channels // groups, kernel_size),而mlx的Conv1d的权重格式为(out_channels, kernel_size, in_channels // groups)。bias的形状则是一样的,都是(out_channels)。
| torch | mlx | ||||
|---|---|---|---|---|---|
|
|
权重转换:
1 | torch_conv1d_weight = ... # Tensor in shape: (out_channels, in_channels, kernel_size) |
| Conv1d | torch | mlx |
|---|---|---|
| Input shape | (B, C_in, L_seq) | (B, L_seq, C_in) |
| Weight shape | (C_out, C_in // groups, kernel_size) | (C_out, kernel_size, C_in // groups) |
ConvTranspose1D
ConvTranspose1D 的输入和权重的差异点也和 Conv1d 类似
| ConvTranspose1D | torch | mlx |
|---|---|---|
| Input shape | (B, C_in, L_seq) | (B, L_seq, C_in) |
| Weight shape | (C_in, C_out, kernel_size) | (C_out, kernel_size, C_in) |
特别注意,mlx ConvTranspose1D 不支持
groups参数,即groups=1。
性能并不如预期
在将BigVGAN成功适配到mlx之后,用我这台 Apple M3 (16G) Macbook Pro 14” 与原版实现进行对比,好家伙,发现还不如原来的pytorch实现(白忙活了一场…)
BigVGAN: 2.3289 seconds per inference
MLX BigVGAN: 4.5342 seconds per inference
Compiled MLX BigVGAN: 4.3205 seconds per inference
进一步 Profile 发现[5],mlx 的 conv1d 和 conv_transpose1d 的性能不如 torch mps backend,提了个Issue mlx#2180, 截至当前没得到回复。
conv1d input: 8x256x1000 weight: 256x32x12 groups: 8
conv_transpose1d input: 8x256x1000 weight: 256x1x12 groups: 256
torch(mps) conv1d: 0.943 ms
mlx_conv1d: 3.906 ms
diff: -2.9631301250046818
torch(mps) conv_transpose1d: 2.912 ms
mlx conv_transpose1d: 5.282 ms
diff: -2.3704653340100776
由此可见,MLX 还有很大的提升空间~ 期待官方的更新和社区的贡献。
Refs
- 1.numpy NumPy is the fundamental package for scientific computing in Python. ↩
- 2.awni's answer to Why not implement this in Pytorch? ↩
- 3.freeze Freeze the Module’s parameters or some of them ↩
- 4.torch.Tensor.requires_grad_ Sets the requires_grad attribute of the tensor ↩
- 5.benchmark scripts for conv1d and conv_transpose1d ↩
Author: Yrom
Link: https://yrom.net/blog/2025/05/14/adapt-a-pytorch-model-to-mlx/
License: 知识共享署名-非商业性使用 4.0 国际许可协议