diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 3e9a76d..668f888 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -25,4 +25,4 @@ jobs: run: go build -v ./... - name: Test - run: go test -v ./... + run: go test -v -cover ./... diff --git a/README.md b/README.md index e544c97..e6f8ac5 100644 --- a/README.md +++ b/README.md @@ -7,9 +7,12 @@ Here is an example of usage: ```go +// Example of upgrading password hash to a greater complexity. +// +// Note: This example uses very unsecure hash functions to allow for predictable output. Use of argon2.Argon2id or scrypt.Scrypt2 for greater hash security is recommended. func Example() { - pass := "my_pass" - hash := "$1$81ed91e1131a3a5a50d8a68e8ef85fa0" + pass := []byte("my_pass") + hash := []byte("$1$81ed91e1131a3a5a50d8a68e8ef85fa0") pwd := passwd.New( argon2.Argon2id, // first is preferred type. @@ -19,23 +22,25 @@ func Example() { _, err := pwd.Passwd(pass, hash) if err != nil { fmt.Println("fail: ", err) + return } // Check if we want to update. if !pwd.IsPreferred(hash) { - newHash, err := pwd.Passwd(pass, "") + newHash, err := pwd.Passwd(pass, nil) if err != nil { fmt.Println("fail: ", err) + return } - fmt.Println("new hash:", newHash) + fmt.Println("new hash:", string(newHash)[:31], "...") } // Output: - // new hash: $argon2id$... + // new hash: $argon2id$v=19,m=65536,t=1,p=4$ ... } ``` -https://github.com/sour-is/go-passwd/blob/main/passwd_test.go#L33-L59 +https://github.com/sour-is/go-passwd/blob/main/passwd_test.go#L40-L68 This shows how one would set a preferred hashing type and if the current version of ones password is not the preferred type updates it to enhance the security of the hashed password when someone logs in. @@ -61,12 +66,12 @@ https://github.com/sour-is/go-passwd/blob/main/passwd_test.go#L28-L31 Circling back to the `IsPreferred` method. A hasher can define its own `IsPreferred` method that will be called to check if the current hash meets the complexity requirements. This is good for updating the password hashes to be more secure over time. ```go -func (p *Passwd) IsPreferred(hash string) bool { +func (p *Passwd) IsPreferred(hash []byte) bool { _, algo := p.getAlgo(hash) if algo != nil && algo == p.d { // if the algorithm defines its own check for preference. - if ck, ok := algo.(interface{ IsPreferred(string) bool }); ok { + if ck, ok := algo.(interface{ IsPreferred([]byte) bool }); ok { return ck.IsPreferred(hash) } diff --git a/passwd.go b/passwd.go index 7795940..a232824 100644 --- a/passwd.go +++ b/passwd.go @@ -1,13 +1,13 @@ package passwd import ( + "bytes" "errors" "fmt" - "strings" ) type Passwder interface { - Passwd(string, string) (string, error) + Passwd(pass, hash []byte) ([]byte, error) ApplyPasswd(*Passwd) } @@ -45,8 +45,8 @@ func (p *Passwd) SetFallthrough(pass Passwder) { p.f = pass } -func (p *Passwd) Passwd(pass, hash string) (string, error) { - if hash == "" { +func (p *Passwd) Passwd(pass, hash []byte) ([]byte, error) { + if hash == nil { return p.d.Passwd(pass, hash) } name, algo := p.getAlgo(hash) @@ -54,17 +54,17 @@ func (p *Passwd) Passwd(pass, hash string) (string, error) { algo = p.f } if algo == nil { - return "", fmt.Errorf("%w: %s", ErrNoHandler, name) + return nil, fmt.Errorf("%w: %s", ErrNoHandler, name) } return algo.Passwd(pass, hash) } -func (p *Passwd) IsPreferred(hash string) bool { +func (p *Passwd) IsPreferred(hash []byte) bool { _, algo := p.getAlgo(hash) if algo != nil && algo == p.d { // if the algorithm defines its own check for preference. - if ck, ok := algo.(interface{ IsPreferred(string) bool }); ok { + if ck, ok := algo.(interface{ IsPreferred([]byte) bool }); ok { return ck.IsPreferred(hash) } @@ -73,17 +73,18 @@ func (p *Passwd) IsPreferred(hash string) bool { return false } -func (p *Passwd) getAlgo(hash string) (string, Passwder) { +func (p *Passwd) getAlgo(hash []byte) (string, Passwder) { var algo string - if !strings.HasPrefix(hash, "$") { + if !bytes.HasPrefix(hash, []byte("$")) { return p.getName(p.f), p.f } - if _, h, ok := strings.Cut(hash, "$"); ok { - algo, _, ok = strings.Cut(h, "$") + if _, h, ok := bytes.Cut(hash, []byte("$")); ok { + a, _, ok := bytes.Cut(h, []byte("$")) if !ok { return "", nil } + algo = string(a) if passwd, ok := p.m[algo]; ok { return algo, passwd diff --git a/passwd_test.go b/passwd_test.go index e786b83..07e9836 100644 --- a/passwd_test.go +++ b/passwd_test.go @@ -1,24 +1,28 @@ package passwd_test import ( + "bytes" "crypto/subtle" "fmt" - "strings" "testing" "github.com/matryer/is" "github.com/sour-is/go-passwd" + "github.com/sour-is/go-passwd/pkg/argon2" "github.com/sour-is/go-passwd/pkg/unix" ) type plainPasswd struct{} -func (p *plainPasswd) Passwd(pass string, check string) (string, error) { - if check == "" { - return fmt.Sprint("$plain$", pass), nil +func (p *plainPasswd) Passwd(pass, check []byte) ([]byte, error) { + if check == nil { + var b bytes.Buffer + b.WriteString("$plain$") + b.Write(pass) + return b.Bytes(), nil } - if subtle.ConstantTimeCompare([]byte(pass), []byte(strings.TrimPrefix(check, "$plain$"))) == 1 { + if subtle.ConstantTimeCompare([]byte(pass), []byte(bytes.TrimPrefix(check, []byte("$plain$")))) == 1 { return check, nil } @@ -34,50 +38,52 @@ func (p *plainPasswd) ApplyPasswd(passwd *passwd.Passwd) { // // Note: This example uses very unsecure hash functions to allow for predictable output. Use of argon2.Argon2id or scrypt.Scrypt2 for greater hash security is recommended. func Example() { - pass := "my_pass" - hash := "my_pass" + pass := []byte("my_pass") + hash := []byte("$1$81ed91e1131a3a5a50d8a68e8ef85fa0") pwd := passwd.New( - &unix.MD5{}, // first is preferred type. - &plainPasswd{}, + argon2.Argon2id, // first is preferred type. + &unix.MD5{}, ) _, err := pwd.Passwd(pass, hash) if err != nil { fmt.Println("fail: ", err) + return } // Check if we want to update. if !pwd.IsPreferred(hash) { - newHash, err := pwd.Passwd(pass, "") + newHash, err := pwd.Passwd(pass, nil) if err != nil { fmt.Println("fail: ", err) + return } - fmt.Println("new hash:", newHash) + fmt.Println("new hash:", string(newHash)[:31], "...") } // Output: - // new hash: $1$81ed91e1131a3a5a50d8a68e8ef85fa0 + // new hash: $argon2id$v=19,m=65536,t=1,p=4$ ... } func TestPasswdHash(t *testing.T) { type testCase struct { - pass, hash string + pass, hash []byte } tests := []testCase{ - {"passwd", "passwd"}, - {"passwd", "$plain$passwd"}, + {[]byte("passwd"), []byte("passwd")}, + {[]byte("passwd"), []byte("$plain$passwd")}, } algos := []passwd.Passwder{&plainPasswd{}} is := is.New(t) // Generate additional test cases for each algo. for _, algo := range algos { - hash, err := algo.Passwd("passwd", "") + hash, err := algo.Passwd([]byte("passwd"), nil) is.NoErr(err) - tests = append(tests, testCase{"passwd", hash}) + tests = append(tests, testCase{[]byte("passwd"), hash}) } pass := passwd.New(algos...) @@ -98,9 +104,9 @@ func TestPasswdIsPreferred(t *testing.T) { pass := passwd.New(&plainPasswd{}) - ok := pass.IsPreferred("$plain$passwd") + ok := pass.IsPreferred([]byte("$plain$passwd")) is.True(ok) - ok = pass.IsPreferred("$foo$passwd") + ok = pass.IsPreferred([]byte("$foo$passwd")) is.True(!ok) } diff --git a/pkg/argon2/argon2.go b/pkg/argon2/argon2.go index 1fed682..274dcd7 100644 --- a/pkg/argon2/argon2.go +++ b/pkg/argon2/argon2.go @@ -1,12 +1,12 @@ package argon2 import ( + "bytes" "crypto/rand" "crypto/subtle" "encoding/base64" "fmt" "strconv" - "strings" "golang.org/x/crypto/argon2" @@ -73,35 +73,35 @@ func NewArgon2id( } } -func (p *argon) Passwd(pass string, check string) (string, error) { +func (p *argon) Passwd(pass, check []byte) ([]byte, error) { var args *pwArgs var err error - if check == "" { + if check == nil { args = p.defaultArgs() _, err := rand.Read(args.salt) if err != nil { - return "", err + return nil, err } - args.hash = p.keyFn([]byte(pass), args.salt, args.time, args.memory, args.threads, args.keyLen) + args.hash = p.keyFn(pass, args.salt, args.time, args.memory, args.threads, args.keyLen) } else { args, err = p.parseArgs(check) if err != nil { - return "", err + return nil, err } - hash := p.keyFn([]byte(pass), args.salt, args.time, args.memory, args.threads, args.keyLen) + hash := p.keyFn(pass, args.salt, args.time, args.memory, args.threads, args.keyLen) if subtle.ConstantTimeCompare(hash, args.hash) == 0 { - return "", passwd.ErrNoMatch + return nil, passwd.ErrNoMatch } } - return args.String(), nil + return args.Bytes(), nil } func (p *argon) ApplyPasswd(passwd *passwd.Passwd) { passwd.Register(p.name, p) } -func (s *argon) IsPreferred(hash string) bool { +func (s *argon) IsPreferred(hash []byte) bool { args, err := s.parseArgs(hash) if err != nil { return false @@ -142,29 +142,33 @@ func (p *argon) defaultArgs() *pwArgs { salt: make([]byte, p.saltLen), } } -func (p *argon) parseArgs(hash string) (*pwArgs, error) { - pfx := "$" + p.name + "$" +func (p *argon) parseArgs(hash []byte) (*pwArgs, error) { + pfx := []byte("$" + p.name + "$") - if !strings.HasPrefix(hash, pfx) { + if !bytes.HasPrefix(hash, pfx) { return nil, fmt.Errorf("%w: missing prefix", passwd.ErrBadHash) } - hash = strings.TrimPrefix(hash, pfx) - args, hash, ok := strings.Cut(hash, "$") + hash = bytes.TrimPrefix(hash, pfx) + args, hash, ok := bytes.Cut(hash, []byte("$")) if !ok { return nil, fmt.Errorf("%w: missing args", passwd.ErrBadHash) } - salt, hash, ok := strings.Cut(hash, "$") + salt, hash, ok := bytes.Cut(hash, []byte("$")) if !ok { return nil, fmt.Errorf("%w: missing salt", passwd.ErrBadHash) } var err error pass := p.defaultArgs() - pass.salt, err = base64.RawStdEncoding.DecodeString(salt) + + pass.salt = make([]byte, base64.RawStdEncoding.DecodedLen(len(salt))) + _, err = base64.RawStdEncoding.Decode(pass.salt, salt) if err != nil { return nil, fmt.Errorf("%w: corrupt salt part", passwd.ErrBadHash) } - pass.hash, err = base64.RawStdEncoding.DecodeString(hash) + + pass.hash = make([]byte, base64.RawStdEncoding.DecodedLen(len(hash))) + _, err = base64.RawStdEncoding.Decode(pass.hash, hash) if err != nil { return nil, fmt.Errorf("%w: corrupt hash part", passwd.ErrBadHash) } @@ -172,23 +176,23 @@ func (p *argon) parseArgs(hash string) (*pwArgs, error) { pass.name = p.name pass.keyLen = uint32(len(pass.hash)) - for _, part := range strings.Split(args, ",") { - if k, v, ok := strings.Cut(part, "="); ok { - switch k { + for _, part := range bytes.Split(args, []byte(",")) { + if k, v, ok := bytes.Cut(part, []byte("=")); ok { + switch string(k) { case "v": - if i, err := strconv.ParseUint(v, 10, 8); err == nil { + if i, err := strconv.ParseUint(string(v), 10, 8); err == nil { pass.version = uint8(i) } case "m": - if i, err := strconv.ParseUint(v, 10, 32); err == nil { + if i, err := strconv.ParseUint(string(v), 10, 32); err == nil { pass.memory = uint32(i) } case "t": - if i, err := strconv.ParseUint(v, 10, 32); err == nil { + if i, err := strconv.ParseUint(string(v), 10, 32); err == nil { pass.time = uint32(i) } case "p": - if i, err := strconv.ParseUint(v, 10, 8); err == nil { + if i, err := strconv.ParseUint(string(v), 10, 8); err == nil { pass.threads = uint8(i) } } @@ -209,9 +213,19 @@ type pwArgs struct { hash []byte } -func (p *pwArgs) String() string { - salt := base64.RawStdEncoding.EncodeToString(p.salt) - hash := base64.RawStdEncoding.EncodeToString(p.hash) +func (p *pwArgs) Bytes() []byte { + var b bytes.Buffer - return fmt.Sprintf("$%s$v=%d,m=%d,t=%d,p=%d$%s$%s", p.name, p.version, p.memory, p.time, p.threads, salt, hash) + fmt.Fprintf(&b, "$%s$v=%d,m=%d,t=%d,p=%d$", p.name, p.version, p.memory, p.time, p.threads) + + salt := make([]byte, base64.RawURLEncoding.EncodedLen(len(p.salt))) + base64.RawStdEncoding.Encode(salt, p.salt) + b.Write(salt) + + hash := make([]byte, base64.RawURLEncoding.EncodedLen(len(p.hash))) + base64.RawStdEncoding.Encode(hash, p.hash) + b.WriteRune('$') + b.Write(hash) + + return b.Bytes() } diff --git a/pkg/argon2/argon2_test.go b/pkg/argon2/argon2_test.go index f65e27e..1d5eb08 100644 --- a/pkg/argon2/argon2_test.go +++ b/pkg/argon2/argon2_test.go @@ -13,7 +13,7 @@ import ( func TestPasswdHash(t *testing.T) { type testCase struct { - pass, hash string + pass, hash []byte } tests := []testCase{} @@ -22,9 +22,9 @@ func TestPasswdHash(t *testing.T) { is := is.New(t) // Generate additional test cases for each algo. for _, algo := range algos { - hash, err := algo.Passwd("passwd", "") + hash, err := algo.Passwd([]byte("passwd"), nil) is.NoErr(err) - tests = append(tests, testCase{"passwd", hash}) + tests = append(tests, testCase{[]byte("passwd"), hash}) } pass := passwd.New(algos...) @@ -45,12 +45,12 @@ func TestPasswdIsPreferred(t *testing.T) { pass := passwd.New(argon2.Argon2i, &unix.MD5{}) - ok := pass.IsPreferred("$argon2i$v=19,m=32768,t=3,p=4$LdaB2Z4EI4lwpxTc78QUFw$VhlPSK0tdF226QCLC24IIrmQcMBmg47Ik9h/Yq6htFI") + ok := pass.IsPreferred([]byte("$argon2i$v=19,m=32768,t=3,p=4$LdaB2Z4EI4lwpxTc78QUFw$VhlPSK0tdF226QCLC24IIrmQcMBmg47Ik9h/Yq6htFI")) is.True(ok) - ok = pass.IsPreferred("$argon2i$v=19,m=1024,t=2,p=4$LdaB2Z4EI4lwpxTc78QUFw$VhlPSK0tdF226QCLC24IIrmQcMBmg47Ik9h/Yq6htFI") + ok = pass.IsPreferred([]byte("$argon2i$v=19,m=1024,t=2,p=4$LdaB2Z4EI4lwpxTc78QUFw$VhlPSK0tdF226QCLC24IIrmQcMBmg47Ik9h/Yq6htFI")) is.True(!ok) - ok = pass.IsPreferred("$1$76a2173be6393254e72ffa4d6df1030a") + ok = pass.IsPreferred([]byte("$1$76a2173be6393254e72ffa4d6df1030a")) is.True(!ok) } diff --git a/pkg/scrypt/scrypt.go b/pkg/scrypt/scrypt.go index c69f7ec..4fc7e31 100644 --- a/pkg/scrypt/scrypt.go +++ b/pkg/scrypt/scrypt.go @@ -1,13 +1,13 @@ package scrypt import ( + "bytes" "crypto/rand" "crypto/subtle" "encoding/base64" "encoding/hex" "fmt" "strconv" - "strings" "github.com/sour-is/go-passwd" "golang.org/x/crypto/scrypt" @@ -22,8 +22,10 @@ type scryptpw struct { name string encoder interface { - EncodeToString(src []byte) string - DecodeString(s string) ([]byte, error) + EncodedLen(n int) int + Encode(dst, src []byte) + DecodedLen(x int) int + Decode(dst, src []byte) (n int, err error) } } type scryptArgs struct { @@ -38,8 +40,10 @@ type scryptArgs struct { hash []byte encoder interface { - EncodeToString(src []byte) string - DecodeString(s string) ([]byte, error) + EncodedLen(n int) int + Encode(dst, src []byte) + DecodedLen(x int) int + Decode(dst, src []byte) (n int, err error) } } @@ -55,37 +59,37 @@ var Scrypt2 = &scryptpw{ name: "s2", encoder: base64.RawStdEncoding, } -func (s *scryptpw) Passwd(pass string, check string) (string, error) { +func (s *scryptpw) Passwd(pass, check []byte) ([]byte, error) { var args *scryptArgs var err error - if check == "" { + if check == nil { args = s.defaultArgs() _, err := rand.Read(args.salt) if err != nil { - return "", err + return nil, err } - args.hash, err = scrypt.Key([]byte(pass), args.salt, args.N, args.R, args.P, args.DKLen) + args.hash, err = scrypt.Key(pass, args.salt, args.N, args.R, args.P, args.DKLen) if err != nil { - return "", err + return nil, err } } else { args, err = s.parseArgs(check) if err != nil { - return "", err + return nil, err } hash, err := scrypt.Key([]byte(pass), args.salt, args.N, args.R, args.P, args.DKLen) if err != nil { - return "", err + return nil, err } if subtle.ConstantTimeCompare(hash, args.hash) == 0 { - return "", passwd.ErrNoMatch + return nil, passwd.ErrNoMatch } } - return args.String(), nil + return args.Bytes(), nil } func (s *scryptpw) ApplyPasswd(p *passwd.Passwd) { p.Register(s.name, s) @@ -93,7 +97,7 @@ func (s *scryptpw) ApplyPasswd(p *passwd.Passwd) { p.SetFallthrough(s) } } -func (s *scryptpw) IsPreferred(hash string) bool { +func (s *scryptpw) IsPreferred(hash []byte) bool { args, err := s.parseArgs(hash) if err != nil { return false @@ -129,49 +133,52 @@ func (s *scryptpw) defaultArgs() *scryptArgs { encoder: s.encoder, } } -func (s *scryptpw) parseArgs(hash string) (*scryptArgs, error) { + +func (s *scryptpw) parseArgs(hash []byte) (*scryptArgs, error) { args := s.defaultArgs() - name := "$" + s.name + "$" - hash = strings.TrimPrefix(hash, name) + name := []byte("$" + s.name + "$") + hash = bytes.TrimPrefix(hash, name) - N, hash, ok := strings.Cut(hash, "$") + N, hash, ok := bytes.Cut(hash, []byte("$")) if !ok { return nil, fmt.Errorf("%w: missing args: N", passwd.ErrBadHash) } - if n, err := strconv.Atoi(N); err == nil { + if n, err := strconv.Atoi(string(N)); err == nil { args.N = n } - R, hash, ok := strings.Cut(hash, "$") + R, hash, ok := bytes.Cut(hash, []byte("$")) if !ok { return nil, fmt.Errorf("%w: missing args: R", passwd.ErrBadHash) } - if r, err := strconv.Atoi(R); err == nil { + if r, err := strconv.Atoi(string(R)); err == nil { args.R = r } - P, hash, ok := strings.Cut(hash, "$") + P, hash, ok := bytes.Cut(hash, []byte("$")) if !ok { return nil, fmt.Errorf("%w: missing args: P", passwd.ErrBadHash) } - if p, err := strconv.Atoi(P); err == nil { + if p, err := strconv.Atoi(string(P)); err == nil { args.P = p } - salt, hash, ok := strings.Cut(hash, "$") + salt, hash, ok := bytes.Cut(hash, []byte("$")) if !ok { return nil, fmt.Errorf("%w: missing args: salt", passwd.ErrBadHash) } var err error - args.salt, err = s.encoder.DecodeString(salt) + args.salt = make([]byte, s.encoder.DecodedLen(len(salt))) + _, err = s.encoder.Decode(args.salt, salt) if err != nil { return nil, fmt.Errorf("%w: corrupt salt part", passwd.ErrBadHash) } args.SaltLen = len(args.salt) - args.hash, err = s.encoder.DecodeString(hash) + args.hash = make([]byte, s.encoder.DecodedLen(len(hash))) + _, err = s.encoder.Decode(args.hash, hash) if err != nil { return nil, fmt.Errorf("%w: corrupt hash part", passwd.ErrBadHash) } @@ -179,22 +186,38 @@ func (s *scryptpw) parseArgs(hash string) (*scryptArgs, error) { return args, nil } -func (s *scryptArgs) String() string { - var name string - if s.name != "s1" { - name = "$" + s.name + "$" - } - salt := s.encoder.EncodeToString(s.salt) - hash := s.encoder.EncodeToString(s.hash) - return fmt.Sprintf("%s%d$%d$%d$%s$%s", name, s.N, s.R, s.P, salt, hash) +func (s *scryptArgs) Bytes() []byte { + var b bytes.Buffer + + if s.name != "s1" { + b.WriteRune('$') + b.WriteString(s.name) + b.WriteRune('$') + } + + fmt.Fprintf(&b, "%d$%d$%d", s.N, s.R, s.P) + + salt := make([]byte, s.encoder.EncodedLen(len(s.salt))) + s.encoder.Encode(salt, s.salt) + b.WriteRune('$') + b.Write(salt) + + hash := make([]byte, s.encoder.EncodedLen(len(s.hash))) + s.encoder.Encode(hash, s.hash) + b.WriteRune('$') + b.Write(hash) + + return b.Bytes() } type hexenc struct{} -func (hexenc) EncodeToString(src []byte) string { - return hex.EncodeToString(src) +func (hexenc) Encode(dst, src []byte) { + hex.Encode(dst, src) } -func (hexenc) DecodeString(s string) ([]byte, error) { - return hex.DecodeString(s) +func (hexenc) EncodedLen(n int) int { return hex.EncodedLen(n) } +func (hexenc) Decode(dst, src []byte) (n int, err error) { + return hex.Decode(dst, src) } +func (hexenc) DecodedLen(x int) int { return hex.DecodedLen(x) } diff --git a/pkg/scrypt/scrypt_test.go b/pkg/scrypt/scrypt_test.go index 9527840..e6bda93 100644 --- a/pkg/scrypt/scrypt_test.go +++ b/pkg/scrypt/scrypt_test.go @@ -13,7 +13,7 @@ import ( func TestPasswdHash(t *testing.T) { type testCase struct { - pass, hash string + pass, hash []byte } tests := []testCase{} @@ -22,9 +22,9 @@ func TestPasswdHash(t *testing.T) { is := is.New(t) // Generate additional test cases for each algo. for _, algo := range algos { - hash, err := algo.Passwd("passwd", "") + hash, err := algo.Passwd([]byte("passwd"), nil) is.NoErr(err) - tests = append(tests, testCase{"passwd", hash}) + tests = append(tests, testCase{[]byte("passwd"), hash}) } pass := passwd.New(algos...) @@ -45,15 +45,15 @@ func TestPasswdIsPreferred(t *testing.T) { pass := passwd.New(scrypt.Scrypt2, &unix.MD5{}) - ok := pass.IsPreferred("16384$8$1$b97ed09792dd74b71dcb7fc8caf04a89$0b5cda82b17298ec4bf6d2139f7ea8587d8478fcc68c09e2506a7cf08b2817c0") + ok := pass.IsPreferred([]byte("16384$8$1$b97ed09792dd74b71dcb7fc8caf04a89$0b5cda82b17298ec4bf6d2139f7ea8587d8478fcc68c09e2506a7cf08b2817c0")) is.True(!ok) - ok = pass.IsPreferred("$s2$16384$8$1$iEdwbgXyKa5GNGNW/0NsOA$9YN/hzbskVVDZ887ppqv5su0n8SxVXwDB/rhVhAc9xQ") + ok = pass.IsPreferred([]byte("$s2$16384$8$1$iEdwbgXyKa5GNGNW/0NsOA$9YN/hzbskVVDZ887ppqv5su0n8SxVXwDB/rhVhAc9xQ")) is.True(ok) - ok = pass.IsPreferred("$s2$16384$7$1$iEdwbgXyKa5GNGNW/0NsOA$9YN/hzbskVVDZ887ppqv5su0n8SxVXwDB/rhVhAc9xQ") + ok = pass.IsPreferred([]byte("$s2$16384$7$1$iEdwbgXyKa5GNGNW/0NsOA$9YN/hzbskVVDZ887ppqv5su0n8SxVXwDB/rhVhAc9xQ")) is.True(!ok) - ok = pass.IsPreferred("$1$76a2173be6393254e72ffa4d6df1030a") + ok = pass.IsPreferred([]byte("$1$76a2173be6393254e72ffa4d6df1030a")) is.True(!ok) } diff --git a/pkg/unix/unix.go b/pkg/unix/unix.go index 7f7ddee..a3dbbd1 100644 --- a/pkg/unix/unix.go +++ b/pkg/unix/unix.go @@ -16,11 +16,11 @@ var All = []passwd.Passwder{ type MD5 struct{} -func (p *MD5) Passwd(pass string, check string) (string, error) { +func (p *MD5) Passwd(pass, check []byte) ([]byte, error) { h := md5.New() - fmt.Fprint(h, pass) + h.Write(pass) - hash := fmt.Sprintf("$1$%x", h.Sum(nil)) + hash := []byte(fmt.Sprintf("$1$%x", h.Sum(nil))) return hashCheck(hash, check) } @@ -31,18 +31,18 @@ func (p *MD5) ApplyPasswd(passwd *passwd.Passwd) { type Blowfish struct{} -func (p *Blowfish) Passwd(pass string, check string) (string, error) { - if check == "" { - b, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost) +func (p *Blowfish) Passwd(pass, check []byte) ([]byte, error) { + if check == nil { + b, err := bcrypt.GenerateFromPassword(pass, bcrypt.DefaultCost) if err != nil { - return "", err + return nil, err } - return string(b), nil + return b, nil } - err := bcrypt.CompareHashAndPassword([]byte(check), []byte(pass)) + err := bcrypt.CompareHashAndPassword(check, pass) if err != nil { - return "", err + return nil, err } return check, nil } @@ -51,42 +51,12 @@ func (p *Blowfish) ApplyPasswd(passwd *passwd.Passwd) { passwd.Register("2a", p) } -// type SHA256 struct{} - -// func (p *SHA256) Passwd(pass string, check string) (string, error) { -// h := sha256.New() -// fmt.Fprint(h, pass) - -// hash := fmt.Sprintf("$5$%x", h.Sum(nil)) - -// return hashCheck(hash, check) -// } - -// func (p *SHA256) ApplyPasswd(passwd *passwd.Passwd) { -// passwd.Register("5", p) -// } - -// type SHA512 struct{} - -// func (p *SHA512) Passwd(pass string, check string) (string, error) { -// h := sha512.New() -// fmt.Fprint(h, pass) - -// hash := fmt.Sprintf("$6$%x", h.Sum(nil)) - -// return hashCheck(hash, check) -// } - -// func (p *SHA512) ApplyPasswd(passwd *passwd.Passwd) { -// passwd.Register("6", p) -// } - -func hashCheck(hash, check string) (string, error) { - if check == "" { +func hashCheck(hash, check []byte) ([]byte, error) { + if check == nil { return hash, nil } - if subtle.ConstantTimeCompare([]byte(hash), []byte(check)) == 1 { + if subtle.ConstantTimeCompare(hash, check) == 1 { return hash, nil } diff --git a/pkg/unix/unix_test.go b/pkg/unix/unix_test.go index c36642d..1715338 100644 --- a/pkg/unix/unix_test.go +++ b/pkg/unix/unix_test.go @@ -12,20 +12,20 @@ import ( func TestPasswdHash(t *testing.T) { type testCase struct { - pass, hash string + pass, hash []byte } tests := []testCase{ - {"passwd", "$1$76a2173be6393254e72ffa4d6df1030a"}, - {"passwd", "$2a$10$GkJwB.nOaaeAvRGgyl2TI.kruM8e.iIo.OozgdslegpNlC/vIFKRq"}, + {[]byte("passwd"), []byte("$1$76a2173be6393254e72ffa4d6df1030a")}, + {[]byte("passwd"), []byte("$2a$10$GkJwB.nOaaeAvRGgyl2TI.kruM8e.iIo.OozgdslegpNlC/vIFKRq")}, } is := is.New(t) // Generate additional test cases for each algo. for _, algo := range unix.All { - hash, err := algo.Passwd("passwd", "") + hash, err := algo.Passwd([]byte("passwd"), nil) is.NoErr(err) - tests = append(tests, testCase{"passwd", hash}) + tests = append(tests, testCase{[]byte("passwd"), hash}) } pass := passwd.New(unix.All...) @@ -35,7 +35,7 @@ func TestPasswdHash(t *testing.T) { is := is.New(t) hash, err := pass.Passwd(tt.pass, tt.hash) - is.Equal(hash, tt.hash) + is.Equal(string(hash), string(tt.hash)) is.NoErr(err) }) }