首先,我分析了哪些變數會影響乘客生還或罹難,發現乘客是否登上救生艇是非常關鍵的變數。
當然,為了提高預測準確率,我也使用了許多其他變數(例如稱謂、性別等)。
在模型方面,我使用的是隨機森林(Random Forest)。
library('gridExtra')
library('ggplot2')
library('ggthemes')
library('scales')
library('dplyr')
##
## Attaching package: 'dplyr'
## The following object is masked from 'package:gridExtra':
##
## combine
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library('mice')
## Loading required package: lattice
library('randomForest')
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
## The following object is masked from 'package:gridExtra':
##
## combine
# Read the training set and the question (test) set; empty cells are read as NA.
read_titanic_csv <- function(path) {
  read.csv(path, stringsAsFactors = FALSE, na.strings = c(""))
}
train.data <- read_titanic_csv("titanicTrain.csv")
test.data <- read_titanic_csv("titanicQuestion.csv")
# Keep the first 1000 training rows and the first 309 question rows.
train.data <- train.data[1:1000, ]
test.data <- test.data[1:309, ]
# Stack both sets so every feature-engineering step below is applied to all
# passengers at once: rows 1-1000 are training rows, rows 1001-1309 test rows.
full <- rbind(train.data, test.data)
str(full)
## 'data.frame': 1309 obs. of 14 variables:
## $ pclass : int 1 1 1 1 1 1 1 1 1 1 ...
## $ survived : int 1 1 0 0 0 1 1 0 1 0 ...
## $ name : chr "Allen, Miss. Elisabeth Walton" "Allison, Master. Hudson Trevor" "Allison, Miss. Helen Loraine" "Allison, Mr. Hudson Joshua Creighton" ...
## $ sex : chr "female" "male" "female" "male" ...
## $ age : num 29 0.917 2 30 25 ...
## $ sibsp : int 0 1 1 1 1 0 1 0 2 0 ...
## $ parch : int 0 2 2 2 2 0 0 0 0 0 ...
## $ ticket : chr "24160" "113781" "113781" "113781" ...
## $ fare : num 211 152 152 152 152 ...
## $ cabin : chr "B5" "C22 C26" "C22 C26" "C22 C26" ...
## $ embarked : chr "S" "S" "S" "S" ...
## $ boat : chr "2" "11" NA NA ...
## $ body : int NA NA NA 135 NA NA NA NA NA 22 ...
## $ home.dest: chr "St Louis, MO" "Montreal, PQ / Chesterville, ON" "Montreal, PQ / Chesterville, ON" "Montreal, PQ / Chesterville, ON" ...
# Treat the outcome, cabin class and sex as categorical variables.
full[c("survived", "pclass", "sex")] <- lapply(full[c("survived", "pclass", "sex")], as.factor)
# Training subset used by the exploratory plots below.
train <- full[1:1000, ]
# Overall number of survivors vs. victims in the training set (pie chart),
# each slice labelled with its count.
ggplot(train, aes(x = survived, fill = survived)) +
  geom_bar(stat = 'count') +
  labs(x = 'Survival') +
  # after_stat() is the supported form of the deprecated ..count.. notation.
  geom_label(stat = 'count', aes(label = after_stat(count)), size = 6) +
  theme_grey(base_size = 16) +
  coord_polar("x", start = 0)
# Survival compared between men, women and children.
ggplot(train, aes(x = sex, fill = survived)) +
  geom_bar(stat = 'count') +
  labs(x = 'Survival of Male and Female') +
  theme_grey(base_size = 16) +
  coord_polar("x", start = 0)
# Children = training passengers with a known age below 10.
kids.survival <- train[(!is.na(train$age)) & (train$age < 10), ]
ggplot(kids.survival, aes(x = survived, fill = survived)) +
  geom_bar(stat = 'count') +
  labs(x = 'Survival of Kids aged under 10') +
  theme_grey(base_size = 16) +
  coord_polar("x", start = 0)
可以看出男人的生還率最低,女人及小孩的生還率較高。
# Survival broken down by cabin class (pclass), drawn as a pie chart.
ggplot(train) +
  aes(x = pclass, fill = survived) +
  geom_bar() +
  theme_grey(base_size = 16) +
  labs(x = "Survival of Different Class Levels") +
  coord_polar(theta = "x", start = 0)
可以發現座艙等級越高,生還率越高。
# Flag each passenger as having a recorded lifeboat ("1") or not ("0");
# the flag is stored as a character column.
# NOTE(review): the lifeboat record looks like information known only after the
# sinking and tracks `survived` closely (see the plot below); using `onboat` as
# a predictor likely leaks the outcome - confirm this is intended.
full$onboat <- ifelse(is.na(full$boat), "0", "1")
# Share of all passengers with / without a lifeboat (pie chart).
ggplot(full) +
  aes(x = onboat, fill = onboat) +
  geom_bar() +
  theme_grey(base_size = 16) +
  labs(x = "\nON or NOT onboat") +
  coord_polar(theta = "x", start = 0)
