diff --git a/develop.go b/develop.go index 899fd34..6bc5dcd 100644 --- a/develop.go +++ b/develop.go @@ -2,6 +2,10 @@ package main +func getSecretPath() (string, error) { + return "", nil +} + func getSecret() (string, error) { return "test", nil } diff --git a/internal/server/options.go b/internal/server/options.go new file mode 100644 index 0000000..063119f --- /dev/null +++ b/internal/server/options.go @@ -0,0 +1,60 @@ +package server + +import "time" + +type Options struct { + Port int + Auth Optional[AuthOptions] + Tls Optional[TlsOptions] + UpdateInterval time.Duration +} + +type Optional[v any] struct { + Enabled bool + value v +} + +func (o *Optional[v]) Get() v { + return o.value +} + +func (o *Optional[v]) Set(value v) { + o.value = value + o.Enabled = true +} + +func (o *Optional[v]) Apply(apply func(*v)) { + o.Enabled = true + apply(&o.value) +} + +type AuthType int + +const ( + Raw AuthType = iota + File +) + +type AuthOptions struct { + // Secret Direct or Path to secret File + Secret string + LoadType AuthType +} + +type TlsOptions struct { + CertPath string + KeyPath string +} + +func NewDefaultOptions() Options { + return Options{ + Port: 8080, + Auth: Optional[AuthOptions]{ + Enabled: false, + }, + Tls: Optional[TlsOptions]{ + Enabled: false, + }, + UpdateInterval: 15 * time.Minute, + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 5742362..f50dda7 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "os" "path/filepath" "strconv" "strings" @@ -37,18 +38,25 @@ type Server struct { DbMgr *database.Manager - secret string - mux *http.ServeMux + mux *http.ServeMux + + options Options + secret string } -func New(provider provider.Provider, db *database.Manager, mux *http.ServeMux, secret string) *Server { +func New(provider provider.Provider, db *database.Manager, mux *http.ServeMux, options ...func(*Options)) *Server { + opts := NewDefaultOptions() + for _, opt := range options { + opt(&opts) + } + s := Server{ ImageBuffers: make(map[string][]byte), Provider: provider, DbMgr: db, Mutex: &sync.Mutex{}, mux: mux, - secret: secret, + options: opts, } return &s @@ -72,23 +80,38 @@ func (s *Server) RegisterRoutes() { s.mux.HandleFunc("GET /update", s.HandleUpdate) } -func (s *Server) StartTLS(port int, certFile, keyFile string) error { - log.Info().Int("Port", port).Str("Certificate", certFile).Str("Key", keyFile).Msg("Starting server") +func (s *Server) Start() error { server := http.Server{ - Addr: fmt.Sprintf(":%d", port), - Handler: s.Auth(s.mux), + Addr: fmt.Sprintf(":%d", s.options.Port), + Handler: s.mux, } - return server.ListenAndServeTLS(certFile, keyFile) -} -func (s *Server) Start(port int) error { - log.Info().Int("Port", port).Msg("Starting server") - - server := http.Server{ - Addr: fmt.Sprintf(":%d", port), - Handler: s.Auth(s.mux), + if s.options.Auth.Enabled { + auth := s.options.Auth.Get() + switch auth.LoadType { + case Raw: + s.secret = auth.Secret + case File: + secretBytes, err := os.ReadFile(auth.Secret) + if err != nil { + return err + } + s.secret = string(secretBytes) + } + s.secret = strings.TrimSpace(s.secret) + server.Handler = s.Auth(s.mux) + } + + s.registerUpdater() + + if s.options.Tls.Enabled { + tls := s.options.Tls.Get() + log.Info().Int("Port", s.options.Port).Str("Cert", tls.CertPath).Str("Key", tls.KeyPath).Msg("Starting server") + return server.ListenAndServeTLS(tls.CertPath, tls.KeyPath) + } else { + log.Info().Int("Port", s.options.Port).Msg("Starting server") + return server.ListenAndServe() } - return server.ListenAndServe() } func (s *Server) UpdateMangaList() { @@ -105,15 +128,18 @@ func (s *Server) UpdateMangaList() { } } -func (s *Server) RegisterUpdater(interval time.Duration) { - go func(s *Server) { - for { - select { - case <-time.After(interval): - s.UpdateMangaList() +func (s *Server) registerUpdater() { + if s.options.UpdateInterval > 0 { + log.Info().Str("Interval", s.options.UpdateInterval.String()).Msg("Registering Updater") + go func(s *Server) { + for { + select { + case <-time.After(s.options.UpdateInterval): + s.UpdateMangaList() + } } - } - }(s) + }(s) + } } func (s *Server) LoadNext() { diff --git a/main.go b/main.go index d14863b..3baa035 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,6 @@ import ( "os/exec" "os/signal" "runtime" - "strings" "time" "github.com/pablu23/mangaGetter/internal/database" @@ -29,18 +28,101 @@ var ( databaseFlag = flag.String("database", "", "Path to sqlite.db file") certFlag = flag.String("cert", "", "Path to cert file, has to be used in conjunction with key") keyFlag = flag.String("key", "", "Path to key file, has to be used in conjunction with cert") - updateIntervalFlag = flag.String("update", "1h", "Interval to update Mangas") + updateIntervalFlag = flag.String("update", "0h", "Interval to update Mangas") debugFlag = flag.Bool("debug", false, "Activate debug Logs") prettyLogsFlag = flag.Bool("pretty", false, "Pretty pring Logs") logPathFlag = flag.String("log", "", "Path to logfile, stderr if default") ) func main() { - var secret string = "" - var filePath string - flag.Parse() + setupLogging() + + filePath := setupDb() + db := database.NewDatabase(filePath, true, *debugFlag) + err := db.Open() + if err != nil { + log.Fatal().Err(err).Str("Path", filePath).Msg("Could not open Database") + } + + mux := http.NewServeMux() + s := server.New(&provider.Bato{}, &db, mux, func(o *server.Options) { + authOptions := setupAuth() + o.Auth.Set(authOptions) + interval, err := time.ParseDuration(*updateIntervalFlag) + if err != nil { + log.Fatal().Err(err).Str("Interval", *updateIntervalFlag).Msg("Could not parse interval") + } + o.UpdateInterval = interval + + if *certFlag != "" && *keyFlag != "" { + o.Tls.Apply(func(to *server.TlsOptions) { + to.CertPath = *certFlag + to.KeyPath = *keyFlag + }) + } + }) + + setupClient() + setupClose(&db) + err = s.Start() + if err != nil { + log.Fatal().Err(err).Msg("Could not start server") + } +} + +func setupAuth() server.AuthOptions { + var authOptions server.AuthOptions + if *secretFlag != "" { + authOptions.LoadType = server.Raw + authOptions.Secret = *secretFlag + } else if *secretFilePathFlag != "" { + authOptions.LoadType = server.File + authOptions.Secret = *secretFilePathFlag + } else if *authFlag { + path, err := getSecretPath() + if err != nil { + log.Fatal().Err(err).Msg("Secret file could not be found") + } + authOptions.Secret = path + authOptions.LoadType = server.File + } + return authOptions +} + +func setupClient() { + if !*serverFlag { + go func() { + time.Sleep(300 * time.Millisecond) + err := open(fmt.Sprintf("http://localhost:%d", *portFlag)) + if err != nil { + log.Error().Err(err).Msg("Could not open Browser") + } + }() + } +} + +func setupClose(db *database.Manager) { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + + go func() { + for range c { + Close(db) + } + }() +} + +func setupDb() string { + if *databaseFlag != "" { + return *databaseFlag + } else { + return getDbPath() + } +} + +func setupLogging() { zerolog.SetGlobalLevel(zerolog.InfoLevel) if *prettyLogsFlag { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) @@ -60,76 +142,6 @@ func main() { MaxBackups: 10, })) } - - if *secretFlag != "" { - secret = *secretFlag - } else if *secretFilePathFlag != "" { - buf, err := os.ReadFile(*secretFilePathFlag) - if err != nil { - log.Fatal().Err(err).Str("Path", *secretFilePathFlag).Msg("Could not read secret File") - } - secret = string(buf) - } else if *authFlag { - cacheSecret, err := getSecret() - secret = cacheSecret - if err != nil { - log.Error().Err(err).Msg("Secret file could not be found or read, not activating Auth") - } - } - - if *databaseFlag != "" { - filePath = *databaseFlag - } else { - filePath = getDbPath() - } - - db := database.NewDatabase(filePath, true, *debugFlag) - err := db.Open() - if err != nil { - log.Fatal().Err(err).Str("Path", filePath).Msg("Could not open Database") - } - - secret = strings.TrimSpace(secret) - mux := http.NewServeMux() - s := server.New(&provider.Bato{}, &db, mux, secret) - - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - - go func() { - for range c { - Close(&db) - } - }() - - if !*serverFlag { - go func() { - time.Sleep(300 * time.Millisecond) - err := open(fmt.Sprintf("http://localhost:%d", *portFlag)) - if err != nil { - log.Error().Err(err).Msg("Could not open Browser") - } - }() - } - - interval, err := time.ParseDuration(*updateIntervalFlag) - if err != nil { - log.Fatal().Err(err).Str("Interval", *updateIntervalFlag).Msg("Could not parse interval") - } - s.RegisterUpdater(interval) - s.RegisterRoutes() - - if *certFlag != "" && *keyFlag != "" { - err = s.StartTLS(*portFlag, *certFlag, *keyFlag) - if err != nil { - log.Fatal().Err(err).Str("Cert", *certFlag).Str("Key", *keyFlag).Int("Port", *portFlag).Msg("Could not start TLS server") - } - } else { - err = s.Start(*portFlag) - if err != nil { - log.Fatal().Err(err).Int("Port", *portFlag).Msg("Could not start server") - } - } } func open(url string) error { diff --git a/release.go b/release.go index 027a9b6..7e20d6b 100644 --- a/release.go +++ b/release.go @@ -7,6 +7,17 @@ import ( "path/filepath" ) +func getSecretPath() (string, error) { + dir, err := os.UserCacheDir() + if err != nil { + return "", err + } + + dirPath := filepath.Join(dir, "MangaGetter") + filePath := filepath.Join(dirPath, "secret.secret") + return filePath, nil +} + func getSecret() (string, error) { dir, err := os.UserCacheDir() if err != nil {