
3.2 Infrastructure at Scale

The Meta-Narrative

Training GPT-4 reportedly required on the order of $100M in compute, and serving it costs millions of dollars per day. At this scale, infrastructure is the product: understanding GPU architectures, networking topologies, and system-level optimizations is as important as understanding Transformers. This chapter covers the engineering of large-scale AI infrastructure.


GPU Architecture Internals

NVIDIA GPU Memory Hierarchy

graph TD
    SM["Streaming Multiprocessor (SM)"] --> REG["Registers<br/>(256KB per SM, fastest)"]
    SM --> SMEM["Shared Memory / L1 Cache<br/>(164KB per SM)"]
    SMEM --> L2["L2 Cache<br/>(40-50MB, chip-wide)"]
    L2 --> HBM["HBM2e / HBM3<br/>(40-80GB, largest capacity, slowest tier)"]

GPU Comparison for AI

| GPU | Year | FLOPS (FP16) | Memory | Memory BW | Use Case |
|------|------|--------------|--------|-----------|----------|
| V100 | 2017 | 125 TFLOPS | 32GB HBM2 | 900 GB/s | Legacy training |
| A100 | 2020 | 312 TFLOPS | 80GB HBM2e | 2 TB/s | Current workhorse |
| H100 | 2022 | 990 TFLOPS | 80GB HBM3 | 3.35 TB/s | LLM training |
| H200 | 2024 | 990 TFLOPS | 141GB HBM3e | 4.8 TB/s | Larger models in memory |
| B200 | 2024 | 2.2 PFLOPS | 192GB HBM3e | 8 TB/s | Next-gen frontier |
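These peak numbers feed directly into capacity planning. A common back-of-envelope estimate uses the standard approximation that training a dense Transformer costs about 6 FLOPs per parameter per token; the 40% model-FLOPs-utilization (MFU) figure below is an assumed, typical value, not a measured one:

```python
# Back-of-envelope training cost: total FLOPs ≈ 6 * params * tokens for a
# dense Transformer, divided by the sustained per-GPU throughput.
def training_gpu_hours(params, tokens, peak_flops, mfu=0.4):
    """Estimated GPU-hours at a given model FLOPs utilization (MFU)."""
    total_flops = 6 * params * tokens
    sustained = peak_flops * mfu       # achievable FLOPs/s per GPU
    return total_flops / sustained / 3600

# 70B-parameter model, 2T tokens, H100 at 990 TFLOPS FP16 peak, 40% MFU
hours = training_gpu_hours(70e9, 2e12, 990e12)
print(f"{hours:,.0f} GPU-hours")   # ~589,000 GPU-hours
```

At that rate, 1,024 H100s would need roughly 24 days; halving MFU doubles the bill, which is why utilization engineering matters as much as raw FLOPS.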

Roofline Model: Understanding Performance Bottlenecks

Every operation is either compute-bound or memory-bound:

  • Compute-bound (matmuls): Performance limited by FLOPS → use Tensor Cores
  • Memory-bound (activations, attention): Performance limited by memory bandwidth → use FlashAttention

The arithmetic intensity \(I = \frac{\text{FLOPs}}{\text{Bytes transferred}}\) determines which category an operation falls into.
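The dividing line is the GPU's ridge point, peak FLOPS divided by peak memory bandwidth. A quick sketch using H100 figures from the table above (the byte counting assumes FP16 operands and no cache reuse, so it is a lower bound on intensity):

```python
# Roofline check: an op is compute-bound when its arithmetic intensity
# I = FLOPs / bytes transferred exceeds the GPU's ridge point.
def arithmetic_intensity_matmul(m, n, k, bytes_per_elem=2):
    flops = 2 * m * n * k                                    # multiply-accumulates
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)   # read A, B; write C
    return flops / bytes_moved

RIDGE_H100 = 990e12 / 3.35e12   # ≈ 295 FLOPs/byte (FP16 peak / HBM3 BW)

ai_large = arithmetic_intensity_matmul(4096, 4096, 4096)  # big square matmul
ai_small = arithmetic_intensity_matmul(1, 4096, 4096)     # batch-1 decode step
print(round(ai_large))      # 1365  -> compute-bound
print(round(ai_small, 2))   # 1.0   -> memory-bound
```

This is why training-style matmuls saturate Tensor Cores while batch-1 autoregressive decoding is dominated by weight loads from HBM.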


Networking for Distributed Training

Communication Patterns

| Pattern | Description | Used In |
|---------|-------------|---------|
| All-Reduce | Average gradients across all workers | Data parallelism |
| All-Gather | Collect full tensors from all workers | Model parallelism (ZeRO-3) |
| Reduce-Scatter | Reduce, then redistribute shards | ZeRO optimizer states |
| Point-to-Point | Direct GPU-to-GPU transfer | Pipeline parallelism |
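All-reduce is typically implemented as a reduce-scatter phase followed by an all-gather phase around a ring. A toy single-process simulation of that pattern (real systems use NCCL; this only illustrates the data movement):

```python
# Toy ring all-reduce: `workers[i]` is worker i's gradient, pre-split into
# n equal chunks. After reduce-scatter + all-gather, every worker holds
# the elementwise sum of all gradients.
def ring_all_reduce(workers):
    n = len(workers)
    # Phase 1: reduce-scatter. Each step, worker i sends one chunk to its
    # right neighbour, which accumulates it. Sends happen "simultaneously",
    # so snapshot outgoing chunks before applying any update.
    for step in range(n - 1):
        outgoing = [(i, (i - step) % n, list(workers[i][(i - step) % n]))
                    for i in range(n)]
        for i, c, chunk in outgoing:
            dst = workers[(i + 1) % n]
            dst[c] = [a + b for a, b in zip(dst[c], chunk)]
    # After n-1 steps, worker i owns the fully reduced chunk (i + 1) % n.
    # Phase 2: all-gather. Circulate each reduced chunk around the ring.
    for step in range(n - 1):
        outgoing = [(i, (i + 1 - step) % n, list(workers[i][(i + 1 - step) % n]))
                    for i in range(n)]
        for i, c, chunk in outgoing:
            workers[(i + 1) % n][c] = chunk
    return workers

workers = [[[1.0], [2.0]], [[3.0], [4.0]]]   # 2 workers, 2 chunks each
ring_all_reduce(workers)
print(workers[0])   # [[4.0], [6.0]] -- every worker ends with the sum
```

The ring structure is what makes the algorithm bandwidth-optimal: each worker sends only 2(n-1)/n of the gradient volume, independent of cluster size.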

Interconnect Technologies

| Technology | Bandwidth | Latency | Topology |
|------------|-----------|---------|----------|
| NVLink 4.0 | 900 GB/s | Very low | Intra-node (GPU↔GPU) |
| NVSwitch | 3.6 TB/s | Very low | Full-mesh intra-node |
| InfiniBand (400G) | 400 Gb/s | Low | Inter-node (RDMA) |
| Ethernet (100G) | 100 Gb/s | Medium | Inter-node |
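These bandwidth gaps translate directly into gradient-synchronization time. A rough bandwidth-only estimate (ignoring latency and overlap with compute; the 50 GB/s effective figure for 400G InfiniBand is an assumed conversion from Gb/s to GB/s):

```python
# Ring all-reduce moves ~2 * (n-1)/n * data_bytes per worker, so the
# bandwidth-limited time is roughly that traffic divided by link bandwidth.
def all_reduce_seconds(param_bytes, n_workers, link_bytes_per_s):
    traffic = 2 * (n_workers - 1) / n_workers * param_bytes
    return traffic / link_bytes_per_s

GRADS = 7e9 * 2   # 7B params in FP16 = 14 GB of gradients

t_nvlink = all_reduce_seconds(GRADS, 8, 900e9)  # NVLink 4.0
t_ib = all_reduce_seconds(GRADS, 8, 50e9)       # 400 Gb/s IB ≈ 50 GB/s
print(f"{t_nvlink * 1e3:.1f} ms vs {t_ib * 1e3:.0f} ms")  # 27.2 ms vs 490 ms
```

An order-of-magnitude gap per step is why parallelism strategies keep the heaviest collectives inside NVLink domains and reserve InfiniBand for the lower-volume inter-node traffic.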

Model Serving Architecture

Real-Time Inference Stack

graph LR
    A["Load Balancer<br/>(NGINX, ALB)"] --> B["API Gateway<br/>(rate limiting, auth)"]
    B --> C["Model Server<br/>(Triton, vLLM, TGI)"]
    C --> D["GPU Pool<br/>(auto-scaled)"]
    C --> E["Model Registry<br/>(versioned models)"]
    B --> F["Cache Layer<br/>(Redis: frequent queries)"]

Serving Optimization Techniques

| Technique | Mechanism | Speedup |
|-----------|-----------|---------|
| Batching | Group requests for GPU efficiency | 2-10× throughput |
| KV-Cache | Cache attention keys/values from previous tokens | Essential for autoregressive decoding |
| Quantization | INT8/INT4 inference | 2-4× speedup |
| Compilation | TorchScript, TensorRT, ONNX | 1.5-3× speedup |
| Speculative Decoding | Small draft model + large verifier | 2-3× for LLMs |
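The KV-cache works because at each decoding step only the new token's query changes; cached keys and values from earlier steps can be reused exactly. A NumPy toy for a single attention head (the identity q = k = v per token is a simplification standing in for the real projections):

```python
import numpy as np

# Toy single-head attention: incremental decoding with a growing KV cache
# produces the same output as recomputing attention from scratch.
def attention(q, K, V):
    scores = q @ K.T / np.sqrt(q.shape[-1])
    w = np.exp(scores - scores.max())   # numerically stable softmax
    w /= w.sum()
    return w @ V

rng = np.random.default_rng(0)
d = 8
tokens = rng.normal(size=(5, d))        # stand-in per-token q/k/v vectors
K_cache, V_cache, outs_cached = [], [], []
for t in range(5):
    q = k = v = tokens[t]
    K_cache.append(k)                   # cache grows one row per step --
    V_cache.append(v)                   # no recomputation of earlier K/V
    outs_cached.append(attention(q, np.array(K_cache), np.array(V_cache)))

# Recompute the final step from scratch over all tokens: identical result
full = attention(tokens[4], tokens, tokens)
print(np.allclose(outs_cached[-1], full))   # True
```

The trade is compute for memory: the cache costs O(layers × heads × sequence × head_dim) bytes per request, which is why KV-cache memory, not FLOPS, often caps serving batch sizes.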

🚀 Lab: High-Performance Model Serving
"""FastAPI model server with batching and health checks."""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import time
from collections import deque

app = FastAPI(title="ML Model Server v2")

# Load a TorchScript model exported ahead of time (torch.jit.save)
model = torch.jit.load("model.pt")
model.eval()

# Request/response schemas
class PredictRequest(BaseModel):
    features: list[float]

class PredictResponse(BaseModel):
    prediction: list[float]
    latency_ms: float

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    avg_latency_ms: float

latency_history = deque(maxlen=100)

@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    start = time.time()
    try:
        input_tensor = torch.tensor([request.features], dtype=torch.float32)
        with torch.no_grad():
            output = model(input_tensor)
        latency = (time.time() - start) * 1000
        latency_history.append(latency)
        return PredictResponse(
            prediction=output.numpy().tolist()[0],
            latency_ms=round(latency, 2),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health", response_model=HealthResponse)
async def health():
    avg_lat = sum(latency_history) / len(latency_history) if latency_history else 0
    return HealthResponse(
        status="healthy", model_loaded=True, avg_latency_ms=round(avg_lat, 2)
    )
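The server above runs one forward pass per request; production servers batch concurrent requests to amortize GPU launch overhead (the first technique in the optimization table). A minimal asyncio micro-batcher sketch, under the assumption that the model accepts a list of inputs in one call (the `Batcher` class and its limits are illustrative, not a real library API):

```python
import asyncio

# Minimal dynamic batcher: requests queue up and are flushed either when
# MAX_BATCH is reached or after MAX_WAIT_S, then run as one forward pass.
MAX_BATCH = 8
MAX_WAIT_S = 0.005

class Batcher:
    def __init__(self, model_fn):
        self.model_fn = model_fn      # callable: list of inputs -> list of outputs
        self.queue = asyncio.Queue()

    async def predict(self, features):
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((features, fut))
        return await fut              # resolved when the batch runs

    async def run(self):
        while True:
            batch = [await self.queue.get()]          # block for first request
            deadline = asyncio.get_running_loop().time() + MAX_WAIT_S
            while len(batch) < MAX_BATCH:
                timeout = deadline - asyncio.get_running_loop().time()
                if timeout <= 0:
                    break
                try:                                   # wait briefly for more
                    batch.append(await asyncio.wait_for(self.queue.get(), timeout))
                except asyncio.TimeoutError:
                    break
            outputs = self.model_fn([x for x, _ in batch])  # one batched pass
            for (_, fut), out in zip(batch, outputs):
                fut.set_result(out)

async def main():
    # Stand-in "model": doubles each feature, batched
    batcher = Batcher(lambda xs: [[2 * v for v in x] for x in xs])
    worker = asyncio.create_task(batcher.run())
    results = await asyncio.gather(*(batcher.predict([i]) for i in range(4)))
    worker.cancel()
    return results

results = asyncio.run(main())
print(results)   # [[0], [2], [4], [6]]
```

Servers like Triton and vLLM implement far more sophisticated versions of this (continuous batching, per-request token scheduling), but the queue-plus-deadline core is the same trade-off: a few milliseconds of added latency for a multiple of throughput.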

References

  • Jia, Z. et al. (2018). Dissecting the NVIDIA Volta GPU Architecture via Microbenchmarking.
  • Rajbhandari, S. et al. (2020). ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.
  • Dao, T. et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.