# XGBoost learns the Canadian Flag

XGBoost is a machine learning library that’s great for classification tasks. It’s often seen in Kaggle competitions, and usually beats other classifiers like logistic regression, random forests, SVMs, and shallow neural networks. One day, I was feeling slightly patriotic, and wondered: can XGBoost learn the Canadian flag?

Above: Our home and native land

Let’s find out!

## Preparing the dataset

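The task is to classify each pixel of the Canadian flag as either red or white, given limited data points. First, we read in the image with R (the file name below is a placeholder). The red channel is of little use here, because it is high on both red and white pixels. The green channel, however, is close to 0 on red pixels and close to 1 on white ones, so we invert it to get a matrix that is 1 on red pixels and 0 on white ones:

```r
library(png)
library(ggplot2)
library(xgboost)

img <- readPNG("canada.png")  # hypothetical file name
# readPNG returns a HEIGHT x WIDTH x channels array of values in [0, 1] (R, G, B order).
# Green (index 2) is ~0 on red pixels and ~1 on white ones; invert it so 1 = red.
red <- 1 - img[,,2]

HEIGHT <- dim(red)[1]
WIDTH <- dim(red)[2]
```

Next, we sample 7500 random points for training. Also, to make it more interesting, each point has a probability 0.05 of flipping to the opposite color.

```r
ERROR_RATE <- 0.05

get_data_points <- function(N) {
  # Random pixel coordinates: x is the column, y is the row
  x <- sample(1:WIDTH, N, replace = TRUE)
  y <- sample(1:HEIGHT, N, replace = TRUE)
  # Look up each pixel and round to a 0/1 label (1 = red, 0 = white)
  p <- round(red[cbind(y, x)])
  # Flip each label independently with probability ERROR_RATE
  flips <- runif(N) < ERROR_RATE
  p[flips] <- 1 - p[flips]
  data.frame(x = as.numeric(x), y = as.numeric(y), p = p)
}

data <- get_data_points(7500)
```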


This is what our classifier sees:

Alright, let’s start training.

## Quick introduction to XGBoost

XGBoost implements gradient boosted decision trees, which were first proposed by Friedman in 1999.

Above: XGBoost learns an ensemble of short decision trees

The output of XGBoost is an ensemble of decision trees. Each individual tree by itself is not very powerful, containing only a few branches. But through gradient boosting, each subsequent tree tries to correct the mistakes of all the trees before it, making the model better. After many iterations, we get a set of decision trees; the sum of all their outputs is our final prediction.
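To make the idea concrete, here is a minimal sketch of plain gradient boosting in base R, assuming squared-error loss and depth-1 trees (stumps) instead of the depth-3 trees used below. It is not XGBoost's actual algorithm, which additionally uses second-order gradient information and regularization; the toy "flag" data and the helper names `fit_stump`, `predict_stump`, and `boost` are hypothetical. Each round fits a stump to the current residuals (the negative gradient of squared loss) and adds a shrunken copy of its output to the running sum.

```r
# Minimal gradient boosting sketch: squared-error loss, depth-1 trees (stumps).
# Simplified for illustration; XGBoost also uses second-order gradients and regularization.
set.seed(1)

# Hypothetical toy "flag": label 1 (red) in the two side bands, 0 (white) in the middle
n <- 400
X <- cbind(x = runif(n, 0, 4), y = runif(n, 0, 2))
label <- as.numeric(X[, "x"] < 1 | X[, "x"] > 3)

# Find the single split (feature j at threshold t) whose two leaf means best fit r
fit_stump <- function(X, r) {
  best <- list(sse = Inf)
  for (j in seq_len(ncol(X))) {
    for (t in unique(X[, j])) {
      left <- X[, j] <= t
      if (all(left)) next  # skip splits that leave the right side empty
      ml <- mean(r[left])
      mr <- mean(r[!left])
      sse <- sum((r[left] - ml)^2) + sum((r[!left] - mr)^2)
      if (sse < best$sse) best <- list(sse = sse, j = j, t = t, ml = ml, mr = mr)
    }
  }
  best
}

predict_stump <- function(s, X) ifelse(X[, s$j] <= s$t, s$ml, s$mr)

boost <- function(X, label, nrounds, eta = 0.3) {
  pred <- rep(mean(label), nrow(X))  # round 0: a constant prediction
  mse <- numeric(nrounds)
  for (m in seq_len(nrounds)) {
    r <- label - pred                         # residuals = negative gradient of squared loss
    s <- fit_stump(X, r)                      # the new tree models the current mistakes
    pred <- pred + eta * predict_stump(s, X)  # running sum of (shrunken) tree outputs
    mse[m] <- mean((label - pred)^2)
  }
  list(pred = pred, mse = mse)
}

res <- boost(X, label, nrounds = 30)
train_accuracy <- mean((res$pred > 0.5) == label)
```

Because each stump is a least-squares fit to the residuals, adding it with a step size between 0 and 2 can never increase the training error, which is why `res$mse` only goes down.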

For more technical details of how this works, refer to this tutorial or the XGBoost paper.

## Experiments

Fitting an XGBoost model is very easy using R. For this experiment, we use decision trees with a maximum depth of 3 (max_depth = 3), but you can play with the hyperparameters.

```r
fit <- xgboost(data = matrix(c(data$x, data$y), ncol = 2), label = data$p,
               nrounds = 1, max_depth = 3)
```

We also need a way of visualizing the results. To do this, we run every pixel through the classifier and display the result:

```r
plot_canada <- function(dataplot) {
  # Negate y so that the first image row is drawn at the top
  dataplot$y <- -dataplot$y
  dataplot$p <- as.factor(dataplot$p)
  ggplot(dataplot, aes(x = x, y = y, color = p)) +
    geom_point(size = 1) +
    scale_x_continuous(limits = c(0, WIDTH)) +
    scale_y_continuous(limits = c(-HEIGHT, 0)) +
    theme_minimal() +
    theme(panel.background = element_rect(fill = 'black')) +
    theme(panel.grid.major = element_blank(), panel.grid.minor = element_blank()) +
    # Label 0 is drawn white and label 1 red
    scale_color_manual(values = c("0" = "white", "1" = "red"))
}

# Predict every pixel, then threshold the raw model output at 0.5
fullimg <- expand.grid(x = as.numeric(1:WIDTH), y = as.numeric(1:HEIGHT))
fullimg$p <- predict(fit, newdata = matrix(c(fullimg$x, fullimg$y), ncol = 2))
fullimg$p <- as.numeric(fullimg$p > 0.5)
plot_canada(fullimg)
```



In the first iteration, XGBoost immediately learns the two red bands at the sides:

After a few more iterations, the maple leaf starts to take form:

By iteration 60, it learns a pretty recognizable maple leaf. Note that the decision trees split on x or y coordinates, so XGBoost can’t learn diagonal decision boundaries, only approximate them with horizontal and vertical lines.

If we run it for too long, then it starts to overfit and capture the random noise in the training data. In practice, we would use cross-validation to detect when this is happening. But why cross-validate when you can just eyeball it?

That was fun. If you liked this, check out this post, which explores various classifiers using the flag of Australia.

The source code for this blog post is posted here. Feel free to experiment with it.

# What if math contests were scored using Principal Component Analysis?

In many math competitions, all problems are weighted equally, even though the problems have very different difficulties. Sometimes, the harder problems are weighted more. But how should we assign weights to each problem?

Usually, the organizers make up weights based on how difficult they believe the problems are. However, they sometimes misjudge the difficulty of problems. Wouldn’t it be better if the weightings were determined from data?

Let’s try Principal Component Analysis!

Principal Component Analysis (PCA) is a statistical procedure that finds the directions along which the data varies the most; the first principal component is the unit-length linear combination of the variables with maximum variance. In our case, the first principal component gives a relative weighting of the problems that maximizes the variance of the total scores. This makes sense because we want to separate the good and bad students in a math contest.

## IMO 2017 Data

The International Mathematical Olympiad (IMO) is an annual math competition for top high school students around the world. It consists of six problems, divided between two days: on each day, contestants are given 4.5 hours to solve three problems.

Here are the 2017 problems, if you want to try them.

Above: Score distribution for IMO 2017

This year, 615 students wrote the IMO. Problems 1 and 4 were the easiest, with the majority of contestants receiving full scores. Problems 3 and 6 were the hardest: only 2 students solved the third problem. Problems 2 and 5 were somewhere in between.

This is a good dataset to play with, because the individual results show what each student scored for every problem.

## Derivation of PCA for the 1-dimensional case

Let $X$ be a matrix containing all the data, where each column represents one problem. There are 615 contestants and 6 problems, so $X$ has 615 rows and 6 columns.

We wish to find a weight vector $\vec u \in \mathbb{R}^{6 \times 1}$ such that the variance of $X \vec u$ is maximized. Of course, scaling up $\vec u$ by a constant factor also increases the variance, so we need the constraint that $\|\vec u\| = 1$.

First, PCA requires that we center $X$ so that the mean score on each problem is 0: we subtract each column's mean from that column. This shifts every contestant's total score by the same constant, and doesn't affect the relative weights of the problems.

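Now, $X \vec u$ is a vector containing the total scores of all the contestants. Because every column of $X$ has mean 0, so does $X \vec u$, and its variance is proportional to the sum of squares of its elements, $\|X \vec u\|^2$ (dividing by the number of contestants doesn't change which $\vec u$ is best).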

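To maximize $\|X \vec u\|^2$ subject to $\|\vec u\| = 1$, we take the singular value decomposition $X = U \Sigma V^T$, with singular values $\sigma_1 \ge \sigma_2 \ge \dots$ on the diagonal of $\Sigma$ and $\vec v_i$ the $i$-th column of $V$. Since $U$ has orthonormal columns and $V$ is orthogonal, $\|X \vec u\|^2 = \|\Sigma V^T \vec u\|^2 = \sum_i \sigma_i^2 (\vec v_i^T \vec u)^2$, where the coefficients $(\vec v_i^T \vec u)^2$ are nonnegative and sum to $\|V^T \vec u\|^2 = 1$. This weighted average of the $\sigma_i^2$ is largest when all of the weight is on $\sigma_1$, that is, when $\vec u = \vec v_1$. So the leftmost column of $V$ (corresponding to the largest singular value) gives the $\vec u$ that maximizes $\|X \vec u\|^2$, with maximum value $\sigma_1^2$. This gives the first principal axis, and we are done.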
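The derivation translates almost line for line into base R. The sketch below uses a hypothetical, randomly generated 615 × 6 score matrix as a stand-in, since the actual IMO scores aren't included here; the made-up `difficulty` values and the latent `ability` variable exist only to make the columns positively correlated, as in a real contest. It centers the columns, takes the SVD, and checks that $\|X\vec v_1\|^2 = \sigma_1^2$, that `prcomp` finds the same axis, and that equal weights of the same length give a smaller sum of squares.

```r
set.seed(2017)

# Hypothetical stand-in for the real data: 615 contestants x 6 problems, scores 0-7
n <- 615
ability <- rnorm(n)
difficulty <- c(-1, 0, 2.5, -1.2, 0.5, 2)  # made-up values, not the IMO 2017 ones
X <- sapply(difficulty, function(d) round(7 * plogis(2 * (ability - d) + rnorm(n))))

Xc <- scale(X, center = TRUE, scale = FALSE)  # subtract each column's mean
s <- svd(Xc)                                  # Xc = U diag(d) V^T, d in decreasing order
u <- s$v[, 1]                                 # first principal axis, with ||u|| = 1
if (sum(u) < 0) u <- -u                       # singular vectors are only defined up to sign

total <- drop(Xc %*% u)                       # centered weighted total per contestant

# ||Xc u||^2 equals the largest singular value squared
all.equal(sum(total^2), s$d[1]^2)

# prcomp() also centers the data and uses the SVD, so it finds the same axis
pc <- prcomp(X)
all.equal(abs(unname(pc$rotation[, 1])), abs(u))

# Equal weights with the same norm cannot give a larger sum of squares
u_equal <- rep(1, 6) / sqrt(6)
sum((Xc %*% u_equal)^2) <= sum(total^2)
```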

## Experiments

Running PCA on the IMO 2017 data produced interesting results. After re-scaling the weights so that the minimum possible score is 0 and the maximum possible score is 42 (to match IMO’s scoring), PCA recommends the following weights:

• Problem 1: 9.15 points
• Problem 2: 9.73 points
• Problem 3: 0.15 points
• Problem 4: 15.34 points
• Problem 5: 5.59 points
• Problem 6: 2.05 points
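Each IMO problem is marked out of 7, so each number above is what a full solution would be worth. Here is a minimal sketch of the rescaling and of scoring a paper under these weights, assuming the raw PCA weights are all nonnegative (as they are here) and that partial marks earn a proportional share; `rescale_to_42` and `pca_total` are hypothetical helper names.

```r
# Rescale nonnegative raw PCA weights u so that a blank paper scores 0 and a
# perfect paper (7 on every problem) scores 42: problem i is worth 42 * u_i / sum(u)
rescale_to_42 <- function(u) 42 * u / sum(u)

# Worth of a full solution under the PCA weights listed above
worth <- c(P1 = 9.15, P2 = 9.73, P3 = 0.15, P4 = 15.34, P5 = 5.59, P6 = 2.05)

# A contestant's total: a raw mark out of 7 earns that fraction of the problem's worth
pca_total <- function(raw) sum(raw / 7 * worth)

pca_total(c(7, 0, 0, 7, 0, 0))  # solved P1 and P4 only: 9.15 + 15.34 = 24.49
pca_total(c(7, 7, 7, 7, 7, 7))  # perfect paper: 42.01, i.e. 42 up to rounding of the weights
```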

This is the weighting that produces the highest variance. That’s right, solving the hardest problem in the history of the IMO would get you a fraction of 1 point. P4 had the highest variance of the six problems, so PCA gave it the highest weight.

The scores and rankings produced by the PCA scheme are reasonably well-correlated with the original scores. Students who did well still did well, and students who did poorly still did poorly. The top students who solved the harder problems (2, 3, 5, 6) usually also solved the easier problems (1 and 4). The students who would be the unhappiest with this scheme are a small number of people who solved P3 or P6, but failed to solve P4.

Here’s a comparison of score distributions with the original and PCA schemes. There is a lot less separation between the best of the best students and the middle of the pack. It is easy to check that PCA does indeed produce higher variance than weighting all six problems equally.

Now, let me comment on the strange results.

It’s clearly absurd to give 0.15 points to the hardest problem on the IMO, and make P4, a much easier problem, be worth 100 times more. But it makes sense from PCA’s perspective. About 99% of the students scored zero on P3, so its variance is very low. Given that PCA has a limited amount of weight to “spend” to increase the total variance, it would be wasteful to use much of it on P3.
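A quick back-of-the-envelope check of why P3 contributes so little variance, under the simplifying (hypothetical) assumption that a problem is scored all-or-nothing, 0 or 7. If a fraction $p$ of contestants solve it, its variance is $49p(1-p)$, so a problem solved by 2 out of 615 contestants has a tiny variance compared to one solved by half the field. The real P3 also had some partial marks, so this is only an approximation.

```r
# Population variance of a problem scored 0 or 7, solved by a fraction p of contestants
var_all_or_nothing <- function(p) 49 * p * (1 - p)

var_all_or_nothing(2 / 615)  # about 0.16: almost everyone scored 0
var_all_or_nothing(0.5)      # 12.25: the largest possible value, reached at p = 1/2

# The same quantity computed directly from a score vector with 2 solvers out of 615
scores <- c(rep(7, 2), rep(0, 613))
mean((scores - mean(scores))^2)
```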

The PCA score distribution has less separation between the good students and the best students. However, by giving a lot of weight to P1 and P4, it clearly separates mediocre students that solve one problem from the ones who couldn’t solve anything at all.

In summary, scoring math contests using PCA doesn’t work very well. Although it maximizes overall variance, math contests are asymmetric: we care most about differentiating between the students at the top end of the spectrum.

## Source Code

If you want to play with the data, I uploaded it as a Kaggle dataset.

The code for this analysis is available here.

# Learning R as a Computer Scientist

If you’re into statistics and data science, you’ve probably heard of the R programming language. It’s a statistical programming language that has gained much popularity lately. It comes with an environment specifically designed to be good at exploring data, plotting visualizations, and fitting models.

R is not like most programming languages. It’s quite different from any other language I’ve worked with: it’s developed by statisticians, who think differently from programmers. In this blog post, I describe some of the pitfalls that I ran into learning R with a computer science background. I used R extensively in two stats courses in university, and afterwards for a bunch of data analysis projects, and now I’m just starting to be comfortable and efficient with it.

## Why a statistical programming language?

When I encountered R for the first time, my first reaction was: “why do we need a new language to do stats? Can’t we just use Python and import some statistical libraries?”

Sure, you can, but R is very streamlined for it. In Python, you would need something like scipy for fitting models, and something like matplotlib to display things on screen. With R, you get RStudio, a complete environment, and it’s very much batteries-included. In RStudio, you can parse the data, run statistics on it, and visualize results with very few lines of code.

Aside: RStudio is an IDE for R. Although it’s possible to run R standalone from the command line, in practice almost everyone uses RStudio.

I’ll do a quick demo of fitting a linear regression on a dataset to demonstrate how easy it is to do in R. First, let’s load the CSV file:

```r
df <- read.csv("fossum.csv")
```


This reads a dataset containing body length measurements for a bunch of possums. Don’t ask why, it was used in a stats course I took. R parses the CSV file into a data frame and automatically figures out the dimensions and variable names and types.

Next, we fit a linear regression model of the total length of the possum versus the head length:

```r
model <- lm(totlngth ~ hdlngth, df)
```


It’s one line of code with the lm function. What’s more, fitting linear models is so common in R that the formula syntax used here (the ~ in totlngth ~ hdlngth) is baked into the language.
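Under the hood, `lm` computes an ordinary least-squares fit (via a QR decomposition). As a sanity check, here is a sketch on simulated data, standing in for `fossum.csv` since that file isn't included here, showing that its coefficients match the textbook normal-equations solution of $(X^TX)\,b = X^Ty$, where `model.matrix` builds the design matrix $X$ from the same kind of formula. The intercept, slope, and noise level in the simulation are made up.

```r
set.seed(1)
# Simulated stand-in for the possum data (not the real measurements)
df <- data.frame(hdlngth = runif(43, 85, 100))
df$totlngth <- 10 + 0.8 * df$hdlngth + rnorm(43, sd = 2.5)

model <- lm(totlngth ~ hdlngth, df)

# Design matrix built from the formula: an intercept column plus hdlngth
X <- model.matrix(~ hdlngth, df)
y <- df$totlngth
# Normal equations: solve (X^T X) b = X^T y
b <- solve(crossprod(X), crossprod(X, y))

# Side by side: the two sets of coefficients agree
cbind(lm = coef(model), normal_equations = drop(b))
```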

Aside: Here, we did totlngth ~ hdlngth to perform a single-variable linear regression, but the notation allows fancier stuff. For example, if we did lm(totlngth ~ (hdlngth + age)^2, df), then we would get a model including both variables and their second-order interaction effect. This is called Wilkinson-Rogers notation, if you want to read more about it.
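One way to see what a formula expands to, without fitting anything, is to ask `terms()` for its term labels. A small sketch; apart from `totlngth`, `hdlngth`, and `age`, the variable names `y`, `a`, `b`, `c` are made up, and no data is needed.

```r
# Which terms does a formula expand to? (no data needed)
attr(terms(totlngth ~ (hdlngth + age)^2), "term.labels")
# "hdlngth" "age" "hdlngth:age": both main effects plus their interaction

# ^2 keeps interactions up to second order only, so a:b:c is left out
attr(terms(y ~ (a + b + c)^2), "term.labels")
# "a" "b" "c" "a:b" "a:c" "b:c"

# a * b is shorthand for a + b + a:b, the same terms as (a + b)^2
identical(attr(terms(y ~ a * b), "term.labels"),
          attr(terms(y ~ (a + b)^2), "term.labels"))
```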

We want to know how the model is doing, so we run the summary command:

```
> summary(model)

Call:
lm(formula = totlngth ~ hdlngth, data = df)

Residuals:
   Min     1Q Median     3Q    Max
-7.275 -1.611  0.136  1.882  5.250

Coefficients:
            Estimate Std. Error t value Pr(>|t|)
(Intercept)  -28.722     14.655  -1.960   0.0568 .
hdlngth        1.266      0.159   7.961  7.5e-10 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 2.653 on 41 degrees of freedom
Multiple R-squared:  0.6072,	Adjusted R-squared:  0.5976
F-statistic: 63.38 on 1 and 41 DF,  p-value: 7.501e-10
```


Don’t worry if this doesn’t mean anything to you: it’s just printing the parameters of the model it fit, along with the results of a bunch of tests that show how significant the model is.

Lastly, let’s visualize the regression with a scatterplot:

```r
plot(df$hdlngth, df$totlngth)
abline(model)
```


And R gives us a nice plot:

All of this only took 5 lines of R code! Hopefully I’ve piqued your interest by now: R is great for quickly trying out a lot of different models on your data without too much effort.

That being said, R has a somewhat steep learning curve as a lot of things don’t work the way you’d expect. Next, I’ll mention some pitfalls I came across.

## Don’t worry about the type system

As computer scientists, we’re used to thinking about type systems, type casting rules, variable scoping rules, closures, stuff like that. These details form the backbone of any programming language, or so I thought. Not the case with R.

R is designed by statisticians, and statisticians are more interested in doing statistics than worrying about intricacies of their programming language. Types do exist, but it’s not worth your time to worry about the difference between a list and a vector; most likely, your code will just work on both.

The most central object for data analysis in R is the data frame, which stores tabular data as a set of named columns of equal length. Data frames are as ubiquitous in R as objects are in Java. They also don’t have a close equivalent in most programming languages; they’re similar to a SQL table or an Excel spreadsheet.
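A small sketch with made-up possum measurements (hypothetical values, not from the real dataset) showing that a data frame is really a list of equal-length columns, and what a couple of table-style operations look like in base R:

```r
# Hypothetical toy data: each column is a vector, and all columns have the same length
possums <- data.frame(
  site    = c("A", "A", "B", "B"),
  hdlngth = c(92.5, 94.1, 90.3, 95.0),
  age     = c(3, 5, 2, 6)
)

is.list(possums)        # TRUE: a data frame is a list of columns
dim(possums)            # 4 rows, 3 columns
sapply(possums, class)  # each column has its own type

possums[possums$age > 2, c("site", "hdlngth")]         # like SELECT site, hdlngth WHERE age > 2
aggregate(hdlngth ~ site, data = possums, FUN = mean)  # like AVG(hdlngth) ... GROUP BY site
```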

## Use dplyr for data wrangling

The base library in R is not the most well-designed package in the world. There are many inconsistencies and arbitrary design decisions, and common operations are needlessly unintuitive. Fortunately, R has an excellent ecosystem of packages that make up for the shortcomings of the base system.

In particular, I highly recommend using the packages dplyr and tidyr instead of the base package for data wrangling tasks. I’m talking about operations that get data into a certain form, like sorting by a variable, or grouping by a set of variables and computing the aggregate sum over each group. Dplyr and tidyr provide a consistent set of functions that make this easy. I won’t go into too much detail, but you can see this page for a comparison between dplyr and base R for some common data wrangling tasks.
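For a flavour of the difference, here is a sketch of the two operations mentioned above, sorting and a grouped sum, on R's built-in `mtcars` dataset, written both in base R and with dplyr (assuming dplyr is installed). The choice of columns is arbitrary.

```r
library(dplyr)

# Sort cars by fuel economy
sorted_base  <- mtcars[order(mtcars$mpg), ]
sorted_dplyr <- mtcars %>% arrange(mpg)

# Total horsepower for each number of cylinders
totals_base  <- aggregate(hp ~ cyl, data = mtcars, FUN = sum)
totals_dplyr <- mtcars %>%
  group_by(cyl) %>%
  summarise(total_hp = sum(hp))
```

The dplyr versions read top to bottom as a sequence of named verbs, and the same verbs (`filter`, `mutate`, `arrange`, `group_by`, `summarise`) compose the same way for other tasks.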

## Use ggplot2 for plotting

Plotting is another domain where the base package falls short. The functions are inconsistent and worse, you’re often forced to hardcode arbitrary constants in your code. Stupid things like plot(..., pch=19) where 19 is the constant for “solid circle” and 17 means “solid triangle”.

There’s no reason to learn the base plotting system — ggplot2 is a much better alternative. Its functions allow you to build graphs piece by piece in a consistent manner (and they look nicer by default). I won’t go into the comparison in detail, but here’s a blog post that describes the advantages of ggplot2 over base graphics.
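As an illustration of building a plot piece by piece, here is a sketch of the earlier possum scatterplot with its regression line, redone in ggplot2 on simulated data (the real `fossum.csv` isn't reproduced here, and the simulation parameters are made up). Each `+` adds one layer or setting, and the regression line comes from `geom_smooth(method = "lm")` instead of a separate `abline` call.

```r
library(ggplot2)

set.seed(1)
# Simulated stand-in for the possum measurements
df <- data.frame(hdlngth = runif(43, 85, 100))
df$totlngth <- 10 + 0.8 * df$hdlngth + rnorm(43, sd = 2.5)

p <- ggplot(df, aes(x = hdlngth, y = totlngth)) +
  geom_point() +                                             # scatterplot layer
  geom_smooth(method = "lm", formula = y ~ x, se = FALSE) +  # fitted regression line
  labs(x = "Head length", y = "Total length")                # axis labels

print(p)
```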

It’s unfortunate that R’s base package falls short in these two areas. But with the package manager, it’s super easy to install better alternatives. Both ggplot2 and dplyr are widely used (currently, both are in the top 5 most downloaded R packages).

## How to self-study R

First off, check out Swirl. It’s a package for teaching beginners the basics of R, interactively within RStudio itself. It guides you through its courses on topics like regression modelling and dplyr, and only takes a few hours to complete.

At some point, read through the tidyverse style guide to get up to speed on the best practices on naming files and variables and stuff like that.

Now go and analyze data! One major difference between R and other languages is that you need a dataset to do anything interesting. There are many public datasets out there; Kaggle provides a sizable repository.

For me, it’s a lot more motivating to analyze data I care about. Analyze your bank statement history, or data on your phone’s pedometer app, or your university’s enrollment statistics data to find which electives have the most girls. Turn it into a mini data-analysis project. Fit some regression models and draw a few graphs with R; this is a great way to learn.

The best thing about R is the number of packages out there. If you read about a statistical model, chances are that someone’s written an R package for it. You can download it and be up and running in minutes.

It takes a while to get used to, but learning R is definitely a worthwhile investment for any aspiring data scientist.