We start with simulated toy data for illustration.
sessionInfo()
## R version 4.1.2 (2021-11-01)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Ubuntu 20.04.3 LTS
##
## Matrix products: default
## BLAS: /usr/lib/x86_64-linux-gnu/atlas/libblas.so.3.10.3
## LAPACK: /usr/lib/x86_64-linux-gnu/atlas/liblapack.so.3.10.3
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## loaded via a namespace (and not attached):
## [1] digest_0.6.28 R6_2.5.1 jsonlite_1.7.2 magrittr_2.0.1
## [5] evaluate_0.14 rlang_0.4.12 stringi_1.7.5 jquerylib_0.1.4
## [9] bslib_0.3.1 rmarkdown_2.11 tools_4.1.2 stringr_1.4.0
## [13] xfun_0.27 yaml_2.2.1 fastmap_1.1.0 compiler_4.1.2
## [17] htmltools_0.5.2 knitr_1.36 sass_0.4.0
set.seed(7360)
# Attach Packages
library(tidyverse) # data manipulation and visualization
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✓ ggplot2 3.3.5 ✓ purrr 0.3.4
## ✓ tibble 3.1.5 ✓ dplyr 1.0.7
## ✓ tidyr 1.1.4 ✓ stringr 1.4.0
## ✓ readr 2.1.0 ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(kernlab) # SVM methodology
##
## Attaching package: 'kernlab'
## The following object is masked from 'package:purrr':
##
## cross
## The following object is masked from 'package:ggplot2':
##
## alpha
library(e1071) # SVM methodology
library(RColorBrewer) # customized coloring of plots
# construct data set
x <- matrix(rnorm(200*2), ncol = 2)
x[1:100,] <- x[1:100,] + 2.5
x[101:150,] <- x[101:150,] - 2.5
y <- c(rep(1,150), rep(2,50))
dat <- data.frame(x=x,y=as.factor(y))
# plot data set
ggplot(data = dat, aes(x = x.2, y = x.1, color = y, shape = y)) +
geom_point(size = 2) +
scale_color_manual(values=c("#000000","#FF0000","#00BA00")) +
theme(legend.position = "none")
Based on the shape of the data (class 2 sits between two class 1 clusters, so no linear boundary can separate them), we use a radial kernel with the svm function from the e1071 package and plot the resulting classifier.
# sample training data and fit model
train <- base::sample(200,100, replace = FALSE)
svmfit <- svm(y~., data = dat[train,], kernel = "radial", gamma = 1, cost = 1)
# plot classifier
plot(svmfit, dat)
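Beyond the plot, it can help to inspect the fitted object directly; a quick sketch using e1071's accessors:
# summarize kernel, cost, and the number of support vectors per class
summary(svmfit)
# row positions of the training points that became support vectors
head(svmfit$index)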
We can use the kernlab package for the same procedure.
# Fit radial-based SVM in kernlab
kernfit <- ksvm(x[train, ], y[train], type = "C-svc", kernel = 'rbfdot', C = 1, scaled = FALSE)
# Plot training data
plot(kernfit, data = x[train,])
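As a rough consistency check (a sketch; the exact numbers depend on the seed), we can compare the training accuracy of the two fits:
# training accuracy of the e1071 fit
mean(predict(svmfit, dat[train, ]) == dat[train, "y"])
# training accuracy of the kernlab fit
mean(predict(kernfit, x[train, ]) == y[train])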
Now tune the model to find the optimal cost and gamma values.
# tune model to find optimal cost, gamma values
tune.out <- tune(svm, y~., data = dat[train,], kernel = "radial",
ranges = list(cost = c(0.1,1,10,100,1000),
gamma = c(0.5,1,2,3,4)))
# show best model
tune.out$best.model
##
## Call:
## best.tune(method = svm, train.x = y ~ ., data = dat[train, ], ranges = list(cost = c(0.1,
## 1, 10, 100, 1000), gamma = c(0.5, 1, 2, 3, 4)), kernel = "radial")
##
##
## Parameters:
## SVM-Type: C-classification
## SVM-Kernel: radial
## cost: 1
##
## Number of Support Vectors: 30
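To see the full cross-validation grid rather than just the winner, e1071 provides a summary method and stores the chosen parameters:
# cross-validation error for every cost/gamma combination
summary(tune.out)
# the selected parameter pair
tune.out$best.parameters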
Now take a look at the model's performance on the held-out data.
# validate model performance
(valid <- table(true = dat[-train, "y"], pred = predict(tune.out$best.model,
                newdata = dat[-train, ])))
## pred
## true 1 2
## 1 57 21
## 2 12 10
sum(diag(valid)) / sum(valid)
## [1] 0.67
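Overall accuracy hides how the two classes fare individually; per-class accuracy falls out of the same table:
# fraction of each true class predicted correctly
diag(valid) / rowSums(valid)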
Now we move from toy data to a real example: classifying handwritten digits from MNIST. Acquire the data:
library(keras)
library(caret)
## Loading required package: lattice
##
## Attaching package: 'caret'
## The following object is masked from 'package:purrr':
##
## lift
library(e1071)
library(kernlab)
library(tidyverse)
mnist <- dataset_mnist()
## Loaded Tensorflow version 2.7.0
x_train <- mnist$train$x
y_train <- mnist$train$y
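Before reshaping, it is worth a quick look at the raw data: each observation is a 28 x 28 matrix of grayscale intensities in 0-255. A sanity-check sketch (the transpose/reverse idiom just fixes image()'s orientation):
# render the first training digit with its label as the title
digit <- x_train[1, , ]
image(t(apply(digit, 2, rev)), col = gray((255:0) / 255),
      axes = FALSE, main = paste("label:", y_train[1]))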
x_train <- array_reshape(x_train, c(nrow(x_train), 28 * 28)) # flatten each 28x28 image into 784 columns
selected.columns <- colSums(x_train) > 0 # drop pixels that are zero in every image
x_train <- x_train[, selected.columns]
x_test <- mnist$test$x
y_test <- mnist$test$y
training <- as_tibble(cbind(y_train, x_train))
## Warning: The `x` argument of `as_tibble.matrix()` must have unique column names if `.name_repair` is omitted as of tibble 2.0.0.
## Using compatibility `.name_repair`.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
# sum(selected.columns) columns remain after filtering, not length(selected.columns)
colnames(training) <- c("label", as.character(1:sum(selected.columns)))
# subsample 10% of the rows: kernel SVM training time grows quickly with n
training <- slice_sample(training, prop = 0.1)
x_test <- array_reshape(x_test, c(nrow(x_test), 28 * 28))
x_test <- x_test[, selected.columns]
test <- as_tibble(cbind(y_test, x_test))
colnames(test) <- c("label", as.character(1:sum(selected.columns)))
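A cheap guard before fitting (a sketch): the formula interface needs identical column layouts in the training and test tables.
# both tables: one label column plus the same retained pixel columns
stopifnot(identical(colnames(training), colnames(test)))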
Fit the SVM:
# kernlab takes kernel parameters via kpar; rbfdot's sigma plays the role of e1071's gamma
model <- ksvm(label ~ ., data = training, type = "C-svc", kernel = "rbfdot",
              C = 100, kpar = list(sigma = 0.001), scaled = FALSE)
The post claimed an accuracy of 0.99 using the full dataset.
predicted <- predict(model, newdata = x_test)
sum(predicted == y_test) / length(y_test)
## [1] 0.9618
We get about 96% accuracy using only 10% of the training set, which is not bad at all.
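For a per-digit breakdown, a sketch using caret (already loaded above), assuming predict() returns the digit labels; confusionMatrix reports class-wise sensitivity and specificity on top of the overall accuracy:
# which digits does the model confuse most often?
confusionMatrix(factor(predicted, levels = 0:9),
                factor(y_test, levels = 0:9))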