201 lines
4.2 KiB
Go
201 lines
4.2 KiB
Go
package scrypt
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/subtle"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/sour-is/go-passwd"
|
|
"golang.org/x/crypto/scrypt"
|
|
)
|
|
|
|
// scryptpw is one scrypt hashing variant: the cost parameters used for new
// hashes, the name the variant is registered under, and the text encoding
// applied to the salt and derived key.
type scryptpw struct {
	N       int // CPU/memory cost N passed as-is to scrypt.Key (e.g. 16384), not log2(N); scrypt.Key requires a power of two > 1
	R       int // block size parameter r passed to scrypt.Key (a multiplier, not a byte count)
	P       int // parallelisation parameter p passed to scrypt.Key (positive int)
	SaltLen int // number of random salt bytes generated for a new hash
	DKLen   int // length of the derived key in bytes

	name string // variant name used for registration and the "$<name>$" prefix, e.g. "s1" or "s2"

	// encoder converts salt and key bytes to and from their text form.
	encoder interface {
		EncodeToString(src []byte) string
		DecodeString(s string) ([]byte, error)
	}
}
|
|
// scryptArgs holds everything needed to encode or re-derive one specific
// hash: its cost parameters, salt, derived key, variant name and encoder.
type scryptArgs struct {
	N       int // CPU/memory cost N passed as-is to scrypt.Key, not log2(N)
	R       int // block size parameter r passed to scrypt.Key (a multiplier, not a byte count)
	P       int // parallelisation parameter p passed to scrypt.Key (positive int)
	SaltLen int // salt length in bytes
	DKLen   int // derived key length in bytes

	name string // variant name, e.g. "s1" or "s2"

	salt []byte // raw (decoded) salt bytes
	hash []byte // raw (decoded) derived key bytes

	// encoder converts salt and hash bytes to and from their text form.
	encoder interface {
		EncodeToString(src []byte) string
		DecodeString(s string) ([]byte, error)
	}
}
|
|
|
|
var All = []passwd.Passwder{Simple, Scrypt2}
|
|
|
|
var Simple = &scryptpw{
|
|
N: 16384, R: 8, P: 1, SaltLen: 16, DKLen: 32,
|
|
name: "s1", encoder: hexenc{},
|
|
}
|
|
|
|
var Scrypt2 = &scryptpw{
|
|
N: 16384, R: 8, P: 1, SaltLen: 16, DKLen: 32,
|
|
name: "s2", encoder: base64.RawStdEncoding,
|
|
}
|
|
|
|
// Passwd hashes or verifies pass with scrypt.
//
// When check is empty, a new hash is created:
//   - a random salt of s.SaltLen bytes is read from crypto/rand;
//   - the key is derived with s's N, R, P and DKLen;
//   - the encoded result (see scryptArgs.String) is returned.
//
// When check is non-empty, it is parsed as a hash of this variant. pass is
// re-derived with the salt and parameters taken from check, and the result
// is compared to the stored key in constant time.
//   - On a match, the re-encoded check is returned.
//   - On a mismatch, the error is passwd.ErrNoMatch.
//   - A stored hash with an empty key part is rejected with passwd.ErrBadHash.
//
// Errors from parseArgs, crypto/rand and scrypt.Key are returned unchanged.
func (s *scryptpw) Passwd(pass string, check string) (string, error) {
	var (
		args *scryptArgs
		err  error
	)
	if check == "" {
		args = s.defaultArgs()
		_, err = rand.Read(args.salt)
	} else {
		args, err = s.parseArgs(check)
	}
	if err != nil {
		return "", err
	}

	// An empty stored key would derive an empty key as well, and
	// ConstantTimeCompare reports two zero-length slices as equal, so any
	// password would match. Refuse such hashes outright.
	if check != "" && len(args.hash) == 0 {
		return "", fmt.Errorf("%w: empty hash part", passwd.ErrBadHash)
	}

	key, err := scrypt.Key([]byte(pass), args.salt, args.N, args.R, args.P, args.DKLen)
	if err != nil {
		return "", err
	}

	if check == "" {
		args.hash = key
		return args.String(), nil
	}

	if subtle.ConstantTimeCompare(key, args.hash) != 1 {
		return "", passwd.ErrNoMatch
	}
	return args.String(), nil
}
|
|
func (s *scryptpw) ApplyPasswd(p *passwd.Passwd) {
|
|
p.Register(s.name, s)
|
|
if s.name == "s1" {
|
|
p.SetFallthrough(s)
|
|
}
|
|
}
|
|
func (s *scryptpw) IsPreferred(hash string) bool {
|
|
args, err := s.parseArgs(hash)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
if args.N < s.N {
|
|
return false
|
|
}
|
|
if args.R < s.R {
|
|
return false
|
|
}
|
|
if args.P < s.P {
|
|
return false
|
|
}
|
|
if args.SaltLen < s.SaltLen {
|
|
return false
|
|
}
|
|
if args.DKLen < s.DKLen {
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
func (s *scryptpw) defaultArgs() *scryptArgs {
|
|
return &scryptArgs{
|
|
name: s.name,
|
|
N: s.N,
|
|
R: s.R,
|
|
P: s.P,
|
|
DKLen: s.DKLen,
|
|
SaltLen: s.SaltLen,
|
|
salt: make([]byte, s.SaltLen),
|
|
encoder: s.encoder,
|
|
}
|
|
}
|
|
// parseArgs decodes an encoded hash of the form
//
//	[$<name>$]<N>$<R>$<P>$<salt>$<hash>
//
// where salt and hash are text-encoded with s.encoder. The "$<name>$" prefix
// is stripped when present but is not required, so strings written without
// it (as scryptArgs.String does for "s1") also parse.
//
// N, R and P must be positive decimal integers. A zero R or P would later
// make scrypt.Key divide by zero, and a malformed number must not silently
// fall back to s's defaults. The hash part must decode to at least one
// byte, because an empty stored key would match every password.
//
// SaltLen and DKLen in the result come from the decoded salt and hash
// lengths, not from s. Every error wraps passwd.ErrBadHash.
func (s *scryptpw) parseArgs(hash string) (*scryptArgs, error) {
	args := s.defaultArgs()

	rest := strings.TrimPrefix(hash, "$"+s.name+"$")

	// nextInt consumes the next "$"-terminated field of rest as a positive int.
	nextInt := func(label string) (int, error) {
		field, tail, ok := strings.Cut(rest, "$")
		if !ok {
			return 0, fmt.Errorf("%w: missing args: %s", passwd.ErrBadHash, label)
		}
		v, err := strconv.Atoi(field)
		if err != nil || v < 1 {
			return 0, fmt.Errorf("%w: corrupt %s part", passwd.ErrBadHash, label)
		}
		rest = tail
		return v, nil
	}

	var err error
	if args.N, err = nextInt("N"); err != nil {
		return nil, err
	}
	if args.R, err = nextInt("R"); err != nil {
		return nil, err
	}
	if args.P, err = nextInt("P"); err != nil {
		return nil, err
	}

	saltPart, hashPart, ok := strings.Cut(rest, "$")
	if !ok {
		return nil, fmt.Errorf("%w: missing args: salt", passwd.ErrBadHash)
	}

	salt, err := s.encoder.DecodeString(saltPart)
	if err != nil {
		return nil, fmt.Errorf("%w: corrupt salt part", passwd.ErrBadHash)
	}

	key, err := s.encoder.DecodeString(hashPart)
	if err != nil || len(key) == 0 {
		return nil, fmt.Errorf("%w: corrupt hash part", passwd.ErrBadHash)
	}

	args.salt, args.SaltLen = salt, len(salt)
	args.hash, args.DKLen = key, len(key)
	return args, nil
}
|
|
func (s *scryptArgs) String() string {
|
|
var name string
|
|
if s.name != "s1" {
|
|
name = "$" + s.name + "$"
|
|
}
|
|
salt := s.encoder.EncodeToString(s.salt)
|
|
hash := s.encoder.EncodeToString(s.hash)
|
|
|
|
return fmt.Sprintf("%s%d$%d$%d$%s$%s", name, s.N, s.R, s.P, salt, hash)
|
|
}
|
|
|
|
type hexenc struct{}
|
|
|
|
func (hexenc) EncodeToString(src []byte) string {
|
|
return hex.EncodeToString(src)
|
|
}
|
|
func (hexenc) DecodeString(s string) ([]byte, error) {
|
|
return hex.DecodeString(s)
|
|
}
|