Optimization Techniques in ML

Lecture 20 – Example 4: SGD, Mini-batch, Momentum & LR Schedules

Stochastic Gradient Descent (SGD)

For loss \(\mathcal{L}(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^n \ell(\mathbf{w}; \mathbf{x}_i,y_i)\), pick a random sample (or mini-batch) and update

\[ \mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t \, \widehat{\nabla} \mathcal{L}(\mathbf{w}_t). \]

Because \(\widehat{\nabla}\mathcal{L}\) is an unbiased estimate of the full gradient, SGD with diminishing step-sizes (e.g., \(\eta_t = \eta_0/(1+\gamma t)\)) and convex \(\ell\) converges in expectation.
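As a minimal sketch of single-sample SGD with a diminishing step size, consider least-squares regression on synthetic data; the constants (`eta0`, `gamma`, the data sizes) are illustrative choices, not from the lecture.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
w_star = np.array([2.0, -1.0])
y = X @ w_star + 0.1 * rng.normal(size=200)   # noisy linear targets

w = np.zeros(2)
eta0, gamma = 0.1, 0.01                       # illustrative hyperparameters
for t in range(500):
    i = rng.integers(len(X))                  # pick one random sample
    grad = (X[i] @ w - y[i]) * X[i]           # unbiased gradient estimate
    eta_t = eta0 / (1 + gamma * t)            # diminishing step size
    w -= eta_t * grad
```

After 500 steps `w` is close to `w_star`; the decaying step size shrinks the noise-induced fluctuations around the optimum.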

Momentum & Nesterov

# Classical momentum (v initialized to zero)
v = beta * v + (1 - beta) * grad   # exponential moving average of gradients
w = w - eta * v

# Nesterov lookahead (pseudocode; grad_at is a placeholder)
w_look = w - eta * beta * v        # peek ahead along the current velocity
grad = grad_at(w_look)             # gradient at the lookahead point
v = beta * v + (1 - beta) * grad
w = w - eta * v

Momentum builds speed along gentle, low-curvature valley directions and damps oscillations along steep, high-curvature ones.
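This effect can be seen on a toy ill-conditioned quadratic \(f(\mathbf{w}) = \tfrac{1}{2}(100 w_1^2 + w_2^2)\), steep in \(w_1\) and gentle in \(w_2\); the step size and iteration count below are illustrative.

```python
import numpy as np

def grad(w):
    # Gradient of 0.5 * (100*w1^2 + w2^2)
    return np.array([100.0 * w[0], 1.0 * w[1]])

def run(beta, eta=0.015, steps=300):
    """Run (momentum) gradient descent; beta=0 recovers plain GD."""
    w = np.array([1.0, 1.0])
    v = np.zeros(2)
    for _ in range(steps):
        v = beta * v + (1 - beta) * grad(w)
        w = w - eta * v
    return w
```

With the same step size, `run(beta=0.9)` ends closer to the minimizer at the origin than `run(beta=0.0)`: momentum makes faster progress along the gentle \(w_2\) direction.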

Learning-Rate Schedules

  • Step decay: \(\eta_t = \eta_0\,\gamma^{\lfloor t/T\rfloor}\)
  • Cosine: \(\eta_t = \eta_{\min}+\tfrac{1}{2}(\eta_0-\eta_{\min})(1+\cos(\pi t/T))\)
  • Cyclical: triangular/triangular2 with periodic restarts
  • Warmup: start small then ramp to \(\eta_0\)
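The first three schedules above, plus a simple linear warmup, can be written as small functions of the step counter \(t\); the default values of `eta0`, `eta_min`, `gamma`, `T`, and `warmup` are illustrative.

```python
import math

def step_decay(t, eta0=0.1, gamma=0.5, T=30):
    # Multiply by gamma every T steps
    return eta0 * gamma ** (t // T)

def cosine(t, eta0=0.1, eta_min=0.001, T=100):
    # Anneal from eta0 down to eta_min over T steps
    return eta_min + 0.5 * (eta0 - eta_min) * (1 + math.cos(math.pi * t / T))

def warmup(t, eta0=0.1, warmup_steps=10):
    # Ramp linearly from eta0/warmup_steps up to eta0, then hold
    return eta0 * min(1.0, (t + 1) / warmup_steps)
```

For example, `step_decay(30)` halves the rate to 0.05, and `cosine(100)` returns `eta_min`.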

Worked Example (Mini-batch Logistic with Momentum)

import numpy as np

rng = np.random.default_rng(0)

# Synthetic logistic-regression data from known parameters
X = rng.normal(size=(400, 3))
w_true = np.array([1.5, -2.0, 0.5]); b_true = -0.3
logits = X @ w_true + b_true
p = 1 / (1 + np.exp(-logits))
y = (rng.uniform(size=400) < p).astype(float)

# Mini-batch SGD with classical momentum
w = np.zeros(3); b = 0.0
eta, beta = 0.1, 0.9
v_w = np.zeros_like(w); v_b = 0.0
batch = 32
for t in range(200):
    idx = rng.choice(len(X), size=batch, replace=False)  # sample a mini-batch
    Xb, yb = X[idx], y[idx]
    z = Xb @ w + b
    pb = 1 / (1 + np.exp(-z))                 # predicted probabilities
    grad_w = Xb.T @ (pb - yb) / batch         # mean logistic-loss gradient
    grad_b = np.sum(pb - yb) / batch
    v_w = beta * v_w + (1 - beta) * grad_w    # momentum updates
    v_b = beta * v_b + (1 - beta) * grad_b
    w -= eta * v_w
    b -= eta * v_b

Adam (Bonus)

Bias-corrected moment estimates:

# t starts at 1 so the bias corrections below are well-defined
m = beta1*m + (1-beta1)*grad           # first moment (mean) estimate
v = beta2*v + (1-beta2)*(grad**2)      # second moment (uncentered variance)
mh = m / (1 - beta1**t)                # bias corrections for zero init
vh = v / (1 - beta2**t)
w -= eta * mh / (np.sqrt(vh) + 1e-8)

Adam uses these moment estimates to adapt the step size per parameter, which helps on sparse or poorly scaled problems.
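Putting the update into a runnable loop on a toy quadratic \(f(\mathbf{w}) = \|\mathbf{w}\|^2\) (the hyperparameters below are the commonly used defaults, not mandated by the lecture):

```python
import numpy as np

w = np.array([5.0, -3.0])
m = np.zeros_like(w); v = np.zeros_like(w)
eta, beta1, beta2, eps = 0.1, 0.9, 0.999, 1e-8
for t in range(1, 501):                       # t starts at 1
    grad = 2 * w                              # gradient of ||w||^2
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    mh = m / (1 - beta1**t)                   # bias-corrected moments
    vh = v / (1 - beta2**t)
    w = w - eta * mh / (np.sqrt(vh) + eps)
```

Because the effective per-coordinate step is roughly `eta` regardless of gradient scale, both coordinates reach the neighborhood of the origin at a similar rate.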
