XGBoost learns the Canadian Flag

XGBoost is a machine learning library that’s great for classification tasks. It’s often seen in Kaggle competitions, and often beats other classifiers like logistic regression, random forests, SVMs, and shallow neural networks. One day, I was feeling slightly patriotic, and wondered: can XGBoost learn the Canadian flag?

canada_original.png

Above: Our home and native land

Let’s find out!

Preparing the dataset

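The task is to classify each pixel of the Canadian flag as either red or white, given limited data points. First, we read in the image with R and take its second channel, which is green in the R, G, B order that readPNG returns (the code keeps the variable name red). The red channel itself would not help here, since it is high for both red and white pixels, whereas green is near 0 on red and near 1 on white: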

library(png)
library(ggplot2)
library(xgboost)

img <- readPNG("canada.png")
red <- img[,,2]

HEIGHT <- dim(red)[1]
WIDTH <- dim(red)[2]

Next, we sample 7500 random points for training. Also, to make it more interesting, each point has a probability 0.05 of flipping to the opposite color.

ERROR_RATE <- 0.05

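get_data_points <- function(N) {
  # Sample N pixel coordinates uniformly, with replacement
  x <- sample(1:WIDTH, N, replace = TRUE)
  y <- sample(1:HEIGHT, N, replace = TRUE)
  # The sampled channel is ~0 on red pixels and ~1 on white ones,
  # so invert it to label red as 1 and white as 0
  p <- 1 - round(red[cbind(y, x)])
  # Flip each label independently with probability ERROR_RATE
  flips <- sample(c(0, 1), N, replace = TRUE,
                  prob = c(1 - ERROR_RATE, ERROR_RATE))
  p[flips == 1] <- 1 - p[flips == 1]
  data.frame(x = as.numeric(x), y = as.numeric(y), p = p)
}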

data <- get_data_points(7500)

This is what our classifier sees:

noisy.png

Alright, let’s start training.

Quick introduction to XGBoost

XGBoost implements gradient boosted decision trees, which were first proposed by Friedman in 1999.

1.png

Above: XGBoost learns an ensemble of short decision trees

The output of XGBoost is an ensemble of decision trees. Each individual tree by itself is not very powerful, containing only a few branches. But through gradient boosting, each subsequent tree tries to correct for the mistakes of all the trees before it, and makes the model better. After many iterations, we get a set of decision trees; the sum of all their outputs is our final prediction.
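To make the "each tree corrects the previous ones" idea concrete, here is a minimal sketch of gradient boosting in base R. It is not how XGBoost is implemented: it assumes squared-error loss, one input variable, and depth-1 trees (stumps), and the names fit_stump, predict_stump, boost, predict_boost and eta are made up for this illustration. XGBoost adds regularization, second-order gradient information, deeper trees, and much faster split finding.

```r
# Fit one stump: pick the threshold on x that best splits the residuals r.
fit_stump <- function(x, r) {
  xs <- sort(unique(x))
  # Candidate thresholds are midpoints between consecutive distinct x values
  thresholds <- (head(xs, -1) + tail(xs, -1)) / 2
  best <- list(sse = Inf)
  for (thr in thresholds) {
    left <- x <= thr
    left_mean <- mean(r[left])
    right_mean <- mean(r[!left])
    sse <- sum((r[left] - left_mean)^2) + sum((r[!left] - right_mean)^2)
    if (sse < best$sse) {
      best <- list(sse = sse, thr = thr, left = left_mean, right = right_mean)
    }
  }
  best
}

predict_stump <- function(stump, x) {
  ifelse(x <= stump$thr, stump$left, stump$right)
}

# Boosting loop: every round fits a stump to what the ensemble still gets wrong.
boost <- function(x, y, n_rounds = 50, eta = 0.3) {
  base <- mean(y)
  pred <- rep(base, length(y))
  stumps <- vector("list", n_rounds)
  for (m in seq_len(n_rounds)) {
    resid <- y - pred  # negative gradient of squared-error loss
    stumps[[m]] <- fit_stump(x, resid)
    pred <- pred + eta * predict_stump(stumps[[m]], x)
  }
  list(base = base, eta = eta, stumps = stumps)
}

# The final prediction is the base value plus the (scaled) sum of all stumps.
predict_boost <- function(model, x) {
  pred <- rep(model$base, length(x))
  for (s in model$stumps) pred <- pred + model$eta * predict_stump(s, x)
  pred
}

set.seed(1)
x <- runif(200)
y <- as.numeric(x > 0.3 & x < 0.7)  # a "band" that a single stump cannot capture

mse <- function(model) mean((y - predict_boost(model, x))^2)
mse_1 <- mse(boost(x, y, n_rounds = 1))
mse_50 <- mse(boost(x, y, n_rounds = 50))
cat("Training MSE after 1 round:", mse_1, " after 50 rounds:", mse_50, "\n")
```

A single stump can only cut the line in two, so it cannot represent the band; the sum of many stumps can, which is why the training error drops as rounds are added.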

For more technical details of how this works, refer to this tutorial or the XGBoost paper.

Experiments

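Fitting an XGBoost model is very easy using R. For this experiment, we use decision trees of depth 3 (max_depth = 3), but you can play with the hyperparameters. The nrounds argument sets the number of boosting iterations, that is, how many trees are fit.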

fit <- xgboost(data = matrix(c(data$x, data$y), ncol = 2), label = data$p,
               nrounds = 1,
               max_depth = 3)

We also need a way of visualizing the results. To do this, we run every pixel through the classifier and display the result:

plot_canada <- function(dataplot) {
  dataplot$y <- -dataplot$y
  dataplot$p <- as.factor(dataplot$p)

  ggplot(dataplot, aes(x = x, y = y, color = p)) +
    geom_point(size = 1) +
    scale_x_continuous(limits = c(0, 240)) +
    scale_y_continuous(limits = c(-120, 0)) +
    theme_minimal() +
    theme(panel.background = element_rect(fill='black')) +
    theme(panel.grid.major = element_blank(), panel.grid.minor = element_blank()) +
    scale_color_manual(values = c("white", "red"))
}

fullimg <- expand.grid(x = as.numeric(1:WIDTH), y = as.numeric(1:HEIGHT))
fullimg$p <- predict(fit, newdata = matrix(c(fullimg$x, fullimg$y), ncol = 2))
fullimg$p <- as.numeric(fullimg$p > 0.5)

plot_canada(fullimg)

In the first iteration, XGBoost immediately learns the two red bands at the sides:

round1.png

After a few more iterations, the maple leaf starts to take form:

round7.png

round15

round60

By iteration 60, it learns a pretty recognizable maple leaf. Note that the decision trees split on x or y coordinates, so XGBoost can’t learn diagonal decision boundaries, only approximate them with horizontal and vertical lines.

If we run it for too long, then it starts to overfit and capture the random noise in the training data. In practice, we would use cross-validation to detect when this is happening. But why cross-validate when you can just eyeball it?

round300.png
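For anyone who would rather not eyeball it, here is a rough sketch of how cross-validation could pick the number of rounds, using xgb.cv from the xgboost package. Several things here are assumptions rather than part of the experiment above: the data is a synthetic stand-in (a noisy disc instead of the flag samples), the objective is swapped to binary:logistic with the classification error metric, and the test_error_mean column name is what I would expect in the evaluation log but may differ between package versions.

```r
library(xgboost)

set.seed(42)
# Synthetic stand-in for the noisy flag samples (an assumption for this sketch):
# points in a 240 x 120 box, labelled 1 inside a disc, with 5% of labels flipped.
n <- 2000
x <- runif(n, 0, 240)
y <- runif(n, 0, 120)
p <- as.numeric((x - 120)^2 + (y - 60)^2 < 40^2)
flip <- runif(n) < 0.05
p[flip] <- 1 - p[flip]

dtrain <- xgb.DMatrix(data = cbind(x, y), label = p)

# 5-fold cross-validation: each round's error is measured on held-out folds
cv <- xgb.cv(params = list(objective = "binary:logistic",
                           eval_metric = "error",
                           max_depth = 3),
             data = dtrain, nrounds = 100, nfold = 5, verbose = FALSE)

eval_log <- as.data.frame(cv$evaluation_log)
best_round <- which.min(eval_log$test_error_mean)
cat("Lowest held-out error at round", best_round, "\n")
```

To try this on the flag, build dtrain from data$x, data$y and data$p instead. The held-out error typically stops improving well before the training error does, and that gap is the overfitting visible in the last image.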

That was fun. If you liked this, check out this post which explores various classifiers using a flag of Australia.

The source code for this blog post is posted here. Feel free to experiment with it.
