AI Infra 学习路线
AI Infra(人工智能基础设施)是大模型时代壁垒最高、最核心的技术高地。本文从前置基础到推理部署,系统梳理 AI Infra 的完整学习路线,为每个模块列出需要掌握的知识点、推荐学习资料以及可量化的检验标准,帮助从业者建立体系化的知识树。
🗺️ 全景概览:三层架构
AI Infra 的本质是 "用系统工程释放硬件算力"。自底向上分为三个核心层级加一个前置知识层:
| 层级 | 名称 | 核心关注点 |
|---|---|---|
| 第零层 | 前置知识 | 编程语言、数学基础、Transformer 架构、PyTorch、通信拓扑 |
| 第一层 | CUDA编程与算子优化 | GPU架构、存储层次、Kernel编写、FlashAttention、AI编译器 |
| 第二层 | 分布式训练 | 数据并行、3D并行、ZeRO、混合精度 |
| 第三层 | 推理与部署 | KV Cache、PagedAttention、量化、Speculative Decoding |
所有的优化都是在 "计算、通信、显存" 这个不可能三角中做取舍:ZeRO 是用通信换显存;重计算(Activation Checkpointing)是用计算换显存;量化是用精度换显存和带宽。学习时始终问自己:这个技术牺牲了什么,换取了什么?
📖 第零层:前置知识
0.1 知识点
编程语言
- Python:熟练使用面向对象、装饰器、生成器、多进程/多线程、性能 profiling
- C/C++:理解指针、内存管理、编译链接过程,能读懂 C++ 项目代码
- Linux 基础:命令行操作、Shell 脚本、进程管理、环境变量配置
数学基础
- 线性代数:矩阵乘法、转置、分块矩阵、特征值分解基本概念。看到 (B, S, H) × (H, V) 能立刻知道结果是 (B, S, V)
- 基础概率论:概率分布、期望、方差、Softmax 的概率解释、交叉熵损失含义
- 微积分(了解):链式法则、梯度的含义
Transformer 架构
必须理解:
- Self-Attention:Q、K、V 的含义与计算过程(QK^T → scale → softmax → PV),Attention 的计算复杂度为
- 前馈网络(FFN):两层线性变换 + 激活函数
- 位置编码:Sinusoidal、RoPE 等
- LayerNorm:Pre-Norm vs Post-Norm 的区别
- 完整前向过程:从 token embedding 开始,能逐步跟踪数据在一个 Transformer Block 中的流转,说清每一步的输入输出维度
PyTorch 框架
- Tensor 操作、自动微分(autograd)、Module / Parameter 的组织方式
- 训练循环:DataLoader → forward → loss → backward → optimizer.step
- 基本调试:
torch.cuda.memory_summary()、torch.profiler
通信拓扑
- 单机内部:NVLink / NVSwitch 带宽与拓扑
- 多机间:InfiniBand(IB)网络、RoCE 协议
- 集合通信原语:AllReduce、AllGather、ReduceScatter 的含义与通信量公式
- NCCL:NVIDIA 集合通信库的基本用法
0.2 推荐资料
| 类型 | 资料 | 说明 |
|---|---|---|
| 论文 | Attention Is All You Need | Transformer 原始论文,必读 |
| 教程 | The Illustrated Transformer (Jay Alammar) | 图文并茂的 Transformer 入门 |
| 工具 | Andrej Karpathy:Let's build GPT from scratch | 从零手写 GPT,每个模块都过一遍 |
| 教程 | PyTorch 官方教程(60 Minute Blitz) | PyTorch 快速入门 |
| 书籍 | 3Blue1Brown:线性代数的本质(视频系列) | 建立线性代数几何直觉 |
| 官方文档 | NVIDIA NCCL 文档 | 集合通信原语与多卡编程 |
0.3 检验标准
- Transformer 白板默写:不看资料,能画出一个 Decoder Block 完整结构,标注每步输入输出维度
- 维度推导:给定 7B 模型配置(hidden_dim=4096, num_heads=32, num_layers=32, vocab_size=32000),能手算总参数量(误差 ≤20%)
- PyTorch 训练脚本:能独立写出完整训练循环(含 DataLoader、forward、loss、backward、optimizer step、checkpoint),在 GPU 上跑通
- Linux 日常:SSH 登录、tmux、conda/pip 管理、nvidia-smi、git、bash 脚本
💻 第一层:CUDA编程与算子优化
1.1 知识点
GPU 硬件架构
把 GPU 想象成一座拥有数千个简单工人的超级工厂——每个工人(CUDA Core)只会基本加减乘除,但胜在人多,吞吐量远超 CPU。
- SM(流多处理器)、Tensor Core、CUDA Core 的区别与协作
- 主流 GPU 规格:A100 / H100 / H200 的算力、显存带宽、HBM 容量
- Memory Wall:显存带宽瓶颈往往比算力瓶颈更致命
- 存储层次:寄存器 > 共享内存 > L1/L2 Cache > HBM > 主机内存
CUDA 编程基础
- 编程模型:Grid / Block / Thread 层级,线程索引计算
- 内存模型:全局内存、共享内存、寄存器、常量内存
- 关键概念:Warp(32线程的最小调度单位)、Bank Conflict、Coalesced Access、Occupancy
常见算子实现与优化
- Reduce:并行归约(Warp Shuffle、多级归约)
- GEMM:分块、向量化、Shared Memory Tiling、Tensor Core
- Softmax:Online normalizer calculation
- 算子融合:将多个小算子合并为一个 kernel,减少全局内存读写
Attention 算子
- FlashAttention V1/V2:通过 tiling 减少 HBM 访问——把大桌子上的拼图分成小块,每次只搬一小块到手边,避免把所有碎片一股脑倒出来。HBM 读写从 降到
- FlashAttention-3:在 Hopper 架构上进一步拉高利用率
- Flash-Decoding:面向 Decode 阶段的 Attention 加速
- PagedAttention CUDA Kernel:vLLM 中 PagedAttention 的底层实现
AI 编译器
- Triton:OpenAI 开源的 GPU 编程语言,大幅降低高效算子编写门槛
torch.compile:PyTorch 2.x 的编译模式,理解 Graph Break 与性能收益
1.2 推荐资料
| 类型 | 资料 | 说明 |
|---|---|---|
| 入门教程 | 小小将:CUDA编程入门极简教程 | CUDA 零基础入门 |
| 官方文档 | NVIDIA CUDA Programming Guide | CUDA 编程权威参考 |
| GEMM | 猛猿:从啥也不会到CUDA GEMM优化 | 从基础分块到极致优化 |
| Attention | FlashAttention V1/V2 Paper | Memory-aware Attention 里程碑 |
| 解读 | 猛猿:图解FlashAttention V1/V2 系列 | 适合新手入门的图文解读 |
| 编译器 | Triton 官方教程 | GPU 编程新范式 |
| 工具 | Nsight Systems User Guide | CPU-GPU 交互分析 |
| 工具 | Nsight Compute Profiling Guide | Kernel 级下钻,定位瓶颈 |
1.3 检验标准
- 硬件参数直觉:拿到 H100,不查资料能说出 HBM 容量(80GB)、带宽(~3.35TB/s)的量级
- Reduce 三连:从全局内存原子加 → 共享内存+树形归约 → Warp Shuffle,三版本跑 Nsight Compute 对比
- GEMM 分块:实现基于 Shared Memory Tiling 的 GEMM kernel,达到 cuBLAS 50% 以上性能
- FlashAttention 白板推导:能在白板上画出 tiling 过程,说清为什么 HBM 读写从 降到
- Profiling 实战:用 Nsight Systems 定位 GPU idle gap 来源;用 Nsight Compute 判断 kernel 是 memory bound 还是 compute bound
🏋️ 第二层:分布式训练
打个比方,训练千亿参数大模型就像抄写一本数万页的百科全书——数据并行是把同一本书复印多份、每人抄不同章节内容然后汇总;张量并行是把每一页拆成几列、每人只抄自己那几列;流水线并行则是第一个人抄完第一章就传给第二人继续。
2.1 知识点
优化器
- Adam / AdamW:每个参数维护两个状态——一阶动量(梯度的指数移动平均)和二阶动量(梯度平方的指数移动平均)
- 优化器状态的显存开销:以 AdamW + 混合精度为例,每个参数额外需要 FP32 参数副本(4B)+ 一阶动量(4B)+ 二阶动量(4B)= 12字节/参数。7B 模型的优化器状态就要占 ~84GB
数据并行
- DDP(DistributedDataParallel):多进程数据并行,理解 AllReduce 梯度同步
- FSDP:PyTorch 原生的 ZeRO-3 实现
模型并行(3D 并行)
- 张量并行(TP):将矩阵乘法沿特定维度切分到多卡,通信密集,通常限于单机
- 流水线并行(PP):将模型不同层切分到不同机器
- 序列并行(SP):沿序列维度切分,与 TP 配合减少激活显存
显存优化
ZeRO 系列(好比合租房里每人只存自己那份家具,需要时互相借用):
- ZeRO-1:优化器状态切分
- ZeRO-2:优化器状态 + 梯度切分
- ZeRO-3:优化器状态 + 梯度 + 参数切分(用通信换显存)
混合精度训练:FP16 / BF16 / FP8 训练,减少显存占用。BF16 比 FP16 指数位更宽(8位 vs 5位),动态范围接近 FP32,不容易 overflow/underflow。
Activation Checkpointing:只保存部分激活值,需要时重新计算,用计算换显存。
2.2 推荐资料
| 类型 | 资料 | 说明 |
|---|---|---|
| 论文 | Megatron-LM Paper | TP 与 PP 原理的里程碑论文 |
| 论文 | ZeRO Paper(DeepSpeed) | 显存优化的核心方法 |
| 文档 | DeepSpeed 官方文档 | ZeRO 配置与使用 |
| 文档 | PyTorch DDP / FSDP 教程 | 原生分布式训练入门 |
| 论文 | DeepSeek V2 技术报告 | MLA 注意力机制 |
| 论文 | DeepSeekMoE Paper | MoE 架构设计 |
2.3 检验标准
- 显存账本:拿到 7B 模型,能口算 FP16 下参数占 ~14GB、Adam 优化器状态占 ~56GB,判断单卡 80GB 能否放下完整训练状态
- ZeRO 拆解:能一句话讲清 ZeRO-2 和 ZeRO-3 的差异(参数是否切分,通信量差异)
- DDP 改造:拿到单卡训练脚本,30 分钟内改成 DDP 多卡版本并跑通
- 3D 并行拓扑:给 64 卡集群(8节点×8卡),能设计出 TP=8(机内)、PP=4(跨机)、DP=2 的并行方案,说明为什么 TP 不能跨机
🚀 第三层:推理与部署
训练是"教会模型知识",推理是"让模型上考场答题"——最重要的是答题速度和同时服务多少考生。
3.1 LLM 推理基础
- 两阶段:Prefill(处理输入,compute-bound)与 Decode(逐 token 生成,memory-bound)
- KV Cache:自回归生成的"草稿纸",把已算过的 K/V 缓存起来避免重复计算,但会随序列长度线性增长显存
- 关键指标:TTFT(首 token 延迟)、TPOT(每 token 延迟)、吞吐量(token/s)、P50/P95 尾延迟
KV Cache 估算:给定 LLaMA-2-7B(32层、32头、head_dim=128),上下文长度 4096,batch_size=16,FP16:
3.2 推理引擎
- PagedAttention:vLLM 提出的虚拟内存分页思想管理 KV Cache,解决碎片化问题
- Continuous Batching:动态组批,请求随到随处理(类似网约车拼单,随到随拼)
- Prefix Cache / RadixAttention:复用已计算的 KV Cache,优化重复前缀场景
| 框架 | 核心特性 | 适用场景 |
|---|---|---|
| vLLM | PagedAttention、Continuous Batching、Prefix Cache | 通用推理服务,社区活跃 |
| SGLang | RadixAttention、cFSM 结构化输出加速 | 复杂 Agent、多轮生成 |
| TensorRT-LLM | Inflight Batching、深度硬件优化 | 追求极限性能、NVIDIA 生态 |
3.3 量化
量化的本质是把高清照片压缩成缩略图——用更少的比特位表示权重,省下显存和带宽,代价是精度会有一定损失。
- W8A8(SmoothQuant):将 activation 的 outlier 难题转移到 weights,工程友好
- INT4(GPTQ / AWQ):只量化权重到 3/4-bit,减少显存和带宽占用
- KV Cache 量化(KIVI):2-bit 量化,长上下文场景效果显著
目标:省显存?省带宽?提吞吐?
├─ 通用、工程友好 → W8A8 (SmoothQuant)
├─ 更省显存/带宽 → INT4 weight-only (AWQ/GPTQ)
└─ 长上下文/大并发 → KV Cache 量化 (KIVI)
3.4 Speculative Decoding
好比让实习生先快速起草一段文字,再让资深主编一次性审阅:猜对的直接用,猜错的当场改,比主编逐字逐句从头写快得多。
- Speculative Sampling:小模型批量猜测 → 大模型一次性验证,保证分布无偏
- Medusa:多个 Decoding Heads 并行预测多 token
- EAGLE-2:动态 Draft Tree,更激进地产生可接受 token
正确性保证:rejection sampling 机制保证接受的 token 严格服从 target model 的分布。
3.5 性能分析工具
torch.profiler:PyTorch 官方 profiler,定位算子与 shape- Nsight Systems:CPU-GPU 交互全链路分析,找到"哪里慢"
- Nsight Compute:Kernel 级分析,找到"为什么慢"
- GenAI-Perf:LLM 指标一站式输出(TTFT/TPOT/throughput)
🧭 新人破局指南
推荐学习路径
基础阶段(0-3个月)
- 完成第零层全部检验标准
- 学习 CUDA 编程基础,能写简单的 Reduce / GEMM kernel
- 用 PyTorch DDP 将训练分布到两张卡上
专项深入(3-6个月)
- 精读四篇里程碑论文并对照代码:Megatron-LM、ZeRO、FlashAttention、vLLM
- 参与开源项目(vLLM、DeepSpeed、SGLang)
工程实践(6个月以上)
- 在 GPU 集群上部署百亿/千亿参数模型,优化端到端性能
- 建立完整的性能分析与回归体系
核心权衡思维
| 优化技术 | 牺牲了什么 | 换取了什么 |
|---|---|---|
| ZeRO | 通信带宽 | 显存空间 |
| Activation Checkpointing | 计算时间 | 显存空间 |
| 量化 | 精度 | 显存 + 带宽 + 吞吐 |
| Speculative Decoding | Prefill 开销 | Decode 速度 |
| FlashAttention | 实现复杂度 | 显存 + 速度 |
| Prefill/Decode 解耦 | 系统复杂度 | 尾延迟 + goodput |
📚 核心参考论文
- Attention Is All You Need:arxiv.org/abs/1706.03762
- Megatron-LM:arxiv.org/abs/1909.08053
- ZeRO:arxiv.org/abs/1910.02054
- FlashAttention V2:arxiv.org/abs/2307.08691
- vLLM / PagedAttention:arxiv.org/abs/2309.06180
- SGLang:arxiv.org/abs/2312.07104
- SmoothQuant:arxiv.org/abs/2211.10438
- AWQ:arxiv.org/abs/2306.00978
- DistServe:arxiv.org/abs/2401.09670