AI

Apple MLX

What is MLX?

Apple MLX is a machine learning framework that is designed to be fast, secure, and easy to use. It is built on top of the TensorFlow framework, and provides a simple and intuitive API for building machine learning models.


Quick Start

Import mlx.core and create an array:

>> import mlx.core as mx
>> a = mx.array([1, 2, 3, 4])
>> a.shape
[4]
>> a.dtype
int32
>> b = mx.array([1.0, 2.0, 3.0, 4.0])
>> b.dtype
float32

MLX 中的操作是惰性的。MLX 操作的输出直到需要时才会被计算。要强制评估一个数组,请使用 eval()。在一些情况下,数组会自动被评估。例如,使用 array.item() 检查标量、打印数组,或将数组从 array 转换为 numpy.ndarray 都会自动评估该数组。

>> c = a + b    # c 尚未评估
>> mx.eval(c)  # 评估 c
>> c = a + b
>> print(c)     # 也评估 c
array([2, 4, 6, 8], dtype=float32)
>> c = a + b
>> import numpy as np
>> np.array(c)   # 也评估 c
array([2., 4., 6., 8.], dtype=float32)

函数和图形变换

MLX 具有标准的函数变换,如 grad()vmap()。变换可以任意组合。例如,允许 grad(vmap(grad(fn)))(或任何其他组合)。

>> x = mx.array(0.0)
>> mx.sin(x)
array(0, dtype=float32)
>> mx.grad(mx.sin)(x)
array(1, dtype=float32)
>> mx.grad(mx.grad(mx.sin))(x)
array(-0, dtype=float32)

其他梯度变换包括 vjp() 用于向量-雅可比积和 jvp() 用于雅可比-向量积。

使用 value_and_grad() 可以高效地同时计算函数的输出和相对于函数输入的梯度。


Unified Memory

Apple Silicon 使用统一内存架构。CPU 和 GPU 可以直接访问相同的内存池,MLX 被开发出来,以便利用这一优势。

具体来说,当你在 MLX 中创建一个数组时,你不需要指定其位置:

a = mx.random.normal((100,))
b = mx.random.normal((100,))

a 和 b 都存在于统一内存中。

在 MLX 中,你不需要将数组移动到设备上,而是在运行操作时指定设备。任何设备都可以在不需要将它们从一个内存位置移动到另一个位置的情况下,对 a 和 b 进行任何操作。例如:

mx.add(a, b, stream=mx.cpu)
mx.add(a, b, stream=mx.gpu)

在上述情况中,CPU 和 GPU 都将执行相同的加法操作。这些操作可以(也很可能会)并行运行,因为它们之间没有依赖关系。有关 MLX 中流语义的更多信息,请参见使用流。

在上述加法示例中,操作之间没有依赖关系,因此不存在竞态条件的可能性。如果存在依赖关系,MLX 调度器将自动管理它们。例如:

c = mx.add(a, b, stream=mx.cpu)
d = mx.add(a, c, stream=mx.gpu)

在上述情况中,第二个加法在 GPU 上运行,但它依赖于在 CPU 上运行的第一个加法的输出。MLX 将自动在两个流之间插入一个依赖,以便在第一个加法完成且 c 可用后,第二个加法才开始执行。

一个简单的例子

这里有一个更有趣的(尽管有点牵强的)例子,说明统一内存是如何有用的。假设我们有以下计算:

def fun(a, b, d1, d2):
    x = mx.matmul(a, b, stream=d1)
    for _ in range(500):
        b = mx.exp(b, stream=d2)
    return x, b

我们希望使用以下参数运行它:

a = mx.random.uniform(shape=(4096, 512))
b = mx.random.uniform(shape=(512, 4))

第一个 matmul 操作非常适合 GPU,因为它计算密集。第二序列操作更适合 CPU,因为它们非常小,可能在 GPU 上会受到开销的限制。

如果我们完全在 GPU 上计时,我们得到 2.8 毫秒。但如果我们使用 d1=mx.gpud2=mx.cpu 运行计算,那么时间只有大约 1.4 毫秒,快了大约两倍。这些时间是在 M1 Max 上测量的。


Use Stream

所有操作(包括随机数生成)都接受一个可选的关键字参数 stream。stream 参数指定操作应该在哪个 Stream 上运行。如果未指定 stream,则操作在默认设备的默认流上运行:mx.default_stream(mx.default_device())。stream 参数也可以是一个设备(例如 stream=my_device),在这种情况下,操作在提供的设备的默认流上运行 mx.default_stream(my_device)。


线性回归

让我们实现一个基本的线性回归模型,作为学习 MLX 的起点。首先导入核心包并设置一些问题元数据:

import mlx.core as mx

num_features = 100
num_examples = 1_000
num_iters = 10_000  # iterations of SGD
lr = 0.01  # learning rate for SGD

我们将通过以下方式生成一个合成数据集:

  1. 抽样设计矩阵 X
  2. 抽样真实参数向量 w_star
  3. 通过向 X @ w_star 添加高斯噪声来计算因变量 y
# True parameters
w_star = mx.random.normal((num_features,))

