【R语言学习笔记】 Day1 CART 逻辑回归、分类树以及随机森林的应用及对比
shanshant 人气:2
1. 目的:根据人口普查数据来预测收入(预测每个个体年收入是否超过$50,000)
2. 数据来源:1994年美国人口普查数据,数据中共含31978个观测值,每个观测值代表一个个体
3. 变量介绍:
(1)age: 年龄(以年表示)
(2)workclass: 工作类别/性质 (e.g., 国家机关工作人员、当地政府工作人员、无收入人员等)
(3)education: 受教育水平 (e.g., 小学、初中、高中、本科、硕士、博士等)
(4)maritalstatus: 婚姻状态(e.g., 未婚、离异等)
(5)occupation: 工作类型 (e.g., 行政/文员、农业养殖人员、销售人员等)
(6)relationship: 家庭身份 (e.g., 丈夫、妻子、孩子等)
(7)race: 种族
(8)sex: 性别
(9)capitalgain: 1994年的资本收入 (买卖股票、债券等)
(10)capitalloss: 1994年的资本支出 (买卖股票、债券等)
(11)hoursperweek: 每周工作时长
(12)nativecountry: 国籍
(13)over50k: 1994年全年工资是否超过$50,000
4. 应用及分析
census <- read.csv("census.csv") #读取文件
library(caTools) # 加载caTools包
# 将数据分为测试集和训练集 set.seed(2000) spl <- sample.split(census$over50k, SplitRatio = 0.6) census.train <- subset(census, spl == T) # 测试集 census.test <- subset(census, spl == F) # 训练集
# 构建逻辑回归模型 census.logistic <- glm(over50k ~ ., data = census.train, family = 'binomial') summary(census.logistic) # 查看模型拟合结果
# 在临界值为0.5的情况下,逻辑回归模型应用到测试集的准确性 ## method1 census.logistic.pred <- predict(census.logistic, newdata = census.test, type = 'response') library(caret) confusionMatrix(as.factor(ifelse(census.logistic.pred >= 0.5, " >50K", " <=50K")), as.factor(census.test$over50k)) ## method2 table(census.test$over50k, census.logistic.pred>= 0.5) sum(diag(table(census.test$over50k, census.logistic.pred>= 0.5)))/nrow(census.test) #0.8552
# 测试集的基础准确性
table(census.test$over50k)/nrow(census.test) #0.759
# ROC 以及 AUC library(ROCR) census.pred <- prediction(census.logistic.pred, census.test$over50k) census.perf <- performance(census.pred, 'tpr', 'fpr') plot(census.perf, colorize = T) #ROC curve as.numeric(performance(census.pred, 'auc')@y.values) #AUC value is 0.9061598
虽然逻辑回归模型准确率高达0.8572,且变量的显著性有助于我们判断个体的收入情况;但是在自变量中的分类变量类别太多的情况下,我们无法判断哪些变量更重要。
因此,接下来构建CART模型。
# 默认的CART模型 library(rpart) library(rpart.plot) census.cart <- rpart(over50k ~ ., data = census.train, method = 'class') prp(census.cart) # 作图
# 模型准确性 census.cart.pred <- predict(census.cart, newdata = census.test, type = 'class') ## method1 table(census.test$over50k, census.cart.pred) sum(diag(table(census.test$over50k, census.cart.pred)))/nrow(census.test) ## method2 confusionMatrix(census.cart.pred, as.factor(census.test$over50k)) # 模型准确性为0.8474
# ROC 以及 AUC census.cart.pred2 <- predict(census.cart, newdata = census.test) census.cart.pred2 census.cart.pred3 <- prediction(census.cart.pred2[,2], census.test$over50k) census.cart.perf <- performance(census.cart.pred3, 'tpr', 'fpr') plot(census.cart.perf, colorize = T) # ROC as.numeric(performance(census.cart.pred3, 'auc')@y.values) #AUC value is 0.8470256
# 随机森林模型 set.seed(1) census.train.small <- census.train[sample(nrow(census.train), 2000),] ## 构建随机森林模型之前先减小训练集样本数量。 ## 因为随机森林过程中包含大量运算过程,小样本更益于模型的建立 library(randomForest) census.train.small.rf <- randomForest(over50k ~ ., data = census.train.small) # 模型预测 census.train.small.rf.pred <- predict(census.train.small.rf, newdata = census.test) # 模型准确性 confusionMatrix(census.train.small.rf.pred, as.factor(census.test$over50k)) # 0.8533
因为随机森林模型是一系列分类决策树的集合,因此与分类决策树相比,随机森林模型的解释性稍差,但仍可用一些方法来衡量变量的重要性
# 方法一:统计随机过程中每个变量出现的次数 vu <- varUsed(census.train.small.rf, count=TRUE) vusrted <- sort(vu, decreasing = FALSE, index.return = TRUE) # draw a Cleveland dot plot dotchart(vusorted$x, names(census.train.small.rf$forest$xlevels[vusorted$ix]))
其中,age出现次数最多,sex出现次数最少。
# 方法二:比较平均Gini指数的下降程度 varImpPlot(census.train.small.rf)
其中,occupation、education、age的平均Gini指数减少的最多,sex的平均Gini指数减少的最少
# 改进的CART模型(考虑cp值) library(caret) library(lattice) library(ggplot2) library(e1071) # 找出使得准确率最高的cp值 set.seed(2) numFolds <- trainControl(method = 'cv', number = 10) cpGrid <- expand.grid(.cp = seq(0.002,0.1,0.002)) train(over50k ~ ., data = census.train, method = 'rpart', trControl = numFolds, tuneGrid = cpGrid) # cp = 0.002时模型准确度最高 # 构建新的CART模型(cp=0.002) census.bestTree <- rpart(over50k ~ ., data = census.train, method = 'class', cp = 0.002) prp(census.bestTree) # 作图 # 模型预测 predCV <- predict(census.bestTree, newdata = census.test, type = 'class') # 计算新模型的准确率 ## method1 table(census.test$over50k, predCV) sum(diag(table(census.test$over50k, predCV)))/nrow(census.test) ## method2 confusionMatrix(predCV, as.factor(census.test$over50k)) # 0.8612
考虑cp值以后的CART模型的准确性比默认模型高了1%左右,但是模型明显复杂了更多,因此需要在模型简洁性及准确性之间做出权衡。
本案例中,默认模型足够简洁且准确度也很高,所以倾向使用默认模型。
加载全部内容