224 lines
4.7 KiB
Go
224 lines
4.7 KiB
Go
package scrypt
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"crypto/subtle"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"strconv"
|
|
|
|
"github.com/sour-is/go-passwd"
|
|
"golang.org/x/crypto/scrypt"
|
|
)
|
|
|
|
type scryptpw struct {
|
|
N int // CPU/memory cost parameter (logN)
|
|
R int // block size parameter (octets)
|
|
P int // parallelization parameter (positive int)
|
|
SaltLen int // bytes to use as salt (octets)
|
|
DKLen int // length of the derived key (octets)
|
|
|
|
name string
|
|
encoder interface {
|
|
EncodedLen(n int) int
|
|
Encode(dst, src []byte)
|
|
DecodedLen(x int) int
|
|
Decode(dst, src []byte) (n int, err error)
|
|
}
|
|
}
|
|
type scryptArgs struct {
|
|
N int // CPU/memory cost parameter (logN)
|
|
R int // block size parameter (octets)
|
|
P int // parallelization parameter (positive int)
|
|
SaltLen int // bytes to use as salt (octets)
|
|
DKLen int // length of the derived key (octets)
|
|
|
|
name string
|
|
salt []byte
|
|
hash []byte
|
|
|
|
encoder interface {
|
|
EncodedLen(n int) int
|
|
Encode(dst, src []byte)
|
|
DecodedLen(x int) int
|
|
Decode(dst, src []byte) (n int, err error)
|
|
}
|
|
}
|
|
|
|
var All = []passwd.Passwder{Simple, Scrypt2}
|
|
|
|
var Simple = &scryptpw{
|
|
N: 16384, R: 8, P: 1, SaltLen: 16, DKLen: 32,
|
|
name: "s1", encoder: hexenc{},
|
|
}
|
|
|
|
var Scrypt2 = &scryptpw{
|
|
N: 16384, R: 8, P: 1, SaltLen: 16, DKLen: 32,
|
|
name: "s2", encoder: base64.RawStdEncoding,
|
|
}
|
|
|
|
func (s *scryptpw) Passwd(pass, check []byte) ([]byte, error) {
|
|
var args *scryptArgs
|
|
var err error
|
|
|
|
if check == nil {
|
|
args = s.defaultArgs()
|
|
_, err := rand.Read(args.salt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
args.hash, err = scrypt.Key(pass, args.salt, args.N, args.R, args.P, args.DKLen)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
} else {
|
|
args, err = s.parseArgs(check)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
hash, err := scrypt.Key([]byte(pass), args.salt, args.N, args.R, args.P, args.DKLen)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if subtle.ConstantTimeCompare(hash, args.hash) == 0 {
|
|
return nil, passwd.ErrNoMatch
|
|
}
|
|
}
|
|
|
|
return args.Bytes(), nil
|
|
}
|
|
func (s *scryptpw) ApplyPasswd(p *passwd.Passwd) {
|
|
p.Register(s.name, s)
|
|
if s.name == "s1" {
|
|
p.SetFallthrough(s)
|
|
}
|
|
}
|
|
func (s *scryptpw) IsPreferred(hash []byte) bool {
|
|
args, err := s.parseArgs(hash)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
if args.N < s.N {
|
|
return false
|
|
}
|
|
if args.R < s.R {
|
|
return false
|
|
}
|
|
if args.P < s.P {
|
|
return false
|
|
}
|
|
if args.SaltLen < s.SaltLen {
|
|
return false
|
|
}
|
|
if args.DKLen < s.DKLen {
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
func (s *scryptpw) defaultArgs() *scryptArgs {
|
|
return &scryptArgs{
|
|
name: s.name,
|
|
N: s.N,
|
|
R: s.R,
|
|
P: s.P,
|
|
DKLen: s.DKLen,
|
|
SaltLen: s.SaltLen,
|
|
salt: make([]byte, s.SaltLen),
|
|
encoder: s.encoder,
|
|
}
|
|
}
|
|
|
|
func (s *scryptpw) parseArgs(hash []byte) (*scryptArgs, error) {
|
|
args := s.defaultArgs()
|
|
|
|
name := []byte("$" + s.name + "$")
|
|
hash = bytes.TrimPrefix(hash, name)
|
|
|
|
N, hash, ok := bytes.Cut(hash, []byte("$"))
|
|
if !ok {
|
|
return nil, fmt.Errorf("%w: missing args: N", passwd.ErrBadHash)
|
|
}
|
|
if n, err := strconv.Atoi(string(N)); err == nil {
|
|
args.N = n
|
|
}
|
|
|
|
R, hash, ok := bytes.Cut(hash, []byte("$"))
|
|
if !ok {
|
|
return nil, fmt.Errorf("%w: missing args: R", passwd.ErrBadHash)
|
|
}
|
|
if r, err := strconv.Atoi(string(R)); err == nil {
|
|
args.R = r
|
|
}
|
|
|
|
P, hash, ok := bytes.Cut(hash, []byte("$"))
|
|
if !ok {
|
|
return nil, fmt.Errorf("%w: missing args: P", passwd.ErrBadHash)
|
|
}
|
|
if p, err := strconv.Atoi(string(P)); err == nil {
|
|
args.P = p
|
|
}
|
|
|
|
salt, hash, ok := bytes.Cut(hash, []byte("$"))
|
|
if !ok {
|
|
return nil, fmt.Errorf("%w: missing args: salt", passwd.ErrBadHash)
|
|
}
|
|
|
|
var err error
|
|
args.salt = make([]byte, s.encoder.DecodedLen(len(salt)))
|
|
_, err = s.encoder.Decode(args.salt, salt)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: corrupt salt part", passwd.ErrBadHash)
|
|
}
|
|
args.SaltLen = len(args.salt)
|
|
|
|
args.hash = make([]byte, s.encoder.DecodedLen(len(hash)))
|
|
_, err = s.encoder.Decode(args.hash, hash)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: corrupt hash part", passwd.ErrBadHash)
|
|
}
|
|
args.DKLen = len(args.hash)
|
|
|
|
return args, nil
|
|
}
|
|
|
|
func (s *scryptArgs) Bytes() []byte {
|
|
var b bytes.Buffer
|
|
|
|
if s.name != "s1" {
|
|
b.WriteRune('$')
|
|
b.WriteString(s.name)
|
|
b.WriteRune('$')
|
|
}
|
|
|
|
fmt.Fprintf(&b, "%d$%d$%d", s.N, s.R, s.P)
|
|
|
|
salt := make([]byte, s.encoder.EncodedLen(len(s.salt)))
|
|
s.encoder.Encode(salt, s.salt)
|
|
b.WriteRune('$')
|
|
b.Write(salt)
|
|
|
|
hash := make([]byte, s.encoder.EncodedLen(len(s.hash)))
|
|
s.encoder.Encode(hash, s.hash)
|
|
b.WriteRune('$')
|
|
b.Write(hash)
|
|
|
|
return b.Bytes()
|
|
}
|
|
|
|
type hexenc struct{}
|
|
|
|
func (hexenc) Encode(dst, src []byte) {
|
|
hex.Encode(dst, src)
|
|
}
|
|
func (hexenc) EncodedLen(n int) int { return hex.EncodedLen(n) }
|
|
func (hexenc) Decode(dst, src []byte) (n int, err error) {
|
|
return hex.Decode(dst, src)
|
|
}
|
|
func (hexenc) DecodedLen(x int) int { return hex.DecodedLen(x) }
|