# Survival of training passengers with / without a lifeboat (stacked bars).
ggplot(head(full, 1000)) +
  aes(x = onboat, fill = survived) +
  geom_bar() +
  theme_grey(base_size = 16) +
  labs(x = "\nSurvival of ON or NOT onboat")
登上救生艇者有很大的機率生還,反之則難逃一劫。(需要注意的是:是否登上救生艇幾乎等同於是否生還,而且很可能是事後才記錄的資訊;把它當作預測變數會讓模型的準確率被高估。)
# Sex composition of passengers with and without a lifeboat (pie chart).
MF.OB <- ggplot(full) +
  aes(x = onboat, fill = sex) +
  geom_bar() +
  theme_grey(base_size = 16) +
  labs(x = "Sex of those who were and weren't Onboat") +
  coord_polar(theta = "x", start = 0)
print(MF.OB)
登上救生艇者大部分是女性,沒登上救生艇者大部分是男性。
# Cabin-class composition of passengers with and without a lifeboat (pie chart).
CL.OB <- ggplot(full, aes(x = onboat, fill = pclass)) +
  geom_bar(stat = 'count') +
  # The fill is pclass, so the axis label must name the class, not the sex.
  labs(x = 'Class of those who were and weren\'t Onboat') +
  theme_grey(base_size = 16) +
  coord_polar("x", start = 0)
CL.OB
登上救生艇者有一大部分是頭等艙(座艙等級最高)的乘客,沒登上救生艇者約有一半是三等艙(座艙等級最低)的乘客。
# Round ages down to whole years. This overwrites full$age, so the truncated
# values are what the later imputation and model see.
full$age <- floor(full$age)

# Pie chart of passengers with / without a lifeboat within `data`,
# with `label` as the axis title.
plot_onboat_share <- function(data, label) {
  ggplot(data, aes(x = onboat, fill = onboat)) +
    geom_bar(stat = "count") +
    labs(x = label) +
    theme_grey(base_size = 16) +
    coord_polar("x", start = 0)
}

# Three disjoint age bands among passengers with a known age:
# children (<= 10), elders (>= 65) and everyone in between.
plot_onboat_share(full[!is.na(full$age) & (full$age <= 10), ],
                  "Kids ON or NOT onboat")
plot_onboat_share(full[!is.na(full$age) & (full$age >= 65), ],
                  "Elders ON or NOT onboat")
plot_onboat_share(full[!is.na(full$age) & (full$age < 65) & (full$age > 10), ],
                  "Others ON or NOT onboat")
發現不少小孩登上了救生艇,老人則大多沒有登上,其他年齡層的乘客則是沒登上救生艇者較多。
# Extract the honorific between ", " and the first "." of each name,
# e.g. "Allen, Miss. Elisabeth Walton" -> "Miss".
full$title <- gsub('(.*, )|(\\..*)', '', full$name)
# Number of each title by sex.
table(full$sex, full$title)
##
## Capt Col Don Dona Dr Jonkheer Lady Major Master Miss Mlle Mme
## female 0 0 0 1 1 0 1 0 0 260 2 1
## male 1 4 1 0 7 1 0 2 61 0 0 0
##
## Mr Mrs Ms Rev Sir the Countess
## female 0 197 2 0 0 1
## male 757 0 0 8 1 0
# Map variant forms onto their common equivalents: Mlle (Mademoiselle) and Ms
# become Miss, Mme (Madame, the French form of Mrs) becomes Mrs. All remaining
# infrequent titles are merged into a single 'Rare Title' group.
rare_title <- c('Dona', 'Lady', 'the Countess', 'Capt', 'Col', 'Don', 'Dr',
                'Major', 'Rev', 'Sir', 'Jonkheer')
full$title[full$title %in% c('Mlle', 'Ms')] <- 'Miss'
full$title[full$title == 'Mme'] <- 'Mrs'
full$title[full$title %in% rare_title] <- 'Rare Title'
table(full$sex, full$title)
##
## Master Miss Mr Mrs Rare Title
## female 0 264 0 198 4
## male 61 0 757 0 25
# Survival per title in the training rows. geom_bar() counts a discrete x
# directly; geom_histogram(stat = "count") warned about unused bin parameters.
ggplot(full[1:1000, ]) +
  geom_bar(aes(x = title, fill = survived))
