From bbbd06b9825bf920fd17a2e5d754aaef4510cdc3 Mon Sep 17 00:00:00 2001 From: Pablu23 Date: Mon, 9 Sep 2024 23:21:53 +0200 Subject: [PATCH] Way better performance, better multithreading --- src/main.rs | 77 +++++++++++++++++++++++------------------------------ 1 file changed, 34 insertions(+), 43 deletions(-) diff --git a/src/main.rs b/src/main.rs index 1556a17..bf9abd6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,10 +2,12 @@ use anyhow::{bail, Result}; use argon2::Config; use chacha20::cipher::{KeyIvInit, StreamCipher}; use chacha20::XChaCha20; +use core::time; use rand::RngCore; use rand_core::OsRng; -use std::fs::ReadDir; +use std::hint; use std::path::{Path, PathBuf}; +use std::sync::atomic::AtomicUsize; use std::{ env, fs, @@ -22,7 +24,7 @@ extern crate rpassword; use rpassword::read_password; -const BUFFER_LEN: usize = 50 * 1024 * 1024; // 50 MiB +const BUFFER_LEN: usize = 64 * 1024; // 64 KiB pub fn encrypt_file( source_path: String, @@ -42,8 +44,7 @@ pub fn encrypt_file( let mut source_file = File::open(&source_path)?; let mut dest_file = File::create(dest_path)?; - - let mut buffer = vec![0u8; BUFFER_LEN].into_boxed_slice(); + let mut buffer = vec![0u8; BUFFER_LEN]/* .into_boxed_slice() */; println!("Now encrypting {source_path}"); @@ -91,7 +92,6 @@ pub fn decrypt_file(source_path: &Path, pwd: &String, config: &Config) -> Result } println!("Start decrypting File {source_path:?}"); - let mut source_file = File::open(&source_path)?; source_file.read(&mut nonce)?; @@ -122,15 +122,14 @@ pub fn decrypt_file(source_path: &Path, pwd: &String, config: &Config) -> Result if cfg!(unix) { file_name = file_name.replace("\\", "/"); } + let path = root_dir_path.join(&file_name); let prefix = path.parent().expect("No parent Directory"); std::fs::create_dir_all(prefix)?; let mut dest_file = File::create(path)?; - - let mut buffer = vec![0u8; BUFFER_LEN].into_boxed_slice(); - + let mut buffer = vec![0u8; BUFFER_LEN]/* .into_boxed_slice() */; loop { let read_count = source_file.read(&mut buffer)?; @@ -139,6 +138,7 @@ pub fn decrypt_file(source_path: &Path, pwd: &String, config: &Config) -> Result dest_file.write(&buffer)?; } else { cipher.apply_keystream(&mut buffer[..read_count]); + dest_file.write(&buffer[..read_count])?; break; } @@ -146,6 +146,7 @@ pub fn decrypt_file(source_path: &Path, pwd: &String, config: &Config) -> Result println!("Finished decrpyting File {file_name}"); + fs::remove_file(source_path)?; Ok(()) } @@ -200,40 +201,39 @@ fn main() -> io::Result<()> { std::io::stdout().flush().unwrap(); let pwd = Arc::new(read_password().unwrap()); - let config = Arc::new(config); - let mut handles: Vec> = Vec::with_capacity(max_threads); - - let mut current_threads = 0; + let curr_threads = Arc::new(AtomicUsize::new(0)); for path_result in paths { let path = path_result.unwrap().path(); let pwd = pwd.clone(); let config = config.clone(); + let curr_threads_clone = curr_threads.clone(); handles.push(thread::spawn(move || { + curr_threads_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst); if path.is_file() { - decrypt_file(path.as_path(), &pwd, &config).unwrap(); - fs::remove_file(String::from(path.to_str().unwrap())).unwrap(); + decrypt_file( + path.as_path(), + &pwd, + &config + ) + .unwrap(); } + + curr_threads_clone.fetch_sub(1, std::sync::atomic::Ordering::SeqCst); })); - current_threads += 1; - if current_threads >= max_threads { - while let Some(handle) = handles.pop() { - handle.join().unwrap(); - current_threads -= 1; - } + while curr_threads.load(std::sync::atomic::Ordering::SeqCst) >= max_threads { + hint::spin_loop(); } } - if current_threads > 0 { - while let Some(handle) = handles.pop() { - handle.join().unwrap(); - current_threads -= 1; - } + for thread in handles { + let _ = thread.join(); } + thread::sleep(time::Duration::from_millis(10)); fs::remove_dir(private)?; } else { let root_dir = fs::read_dir(&cwd).unwrap(); @@ -261,24 +261,21 @@ fn main() -> io::Result<()> { print!("Type password to encrypt files: "); std::io::stdout().flush().unwrap(); let pwd = Arc::new(read_password().unwrap()); - - let mut nonce = [0u8; 24]; - OsRng.fill_bytes(&mut nonce); - let exe = Arc::new(exe); - // let private = Arc::new(private); - let mut handles: Vec> = Vec::with_capacity(max_threads); - let mut current_threads = 0; let cwd: Arc = Arc::from(String::from(cwd.to_str().unwrap())); + let curr_threads = Arc::new(AtomicUsize::new(0)); + for path in paths { let pwd = pwd.clone(); let exe = exe.clone(); let config = config.clone(); let cwd = cwd.clone(); + let curr_threads_clone = curr_threads.clone(); handles.push(thread::spawn(move || { + curr_threads_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst); if path.is_file() && path.as_os_str() != exe.as_os_str() { encrypt_file( String::from(path.to_str().unwrap()), @@ -288,22 +285,16 @@ fn main() -> io::Result<()> { ) .unwrap(); } + curr_threads_clone.fetch_sub(1, std::sync::atomic::Ordering::SeqCst); })); - current_threads += 1; - if current_threads >= max_threads { - while let Some(handle) = handles.pop() { - handle.join().unwrap(); - current_threads -= 1; - } + while curr_threads.load(std::sync::atomic::Ordering::SeqCst) >= max_threads { + hint::spin_loop(); } } - if current_threads > 0 { - while let Some(handle) = handles.pop() { - handle.join().unwrap(); - current_threads -= 1; - } + for thread in handles { + let _ = thread.join(); } for dir in dir_list {