diff --git a/internal/backend.go b/internal/backend.go index 985e6c2..669935a 100644 --- a/internal/backend.go +++ b/internal/backend.go @@ -6,6 +6,7 @@ import ( "fmt" "net/url" "os" + "regexp" "strings" "github.com/cupcakearmy/autorestic/internal/colors" @@ -58,6 +59,8 @@ func (b Backend) generateRepo() (string, error) { } } +var nonAlphaRegex = regexp.MustCompile("[^A-Za-z0-9]") + func (b Backend) getEnv() (map[string]string, error) { env := make(map[string]string) // Key @@ -73,7 +76,9 @@ func (b Backend) getEnv() (map[string]string, error) { } // From Envfile and passed as env - var prefix = "AUTORESTIC_" + strings.ToUpper(b.name) + "_" + nameForEnv := strings.ToUpper(b.name) + nameForEnv = nonAlphaRegex.ReplaceAllString(nameForEnv, "_") + var prefix = "AUTORESTIC_" + nameForEnv + "_" for _, variable := range os.Environ() { var splitted = strings.SplitN(variable, "=", 2) if strings.HasPrefix(splitted[0], prefix) { diff --git a/internal/backend_test.go b/internal/backend_test.go index a24afe0..e65e4fe 100644 --- a/internal/backend_test.go +++ b/internal/backend_test.go @@ -195,6 +195,33 @@ func TestGetEnv(t *testing.T) { assertEqual(t, result["B2_ACCOUNT_ID"], "foo123") assertEqual(t, result["B2_ACCOUNT_KEY"], "foo456") }) + + for _, char := range "@-_:/" { + t.Run(fmt.Sprintf("env var with special char (%c)", char), func(t *testing.T) { + // generate env variables + // TODO better way to teardown + defer os.Unsetenv("AUTORESTIC_FOO_BAR_RESTIC_PASSWORD") + defer os.Unsetenv("AUTORESTIC_FOO_BAR_B2_ACCOUNT_ID") + defer os.Unsetenv("AUTORESTIC_FOO_BAR_B2_ACCOUNT_KEY") + os.Setenv("AUTORESTIC_FOO_BAR_RESTIC_PASSWORD", "secret123") + os.Setenv("AUTORESTIC_FOO_BAR_B2_ACCOUNT_ID", "foo123") + os.Setenv("AUTORESTIC_FOO_BAR_B2_ACCOUNT_KEY", "foo456") + + b := Backend{ + name: fmt.Sprintf("foo%cbar", char), + Type: "local", + Path: "/foo/bar", + } + result, err := b.getEnv() + if err != nil { + t.Errorf("unexpected error %v", err) + } + assertEqual(t, result["RESTIC_REPOSITORY"], "/foo/bar") + assertEqual(t, result["RESTIC_PASSWORD"], "secret123") + assertEqual(t, result["B2_ACCOUNT_ID"], "foo123") + assertEqual(t, result["B2_ACCOUNT_KEY"], "foo456") + }) + } } func TestValidate(t *testing.T) {