# DADA2 Amplicon Analysis Pipeline (Paired-End) + decontam blank correction for HalSval2024 dataset

suppressPackageStartupMessages({
  library(dada2)
  library(phyloseq)
  library(ShortRead)
  library(Biostrings)
  library(tidyverse)
  library(readxl)
  library(ggplot2)
  library(pheatmap)
  library(decontam)
})

# 0. Number of cores -------------------------------------------------------
# Threads used by all multithread-capable steps below. Default is 8, but it
# can be overridden without editing the script via the DADA2_NCORES
# environment variable (falls back to 8 on unset or non-numeric values).
ncores <- as.integer(Sys.getenv("DADA2_NCORES", unset = "8"))
if (is.na(ncores) || ncores < 1) {
  ncores <- 8
}
message("Using ", ncores, " cores for multithreading.")

# 1. File setup ------------------------------------------------------------
path <- "00_RAW"        # directory holding the raw paired-end FASTQ files
out_path <- "dada2_out" # all pipeline output is written here

if (!dir.exists(out_path)) dir.create(out_path, recursive = TRUE)

# Escape the dots so the pattern matches the literal ".fastq.gz" suffix;
# an unescaped "." is a regex wildcard and could match unintended files.
fnFs <- sort(list.files(path, pattern = "__16S_1\\.fastq\\.gz$", full.names = TRUE))
fnRs <- sort(list.files(path, pattern = "__16S_2\\.fastq\\.gz$", full.names = TRUE))

# Fail fast with clear messages instead of erroring obscurely downstream.
if (length(fnFs) == 0) {
  stop("No forward reads matching '__16S_1.fastq.gz' found in '", path, "'.")
}
if (length(fnFs) != length(fnRs)) {
  stop("Unequal numbers of forward and reverse files!")
}

# Robust sample name extraction
# Primary scheme: strip the "__16S_1.fastq.gz" suffix from forward files.
sample_names <- basename(fnFs) %>%
  str_remove("__16S_1\\.fastq\\.gz$")

# Fallback: if any name is unchanged the primary pattern did not match, so
# try Illumina-style "_R1_001.fastq.gz" / lane suffixes instead.
if (any(sample_names == basename(fnFs))) {
  sample_names <- basename(fnFs) %>%
    str_replace("_R1_001\\.fastq\\.gz$", "") %>%
    str_replace("_L\\d+.*$", "")
}

# Duplicate names would silently overwrite entries in the named file
# vectors used throughout the pipeline, dropping samples — abort instead.
if (anyDuplicated(sample_names) > 0) {
  dups <- unique(sample_names[duplicated(sample_names)])
  stop("Duplicate sample names after extraction: ",
       paste(dups, collapse = ", "))
}

names(fnFs) <- sample_names
names(fnRs) <- sample_names
message("Detected samples: ", paste(sample_names, collapse = ", "))

# ==========================================================================
# TRACKING TABLE INITIALISATION
# ==========================================================================
# One row per sample; a count column is appended after each pipeline stage
# so read attrition can be reviewed at the end.
track <- tibble(Sample = sample_names)

# Raw read count per sample (forward file only — mates are paired 1:1).
track$raw <- vapply(fnFs, function(f) countFastq(f)[1, "records"], numeric(1))

# 2. Primer definitions ----------------------------------------------------
FWD <- "CCTACGGGNGGCWGCAG"      # 341F
REV <- "GACTACHVGGGTATCTAATCC"  # 785R

# Return all four orientations of a primer (forward, complement, reverse,
# reverse-complement) as a named character vector, for primer auditing.
allOrients <- function(primer) {
  dna <- DNAString(primer)
  variants <- c(Forward    = dna,
                Complement = complement(dna),
                Reverse    = reverse(dna),
                RevComp    = reverseComplement(dna))
  vapply(variants, toString, character(1))
}

