Iris Flower Classification (R)


Project Summary

Abstract

This project focuses on classifying iris flowers into their respective species using the K-Nearest Neighbors machine-learning algorithm. The three species in this classification problem are setosa, versicolor, and virginica. The explanatory variables are sepal length, sepal width, petal length, and petal width (see the Wikipedia entries for sepal and petal). We are essentially trying to predict the species of an iris flower based on its physical features!

The K-Nearest Neighbors algorithm is interesting because it is a simple yet powerful machine-learning method for classification. To classify a new observation, it finds the k closest observations (measured here by Euclidean distance across the attributes) and assigns whichever class holds the majority vote among those neighbors.
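
To make the idea concrete, here is a minimal toy sketch of that voting step, written directly against the built-in iris data (so it runs on its own). The new_flower measurements and the choice of k = 5 are made up for illustration and are not part of the model we fit later.

# Toy illustration of the KNN idea: classify one hypothetical flower by the
# majority species among its k closest neighbors in the iris data
k <- 5
new_flower <- c(5.8, 2.7, 4.1, 1.0)  # made-up sepal/petal measurements

# Euclidean distance from the new flower to every observation
dists <- apply(iris[, 1:4], 1, function(row) sqrt(sum((row - new_flower)^2)))

# Majority vote among the k nearest neighbors
nearest <- order(dists)[1:k]
table(iris$Species[nearest])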


1. Load Packages

First we load the appropriate packages into our R environment.

For this we use the library() function with the package names as arguments. Make sure to install the packages first with install.packages() if you haven't done so already.

# Here if you haven't installed these packages do so!
install.packages("data.table")
install.packages("ggplot2")
install.packages("ggfortify")
install.packages("caret")
install.packages("class")
install.packages("gridExtra")
install.packages("GGally")
install.packages("RGraphics")
install.packages("gmodels")

library(data.table)
library(ggplot2)
library(ggfortify)
library(caret)
library(class)
library(gridExtra)
library(GGally)
library(RGraphics)
library(gmodels)
library(plotly)
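
If you would rather not re-run install.packages() every time, a common pattern is to install only the packages that are missing and then load everything in one go. This is just an optional convenience sketch; the pkgs vector simply lists the same packages used above.

# Install any packages that aren't already present, then load them all
pkgs <- c("data.table", "ggplot2", "ggfortify", "caret", "class",
          "gridExtra", "GGally", "RGraphics", "gmodels", "plotly")
missing <- pkgs[!pkgs %in% rownames(installed.packages())]
if (length(missing) > 0) install.packages(missing)
invisible(lapply(pkgs, library, character.only = TRUE))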

Next we load our data.

2. Get Data

The iris dataset is very popular in statistical learning and is readily available in base R. To load it, we call data() on iris and attach() it to the search path, then run head() to get a quick glance at our data.

data(iris)
attach(iris)
head(iris)

Terminal Output

> head(iris)
  Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1          5.1         3.5          1.4         0.2  setosa
2          4.9         3.0          1.4         0.2  setosa
3          4.7         3.2          1.3         0.2  setosa
4          4.6         3.1          1.5         0.2  setosa
5          5.0         3.6          1.4         0.2  setosa
6          5.4         3.9          1.7         0.4  setosa

Our terminal output above shows the first six observations of our data, which contains five variables in total. The goal is to predict Species as a function of the other four variables.
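
Before plotting anything it can also help to confirm the structure of the data. A quick optional check with str() and summary() shows the variable types and value ranges (iris has 150 observations: four numeric predictors and the Species factor).

# Optional: check variable types and summary statistics before modeling
str(iris)
summary(iris)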

Next we do exploratory analysis.

3. Exploratory Analysis

We begin our exploratory analysis by looking for relationships between our explanatory variables and the predicted variable. For this we use ggplot() together with ggplotly() from the plotly package to generate scatterplots of sepal length (y-axis) against sepal width (x-axis), and of petal length (y-axis) against petal width (x-axis).

gg1<-ggplot(iris,aes(x=Sepal.Width,y=Sepal.Length, 
                     shape=Species, color=Species)) + 
  theme(panel.background = element_rect(fill = "gray98"),
        axis.line   = element_line(colour="black"),
        axis.line.x = element_line(colour="gray"),
        axis.line.y = element_line(colour="gray")) +
  geom_point(size=2) + 
  labs(title = "Sepal Width Vs. Sepal Length")

ggplotly(gg1)

Next we plot the petal length vs the petal width!

gg2<-ggplot(iris,aes(x=Petal.Width,y=Petal.Length, 
                     shape=Species, color=Species)) + 
  theme(panel.background = element_rect(fill = "gray98"),
        axis.line   = element_line(colour="black"),
        axis.line.x = element_line(colour="gray"),
        axis.line.y = element_line(colour="gray")) + 
  geom_point(size=2) + 
  labs(title = "Petal Length Vs. Petal Width")

ggplotly(gg2)

The plots show setosa to be the most distinguishable of the three species with respect to both sepal and petal attributes. We can therefore expect the setosa species to yield the fewest prediction errors, while the other two species, versicolor and virginica, may be harder to tell apart.
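
One quick way to back up this visual impression with numbers is to compare the average measurements by species; setosa's petal dimensions sit well apart, while versicolor and virginica are closer together.

# Mean of each measurement by species
aggregate(. ~ Species, data = iris, FUN = mean)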

Below is a plot that shows the relationships across our various explanatory variables.

pairs <- ggpairs(iris,mapping=aes(color=Species),columns=1:4) + 
  theme(panel.background = element_rect(fill = "gray98"),
        axis.line   = element_line(colour="black"),
        axis.line.x = element_line(colour="gray"),
        axis.line.y = element_line(colour="gray")) 
pairs
ggplotly(pairs) %>%
  layout(showlegend = FALSE)

This pairs plot gives an overarching view of the pairwise relationships and distributions across the different attributes. It will also come in handy for other classification models such as Linear Discriminant Analysis, which is not covered in this project, but more statistical analysis of the iris dataset is included in the GitHub repository.

4. Model Estimation

The K-Nearest Neighbors algorithm predicts by majority vote, measuring a certain number of neighboring observations (k) and classifying based on which class is most prevalent among them, using Euclidean distance. Here is documentation on the knn() function; check the documentation on its package, class, as well.

We begin this section by randomly splitting the observations into a training set (75%) and a test set (25%); the set.seed() call makes the random split reproducible. Here is a nice Stack Overflow post on how to split data into training and testing sets. We didn't print the indices or the full sets here because the output would be quite long, so we recommend running the code yourself to see exactly what's going on.

# Creating training/test set 
set.seed(123)
samp.size <- floor(nrow(iris) * .75)
samp.size
set.seed(123)

train.ind <- sample(seq_len(nrow(iris)), size = samp.size) 
train.ind 
train <- iris[train.ind, ] 
head(train)

train.set <- subset(train, select = -c(Species))
head(train.set)

test <- iris[-train.ind, ]
head(test)

test.set <- subset(test, select = -c(Species))
head(test.set)

class <- train[,"Species"]

test.class <- test[,"Species"]
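
As a side note, caret also offers createDataPartition(), which performs the split while preserving the class proportions in each set. Below is an optional sketch of the same 75/25 split done that way; the train2 and test2 objects are only for illustration and are not used in the rest of the project.

# Alternative: stratified 75/25 split with caret, keeping the species balanced
set.seed(123)
in.train <- createDataPartition(iris$Species, p = 0.75, list = FALSE)
train2   <- iris[in.train, ]
test2    <- iris[-in.train, ]
table(train2$Species)  # roughly equal counts per species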

Before passing parameter values to the knn() function, we use cross-validation from the caret package to find the optimal value of k, i.e. the value that should produce the smallest test error rate. The trainControl() and train() functions help us with this task.

set.seed(123)

contrl <- trainControl(method="repeatedcv",repeats = 3)

knn.K <- train(Species ~ ., data = train, method = "knn",
               trControl = contrl, preProcess = c("center", "scale"),
               tuneLength = 20)


# Here we output the results from knn.K! 
knn.K

Terminal Output

> knn.K
k-Nearest Neighbors 

112 samples
  4 predictors
  3 classes: 'setosa', 'versicolor', 'virginica' 

Pre-processing: centered (4), scaled (4) 
Resampling: Cross-Validated (10 fold, repeated 3 times) 
Summary of sample sizes: 100, 101, 100, 100, 101, 101, ... 
Resampling results across tuning parameters:

  k  Accuracy   Kappa    
  5  0.9496970  0.9242859
  7  0.9438384  0.9155406
  9  0.9590404  0.9386181

Accuracy was used to select the optimal model using  the largest value.
The final value used for the model was 9.

The terminal output above shows that our optimal k = 9 based on the Accuracy and Kappa values.
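
caret's train object can also be plotted directly to see how the cross-validated accuracy changes across the candidate values of k, which is a quick visual check of the tuning results above.

# Visualize cross-validated accuracy across the candidate values of k
plot(knn.K)

# The chosen tuning parameter is also stored on the model object
knn.K$bestTune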

5. Prediction Results

Now we can plug in k = 9 as one of the parameters of the knn() function. We also print our knn.iris model to see its output. Finally, we run CrossTable() to evaluate our model against the test data.

knn.iris <- knn(train = train.set, test = test.set, cl = class, k = 9, prob = T)

dim(train)
dim(class)
length(class)
table(test.class, knn.iris)

knn.iris

CrossTable(x = test.class, y = knn.iris, prop.chisq=FALSE)

Below we can see our results of our model.

Terminal Output

> knn.iris <- knn(train = train.set, test = test.set, cl = class, k = 9, prob = T)

> dim(train)
[1] 112   5

> dim(class) # This is important since knn() only accepts the class component as a vector. You will get an error if you pass a data frame!
NULL

> length(class)
[1] 112

> table(test.class, knn.iris)
knn.iris
test.class   setosa versicolor virginica
setosa         11          0         0
versicolor      0         13         0
virginica       0          1        13

> knn.iris
[1] setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa    
[10] setosa     setosa     versicolor versicolor versicolor versicolor versicolor versicolor versicolor
[19] versicolor versicolor versicolor versicolor versicolor versicolor virginica  virginica  virginica 
[28] versicolor virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
[37] virginica  virginica 
attr(,"prob")
[1] 1.0000000 1.0000000 1.0000000 1.0000000 1.0000000 1.0000000 1.0000000 1.0000000 1.0000000 1.0000000
[11] 1.0000000 0.7692308 1.0000000 1.0000000 1.0000000 1.0000000 1.0000000 1.0000000 1.0000000 1.0000000
[21] 1.0000000 1.0000000 1.0000000 1.0000000 0.9090909 1.0000000 1.0000000 0.8181818 1.0000000 1.0000000
[31] 0.6363636 1.0000000 1.0000000 1.0000000 1.0000000 0.9090909 1.0000000 1.0000000
Levels: setosa versicolor virginica

> CrossTable(x = test.class, y = knn.iris, prop.chisq=FALSE)


Cell Contents
|-------------------------|
|                       N |
|           N / Row Total |
|           N / Col Total |
|         N / Table Total |
|-------------------------|


Total Observations in Table:  38 


             | knn.iris 
  test.class |     setosa | versicolor |  virginica |  Row Total | 
-------------|------------|------------|------------|------------|
      setosa |         11 |          0 |          0 |         11 | 
             |      1.000 |      0.000 |      0.000 |      0.289 | 
             |      1.000 |      0.000 |      0.000 |            | 
             |      0.289 |      0.000 |      0.000 |            | 
-------------|------------|------------|------------|------------|
  versicolor |          0 |         13 |          0 |         13 | 
             |      0.000 |      1.000 |      0.000 |      0.342 | 
             |      0.000 |      0.929 |      0.000 |            | 
             |      0.000 |      0.342 |      0.000 |            | 
-------------|------------|------------|------------|------------|
   virginica |          0 |          1 |         13 |         14 | 
             |      0.000 |      0.071 |      0.929 |      0.368 | 
             |      0.000 |      0.071 |      1.000 |            | 
             |      0.000 |      0.026 |      0.342 |            | 
-------------|------------|------------|------------|------------|
Column Total |         11 |         14 |         13 |         38 | 
             |      0.289 |      0.368 |      0.342 |            | 
-------------|------------|------------|------------|------------|

We can see here that the model predicted versicolor for one observation that was actually virginica; from our exploratory analysis we expected some prediction errors here, since these two species were the least distinguishable of the three. Both tables reiterate the same result: one incorrect prediction out of 38 observations. Thus we calculate the test error as:

Terminal Output

# Total Number of observations is 38 
# Total Number of observations predicted incorrectly is 1
> 1/38
[1] 0.02631579

Thus we get a test error rate of about 0.0263 (roughly 2.6%), which is not bad at all, though granted this is a very easy dataset for KNN modeling!
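
Rather than computing the error by hand, you can also get the same number, and a fuller summary, directly in R. The mean() line below simply counts the share of mismatched predictions, and confusionMatrix() comes from the caret package we already loaded.

# Test error rate computed directly from the predictions
mean(knn.iris != test.class)   # 1/38 = 0.02631579

# caret's confusionMatrix() gives accuracy, per-class sensitivity, and more
confusionMatrix(knn.iris, test.class)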

6. Conclusions

Our model yielded a test error rate of about 0.0263 across the three species, not bad! As this project shows, K-Nearest Neighbors modeling is fairly simple. For datasets with a small number of variables KNN is a viable method, but for datasets with many variables we run into the curse of dimensionality. Check this Stack Exchange post for an explanation of the curse of dimensionality. We feel this project is a good introduction to training and test sets, which are important components not just of data science but of statistical learning overall!

CONGRATULATIONS!

Congratulations for getting this far! We hope you enjoyed this project. Please reach out to us here if you have any feedback or would like to publish your own project.

Fork this project on GitHub

Try this project next:


Forecasting the Stock Market (Python)

Time-Series Analysis of the S&P 500 Stock Index with Python