2025 AI & Machine Learning
WWDC25 · 19 min · AI & Machine Learning
Get started with MLX for Apple silicon
MLX is a flexible and efficient array framework for numerical computing and machine learning on Apple silicon. We’ll explore fundamental features including unified memory, lazy computation, and function transformations. We’ll also look at more advanced techniques for building and accelerating machine learning models across Apple’s platforms using Swift and Python APIs.
Watch at developer.apple.com ↗
Chapters
Code shown on screen · 14 snippets
Basics
# Basics: create MLX arrays, combine them with an operator, and inspect the
# metadata of the result.
import mlx.core as mx

# Make an array
a = mx.array([1, 2, 3])
# Make another array
b = mx.array([4, 5, 6])
# Do an operation; the result is another mx.array
c = a + b
# Access information about the array
shape = c.shape
dtype = c.dtype
print(f"Result c: {c}")
print(f"Shape: {shape}")
print(f"Data type: {dtype}")

# --- Unified memory ---
# Unified memory: the same input arrays feed an operation on the GPU and one
# on the CPU. The device is chosen per operation with `stream=`; the code
# never moves or copies a and b to a device.
import mlx.core as mx

a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])

# Same inputs, different devices.
c = mx.add(a, b, stream=mx.gpu)
d = mx.multiply(a, b, stream=mx.cpu)

print(f"c computed on the GPU: {c}")
print(f"d computed on the CPU: {d}")

# --- Lazy computation ---
# Lazy computation: `c = a + b` records the operation; the sum is computed
# only once its value is actually needed.
import mlx.core as mx

# Make an array
a = mx.array([1, 2, 3])
# Make another array
b = mx.array([4, 5, 6])
# Do an operation (not computed yet)
c = a + b
# Evaluates c before printing it
print(c)
# Converting to a Python list also forces evaluation if c is still unevaluated
c_list = c.tolist()
# mx.eval(...) is the explicit way to force evaluation
mx.eval(c)
print(f"Evaluate c by converting to list: {c_list}")
print(f"Evaluate c using print: {c}")
print(f"Evaluate c using mx.eval(): {c}")

# --- Function transformation ---
# Function transformations: mx.grad turns a function into a new function
# that computes its gradient, and transformations can be composed.
import mlx.core as mx


def sin(x):
    """Return the elementwise sine of ``x`` (a thin wrapper over mx.sin)."""
    return mx.sin(x)


# Gradient of sin: evaluates cos(x).
dfdx = mx.grad(sin)

# Composing grad twice gives the second derivative: evaluates -sin(x).
d2fdx2 = mx.grad(mx.grad(mx.sin))

x = mx.array(1.0)
# First derivative of sine at 1.0
print(f"dfdx(1.0): {dfdx(x)}")
# Computes the second derivative of sine at 1.0
print(f"d2fdx2(1.0): {d2fdx2(x)}")

# --- Neural Networks in MLX ---
import mlx.core as mx  # not used below; kept as shown in the session
import mlx.nn as nn
import mlx.optimizers as optim  # not used below; kept as shown in the session


class MLP(nn.Module):
    """A simple MLP.

    Two linear layers with a ReLU in between: maps features of width ``dim``
    to a hidden width ``h_dim`` and back to ``dim``.
    """

    def __init__(self, dim, h_dim):
        """Create the two linear layers.

        Args:
            dim: Input (and output) feature size.
            h_dim: Hidden feature size.
        """
        super().__init__()
        self.linear1 = nn.Linear(dim, h_dim)
        self.linear2 = nn.Linear(h_dim, dim)

    def __call__(self, x):
        """Run the forward pass (MLX modules implement ``__call__`` directly).

        Args:
            x: Input whose last axis has size ``dim`` (the input width of
                ``linear1``).

        Returns:
            The output of ``linear2``, whose last axis has size ``dim``.
        """
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.linear2(x)
        return x


# --- MLX and PyTorch ---
# MLX version
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim


class MLP(nn.Module):
    """A simple MLP (MLX): Linear(dim, h_dim) -> ReLU -> Linear(h_dim, dim)."""

    def __init__(self, dim, h_dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, h_dim)
        self.linear2 = nn.Linear(h_dim, dim)

    # MLX modules define the forward pass in __call__.
    def __call__(self, x):
        x = self.linear1(x)
        x = nn.relu(x)
        x = self.linear2(x)
        return x


# PyTorch version
# NOTE: if both halves run as a single script, these imports rebind `nn` and
# `optim`, and the class below replaces the MLX `MLP` defined above.
import torch
import torch.nn as nn
import torch.optim as optim


class MLP(nn.Module):
    """A simple MLP (PyTorch): Linear(dim, h_dim) -> ReLU -> Linear(h_dim, dim)."""

    def __init__(self, dim, h_dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, h_dim)
        self.linear2 = nn.Linear(h_dim, dim)

    # PyTorch modules define the forward pass in forward(); the only other
    # difference from the MLX version is using the tensor method relu().
    def forward(self, x):
        x = self.linear1(x)
        x = x.relu()
        x = self.linear2(x)
        return x


# --- Compiling MLX functions ---
import mlx.core as mx
import math


def gelu(x):
    """GELU activation, ``x * (1 + erf(x / sqrt(2))) / 2``, using plain MLX ops."""
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2


# @mx.compile returns a compiled function that produces the same output as the
# plain version. Without the decorator this would just be an uncompiled copy
# of gelu, and the comparison below would show nothing.
@mx.compile
def compiled_gelu(x):
    """Same GELU as ``gelu``, compiled with ``mx.compile``."""
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2


