usersig.go 5.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. package tencentsig
  2. import (
  3. "bytes"
  4. "compress/zlib"
  5. "crypto"
  6. "crypto/ecdsa"
  7. "crypto/elliptic"
  8. "crypto/rand"
  9. "crypto/sha256"
  10. "crypto/x509"
  11. "encoding/asn1"
  12. "encoding/base64"
  13. "encoding/hex"
  14. "encoding/json"
  15. "encoding/pem"
  16. "fmt"
  17. "io/ioutil"
  18. "math/big"
  19. "strings"
  20. "time"
  21. )
  22. const (
  23. accountType = "29296"
  24. version = "201512300000"
  25. defaultExpire = 3600 * 24 * 180
  26. )
  27. var (
  28. tlsReplace = map[string]string{
  29. "+": "*",
  30. "/": "-",
  31. "=": "_",
  32. }
  33. )
  34. type Conf struct {
  35. AccountType string `json:"TLS.account_type"`
  36. Identifier string `json:"TLS.identifier"`
  37. AppidAt3rd string `json:"TLS.appid_at_3rd"`
  38. SdkAppid string `json:"TLS.sdk_appid"`
  39. ExpireAfter string `json:"TLS.expire_after"`
  40. Version string `json:"TLS.version"`
  41. Time string `json:"TLS.time"`
  42. Sig string `json:"TLS.sig"`
  43. }
  44. func NewConf(sdkAppId string, identifier string, appidAt3rd string) *Conf {
  45. return &Conf{
  46. AccountType: accountType,
  47. Identifier: identifier,
  48. AppidAt3rd: appidAt3rd,
  49. SdkAppid: sdkAppId,
  50. ExpireAfter: fmt.Sprintf("%d", defaultExpire),
  51. Version: version,
  52. Time: fmt.Sprintf("%d", time.Now().Unix()),
  53. }
  54. }
  55. func (c *Conf) WithExpire(expireInSeconds int) *Conf {
  56. c.ExpireAfter = fmt.Sprintf("%d", expireInSeconds)
  57. return c
  58. }
  59. func (c *Conf) GenUserSig(pemPrivateKey string) (string, error) {
  60. var err error
  61. c.Sig, err = c.sign(pemPrivateKey)
  62. if err != nil {
  63. return "", err
  64. }
  65. data, _ := json.Marshal(c)
  66. var b bytes.Buffer
  67. z := zlib.NewWriter(&b)
  68. z.Write(data)
  69. z.Close()
  70. return base64Encode(b.Bytes()), nil
  71. }
  72. func VerifyUserSig(pemPublicKey string, userSig string) (*Conf, bool, error) {
  73. data, err := base64Decode(userSig)
  74. if err != nil {
  75. return nil, false, err
  76. }
  77. reader, err := zlib.NewReader(bytes.NewReader(data))
  78. if err != nil {
  79. return nil, false, err
  80. }
  81. data, err = ioutil.ReadAll(reader)
  82. if err != nil {
  83. return nil, false, err
  84. }
  85. var conf Conf
  86. err = json.Unmarshal(data, &conf)
  87. if err != nil {
  88. return nil, false, err
  89. }
  90. block, _ := pem.Decode([]byte(pemPublicKey))
  91. pk, err := x509.ParsePKIXPublicKey(block.Bytes)
  92. if err != nil {
  93. if strings.Contains(err.Error(), "unsupported elliptic curve") {
  94. var pki publicKeyInfo
  95. if _, err := asn1.Unmarshal(block.Bytes, &pki); err != nil {
  96. return nil, false, err
  97. }
  98. asn1Data := pki.PublicKey.RightAlign()
  99. fmt.Println(hex.EncodeToString(asn1Data))
  100. paramsData := pki.Algorithm.Parameters.FullBytes
  101. namedCurveOID := new(asn1.ObjectIdentifier)
  102. _, err = asn1.Unmarshal(paramsData, namedCurveOID)
  103. if err != nil {
  104. return nil, false, err
  105. }
  106. if namedCurveOID.Equal(oidNamedCurveS256) {
  107. pubk := new(ecdsa.PublicKey)
  108. pubk.Curve = S256()
  109. pubk.X, pubk.Y = elliptic.Unmarshal(pubk.Curve, asn1Data)
  110. pk = pubk
  111. }
  112. } else {
  113. return nil, false, err
  114. }
  115. }
  116. pubKey := pk.(*ecdsa.PublicKey)
  117. content := conf.signContent()
  118. hashed := sha256.Sum256([]byte(content))
  119. signature, _ := base64.StdEncoding.DecodeString(conf.Sig)
  120. r, s, err := pointsFromDER(signature)
  121. if err != nil {
  122. return nil, false, err
  123. }
  124. res := ecdsa.Verify(pubKey, hashed[:], r, s)
  125. return &conf, res, nil
  126. }
  127. func (c *Conf) signContent() string {
  128. var builder strings.Builder
  129. builder.WriteString("TLS.appid_at_3rd:")
  130. builder.WriteString(c.AppidAt3rd)
  131. builder.WriteString("\n")
  132. builder.WriteString("TLS.account_type:")
  133. builder.WriteString(c.AccountType)
  134. builder.WriteString("\n")
  135. builder.WriteString("TLS.identifier:")
  136. builder.WriteString(c.Identifier)
  137. builder.WriteString("\n")
  138. builder.WriteString("TLS.sdk_appid:")
  139. builder.WriteString(c.SdkAppid)
  140. builder.WriteString("\n")
  141. builder.WriteString("TLS.time:")
  142. builder.WriteString(c.Time)
  143. builder.WriteString("\n")
  144. builder.WriteString("TLS.expire_after:")
  145. builder.WriteString(c.ExpireAfter)
  146. builder.WriteString("\n")
  147. return builder.String()
  148. }
  149. func (c *Conf) sign(privateKey string) (string, error) {
  150. block, _ := pem.Decode([]byte(privateKey))
  151. pk, err := x509.ParsePKCS8PrivateKey(block.Bytes)
  152. if err != nil {
  153. if strings.Contains(err.Error(), "unknown elliptic curve") {
  154. var privKey pkcs8
  155. if _, err := asn1.Unmarshal(block.Bytes, &privKey); err != nil {
  156. return "", err
  157. }
  158. if privKey.Algo.Algorithm.Equal(oidPublicKeyECDSA) {
  159. namedCurveOID := new(asn1.ObjectIdentifier)
  160. asn1.Unmarshal(privKey.Algo.Parameters.FullBytes, namedCurveOID)
  161. if namedCurveOID.Equal(oidNamedCurveS256) {
  162. var ecPrivKey ecPrivateKey
  163. asn1.Unmarshal(privKey.PrivateKey, &ecPrivKey)
  164. k := new(ecdsa.PrivateKey)
  165. k.Curve = S256()
  166. d := new(big.Int)
  167. d.SetBytes(ecPrivKey.PrivateKey)
  168. k.D = d
  169. k.X, k.Y = S256().ScalarBaseMult(d.Bytes())
  170. pk = k
  171. }
  172. }
  173. } else {
  174. return "", err
  175. }
  176. }
  177. priv := pk.(*ecdsa.PrivateKey)
  178. content := c.signContent()
  179. hashed := sha256.Sum256([]byte(content))
  180. sig, err := priv.Sign(rand.Reader, hashed[:], crypto.SHA256)
  181. if err != nil {
  182. return "", err
  183. }
  184. return base64.StdEncoding.EncodeToString(sig), nil
  185. }
  186. func base64Encode(data []byte) string {
  187. res := base64.StdEncoding.EncodeToString(data)
  188. for k, v := range tlsReplace {
  189. res = strings.Replace(res, k, v, -1)
  190. }
  191. return res
  192. }
  193. func base64Decode(data string) ([]byte, error) {
  194. for k, v := range tlsReplace {
  195. data = strings.Replace(data, v, k, -1)
  196. }
  197. return base64.StdEncoding.DecodeString(data)
  198. }
  199. func pointsFromDER(der []byte) (R, S *big.Int, err error) {
  200. R, S = &big.Int{}, &big.Int{}
  201. data := asn1.RawValue{}
  202. if _, err = asn1.Unmarshal(der, &data); err != nil {
  203. return
  204. }
  205. // The format of our DER string is 0x02 + rlen + r + 0x02 + slen + s
  206. rLen := data.Bytes[1] // The entire length of R + offset of 2 for 0x02 and rlen
  207. r := data.Bytes[2 : rLen+2]
  208. // Ignore the next 0x02 and slen bytes and just take the start of S to the end of the byte array
  209. s := data.Bytes[rLen+4:]
  210. R.SetBytes(r)
  211. S.SetBytes(s)
  212. return
  213. }