initial commit
This commit is contained in:
174
pkg/scrypt/scrypt.go
Normal file
174
pkg/scrypt/scrypt.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package scrypt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sour-is/go-passwd"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
)
|
||||
|
||||
type scryptpw struct {
|
||||
N int // CPU/memory cost parameter (logN)
|
||||
R int // block size parameter (octets)
|
||||
P int // parallelisation parameter (positive int)
|
||||
SaltLen int // bytes to use as salt (octets)
|
||||
DKLen int // length of the derived key (octets)
|
||||
|
||||
name string
|
||||
encoder interface {
|
||||
EncodeToString(src []byte) string
|
||||
DecodeString(s string) ([]byte, error)
|
||||
}
|
||||
}
|
||||
type scryptArgs struct {
|
||||
N int // CPU/memory cost parameter (logN)
|
||||
R int // block size parameter (octets)
|
||||
P int // parallelisation parameter (positive int)
|
||||
SaltLen int // bytes to use as salt (octets)
|
||||
DKLen int // length of the derived key (octets)
|
||||
|
||||
name string
|
||||
salt []byte
|
||||
hash []byte
|
||||
|
||||
encoder interface {
|
||||
EncodeToString(src []byte) string
|
||||
DecodeString(s string) ([]byte, error)
|
||||
}
|
||||
}
|
||||
|
||||
var All = []passwd.Passwder{Simple, Scrypt2}
|
||||
|
||||
var Simple = &scryptpw{
|
||||
N: 16384, R: 8, P: 1, SaltLen: 16, DKLen: 32,
|
||||
name: "s1", encoder: hexenc{},
|
||||
}
|
||||
|
||||
var Scrypt2 = &scryptpw{
|
||||
N: 16384, R: 8, P: 1, SaltLen: 16, DKLen: 32,
|
||||
name: "s2", encoder: base64.RawStdEncoding,
|
||||
}
|
||||
|
||||
func (s *scryptpw) Passwd(pass string, check string) (string, error) {
|
||||
var args *scryptArgs
|
||||
var err error
|
||||
|
||||
if check == "" {
|
||||
args = s.defaultArgs()
|
||||
_, err := rand.Read(args.salt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
args.hash, err = scrypt.Key([]byte(pass), args.salt, args.N, args.R, args.P, args.DKLen)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
} else {
|
||||
args, err = s.parseArgs(check)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
hash, err := scrypt.Key([]byte(pass), args.salt, args.N, args.R, args.P, args.DKLen)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if subtle.ConstantTimeCompare(hash, args.hash) == 0 {
|
||||
return "", passwd.ErrNoMatch
|
||||
}
|
||||
}
|
||||
|
||||
return args.String(), nil
|
||||
}
|
||||
func (s *scryptpw) ApplyPasswd(p *passwd.Passwd) {
|
||||
p.Register(s.name, s)
|
||||
if s.name == "s1" {
|
||||
p.SetFallthrough(s)
|
||||
}
|
||||
}
|
||||
func (s *scryptpw) defaultArgs() *scryptArgs {
|
||||
return &scryptArgs{
|
||||
name: s.name,
|
||||
N: s.N,
|
||||
R: s.R,
|
||||
P: s.P,
|
||||
DKLen: s.DKLen,
|
||||
SaltLen: s.SaltLen,
|
||||
salt: make([]byte, s.SaltLen),
|
||||
encoder: s.encoder,
|
||||
}
|
||||
}
|
||||
func (s *scryptpw) parseArgs(hash string) (*scryptArgs, error) {
|
||||
args := s.defaultArgs()
|
||||
|
||||
name := "$" + s.name + "$"
|
||||
hash = strings.TrimPrefix(hash, name)
|
||||
|
||||
N, hash, ok := strings.Cut(hash, "$")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: missing args: N", passwd.ErrBadHash)
|
||||
}
|
||||
if n, err := strconv.Atoi(N); err == nil {
|
||||
args.N = n
|
||||
}
|
||||
|
||||
R, hash, ok := strings.Cut(hash, "$")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: missing args: R", passwd.ErrBadHash)
|
||||
}
|
||||
if r, err := strconv.Atoi(R); err == nil {
|
||||
args.R = r
|
||||
}
|
||||
|
||||
P, hash, ok := strings.Cut(hash, "$")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: missing args: P", passwd.ErrBadHash)
|
||||
}
|
||||
if p, err := strconv.Atoi(P); err == nil {
|
||||
args.P = p
|
||||
}
|
||||
|
||||
salt, hash, ok := strings.Cut(hash, "$")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: missing args: salt", passwd.ErrBadHash)
|
||||
}
|
||||
|
||||
var err error
|
||||
args.salt, err = s.encoder.DecodeString(salt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: corrupt salt part", passwd.ErrBadHash)
|
||||
}
|
||||
|
||||
args.hash, err = s.encoder.DecodeString(hash)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: corrupt hash part", passwd.ErrBadHash)
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
func (s *scryptArgs) String() string {
|
||||
var name string
|
||||
if s.name != "s1" {
|
||||
name = "$" + s.name + "$"
|
||||
}
|
||||
salt := s.encoder.EncodeToString(s.salt)
|
||||
hash := s.encoder.EncodeToString(s.hash)
|
||||
|
||||
return fmt.Sprintf("%s%d$%d$%d$%s$%s", name, s.N, s.R, s.P, salt, hash)
|
||||
}
|
||||
|
||||
type hexenc struct{}
|
||||
|
||||
func (hexenc) EncodeToString(src []byte) string {
|
||||
return hex.EncodeToString(src)
|
||||
}
|
||||
func (hexenc) DecodeString(s string) ([]byte, error) {
|
||||
return hex.DecodeString(s)
|
||||
}
|
||||
40
pkg/scrypt/scrypt_test.go
Normal file
40
pkg/scrypt/scrypt_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package scrypt_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/matryer/is"
|
||||
|
||||
"github.com/sour-is/go-passwd"
|
||||
"github.com/sour-is/go-passwd/pkg/scrypt"
|
||||
)
|
||||
|
||||
func TestPasswdHash(t *testing.T) {
|
||||
type testCase struct {
|
||||
pass, hash string
|
||||
}
|
||||
|
||||
tests := []testCase{}
|
||||
algos := scrypt.All
|
||||
|
||||
is := is.New(t)
|
||||
// Generate additional test cases for each algo.
|
||||
for _, algo := range algos {
|
||||
hash, err := algo.Passwd("passwd", "")
|
||||
is.NoErr(err)
|
||||
tests = append(tests, testCase{"passwd", hash})
|
||||
}
|
||||
|
||||
pass := passwd.New(algos...)
|
||||
|
||||
for i, tt := range tests {
|
||||
t.Run(fmt.Sprint("Test-", i), func(t *testing.T) {
|
||||
is := is.New(t)
|
||||
|
||||
hash, err := pass.Passwd(tt.pass, tt.hash)
|
||||
is.NoErr(err)
|
||||
is.Equal(hash, tt.hash)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user