FWD.orients <- allOrients(FWD)
REV.orients <- allOrients(REV)

# 3. Filter out reads with Ns ---------------------------------------------
# Cutadapt needs N-free reads for reliable primer matching, so pre-filter
# with maxN = 0 into a "filtN" subdirectory (no other trimming yet).
filtN_dir <- file.path(path, "filtN")
if (!dir.exists(filtN_dir)) dir.create(filtN_dir)

fnFs.filtN <- setNames(file.path(filtN_dir, basename(fnFs)), sample_names)
fnRs.filtN <- setNames(file.path(filtN_dir, basename(fnRs)), sample_names)

filterAndTrim(fnFs, fnFs.filtN, fnRs, fnRs.filtN,
              maxN = 0, multithread = ncores, verbose = TRUE)

# 4. Cutadapt trimming -----------------------------------------------------
cutadapt <- "/PAATH/TO//cutadapt"  # EDIT: absolute path to the cutadapt binary

# Verify the binary actually runs before looping over samples; a bad path
# would otherwise fail once per sample with confusing downstream errors.
cutadapt_status <- tryCatch(
  suppressWarnings(
    system2(cutadapt, args = "--version", stdout = FALSE, stderr = FALSE)
  ),
  error = function(e) 127L
)
if (!identical(cutadapt_status, 0L)) {
  stop("cutadapt could not be executed at '", cutadapt, "' — check the path.")
}

# Reverse-complements: -a/-A remove read-through into the opposite primer,
# in addition to the 5' primers removed by -g/-G. Computed once, outside
# the loop (the original recomputed rc(FWD) every iteration).
FWD.RC <- dada2:::rc(FWD)
REV.RC <- dada2:::rc(REV)

cut_dir <- file.path(path, "cutadapt")
if (!dir.exists(cut_dir)) dir.create(cut_dir)

fnFs.cut <- file.path(cut_dir, basename(fnFs.filtN))
fnRs.cut <- file.path(cut_dir, basename(fnRs.filtN))
names(fnFs.cut) <- sample_names
names(fnRs.cut) <- sample_names

min_length <- 50
common_flags <- c("-j", ncores, "-n", "2", "-m", min_length,
                  "--pair-filter=any", "--discard-untrimmed")

for (s in sample_names) {
  message("Cutadapt trimming sample: ", s)
  status <- system2(cutadapt, args = c(common_flags,
                                       "-g", FWD, "-a", REV.RC,
                                       "-G", REV, "-A", FWD.RC,
                                       "-o", fnFs.cut[s], "-p", fnRs.cut[s],
                                       fnFs.filtN[s], fnRs.filtN[s]))
  # The original ignored the exit status, so a failed run left empty or
  # missing files that only surfaced as cryptic errors later — abort here.
  if (status != 0) {
    stop("cutadapt failed for sample '", s, "' (exit status ", status, ").")
  }
}

# Count after primer trimming
track$after_primer <- map_dbl(fnFs.cut, ~ countFastq(.x)[1, "records"])

# 5. DADA2 filtering/trimming ---------------------------------------------
# Quality filtering after primer removal: truncate reads, cap expected
# errors, and drop PhiX. filterAndTrim() returns a per-file matrix with
# "reads.in" / "reads.out" columns (rows in the same order as the inputs).
filt_dir <- file.path(out_path, "filtered")
if (!dir.exists(filt_dir)) dir.create(filt_dir)

filtFs <- setNames(
  file.path(filt_dir, paste0(sample_names, "_R1_filt.fastq.gz")),
  sample_names
)
filtRs <- setNames(
  file.path(filt_dir, paste0(sample_names, "_R2_filt.fastq.gz")),
  sample_names
)

filter_stats <- filterAndTrim(fnFs.cut, filtFs,
                              fnRs.cut, filtRs,
                              truncLen = c(270, 260),
                              maxN = 0,
                              truncQ = 12,
                              maxEE = c(1, 1),
                              rm.phix = TRUE,
                              compress = TRUE,
                              multithread = ncores,
                              verbose = TRUE)
