| 1 | package auth |
| 2 | |
| 3 | import ( |
| 4 | "net/http" |
| 5 | "strings" |
| 6 | "testing" |
| 7 | |
| 8 | "golang.org/x/crypto/bcrypt" |
| 9 | ) |
| 10 | |
| 11 | func TestSecureEqual_Equal(t *testing.T) { |
| 12 | if !SecureEqual("hunter2", "hunter2") { |
| 13 | t.Fatal("equal strings should match") |
| 14 | } |
| 15 | } |
| 16 | |
| 17 | func TestSecureEqual_Unequal(t *testing.T) { |
| 18 | if SecureEqual("hunter2", "hunter3") { |
| 19 | t.Fatal("different strings should not match") |
| 20 | } |
| 21 | } |
| 22 | |
| 23 | func TestSecureEqual_BothEmpty(t *testing.T) { |
| 24 | if !SecureEqual("", "") { |
| 25 | t.Fatal("two empty strings should match") |
| 26 | } |
| 27 | } |
| 28 | |
| 29 | func TestSecureEqual_EmptyVsNonempty(t *testing.T) { |
| 30 | if SecureEqual("", "x") { |
| 31 | t.Fatal("empty vs nonempty should not match") |
| 32 | } |
| 33 | if SecureEqual("x", "") { |
| 34 | t.Fatal("nonempty vs empty should not match") |
| 35 | } |
| 36 | } |
| 37 | |
| 38 | func TestSecureEqual_LengthMismatch(t *testing.T) { |
| 39 | if SecureEqual("short", "longer_password") { |
| 40 | t.Fatal("length mismatch should not match") |
| 41 | } |
| 42 | } |
| 43 | |
| 44 | func TestSecureEqual_Over256SameLengthAndPrefix(t *testing.T) { |
| 45 | a := strings.Repeat("a", 300) |
| 46 | b := strings.Repeat("a", 256) + strings.Repeat("b", 44) |
| 47 | if !SecureEqual(a, b) { |
| 48 | t.Fatal("same length, identical first 256 bytes should match after pad/truncate") |
| 49 | } |
| 50 | } |
| 51 | |
| 52 | func TestSecureEqual_Over256DifferentPrefix(t *testing.T) { |
| 53 | a := strings.Repeat("a", 300) |
| 54 | b := "z" + strings.Repeat("a", 299) |
| 55 | if SecureEqual(a, b) { |
| 56 | t.Fatal("differing prefix within first 256 bytes should not match") |
| 57 | } |
| 58 | } |
| 59 | |
| 60 | func TestSecureEqual_Exactly256(t *testing.T) { |
| 61 | pw := strings.Repeat("x", 256) |
| 62 | if !SecureEqual(pw, pw) { |
| 63 | t.Fatal("exact 256-byte identical strings should match") |
| 64 | } |
| 65 | } |
| 66 | |
| 67 | func TestVerifyPassword_PlainHappy(t *testing.T) { |
| 68 | if !VerifyPassword("hunter2", "hunter2") { |
| 69 | t.Fatal("plain password should verify") |
| 70 | } |
| 71 | } |
| 72 | |
| 73 | func TestVerifyPassword_PlainSad(t *testing.T) { |
| 74 | if VerifyPassword("hunter2", "hunter3") { |
| 75 | t.Fatal("wrong plain password should fail") |
| 76 | } |
| 77 | } |
| 78 | |
| 79 | func TestVerifyPassword_PlainLengthMismatch(t *testing.T) { |
| 80 | if VerifyPassword("short", "longer_password") { |
| 81 | t.Fatal("length mismatch should fail") |
| 82 | } |
| 83 | } |
| 84 | |
| 85 | func TestVerifyPassword_BcryptHappy(t *testing.T) { |
| 86 | hash, err := bcrypt.GenerateFromPassword([]byte("hunter2"), bcrypt.MinCost) |
| 87 | if err != nil { |
| 88 | t.Fatal(err) |
| 89 | } |
| 90 | if !VerifyPassword("hunter2", string(hash)) { |
| 91 | t.Fatal("bcrypt password should verify") |
| 92 | } |
| 93 | } |
| 94 | |
| 95 | func TestVerifyPassword_BcryptSad(t *testing.T) { |
| 96 | hash, err := bcrypt.GenerateFromPassword([]byte("hunter2"), bcrypt.MinCost) |
| 97 | if err != nil { |
| 98 | t.Fatal(err) |
| 99 | } |
| 100 | if VerifyPassword("nope", string(hash)) { |
| 101 | t.Fatal("wrong bcrypt password should fail") |
| 102 | } |
| 103 | } |
| 104 | |
| 105 | func TestGenerateSessionToken(t *testing.T) { |
| 106 | tok, err := GenerateSessionToken() |
| 107 | if err != nil { |
| 108 | t.Fatal(err) |
| 109 | } |
| 110 | if len(tok) != 64 { |
| 111 | t.Fatalf("want 64 hex chars, got %d", len(tok)) |
| 112 | } |
| 113 | for _, c := range tok { |
| 114 | isHex := (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') |
| 115 | if !isHex { |
| 116 | t.Fatalf("non-hex char %q in token", c) |
| 117 | } |
| 118 | } |
| 119 | tok2, _ := GenerateSessionToken() |
| 120 | if tok == tok2 { |
| 121 | t.Fatal("two tokens should differ") |
| 122 | } |
| 123 | } |
| 124 | |
| 125 | func TestSessionCookie_Attrs(t *testing.T) { |
| 126 | s := &Store{CookieName: "session", CookieSecure: false} |
| 127 | c := s.SessionCookie("abc123") |
| 128 | if c.Name != "session" || c.Value != "abc123" { |
| 129 | t.Fatalf("bad name/value: %+v", c) |
| 130 | } |
| 131 | if !c.HttpOnly { |
| 132 | t.Fatal("HttpOnly should be set") |
| 133 | } |
| 134 | if c.SameSite != http.SameSiteStrictMode { |
| 135 | t.Fatalf("want SameSite=Strict, got %v", c.SameSite) |
| 136 | } |
| 137 | if c.Path != "/" { |
| 138 | t.Fatalf("want Path=/, got %q", c.Path) |
| 139 | } |
| 140 | if c.MaxAge != 7*24*3600 { |
| 141 | t.Fatalf("want MaxAge=604800, got %d", c.MaxAge) |
| 142 | } |
| 143 | if c.Secure { |
| 144 | t.Fatal("Secure should be false") |
| 145 | } |
| 146 | } |
| 147 | |
| 148 | func TestSessionCookie_Secure(t *testing.T) { |
| 149 | s := &Store{CookieName: "session", CookieSecure: true} |
| 150 | if !s.SessionCookie("x").Secure { |
| 151 | t.Fatal("Secure should be true when CookieSecure=true") |
| 152 | } |
| 153 | } |
| 154 | |
| 155 | func TestClearCookie(t *testing.T) { |
| 156 | s := &Store{CookieName: "session"} |
| 157 | c := s.ClearCookie() |
| 158 | if c.Value != "" || c.MaxAge != -1 { |
| 159 | t.Fatalf("clear cookie should have empty value and MaxAge=-1, got %+v", c) |
| 160 | } |
| 161 | if c.SameSite != http.SameSiteStrictMode { |
| 162 | t.Fatalf("want SameSite=Strict, got %v", c.SameSite) |
| 163 | } |
| 164 | } |
| 165 | |
| 166 | func TestGenerateShortID(t *testing.T) { |
| 167 | id, err := GenerateShortID(10) |
| 168 | if err != nil { |
| 169 | t.Fatal(err) |
| 170 | } |
| 171 | if len(id) != 10 { |
| 172 | t.Fatalf("want len 10, got %d", len(id)) |
| 173 | } |
| 174 | if _, err := GenerateShortID(0); err == nil { |
| 175 | t.Fatal("zero length should error") |
| 176 | } |
| 177 | } |