# Input examples (design matrix)
X = mx.random.normal((num_examples, num_features))

# Noisy labels
eps = 1e-2 * mx.random.normal((num_examples,))
y = X @ w_star + eps

我们将使用 SGD(随机梯度下降)来找到最优权重。首先,定义平方损失并获取损失相对于参数的梯度函数。

def loss_fn(w):
    return 0.5 * mx.mean(mx.square(X @ w - y))

grad_fn = mx.grad(loss_fn)

启动优化,首先随机初始化参数 w。然后重复更新参数,进行 num_iters 次迭代。

w = 1e-2 * mx.random.normal((num_features,))

for _ in range(num_iters):
    grad = grad_fn(w)
    w = w - lr * grad
    mx.eval(w)

最后,计算学习到的参数的损失,并验证它们是否接近于真实参数。

loss = loss_fn(w)
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5

print(
    f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
)

# Should print something close to: Loss 0.00005, |w-w*| = 0.00364

完整的线性回归和逻辑回归示例可在 MLX 的 GitHub 仓库中找到。


多层感知器

多层感知器(MLP)是一种基本的神经网络架构,由多个全连接层组成。MLP 通常用于分类任务,但也可以用于回归任务。

在这个例子中,我们将学习使用 mlx.nn,通过实现一个简单的多层感知器来对 MNIST 进行分类。

作为第一步,导入我们需要的 MLX 包:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np

模型被定义为继承自 mlx.nn.ModuleMLP 类。

我们遵循创建新模块的标准惯例:

定义一个 init 方法,在这里设置参数和/或子模块。有关 mlx.nn.Module 如何注册参数的更多信息,请参阅 Module 类文档。 定义一个 call 方法,实现计算过程。

class MLP(nn.Module):
    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = mx.maximum(l(x), 0.0)
        return self.layers[-1](x)

LLM Influence

MLX 可以在 Apple Silicon 上高效推理大型转换器模型,而不会影响使用的便利性。在这个示例中,我们将创建一个用于 Llama 系列转换器模型的推理脚本,其中模型在不到 200 行 Python 代码中定义。

模型的实现

我们将使用 mlx.nn 模块中定义的神经网络构建块来简洁地定义模型架构。

注意力层

我们将从 Llama 注意力层开始,该层 notably 使用 RoPE 位置编码。此外,我们的注意力层还可以选择使用一个键/值缓存,该缓存将与提供的键和值连接在一起,以支持高效的推理。

我们的实现使用 mlx.nn.Linear 来进行所有的投影,使用 mlx.nn.RoPE 进行位置编码。

import mlx.core as mx
import mlx.nn as nn

class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()

        self.num_heads = num_heads

        self.rope = nn.RoPE(dims // num_heads, traditional=True)
        self.query_proj = nn.Linear(dims, dims, bias=False)
        self.key_proj = nn.Linear(dims, dims, bias=False)
        self.value_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        # 提取一些形状信息
        num_heads = self.num_heads
        B, L, D = queries.shape

        # 为注意力计算准备查询、键和值
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)

        # 将 RoPE 添加到查询和键中,并与缓存组合
        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # 最后进行注意力计算
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        # 注意我们返回键和值,可能用作缓存
        return self.out_proj(values_hat), (keys, values)

编码器层

Llama 模型的另一个组成部分是编码器层,它使用 RMS 标准化和 SwiGLU。对于 RMS 标准化,我们将使用 mlx.nn.RMSNorm,该模块已经在 mlx.nn 中提供。