track$after_filterAndTrim <- filter_stats[, "reads.out"]

# 6. Learn error rates -----------------------------------------------------
errF <- learnErrors(filtFs, multithread = ncores, verbose = TRUE)
errR <- learnErrors(filtRs, multithread = ncores, verbose = TRUE)

# 7. Dereplication & DADA --------------------------------------------------
derepFs <- derepFastq(filtFs, verbose = TRUE)
derepRs <- derepFastq(filtRs, verbose = TRUE)
names(derepFs) <- sample_names
names(derepRs) <- sample_names

# Per-sample denoising (pool = FALSE: each sample processed independently).
dadaFs <- dada(derepFs, err = errF, multithread = ncores, pool = FALSE)
dadaRs <- dada(derepRs, err = errR, multithread = ncores, pool = FALSE)

# Reads surviving denoising (forward direction).
track$after_denoise <- vapply(dadaFs,
                              function(d) sum(getUniques(d)),
                              numeric(1))

# 8. Merge pairs ---------------------------------------------------------
mergers <- mergePairs(dadaFs, derepFs, dadaRs, derepRs, verbose = TRUE)

track$after_merge <- vapply(mergers,
                            function(m) sum(getUniques(m)),
                            numeric(1))

# 9. Construct sequence table --------------------------------------------
seqtab <- makeSequenceTable(mergers)

# Keep only ASVs within the expected amplicon length window (inclusive).
min.len <- 400
max.len <- 435
asv_len <- nchar(colnames(seqtab))
seqtab2 <- seqtab[, asv_len >= min.len & asv_len <= max.len]

# 10. Chimera removal ------------------------------------------------------
# Consensus bimera removal across samples.
seqtab.nochim_all <- removeBimeraDenovo(seqtab2,
                                        method = "consensus",
                                        multithread = ncores,
                                        verbose = TRUE)

# Per-sample read totals surviving chimera removal.
track$after_chimera <- rowSums(seqtab.nochim_all)

# 11. Taxonomy assignment --------------------------------------------------
# Classify ASVs against a SILVA training set (path must be edited locally).
silva_train_set <- "/PATH/TO/SILVA/DATABASE"
taxa_all <- assignTaxonomy(seqtab.nochim_all,
                           silva_train_set,
                           multithread = ncores)

# 12. Metadata ------------------------------------------------------------
# Read sample metadata; rows are keyed by the "Name" column so they can be
# reordered to match the sequencing sample order.
sample_df <- read_excel("Metadata_HarSval2024.xlsx")
sample_df <- as.data.frame(sample_df)
sample_df <- sample_df %>%
  dplyr::rename(SampleWell = Sample)
rownames(sample_df) <- sample_df$Name

# The original indexed directly: samples absent from the metadata silently
# became all-NA rows, which breaks phyloseq/decontam downstream. Validate
# the join before reordering.
missing_meta <- setdiff(sample_names, rownames(sample_df))
if (length(missing_meta) > 0) {
  stop("Samples missing from metadata 'Name' column: ",
       paste(missing_meta, collapse = ", "))
}
sample_df <- sample_df[sample_names, , drop = FALSE]

# 13. Phyloseq object (post-chimera) --------------------------------------
# Bundle the ASV table (samples as rows), taxonomy, and metadata.
otu_component  <- otu_table(seqtab.nochim_all, taxa_are_rows = FALSE)
tax_component  <- tax_table(taxa_all)
meta_component <- sample_data(sample_df)

physeq_all <- phyloseq(otu_component, tax_component, meta_component)

# 14. Decontam ------------------------------------------------------------
# Identify contaminant ASVs by prevalence in blank (negative-control)
# samples relative to real samples.
blank_flag <- as.character(sample_data(physeq_all)$BLANK)
sample_data(physeq_all)$BLANK <- blank_flag

