Dunfey · Hotel WWDC as data, est. 1983
Front desk everything
Years
Topics

2025 AI & Machine Learning

WWDC25 · 19 min · AI & Machine Learning

Get started with MLX for Apple silicon

MLX is a flexible and efficient array framework for numerical computing and machine learning on Apple silicon. We’ll explore fundamental features including unified memory, lazy computation, and function transformations. We’ll also look at more advanced techniques for building and accelerating machine learning models across Apple’s platforms using Swift and Python APIs.

Watch at developer.apple.com ↗

Transcript all transcripts

Chapters

Code shown on screen · 14 snippets

Basics python · at 3:48 ↗
import mlx.core as mx

# Build the two input arrays.
a, b = mx.array([1, 2, 3]), mx.array([4, 5, 6])

# Add them; the result is a new array.
c = mx.add(a, b)

# Read the result's shape and dtype attributes.
shape, dtype = c.shape, c.dtype

# One print call, one line per piece of information.
print(
    f"Result c: {c}",
    f"Shape: {shape}",
    f"Data type: {dtype}",
    sep="\n",
)
Unified memory python · at 5:31 ↗
import mlx.core as mx

# The same two arrays are handed to both operations below.
a, b = mx.array([1, 2, 3]), mx.array([4, 5, 6])

# The stream argument selects where each operation runs.
c = mx.add(a, b, stream=mx.gpu)
d = mx.multiply(a, b, stream=mx.cpu)

print(
    f"c computed on the GPU: {c}",
    f"d computed on the CPU: {d}",
    sep="\n",
)
Lazy computation python · at 6:20 ↗
import mlx.core as mx

# Make an array
a = mx.array([1, 2, 3])

# Make another array
b = mx.array([4, 5, 6])

# Do an operation. The three statements that follow each show a different
# way of triggering evaluation of c.
c = a + b

# Way 1: printing. Evaluates c before printing it
print(c)

# Way 2: converting to a Python list. Also evaluates c
c_list = c.tolist()

# Way 3: an explicit mx.eval call. Also evaluates c
mx.eval(c)

# Summary output, one line per evaluation route shown above.
print(f"Evaluate c by converting to list: {c_list}")
print(f"Evaluate c using print: {c}")
print(f"Evaluate c using mx.eval(): {c}")
Function transformation python · at 7:32 ↗
import mlx.core as mx

def sin(x):
    """Return the sine of ``x``, computed with ``mx.sin``."""
    return mx.sin(x)


# First derivative: mx.grad takes a function and returns a new function.
dfdx = mx.grad(sin)

# Transformations compose, so applying mx.grad twice gives the second
# derivative. (The duplicate definition of sin was removed; one is enough.)
d2fdx2 = mx.grad(mx.grad(sin))

# Computes the second derivative of sine at 1.0. The value is printed so that
# running this as a script actually shows the result instead of discarding it.
print(d2fdx2(mx.array(1.0)))
Neural Networks in MLX python · at 9:16 ↗
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

class MLP(nn.Module):
    """Two linear layers (``dim -> h_dim -> dim``) with a ReLU in between."""

    def __init__(self, dim, h_dim):
        super().__init__()
        # Project to the hidden width, then back to the input width.
        self.linear1 = nn.Linear(dim, h_dim)
        self.linear2 = nn.Linear(h_dim, dim)

    def __call__(self, x):
        """Apply ``linear1``, ReLU, then ``linear2`` to ``x``."""
        hidden = nn.relu(self.linear1(x))
        return self.linear2(hidden)
MLX and PyTorch python · at 9:57 ↗
# MLX version
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

class MLP(nn.Module):
    """Two linear layers (``dim -> h_dim -> dim``) with a ReLU in between."""

    def __init__(self, dim, h_dim):
        super().__init__()
        # Project to the hidden width, then back to the input width.
        self.linear1 = nn.Linear(dim, h_dim)
        self.linear2 = nn.Linear(h_dim, dim)

    def __call__(self, x):
        """Apply ``linear1``, ReLU, then ``linear2`` to ``x``."""
        hidden = nn.relu(self.linear1(x))
        return self.linear2(hidden)

# PyTorch version
import torch
import torch.nn as nn
import torch.optim as optim

class MLP(nn.Module):
    """Two linear layers (``dim -> h_dim -> dim``) with a ReLU in between."""

    def __init__(self, dim, h_dim):
        super().__init__()
        # Project to the hidden width, then back to the input width.
        self.linear1 = nn.Linear(dim, h_dim)
        self.linear2 = nn.Linear(h_dim, dim)

    def forward(self, x):
        """Apply ``linear1``, ReLU, then ``linear2`` to ``x``."""
        hidden = self.linear1(x).relu()
        return self.linear2(hidden)
Compiling MLX functions python · at 11:35 ↗
import mlx.core as mx
import math

def gelu(x):
    """Evaluate ``x * (1 + erf(x / sqrt(2))) / 2``."""
    erf_term = mx.erf(x / math.sqrt(2))
    return x * (1 + erf_term) / 2


@mx.compile
def compiled_gelu(x):
    """Same expression as ``gelu``, wrapped with ``mx.compile``."""
    erf_term = mx.erf(x / math.sqrt(2))
    return x * (1 + erf_term) / 2


