기계학습

의사결정나무
Author: 이광춘
Affiliation: 한국R사용자회
Published: January 2, 2023

1 시각화

library(tidyverse)
library(tidymodels)
library(parttree)
library(palmerpenguins)
library(ggparty)

# Dataset: palmerpenguins' `penguins`, keeping only complete rows
# (na.omit() removes every row that has an NA in any column).
# NOTE(review): this also drops rows whose only NA is in a column the models
# below do not use; confirm that complete-case filtering is intended.
penguins_tbl <- na.omit(penguins)

# Model: a parsnip decision_tree() spec on the "rpart" engine in
# classification mode, fitted to predict `species` from flipper length and
# bill length on `penguins_tbl`.
penguins_dt <- decision_tree() %>%
  set_engine(engine = "rpart") %>%
  set_mode(mode = "classification") %>%
  fit(
    formula = species ~ flipper_length_mm + bill_length_mm,
    data    = penguins_tbl
  )

1.1 펭귄 종 분류 - 결정 경계(Decision Boundary)

# Visualization: scatter plot of flipper length (x) vs bill length (y),
# coloured by species, with the rectangular regions of the fitted tree
# `penguins_dt` shaded underneath via parttree::geom_parttree().
penguins_tbl %>%
  ggplot(aes(x = flipper_length_mm, y = bill_length_mm)) +
  # Alternative point layer (jittered, semi-transparent) kept for reference:
  # geom_jitter(aes(col=species), alpha=0.7) +
  geom_point(aes(color = species)) +
  # NOTE(review): flipaxes = FALSE presumably keeps the tree's split
  # variables on the same x/y axes as the points; confirm against the
  # installed parttree version's documentation.
  geom_parttree(data = penguins_dt, aes(fill = species), alpha = 0.1,
                flipaxes = FALSE) +
  # A single manual scale drives both the point colour and the region fill,
  # so the species palette is defined once and cannot drift between them.
  scale_color_manual(
    values = c("Adelie"    = "blue",
               "Chinstrap" = "darkgreen",
               "Gentoo"    = "black"),
    aesthetics = c("colour", "fill")
  ) +
  theme_minimal()

1.2 펭귄 종 분류 - 의사결정나무 구조

library(rpart)

# Fit a classification tree for `species` on the same two predictors and the
# same data, calling rpart() directly with rpart's own defaults.
# NOTE(review): this duplicates the parsnip fit in `penguins_dt`; presumably
# a plain rpart object is wanted for partykit::as.party(). Confirm, and check
# whether both fits use the same control settings.
rpart_fit <- rpart(
  formula = species ~ flipper_length_mm + bill_length_mm,
  data    = penguins_tbl
)

# Draw `rpart_fit` as a ggplot tree: convert it to a partykit party object,
# then layer edges, edge labels, split-variable labels, a species bar chart
# per node and captions on the terminal nodes.
partykit::as.party(rpart_fit) %>%
  ggparty() +
  # Tree skeleton: thick edges, grey split-condition labels on the edges,
  # and the split variable at the inner nodes.
  geom_edge(size = 1.5) +
  geom_edge_label(size = 6, colour = "gray30") +
  geom_node_splitvar() +
  # Species bar chart inside the nodes selected by geom_node_plot()'s
  # default `ids`, sharing one palette with the rest of the document.
  geom_node_plot(
    gglist = list(
      geom_bar(aes(x = species, fill = species), alpha = 0.8),
      theme_bw(base_size = 15),
      scale_fill_manual(values = c("Adelie"    = "blue",
                                   "Chinstrap" = "darkgreen",
                                   "Gentoo"    = "black")),
      labs(x = "", y = "펭귄 개체수", fill = "펭귄 종")
    ),
    legend_separator   = TRUE,
    shared_axis_labels = TRUE
  ) +
  # Terminal-node captions: node id and number of penguins in that node.
  geom_node_label(
    aes(label = paste0("노드 ", id, ", 펭귄수 = ", nodesize)),
    ids      = "terminal",
    fontface = "bold",
    size     = 5,
    nudge_y  = 0.01
  ) +
  theme(legend.position = "none")