可得知女性稱謂(Miss、Mrs)的生還率較高,與前述女性生還率較高的結果一致。
# Surname = the text before the first "," or "." of the name (e.g. "Allen").
# vapply() pins the result to one string per passenger, unlike sapply(),
# whose return type depends on its input.
full$surname <- vapply(full$name, function(nm) strsplit(nm, split = '[,.]')[[1]][1], character(1))
接著從乘客姓名中提取姓氏(surname)。
# Family size on board: siblings/spouses + parents/children + the passenger.
full$familycount <- with(full, sibsp + parch + 1)
# Survival counts per family size in the training rows (side-by-side bars).
ggplot(head(full, 1000)) +
  aes(x = familycount, fill = survived) +
  geom_bar(position = 'dodge') +
  scale_x_continuous(breaks = 1:8) +
  labs(x = 'Family Size') +
  theme_gray()
# Bucket family size as character labels: 1 -> 'single', 2-4 -> 'small',
# 5 or more -> 'large'; anything else stays NA.
full$familysize <- with(full,
  ifelse(familycount > 4, 'large',
    ifelse(familycount < 5 & familycount > 1, 'small',
      ifelse(familycount == 1, 'single', NA_character_))))
# Survival proportion within each family-size bucket (filled bars).
ggplot(head(full, 1000)) +
  aes(x = familysize, fill = survived) +
  geom_bar(position = 'fill') +
  theme_grey(base_size = 16) +
  labs(x = 'Survival of Different Family Size')
單獨一人的罹難率高;家庭人數二到四人(小家庭)的生存率最高;家庭人數五人以上(大家庭)的生存率低。
# How many passengers lack a port of embarkation, and which rows they are.
missing.embarked <- is.na(full$embarked)
sum(missing.embarked)
## [1] 2
full[missing.embarked, ]
## pclass survived name sex age
## 169 1 1 Icard, Miss. Amelie female 38
## 285 1 1 Stone, Mrs. George Nelson (Martha Evelyn) female 62
## sibsp parch ticket fare cabin embarked boat body home.dest onboat
## 169 0 0 113572 80 B28 <NA> 6 NA <NA> 1
## 285 0 0 113572 80 B28 <NA> 6 NA Cincinatti, OH 1
## title surname familycount familysize
## 169 Miss Icard 1 single
## 285 Mrs Stone 1 single
發現有兩筆資料缺少登船港口(embarked):第169筆與第285筆。
我推測艙位等級(pclass)與票價(fare)相同的乘客,可能是從相同的港口登船(embarked)。
# Fare distribution per port and cabin class. The dashed red line marks the
# fare of 80 paid by both passengers whose port is missing.
ggplot(full) +
  aes(x = embarked, y = fare, fill = factor(pclass)) +
  geom_boxplot() +
  geom_hline(aes(yintercept = 80), colour = 'red', linetype = 'dashed', lwd = 2) +
  scale_y_continuous(labels = dollar_format()) +
  theme_gray()
## Warning: Removed 1 rows containing non-finite values (stat_boxplot).
這兩位乘客皆為頭等艙、票價80,而從港口C登船的頭等艙乘客票價中位數約為80,因此將乘客169、285的登船港口缺失值補為‘C’。
# Fill the missing ports with 'C': the two passengers found above are first
# class with a fare of 80, which matches the class-1 median at port C.
# Rows are selected by is.na() instead of hard-coded indices (169, 285), so
# the fix does not silently hit the wrong rows if the row order changes.
full$embarked[is.na(full$embarked)] <- 'C'
# How many passengers lack a fare, and which rows they are.
missing.fare <- is.na(full$fare)
sum(missing.fare)
## [1] 1
full[missing.fare, ]
## pclass survived name sex age sibsp parch ticket fare
## 1226 3 <NA> Storey, Mr. Thomas male 60 0 0 3701 NA
## cabin embarked boat body home.dest onboat title surname familycount
## 1226 <NA> S <NA> 261 <NA> 0 Mr Storey 1
## familysize
## 1226 single
票價(fare)有一個缺失值,在第1226筆資料。
這位乘客的座艙等級為3、登船港口為S,所以我找出具有相同條件乘客的票價中位數,作為此缺失值的票價。
# Fare density of third-class passengers who embarked at 'S';
# the dashed red line marks the median fare of that group.
ggplot(full[full$pclass == '3' & full$embarked == 'S', ]) +
  aes(x = fare) +
  geom_density(fill = '#99d6ff', alpha = 0.4) +
  geom_vline(aes(xintercept = median(fare, na.rm = TRUE)),
             colour = 'red', linetype = 'dashed', lwd = 1) +
  scale_x_continuous(labels = dollar_format()) +
  theme_few()