# One random input of shape (4,), shared by both variants.
x = mx.random.normal(shape=(4,))

out, compiled_out = gelu(x), compiled_gelu(x)
print(
    f"gelu:          {out}",
    f"compiled gelu: {compiled_out}",
    sep="\n",
)
MLX Fast package python · at 12:32 ↗
import mlx.core as mx
import time

def rms_norm(x, weight, eps=1e-5):
    """Reference RMS normalization written with basic MLX operations.

    The input is cast to float32, multiplied by the reciprocal square root of
    (mean of squares over the last axis + ``eps``), scaled by ``weight``, and
    cast back to the dtype of ``x``.
    """
    y = x.astype(mx.float32)
    mean_sq = mx.mean(mx.square(y), axis=-1, keepdims=True)
    y = y * mx.rsqrt(mean_sq + eps)
    return (weight * y).astype(x.dtype)


batch_size = 8192
feature_dim = 4096
iterations = 1000

x = mx.random.normal([batch_size, feature_dim])
weight = mx.ones(feature_dim)

# Time the hand-written version; mx.eval inside the loop forces each result
# to be computed within the timed region.
start_time = time.perf_counter()
for _ in range(iterations):
    y = rms_norm(x, weight, eps=1e-5)
    mx.eval(y)
rms_norm_time = time.perf_counter() - start_time
# Fixed: this line previously formatted the undefined name `gelu_time`.
print(f"rms_norm execution: {rms_norm_time:0.4f} sec")

# Time the mx.fast implementation under the same conditions.
start_time = time.perf_counter()
for _ in range(iterations):
    mx.eval(mx.fast.rms_norm(x, weight, eps=1e-5))
fast_rms_norm_time = time.perf_counter() - start_time
# Fixed: this line previously formatted the undefined name `compiled_gelu_time`.
print(f"mx.fast.rms_norm execution: {fast_rms_norm_time:0.4f} sec")

print(f"mx.fast.rms_norm speedup: {rms_norm_time/fast_rms_norm_time:0.2f}x")
Custom Metal kernel python · at 13:30 ↗
import mlx.core as mx

# Kernel body: write exp(inp[elem]) to out[elem], where elem is the thread's
# x position in the grid.
kernel_source = """
    uint elem = thread_position_in_grid.x;
    out[elem] = metal::exp(inp[elem]);
"""

# Wrap the source as a callable kernel with one named input and one named output.
exp_kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=kernel_source,
)

# Run the kernel on a sample input; the grid's x extent is the element count
# of x, and the single output takes x's shape and dtype.
x = mx.array([1.0, 2.0, 3.0])
kernel_outputs = exp_kernel(
    inputs=[x],
    grid=(x.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[x.shape],
    output_dtypes=[x.dtype],
)
out = kernel_outputs[0]
print(out)
Quantization python · at 14:41 ↗
import mlx.core as mx

# Quantization settings shared by every call below.
quant_config = {"bits": 4, "group_size": 32}

x = mx.random.normal([1024])
weight = mx.random.normal([1024, 1024])

# Quantize the weight; the call returns the packed weight plus its scales and biases.
quantized_weight, scales, biases = mx.quantize(weight, **quant_config)

# Multiply x by the quantized weight directly.
y = mx.quantized_matmul(
    x,
    quantized_weight,
    scales=scales,
    biases=biases,
    **quant_config,
)

# Convert the quantized weight back with the same scales, biases, and settings.
w_orig = mx.dequantize(
    quantized_weight,
    scales=scales,
    biases=biases,
    **quant_config,
)
Quantized models python · at 15:23 ↗
import mlx.nn as nn

# An embedding followed by three linear layers.
layers = [
    nn.Embedding(100, 32),
    nn.Linear(32, 32),
    nn.Linear(32, 32),
    nn.Linear(32, 1),
]
model = nn.Sequential(*layers)

# Show the model before quantization.
print(model)

# Quantize the model; it is printed again afterwards for comparison.
nn.quantize(model, group_size=32, bits=4)

print(model)
Distributed python · at 16:50 ↗
import mlx.core as mx

# Initialize distributed communication and query the group's size and rank.
group = mx.distributed.init()
world_size, rank = group.size(), group.rank()

# A one-element array to pass to all_sum.
x = mx.array([1.0])
x_sum = mx.distributed.all_sum(x)

print(x_sum)
Distributed launcher bash · at 17:20 ↗
# --hosts takes a single comma-separated list; spaces after the commas would
# make the shell split it into separate arguments.
mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
MLX Swift swift · at 18:20 ↗
// Swift
import MLX

// Make an array
let a = MLXArray([1, 2, 3])

// Make another array
// NOTE(review): b holds the same values as a in this snippet.
let b = MLXArray([1, 2, 3])

// Do an operation: add the two arrays
let c = a + b

// Access information about the array via its shape and dtype properties
let shape = c.shape
let dtype = c.dtype

// Print results: the inputs, the sum, then the sum's shape and dtype
print("a: \(a)")
print("b: \(b)")
print("c = a + b: \(c)")
print("shape: \(shape)")
print("dtype: \(dtype)")

Resources