Note
Workflow labs use the variant template: Goal → Approach → Execution → Check → Report.
torch.
Regularisation from Week 1 Session 2; basic gradient descent intuition.
Neural networks rarely beat gradient-boosted trees on small tabular biomedical problems. They come into their own when the feature representation itself must be learned (images, sequences, free text) or when the sample size is large enough to feed a flexible model. On a feature matrix of hundreds of rows and tens of columns, a logistic or linear regression with appropriate regularisation is a better baseline and often a better final model.
That said, understanding the training loop — forward pass, loss, backward pass, optimiser step — is a prerequisite for everything in imaging, sequence, and modern generative modelling. This lab walks through that loop on a small simulated regression so the shapes and operations are transparent.
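For a linear model with a mean squared error loss, which is the setting of the base R example in this lab, one pass through that loop can be written out explicitly. The symbols $\hat{y}$, $L$ and $\eta$ (the learning rate) are notation introduced here for illustration only.

```latex
\begin{aligned}
\hat{y} &= Xw && \text{(forward pass)} \\
L(w) &= \frac{1}{n}\,\lVert y - \hat{y} \rVert^2 && \text{(loss)} \\
\nabla_w L &= -\frac{2}{n}\, X^\top (y - \hat{y}) && \text{(backward pass)} \\
w &\leftarrow w - \eta\, \nabla_w L && \text{(optimiser step)}
\end{aligned}
```

In the base R loop the gradient is computed by hand from this formula; in torch the backward pass obtains it by automatic differentiation.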
We mark the torch training loop with #| eval: false because installing torch downloads a backend library that may not be present in every rendering environment. A small parallel example using base R gradient descent is shown first so the ideas run end-to-end regardless of what is installed.
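As an aside, whether a given environment could run the torch chunk can be checked before rendering. The sketch below is one possible guard, not part of the lab's own setup; the variable name has_torch is hypothetical. It relies on torch::torch_is_installed(), which is only called when the package itself is found.

```r
# TRUE only if the torch package is installed AND its backend library is present.
# The && short-circuits, so torch:: is never touched when the package is missing.
has_torch <- requireNamespace("torch", quietly = TRUE) &&
  torch::torch_is_installed()

if (has_torch) {
  message("torch is available; the torch chunk could be evaluated here.")
} else {
  message("torch is not available; only the base R example will run.")
}
```

A chunk option such as #| eval: !expr has_torch could then replace the fixed #| eval: false, at the cost of output that differs between environments.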
Fit a small regression by manual gradient descent, then show the equivalent torch code with #| eval: false.
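The gradient-descent code that follows uses X, y, n and p, whose simulation is not shown in this section. A minimal sketch of such a setup is given here, assuming the n = 500, p = 10 design mentioned in the report; the seed, the coefficient vector beta_true and the noise standard deviation are illustrative assumptions, not the lab's actual values.

```r
# Hypothetical simulation: seed, coefficients and noise level are illustrative.
set.seed(1)
n <- 500   # sample size (from the lab's report)
p <- 10    # number of predictors (from the lab's report)

# Independent standard-normal predictors, no intercept column
X <- matrix(rnorm(n * p), nrow = n, ncol = p)

# Assumed "true" coefficients, evenly spaced between -1 and 1
beta_true <- seq(-1, 1, length.out = p)

# Linear signal plus Gaussian noise (assumed sd = 0.5)
y <- as.numeric(X %*% beta_true + rnorm(n, sd = 0.5))
```

No intercept is simulated because the weight vector in the loop has length p and the design matrix carries no column of ones.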
Base-R gradient descent on the same problem.
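# Start from zero weights; lr is the learning rate (step size)
w <- rep(0, p); lr <- 0.01
losses <- numeric(200)
for (step in 1:200) {
  yhat <- X %*% w                    # forward pass: predictions
  resid <- as.numeric(y - yhat)      # residuals
  grad <- -2 * t(X) %*% resid / n    # gradient of the mean squared error
  w <- w - lr * grad                 # gradient-descent update
  losses[step] <- mean(resid^2)      # loss at the weights before this update
}
# Loss curve over iterations
tibble(step = 1:200, loss = losses) |>
  ggplot(aes(step, loss)) + geom_line() +
  labs(x = "iteration", y = "mean squared error")

The equivalent torch code.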
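library(torch)
# Convert the design matrix and response to float tensors
x_t <- torch_tensor(X, dtype = torch_float())
y_t <- torch_tensor(matrix(y, ncol = 1), dtype = torch_float())
# One hidden layer of 16 ReLU units, then a single linear output
net <- nn_sequential(
  nn_linear(p, 16), nn_relu(),
  nn_linear(16, 1)
)
optimizer <- optim_adam(net$parameters, lr = 1e-2)
loss_fn <- nn_mse_loss()
for (epoch in 1:200) {
  optimizer$zero_grad()        # clear gradients from the previous step
  pred <- net(x_t)             # forward pass
  loss <- loss_fn(pred, y_t)   # mean squared error
  loss$backward()              # backward pass: compute gradients
  optimizer$step()             # update the parameters
}

Compare our gradient-descent weights against the OLS solution.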
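One way this check could look is sketched below. It is self-contained, so it re-simulates data under assumed settings (seed, beta_true and noise level are hypothetical, as is the name comparison) and reruns the same update rule before comparing against the closed-form least-squares solution without an intercept.

```r
# Hypothetical data: n and p follow the lab's report, everything else is assumed.
set.seed(1)
n <- 500; p <- 10
X <- matrix(rnorm(n * p), nrow = n)
beta_true <- seq(-1, 1, length.out = p)
y <- as.numeric(X %*% beta_true + rnorm(n, sd = 0.5))

# Gradient descent with the same update rule, learning rate and step count
w <- rep(0, p); lr <- 0.01
for (step in 1:200) {
  resid <- as.numeric(y - X %*% w)
  grad <- -2 * t(X) %*% resid / n
  w <- w - lr * grad
}

# OLS without intercept via the normal equations: (X'X) w = X'y
w_ols <- solve(crossprod(X), crossprod(X, y))

# Side-by-side comparison of the three coefficient vectors
comparison <- data.frame(
  truth = beta_true,
  gd    = as.numeric(w),
  ols   = as.numeric(w_ols)
)
print(round(comparison, 3))

# Largest absolute difference between gradient descent and OLS
max_gap <- max(abs(comparison$gd - comparison$ols))
max_gap
```

With a learning rate of 0.01 and only 200 steps, gradient descent may not have fully converged, so a small remaining gap to OLS is expected; more iterations or a larger step size should shrink it.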
A small feed-forward network and a linear model both recover the true coefficients on a simulated n = 500, p = 10 regression. Under this signal-to-noise ratio and sample size, the linear model is the appropriate baseline; the neural network is shown here for machinery, not as the recommended estimator.
The training-loop pattern — zero grads, forward, loss, backward, step — is the same whether you are fitting a 3-layer MLP on penguins or a ResNet on histology. The difference is the architecture, the data pipeline, and the compute.
torch for R book by Keydana (online).

R version 4.4.1 (2024-06-14)
Platform: x86_64-pc-linux-gnu
Running under: Ubuntu 24.04.4 LTS
Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so; LAPACK version 3.12.0
locale:
[1] LC_CTYPE=C.UTF-8 LC_NUMERIC=C LC_TIME=C.UTF-8
[4] LC_COLLATE=C.UTF-8 LC_MONETARY=C.UTF-8 LC_MESSAGES=C.UTF-8
[7] LC_PAPER=C.UTF-8 LC_NAME=C LC_ADDRESS=C
[10] LC_TELEPHONE=C LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C
time zone: UTC
tzcode source: system (glibc)
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] lubridate_1.9.5 forcats_1.0.1 stringr_1.6.0 dplyr_1.2.1
[5] purrr_1.2.2 readr_2.2.0 tidyr_1.3.2 tibble_3.3.1
[9] ggplot2_4.0.3 tidyverse_2.0.0
loaded via a namespace (and not attached):
[1] gtable_0.3.6 jsonlite_2.0.0 compiler_4.4.1 tidyselect_1.2.1
[5] scales_1.4.0 yaml_2.3.12 fastmap_1.2.0 R6_2.6.1
[9] labeling_0.4.3 generics_0.1.4 knitr_1.51 htmlwidgets_1.6.4
[13] pillar_1.11.1 RColorBrewer_1.1-3 tzdb_0.5.0 rlang_1.2.0
[17] stringi_1.8.7 xfun_0.57 S7_0.2.2 otel_0.2.0
[21] timechange_0.4.0 cli_3.6.6 withr_3.0.2 magrittr_2.0.5
[25] digest_0.6.39 grid_4.4.1 hms_1.1.4 lifecycle_1.0.5
[29] vctrs_0.7.3 evaluate_1.0.5 glue_1.8.1 farver_2.1.2
[33] rmarkdown_2.31 tools_4.4.1 pkgconfig_2.0.3 htmltools_0.5.9