library(tidyverse)
library(purrr)
This assignment will challenge your function writing abilities. I’m not going to lie, these functions are difficult but well within your reach. I do, however, want to recognize that not everyone is interested in being a “virtuoso” with their function writing. So, there are two options for this week’s lab:
- Option 1: Complete this lab assignment in search of virtuoso status with your function writing
- Option 2: Complete one of the difficult functions (Exercise 1 or Exercise 2) and complete the “Alternative Lab 6”.
Setting the Stage
My number one use case for writing functions and iteration / looping is to perform some exploration or modeling repeatedly for different “tweaked” versions. For example, our broad goal might be to fit a linear regression model to our data. However, there are often multiple choices that we have to make in practice:
- Keep missing values or fill them in (imputation)?
- Filter out outliers in one or more variables?
We can map these choices to arguments in a custom model-fitting function:
- impute: TRUE or FALSE
- remove_outliers: TRUE or FALSE
A function that implements the analysis and allows for variation in these choices:
fit_model <- function(df, impute, remove_outliers, mod) {
  if (impute) {
    df <- some_imputation_function(df)
  }
  if (remove_outliers) {
    df <- function_for_removing_outliers(df)
  }
  lm(mod, data = df)
}
Helper Functions
Exercise 1: Write a function that removes outliers in a dataset. The user should be able to supply the dataset, the variables to remove outliers from, and a threshold on the number of SDs away from the mean used to define outliers. Hint 1: You will need to calculate a z-score to filter the values! Hint 2: You might want to consider specifying a default value (e.g., 3) for sd_thresh.
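A minimal sketch of the z-score idea from Hint 1, on a toy vector (the names here are hypothetical, not part of the solution):

```r
x <- c(10, 12, 11, 13, 40)
z <- (x - mean(x)) / sd(x)  # z-score: how many SDs each value sits from the mean
x[abs(z) < 3]               # keep only values within 3 SDs of the mean
```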
remove_outliers <- function(df, ..., sd_thresh = 3) {
  selected_cols <- rlang::enquos(...) # suggested by discord, specify rlang after class discussion

  # make sure that at least one variable is supplied
  if (length(selected_cols) == 0) {
    stop("Specify at least one variable to remove outliers from")
  }

  df_filtered <- df

  for (col_quo in selected_cols) {
    # column name as a string, for filtering
    col_name <- rlang::as_name(col_quo)

    # make sure the column is in df
    if (!col_name %in% names(df_filtered)) {
      warning(paste(col_name, "not found in the dataframe. Skipping."))
      next
    }

    # extract the column as a vector (ChatGPT helped with this part)
    col_vec <- df_filtered[[col_name]]

    # check if numeric
    if (!is.numeric(col_vec)) {
      warning(paste(col_name, "is not numeric. Skipping."))
      next
    }

    col_mean <- mean(col_vec, na.rm = TRUE)
    col_sd <- sd(col_vec, na.rm = TRUE)

    z_scores <- (col_vec - col_mean) / col_sd

    # identify obs that are outliers
    remove <- abs(z_scores) >= sd_thresh

    # remove outliers from df (`remove` is logical, so negate with !, not -)
    df_filtered <- df_filtered[!remove, ]
  }

  return(df_filtered)
}
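For comparison, a more compact tidyverse version of the same idea is sketched below (an alternative, not the required loop). One behavioral difference to note: this version computes every column's mean and SD on the original data, while the loop above recomputes them after each column's outliers are dropped, so the two can return slightly different rows.

```r
remove_outliers_tidy <- function(df, ..., sd_thresh = 3) {
  cols <- names(dplyr::select(df, ...))  # resolve the tidyselect dots to column names
  dplyr::filter(
    df,
    dplyr::if_all(
      dplyr::all_of(cols),
      ~ abs((.x - mean(.x, na.rm = TRUE)) / sd(.x, na.rm = TRUE)) < sd_thresh
    )
  )
}
```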
Testing Your Function!
## Testing how your function handles multiple input variables
remove_outliers(diamonds,
price,
x,
y, z)
# A tibble: 53,936 × 10
carat cut color clarity depth table price x y z
<dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
1 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
2 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
3 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
4 0.26 Very Good H SI1 61.9 55 337 4.07 4.11 2.53
5 0.22 Fair E VS2 65.1 61 337 3.87 3.78 2.49
6 0.23 Very Good H VS1 59.4 61 338 4 4.05 2.39
7 0.3 Good J SI1 64 55 339 4.25 4.28 2.73
8 0.23 Ideal J VS1 62.8 56 340 3.93 3.9 2.46
9 0.22 Premium F SI1 60.4 61 342 3.88 3.84 2.33
10 0.31 Ideal J SI2 62.2 54 344 4.35 4.37 2.71
# ℹ 53,926 more rows
## Testing how your function handles an input that isn't numeric
remove_outliers(diamonds,
price, color)
# A tibble: 53,939 × 10
carat cut color clarity depth table price x y z
<dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
1 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
2 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
3 0.29 Premium I VS2 62.4 58 334 4.2 4.23 2.63
4 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
5 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
6 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
7 0.26 Very Good H SI1 61.9 55 337 4.07 4.11 2.53
8 0.22 Fair E VS2 65.1 61 337 3.87 3.78 2.49
9 0.23 Very Good H VS1 59.4 61 338 4 4.05 2.39
10 0.3 Good J SI1 64 55 339 4.25 4.28 2.73
# ℹ 53,929 more rows
## Testing how your function handles a non-default sd_thresh
remove_outliers(diamonds,
price,
x,
y,
z, sd_thresh = 2)
# A tibble: 53,936 × 10
carat cut color clarity depth table price x y z
<dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
1 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
2 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
3 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
4 0.26 Very Good H SI1 61.9 55 337 4.07 4.11 2.53
5 0.22 Fair E VS2 65.1 61 337 3.87 3.78 2.49
6 0.23 Very Good H VS1 59.4 61 338 4 4.05 2.39
7 0.3 Good J SI1 64 55 339 4.25 4.28 2.73
8 0.23 Ideal J VS1 62.8 56 340 3.93 3.9 2.46
9 0.22 Premium F SI1 60.4 61 342 3.88 3.84 2.33
10 0.31 Ideal J SI2 62.2 54 344 4.35 4.37 2.71
# ℹ 53,926 more rows
## Demonstrating the warning message for a column that isn't found
remove_outliers(diamonds,
lol,
x,
y, z)
# A tibble: 53,937 × 10
carat cut color clarity depth table price x y z
<dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
1 0.29 Premium I VS2 62.4 58 334 4.2 4.23 2.63
2 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
3 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
4 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
5 0.26 Very Good H SI1 61.9 55 337 4.07 4.11 2.53
6 0.22 Fair E VS2 65.1 61 337 3.87 3.78 2.49
7 0.23 Very Good H VS1 59.4 61 338 4 4.05 2.39
8 0.3 Good J SI1 64 55 339 4.25 4.28 2.73
9 0.23 Ideal J VS1 62.8 56 340 3.93 3.9 2.46
10 0.22 Premium F SI1 60.4 61 342 3.88 3.84 2.33
# ℹ 53,927 more rows
Exercise 2: Write a function that imputes missing values for numeric variables in a dataset. The user should be able to supply the dataset, the variables to impute values for, and a function to use when imputing. Hint 1: You will need to use across() to apply your function, since the user can input multiple variables. Hint 2: The replace_na() function is helpful here!
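A quick illustration of Hint 2 on a toy vector (not the solution itself):

```r
x <- c(1, NA, 3)
tidyr::replace_na(x, mean(x, na.rm = TRUE))  # the NA becomes the mean of the observed values, 2
```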
impute_missing <- function(df, ..., impute_fun = mean) {
  selected_cols <- rlang::enquos(...)

  # warning message for no columns selected
  if (length(selected_cols) == 0) {
    warning("No columns specified for imputation, returning original dataframe")
    return(df)
  }

  # selecting df columns; !!! splices the captured expressions into select()
  filtered_df <- df %>% dplyr::select(!!!selected_cols)
  col_names <- names(filtered_df)

  # error when none of the specified columns exist
  if (length(col_names) == 0) {
    stop("Columns specified do not exist")
  }

  # checking that all selected columns are numeric
  are_numeric <- sapply(filtered_df, is.numeric)
  if (!all(are_numeric)) {
    # identify the non-numeric columns and stop
    non_numeric_selected_cols <- col_names[!are_numeric]
    stop(paste("The following columns are non-numeric:",
               paste(non_numeric_selected_cols, collapse = ", ")))
  }

  df_imputed <- df %>%
    mutate(
      across(
        .cols = all_of(col_names), # selecting cols
        .fns = ~ {
          imputation_val <- impute_fun(.x, na.rm = TRUE) # computing the impute value
          tidyr::replace_na(.x, imputation_val)          # replacing NAs
        }
      )
    )

  return(df_imputed)
}
Testing Your Function!
## Testing how your function handles multiple input variables
impute_missing(nycflights13::flights,
arr_delay, dep_delay)
# A tibble: 336,776 × 19
year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time
<int> <int> <int> <int> <int> <dbl> <int> <int>
1 2013 1 1 517 515 2 830 819
2 2013 1 1 533 529 4 850 830
3 2013 1 1 542 540 2 923 850
4 2013 1 1 544 545 -1 1004 1022
5 2013 1 1 554 600 -6 812 837
6 2013 1 1 554 558 -4 740 728
7 2013 1 1 555 600 -5 913 854
8 2013 1 1 557 600 -3 709 723
9 2013 1 1 557 600 -3 838 846
10 2013 1 1 558 600 -2 753 745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
# tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
# hour <dbl>, minute <dbl>, time_hour <dttm>
## Testing how your function handles an input that isn't numeric
impute_missing(nycflights13::flights,
arr_delay, carrier)
Error in impute_missing(nycflights13::flights, arr_delay, carrier): The following columns are non-numeric: carrier
## Testing how your function handles a non-default impute_fun
impute_missing(nycflights13::flights,
arr_delay,
dep_delay, impute_fun = median)
# A tibble: 336,776 × 19
year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time
<int> <int> <int> <int> <int> <dbl> <int> <int>
1 2013 1 1 517 515 2 830 819
2 2013 1 1 533 529 4 850 830
3 2013 1 1 542 540 2 923 850
4 2013 1 1 544 545 -1 1004 1022
5 2013 1 1 554 600 -6 812 837
6 2013 1 1 554 558 -4 740 728
7 2013 1 1 555 600 -5 913 854
8 2013 1 1 557 600 -3 709 723
9 2013 1 1 557 600 -3 838 846
10 2013 1 1 558 600 -2 753 745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
# tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
# hour <dbl>, minute <dbl>, time_hour <dttm>
Primary Function
Exercise 3: Write a fit_model() function that fits a specified linear regression model for a specified dataset. The function should:
- allow the user to specify if outliers should be removed (TRUE or FALSE)
- allow the user to specify if missing observations should be imputed (TRUE or FALSE)

If either option is TRUE, your function should call your remove_outliers() or impute_missing() functions to modify the data before the regression model is fit.
fit_model <- function(df, mod_formula, remove_outliers = FALSE, impute_missing = FALSE, ...) {
  selected_cols <- rlang::enquos(...)

  # error if no variables are provided
  if (length(selected_cols) == 0) {
    stop("Please provide the variables used in the formula")
  }

  # error if the formula is missing or is not a formula/call object
  if (is.null(mod_formula) || !is.call(mod_formula)) {
    stop("Model formula is not specified correctly")
  }

  df_filtered <- df

  # removing outliers
  if (remove_outliers) {
    df_filtered <- remove_outliers(df_filtered, !!!selected_cols)
  }

  # imputing
  if (impute_missing) {
    df_filtered <- impute_missing(df_filtered, !!!selected_cols)
  }

  # fitting the model
  model <- lm(mod_formula, data = df_filtered)

  return(model)
}
Testing Your Function!
fit_model(
  diamonds,
  mod_formula = price ~ carat + cut,
  remove_outliers = TRUE,
  impute_missing = TRUE,
  price,
  carat
)
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2701.47 7871.17 1239.78 -528.59 367.96 74.63
Iteration
In the diamonds dataset, we want to understand the relationship between price and size (carat). We want to explore variation along three choices:

1. The variables included in the model. We'll explore four sets of variables:
   - No further variables (just price and carat)
   - Adjusting for cut
   - Adjusting for cut and clarity
   - Adjusting for cut, clarity, and color
2. Whether or not to impute missing values
3. Whether or not to remove outliers in the carat variable (we'll define outliers as cases whose carat is over 3 SDs away from the mean).
Parameters
First, we need to define the set of parameters we want to iterate the fit_model() function over. The tidyr package has a function called crossing() that is useful for generating argument combinations: for each argument, we specify all possible values, and crossing() generates every combination. Note that you can create a list of formula objects in R with c(y ~ x1, y ~ x1 + x2).
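A quick check of that last point: c() applied to formulas returns a list of formula objects, which is exactly the kind of column crossing() can hold.

```r
forms <- c(y ~ x1, y ~ x1 + x2)
length(forms)      # 2
class(forms[[1]])  # "formula"
```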
df_arg_combos <- crossing(
  impute = c(TRUE, FALSE),
  remove_outliers = c(TRUE, FALSE),
  mod = c(y ~ x1,
          y ~ x1 + x2)
)
df_arg_combos
Exercise 4: Use crossing() to create the data frame of argument combinations for our analyses.
df_arg_combos <- crossing(
  # column names must match fit_model()'s argument names so pmap() can match them
  impute_missing = c(TRUE, FALSE),
  remove_outliers = c(TRUE, FALSE),
  mod_formula = c(price ~ carat,
                  price ~ carat + cut,
                  price ~ carat + cut + clarity,
                  price ~ carat + cut + clarity + color)
)
df_arg_combos
# A tibble: 16 × 3
   impute_missing remove_outliers mod_formula
   <lgl>          <lgl>           <list>
1 FALSE FALSE <formula>
2 FALSE FALSE <formula>
3 FALSE FALSE <formula>
4 FALSE FALSE <formula>
5 FALSE TRUE <formula>
6 FALSE TRUE <formula>
7 FALSE TRUE <formula>
8 FALSE TRUE <formula>
9 TRUE FALSE <formula>
10 TRUE FALSE <formula>
11 TRUE FALSE <formula>
12 TRUE FALSE <formula>
13 TRUE TRUE <formula>
14 TRUE TRUE <formula>
15 TRUE TRUE <formula>
16 TRUE TRUE <formula>
Iterating Over the Parameters
We’ve arrived at the final step!
Exercise 5: Use pmap() from purrr to apply the fit_model() function to every combination of arguments from df_arg_combos, fitting each model to the diamonds data.
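The key to this step is that pmap() matches the columns of the data frame to the function's arguments by name, while any extra arguments supplied after the function are passed to every call. A toy sketch of that matching (hypothetical function f, unrelated to the lab):

```r
f <- function(a, b, scale = 1) (a + b) * scale
purrr::pmap_dbl(tibble::tibble(a = 1:2, b = c(10, 20)), f, scale = 2)  # 22 44
```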
pmap(df_arg_combos, fit_model, df = diamonds, price, carat)
[[1]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat
-2256 7756
[[2]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2701.38 7871.08 1239.80 -528.60 367.91 74.59
[[3]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3187.540 8472.026 713.804 -334.503 188.482 1.663
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4011.681 -1821.922 917.658 -430.047 257.141 26.909
clarity^7
186.742
[[4]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3710.603 8886.129 698.907 -327.686 180.565 -1.207
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4217.535 -1832.406 923.273 -361.995 216.616 2.105
clarity^7 color.L color.Q color.C color^4 color^5
110.340 -1910.288 -627.954 -171.960 21.678 -85.943
color^6
-49.986
[[5]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat
-2256 7757
[[6]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2701.47 7871.17 1239.78 -528.59 367.96 74.63
[[7]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3187.784 8472.316 713.686 -334.533 188.544 1.705
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4012.037 -1821.966 917.545 -429.913 257.052 26.916
clarity^7
186.767
[[8]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3710.814 8886.379 698.795 -327.726 180.624 -1.168
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4217.848 -1832.466 923.155 -361.872 216.522 2.114
clarity^7 color.L color.Q color.C color^4 color^5
110.368 -1910.260 -627.961 -172.133 21.894 -86.104
color^6
-49.899
[[9]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat
-2256 7756
[[10]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2701.38 7871.08 1239.80 -528.60 367.91 74.59
[[11]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3187.540 8472.026 713.804 -334.503 188.482 1.663
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4011.681 -1821.922 917.658 -430.047 257.141 26.909
clarity^7
186.742
[[12]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3710.603 8886.129 698.907 -327.686 180.565 -1.207
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4217.535 -1832.406 923.273 -361.995 216.616 2.105
clarity^7 color.L color.Q color.C color^4 color^5
110.340 -1910.288 -627.954 -171.960 21.678 -85.943
color^6
-49.986
[[13]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat
-2256 7757
[[14]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2701.47 7871.17 1239.78 -528.59 367.96 74.63
[[15]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3187.784 8472.316 713.686 -334.533 188.544 1.705
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4012.037 -1821.966 917.545 -429.913 257.052 26.916
clarity^7
186.767
[[16]]
Call:
lm(formula = mod_formula, data = df_filtered)
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3710.814 8886.379 698.795 -327.726 180.624 -1.168
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4217.848 -1832.466 923.155 -361.872 216.522 2.114
clarity^7 color.L color.Q color.C color^4 color^5
110.368 -1910.260 -627.961 -172.133 21.894 -86.104
color^6
-49.899