class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = a * mx.sigmoid(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache

完整的模型

要实现任意 Llama 模型,我们只需将 LlamaEncoderLayer 实例与 mlx.nn.Embedding 结合起来以嵌入输入 Token。

class Llama(nn.Module):
    def __init__(
        self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, dims)
        self.layers = [
            LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
        ]
        self.norm = nn.RMSNorm(dims)
        self.out_proj = nn.Linear(dims, vocab_size, bias=False)

    def __call__(self, x):
        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
        mask = mask.astype(self.embedding.weight.dtype)

        x = self.embedding(x)
        for l in self.layers:
            x, _ = l(x, mask)
        x = self.norm(x)
        return self.out_proj(x)

请注意,在上面的实现中,我们使用一个简单的列表来保存编码器层,但使用 model.parameters() 仍然会考虑这些层。

生成

我们的 Llama 模块可以用于训练,但不适用于推理,因为上面的 call 方法处理一个输入,完全忽略了缓存,并且不进行任何抽样。在本小节的其余部分,我们将实现推理函数作为 Python 生成器,该函数会处理提示,并自动逐个生成 Token。

class Llama(nn.Module):
    ...

    def generate(self, x, temp=1.0):
        cache = []

        # Make an additive causal mask. We will need that to process the prompt.
        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
        mask = mask.astype(self.embedding.weight.dtype)

        # First we process the prompt x the same way as in __call__ but
        # save the caches in cache
        x = self.embedding(x)
        for l in self.layers:
            x, c = l(x, mask=mask)
            cache.append(c)  # <--- we store the per layer cache in a
                             #      simple python list
        x = self.norm(x)
        y = self.out_proj(x[:, -1])  # <--- we only care about the last logits
                                     #      that generate the next token
        y = mx.random.categorical(y * (1/temp))

        # y now has size [1]
        # Since MLX is lazily evaluated nothing is computed yet.
        # Calling y.item() would force the computation to happen at
        # this point but we can also choose not to do that and let the
        # user choose when to start the computation.
        yield y

        # Now we parsed the prompt and generated the first token we
        # need to feed it back into the model and loop to generate the
        # rest.
        while True:
            # Unsqueezing the last dimension to add a sequence length
            # dimension of 1
            x = y[:, None]

            x = self.embedding(x)
            for i in range(len(cache)):
                # We are overwriting the arrays in the cache list. When
                # the computation will happen, MLX will be discarding the
                # old cache the moment it is not needed anymore.
                x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
            x = self.norm(x)
            y = self.out_proj(x[:, -1])
            y = mx.random.categorical(y * (1/temp))

            yield y

汇总

我们现在已经拥有创建Llama模型并从中生成示例标记所需的一切。在以下代码中,我们随机初始化了一个小型的 Llama 模型,处理了 6 个提示标记并生成了 10 个标记。

model = Llama(num_layers=12, vocab_size=8192, dims=512, mlp_dims=1024, num_heads=8)

# Since MLX is lazily evaluated nothing has actually been materialized yet.
# We could have set the `dims` to 20_000 on a machine with 8GB of RAM and the
# code above would still run. Let's actually materialize the model.
mx.eval(model.parameters())

prompt = mx.array([[1, 10, 8, 32, 44, 7]])  # <-- Note the double brackets because we
                                            #     have a batch dimension even
                                            #     though it is 1 in this case

generated = [t for i, t in zip(range(10), model.generate(prompt, 0.8))]

# Since we haven't evaluated anything, nothing is computed yet. The list
# `generated` contains the arrays that hold the computation graph for the
# full processing of the prompt and the generation of 10 tokens.
#
# We can evaluate them one at a time, or all together. Concatenate them or
# print them. They would all result in very similar runtimes and give exactly
# the same results.
mx.eval(generated)

权重转换

本节假设您可以访问原始的 Llama 权重以及随附的 SentencePiece 模型。我们将编写一个小脚本将 PyTorch 权重转换为MLX兼容的权重,并将它们写入可以直接由 MLX 加载的 NPZ 文件中。

import argparse
from itertools import starmap

import numpy as np
import torch

def map_torch_to_mlx(key, value):
    if "tok_embedding" in key:
        key = "embedding.weight"

    elif "norm" in key:
        key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")

    elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
        key = key.replace("wq", "query_proj")
        key = key.replace("wk", "key_proj")
        key = key.replace("wv", "value_proj")
        key = key.replace("wo", "out_proj")

    elif "w1" in key or "w2" in key or "w3" in key:
        # The FFN is a separate submodule in PyTorch
        key = key.replace("feed_forward.w1", "linear1")
        key = key.replace("feed_forward.w3", "linear2")
        key = key.replace("feed_forward.w2", "linear3")

    elif "output" in key:
        key = key.replace("output", "out_proj")

    elif "rope" in key:
        return None, None

    return key, value.numpy()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
    parser.add_argument("torch_weights")
    parser.add_argument("output_file")
    args = parser.parse_args()

    state = torch.load(args.torch_weights)
    np.savez(
        args.output_file,
        **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
    )

加载权重

在将权重转换为与我们的实现兼容后,剩下的就是从磁盘加载它们,然后我们终于可以使用LLM生成文本了。我们可以使用mlx.core.load() 操作加载numpy格式的文件。

为了从 NPZ 文件的键/值表示中创建参数字典,我们将使用mlx.utils.tree_unflatten()辅助方法,如下所示:

from mlx.utils import tree_unflatten
model.update(tree_unflatten(list(mx.load(weight_file).items())))

mlx.utils.tree_unflatten()会接受看起来像layers.2.attention.query_proj.weight的NPZ文件中的键,并将它们转换为

{"layers": [..., ..., {"attention": {"query_proj": {"weight": ...}}}]}

然后可以用于更新模型。请注意,上述方法会导致从磁盘到numpy再到MLX的多次不必要的拷贝。它将在将来用直接加载到MLX的方式替代。

您可以在mlx-examples中下载完整的示例代码。假设在当前工作目录中存在 weights.pth 和 tokenizer.model,我们可以如下方式使用我们的推理脚本(计时代表了M1 Ultra和7B参数的Llama模型):

$ python convert.py weights.pth llama-7B.mlx.npz
$ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely'
[INFO] Loading model from disk: 5.247 s
Press enter to start generation
------
, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down,
------
[INFO] Prompt processing: 0.437 s
[INFO] Full generation: 4.330 s

我们观察到生成100个标记需要4.3秒,其中0.4秒用于处理提示。这相当于每个标记略超过39毫秒。

通过使用更大的提示来运行,我们可以看到每个标记的生成时间以及提示处理时间几乎保持不变。

Previous
Getting started