diff --git a/auth/oauth1.go b/auth/oauth1.go new file mode 100644 index 0000000000000000000000000000000000000000..af139761fc5ce34fc52ca047930e30445f6472c3 --- /dev/null +++ b/auth/oauth1.go @@ -0,0 +1,94 @@ +package oauth1 + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "github.com/aws/aws-sdk-go/aws" + "math/rand" + "net/url" + "strconv" + "strings" + "time" +) + +type Oauth1 struct { + ConsumerKey string + ConsumerSecret string + AccessToken string + AccessSecret string +} + +func (auth Oauth1) GenerateAuthorizationHeader(method, requestUrl string) (AuthorizationValue *string, err error) { + + // Take the URL and get all its parts + urlParts, err := url.Parse(requestUrl) + if err != nil { + return nil, err + } + + // Get the query parameters from the URL and make it easier to work with + rawParams, err := url.ParseQuery(urlParts.RawQuery) + if err != nil { + return nil, err + } + params := make(map[string]string) + for key, _ := range rawParams { + params[key] = rawParams.Get(key) + } + + urlValues := url.Values{} + urlValues.Add("oauth_nonce", generateNonce()) + urlValues.Add("oauth_consumer_key", auth.ConsumerKey) + urlValues.Add("oauth_signature_method", "HMAC-SHA256") + urlValues.Add("oauth_timestamp", strconv.Itoa(int(time.Now().Unix()))) + urlValues.Add("oauth_token", auth.AccessToken) + urlValues.Add("oauth_version", "1.0") + + for k, v := range params { + urlValues.Add(k, v) + } + + // If there are any '+' encoded spaces replace them with the proper urlencoded version of a space + parameterString := strings.Replace(urlValues.Encode(), "+", "%20", -1) + + // Build the signature + signatureBase := strings.ToUpper(method) + "&" + url.QueryEscape(strings.Split(requestUrl, "?")[0]) + "&" + url.QueryEscape(parameterString) + signingKey := url.QueryEscape(auth.ConsumerSecret) + "&" + url.QueryEscape(auth.AccessSecret) + signature := calculateSignature(signatureBase, signingKey) + + // Populate all the authorisation parameters + authParams := map[string]string{ + "oauth_consumer_key": url.QueryEscape(urlValues.Get("oauth_consumer_key")), + "oauth_nonce": url.QueryEscape(urlValues.Get("oauth_nonce")), + "oauth_signature": url.QueryEscape(signature), + "oauth_signature_method": url.QueryEscape(urlValues.Get("oauth_signature_method")), + "oauth_timestamp": url.QueryEscape(urlValues.Get("oauth_timestamp")), + "oauth_token": url.QueryEscape(urlValues.Get("oauth_token")), + "oauth_version": url.QueryEscape(urlValues.Get("oauth_version")), + } + + // Convert all the parameters into a comma delimited string with values defined as "" + var AuthorizationString string + for k, v := range authParams { + AuthorizationString += k + "=\"" + v + "\"," + } + + return aws.String("OAuth " + strings.TrimSuffix(AuthorizationString, ",")), nil +} + +func calculateSignature(base, key string) string { + hash := hmac.New(sha256.New, []byte(key)) + hash.Write([]byte(base)) + signature := hash.Sum(nil) + return base64.StdEncoding.EncodeToString(signature) +} + +func generateNonce() string { + const allowed = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, 48) + for i := range b { + b[i] = allowed[rand.Intn(len(allowed))] + } + return string(b) +}