x = mx.random.normal(shape=(4,))
out = gelu(x)
compiled_out = compiled_gelu(x)
print(f"gelu: {out}")
print(f"compiled gelu: {compiled_out}")

# --- MLX Fast package ---
import mlx.core as mx
import time


def rms_norm(x, weight, eps=1e-5):
    """Reference RMS normalization built from basic MLX ops.

    Computes ``weight * x / sqrt(mean(x**2, axis=-1) + eps)`` in float32 and
    casts the result back to ``x.dtype``.

    Args:
        x: Input array; normalized over its last axis.
        weight: Per-feature scale, broadcast against the normalized input.
        eps: Small constant added inside the square root.
    """
    # Compute in float32 regardless of the input dtype.
    y = x.astype(mx.float32)
    y = y * mx.rsqrt(mx.mean(
        mx.square(y),
        axis=-1,
        keepdims=True,
    ) + eps)
    return (weight * y).astype(x.dtype)


batch_size = 8192
feature_dim = 4096
iterations = 1000

x = mx.random.normal([batch_size, feature_dim])
weight = mx.ones(feature_dim)

# Run each version once before timing so one-time first-call overhead is not
# counted in either timed loop.
mx.eval(rms_norm(x, weight, eps=1e-5))
mx.eval(mx.fast.rms_norm(x, weight, eps=1e-5))

# mx.eval is called inside each loop because computation is lazy: without it
# no work would actually run while the timer is going.
start_time = time.perf_counter()
for _ in range(iterations):
    y = rms_norm(x, weight, eps=1e-5)
    mx.eval(y)
rms_norm_time = time.perf_counter() - start_time
print(f"rms_norm execution: {rms_norm_time:0.4f} sec")

start_time = time.perf_counter()
for _ in range(iterations):
    mx.eval(mx.fast.rms_norm(x, weight, eps=1e-5))
fast_rms_norm_time = time.perf_counter() - start_time
print(f"mx.fast.rms_norm execution: {fast_rms_norm_time:0.4f} sec")
print(f"mx.fast.rms_norm speedup: {rms_norm_time/fast_rms_norm_time:0.2f}x")

# --- Custom Metal kernel ---
import mlx.core as mx

# Build the kernel.
# `source` holds only the kernel body. `inp` and `out` refer to the names
# given in input_names / output_names below. Each thread handles one element:
# it writes exp of the input at its x position in the grid.
source = """
uint elem = thread_position_in_grid.x;
out[elem] = metal::exp(inp[elem]);
"""

kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

# Call the kernel on a sample input. The grid has one thread per element of
# x, and the output has the same shape and dtype as the input.
x = mx.array([1.0, 2.0, 3.0])
out = kernel(
    inputs=[x],
    grid=(x.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[x.shape],
    output_dtypes=[x.dtype],
)[0]  # one result per entry in output_names; take "out"
print(out)

# --- Quantization ---
import mlx.core as mx

# Quantization settings shared by every call below. quantize,
# quantized_matmul and dequantize must all agree on them.
bits = 4
group_size = 32

x = mx.random.normal([1024])
weight = mx.random.normal([1024, 1024])

# Quantize the weight. The packed weight comes back together with the
# per-group scales and biases needed to interpret it.
quantized_weight, scales, biases = mx.quantize(
    weight, bits=bits, group_size=group_size,
)

# Multiply x directly against the quantized weight.
y = mx.quantized_matmul(
    x,
    quantized_weight,
    scales=scales,
    biases=biases,
    bits=bits,
    group_size=group_size,
)

# Reconstruct a floating-point approximation of the original weight.
# Quantization is lossy, so this is not expected to equal `weight` exactly.
w_orig = mx.dequantize(
    quantized_weight,
    scales=scales,
    biases=biases,
    bits=bits,
    group_size=group_size,
)

# --- Quantized models ---
import mlx.nn as nn

# A small model: an embedding layer followed by three linear layers.
model = nn.Sequential(
    nn.Embedding(100, 32),
    nn.Linear(32, 32),
    nn.Linear(32, 32),
    nn.Linear(32, 1),
)
print(model)

# nn.quantize updates `model` in place; nothing is reassigned. Comparing the
# two printouts shows the layers before and after quantization.
nn.quantize(
    model,
    bits=4,
    group_size=32,
)
print(model)

# --- Distributed ---
import mlx.core as mx

# Initialize distributed communication and get this process's group.
group = mx.distributed.init()
world_size = group.size()
rank = group.rank()

# Every process contributes x = [1.0]. all_sum returns the sum over all
# processes in the group, so each process should see [world_size].
x = mx.array([1.0])
x_sum = mx.distributed.all_sum(x, group=group)
print(f"rank {rank} of {world_size}: {x_sum}")

# --- Distributed launcher ---
# Run my_script.py on four hosts. --hosts takes ONE comma-separated list with
# no spaces; spaces would split it into separate shell arguments.
mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
# --- MLX Swift ---
// Swift
// MLX Swift: create arrays, add them, and inspect the result's metadata
// using the Swift API.
import MLX

// Make an array
let a = MLXArray([1, 2, 3])
// Make another array
let b = MLXArray([1, 2, 3])
// Do an operation
let c = a + b
// Access information about the array
let shape = c.shape
let dtype = c.dtype
// Print results
print("a: \(a)")
print("b: \(b)")
print("c = a + b: \(c)")
print("shape: \(shape)")
print("dtype: \(dtype)")

Resources
Related sessions
-
20 min