# NA-safe negative-control mask: with the original `BLANK == "Yes"`, an NA
# in the BLANK column propagated to NA, which made `any(is.neg)` NA and
# crashed the `if` below (and would confuse isContaminant's `neg` vector).
is.neg <- !is.na(blank_flag) & blank_flag == "Yes"

if (!any(is.neg)) {
  warning("No BLANK == 'Yes' samples detected. Skipping decontam.")
  physeq_decontam <- physeq_all
  track$after_decontam <- track$after_chimera
} else {
  contam_df <- isContaminant(physeq_all,
                             method = "prevalence",
                             neg = is.neg,
                             threshold = 0.2)
  # Drop flagged contaminants; keep everything else.
  physeq_decontam <- prune_taxa(!contam_df$contaminant, physeq_all)
  track$after_decontam <- sample_sums(physeq_decontam)
}

# 15. Post-decontam filters (prevalence & abundance) ----------------------
# Keep an ASV only if it appears in >= 3 samples AND has >= 3 total reads.
min_prev_samples <- 3
min_total_abund  <- 3

# Work on a samples-x-taxa matrix regardless of the object's orientation.
otu_mat <- as(otu_table(physeq_decontam), "matrix")
if (taxa_are_rows(physeq_decontam)) otu_mat <- t(otu_mat)

keep <- colSums(otu_mat > 0) >= min_prev_samples &
  colSums(otu_mat) >= min_total_abund
physeq_filt <- prune_taxa(keep, physeq_decontam)

track$after_prev_abund <- sample_sums(physeq_filt)

# 16. Remove non-Bacteria / organelles ------------------------------------
# NA-safe organelle filter. With the original `!(Order == "Chloroplast")`,
# any taxon with NA at Order (or Family) evaluated to NA and was silently
# dropped by subset_taxa() — discarding legitimate, merely unclassified
# ASVs. Keep NA-annotated taxa; remove only confirmed organelles.
# (Kingdom uses %in%, which already treats NA as FALSE and so intentionally
# drops taxa with no kingdom assignment.)
physeq_filt <- subset_taxa(physeq_filt,
                           Kingdom %in% c("Bacteria", "Archaea") &
                             (is.na(Order)  | Order  != "Chloroplast") &
                             (is.na(Family) | Family != "Mitochondria"))

# Get sample sums as a named vector
tax_filter_counts <- sample_sums(physeq_filt)  # names = sample IDs

# Merge into track by sample name; samples absent from the filtered
# object (e.g. emptied blanks) get NA rather than a misaligned count.
track <- track %>%
  mutate(after_tax_filter = if_else(Sample %in% names(tax_filter_counts),
                                    tax_filter_counts[Sample],
                                    NA_real_))

# 17. Final phyloseq object ------------------------------------------------
physeq <- physeq_filt

# ==========================================================================
# OUTPUT TRACKING TABLE
# ==========================================================================
# Helper: percentage of the raw read count surviving a given stage.
pct_of_raw <- function(count, raw) count / raw * 100

track <- track %>%
  mutate(
    pct_after_primer      = pct_of_raw(after_primer, raw),
    pct_after_filter      = pct_of_raw(after_filterAndTrim, raw),
    pct_after_denoise     = pct_of_raw(after_denoise, raw),
    pct_after_merge       = pct_of_raw(after_merge, raw),
    pct_after_chimera     = pct_of_raw(after_chimera, raw),
    pct_after_decontam    = pct_of_raw(after_decontam, raw),
    pct_after_prev_abund  = pct_of_raw(after_prev_abund, raw),
    pct_after_tax_filter  = pct_of_raw(after_tax_filter, raw)
  )

# Save to CSV
write_csv(track, file = "read_tracking_summary.csv")
message("Tracking table saved to 'read_tracking_summary.csv'")

# Print the table
print(track)