## Warning: Removed 1 rows containing non-finite values (stat_density).
# Exact median fare of that group (known fares only).
with(full, median(fare[!is.na(fare) & (embarked == 'S') & (pclass == '3')]))
## [1] 8.05
從圖中可看出中位數約在8左右,再用median函數算出實際值為8.05。
# Impute a missing fare with the median fare of third-class passengers who
# embarked at 'S' (the group of the passenger found above). Only rows with a
# missing fare in that same group are filled, and they are located with is.na()
# rather than the hard-coded row index 1226.
fare.median.3S <- median(full[!is.na(full$fare) & (full$embarked == 'S') & (full$pclass == '3'), "fare"])
full$fare[is.na(full$fare) & full$embarked == 'S' & full$pclass == '3'] <- fare.median.3S
把此乘客的票價設定為同樣條件的人的票價的中位數。
# Number of passengers whose age is unknown.
sum(is.na(full[["age"]]))
## [1] 263
年齡的缺失值多達263個,若直接刪除這些資料會損失太多樣本,
所以我決定用mice套件進行多重插補,來填補年齡的缺失值。
# Convert the categorical columns to factors.
factor_vars <- c('pclass', 'sex', 'embarked', 'title', 'surname', 'familysize')
full[factor_vars] <- lapply(full[factor_vars], as.factor)
# Fix the random seed so the imputation is reproducible.
set.seed(2018)
# Multiple imputation with the random-forest method, leaving identifiers,
# free-text columns, the outcome and `body` out of the imputation data.
# NOTE(review): `boat` (character, mostly NA) is still passed to mice - confirm
# that is intended.
mice.age <- mice(full[, !names(full) %in% c('name', 'ticket', 'cabin', 'surname', 'survived', 'home.dest', 'body')], method = 'rf')
##
## iter imp variable
## 1 1 age
## 1 2 age
## 1 3 age
## 1 4 age
## 1 5 age
## 2 1 age
## 2 2 age
## 2 3 age
## 2 4 age
## 2 5 age
## 3 1 age
## 3 2 age
## 3 3 age
## 3 4 age
## 3 5 age
## 4 1 age
## 4 2 age
## 4 3 age
## 4 4 age
## 4 5 age
## 5 1 age
## 5 2 age
## 5 3 age
## 5 4 age
## 5 5 age
# Keep one completed data set (complete()'s default).
mice.output <- complete(mice.age)
# Compare the age distribution before and after imputation side by side, then
# restore the previous graphics layout so later base plots are unaffected.
old.par <- par(mfrow = c(1, 2))
hist(full$age, freq = FALSE, main = 'Age: Original Data', col = 'orange', ylim = c(0, 0.04))
hist(mice.output$age, freq = FALSE, main = 'Age: MICE Output', col = 'red', ylim = c(0, 0.04))
par(old.par)
插補後的年齡分布與原始資料相近,效果不錯!
# Swap in the imputed ages, then confirm that no age is still missing.
full[["age"]] <- mice.output[["age"]]
sum(is.na(full[["age"]]))
## [1] 0
確認年齡缺失值已全數填補。
# Split back into the training rows (1-1000) and the test rows (1001-1309).
train <- full[1:1000, ]
test <- full[1001:1309, ]
# Fit a random forest on the engineered features; the seed makes it reproducible.
# NOTE(review): `onboat` comes from the lifeboat record and may leak the outcome.
set.seed(2018)
rf_model <- randomForest(survived ~ pclass + sex + age + sibsp + parch + fare +
                           embarked + title + familysize + onboat,
                         data = train)
# Resubstitution accuracy: predictions on the very rows the forest was trained
# on. For a random forest this is strongly optimistic.
prediction.train <- predict(rf_model, train)
correction.rate <- mean(prediction.train == train$survived)
correction.rate
## [1] 0.99
# Out-of-bag accuracy: predict() without newdata returns each passenger's OOB
# prediction (from trees that did not see that passenger), a fairer estimate
# of accuracy on new data.
oob.rate <- mean(predict(rf_model) == train$survived, na.rm = TRUE)
oob.rate
針對原訓練資料,正確率約為99%。不過這是用訓練資料本身評估的結果,會高估模型在新資料上的表現;較可靠的估計應參考隨機森林的袋外(OOB)正確率或交叉驗證。
# Predict survival for the test passengers.
prediction <- predict(rf_model, newdata = test)
answer <- data.frame(survived = prediction)
# Bar chart of predicted survivors versus victims.
ggplot(answer) +
  aes(x = survived, fill = survived) +
  geom_bar() +
  theme_grey(base_size = 16) +
  labs(x = "Prediction of Survival")
看一下生還與罹難人數比較。
# Save the test-set predictions (a single `survived` column, no row names).
write.csv(answer, "predict1.csv", row.names = FALSE)