Skip to content

Commit 73679fa

Browse files
authored
fix: parameterize CQL/N1QL queries to prevent injection (#500)
* fix: parameterize CQL/N1QL queries to prevent injection (CWE-943, CWE-209) Replace all fmt.Sprintf string interpolation of user-controlled values with parameterized queries across Cassandra (10 files, 66+ queries) and Couchbase (2 files, 3 queries) backends. Cassandra: use gocql ? placeholders with values passed to Query(). Couchbase: use $param named parameters with NamedParameters option. * fix: convert json.Number to native types before passing to gocql The JSON decoder with UseNumber() produces json.Number values which gocql cannot marshal into CQL bigint columns. Add convertMapValues() helper to convert json.Number to int64/float64, called after every JSON map decode in dynamic INSERT/UPDATE queries.
1 parent 109ae0a commit 73679fa

13 files changed

Lines changed: 281 additions & 259 deletions

File tree

internal/storage/db/cassandradb/authenticator.go

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7-
"reflect"
87
"strings"
98
"time"
109

@@ -40,31 +39,28 @@ func (p *provider) AddAuthenticator(ctx context.Context, authenticators *schemas
4039
if err != nil {
4140
return nil, err
4241
}
42+
convertMapValues(authenticatorsMap)
4343

4444
fields := "("
45-
values := "("
45+
placeholders := "("
46+
var insertValues []interface{}
4647
for key, value := range authenticatorsMap {
4748
if value != nil {
4849
if key == "_id" {
4950
fields += "id,"
5051
} else {
5152
fields += key + ","
5253
}
53-
54-
valueType := reflect.TypeOf(value)
55-
if valueType.Name() == "string" {
56-
values += fmt.Sprintf("'%s',", value.(string))
57-
} else {
58-
values += fmt.Sprintf("%v,", value)
59-
}
54+
placeholders += "?,"
55+
insertValues = append(insertValues, value)
6056
}
6157
}
6258

6359
fields = fields[:len(fields)-1] + ")"
64-
values = values[:len(values)-1] + ")"
60+
placeholders = placeholders[:len(placeholders)-1] + ")"
6561

66-
query := fmt.Sprintf("INSERT INTO %s %s VALUES %s IF NOT EXISTS", KeySpace+"."+schemas.Collections.Authenticators, fields, values)
67-
err = p.db.Query(query).Exec()
62+
query := fmt.Sprintf("INSERT INTO %s %s VALUES %s IF NOT EXISTS", KeySpace+"."+schemas.Collections.Authenticators, fields, placeholders)
63+
err = p.db.Query(query, insertValues...).Exec()
6864
if err != nil {
6965
return nil, err
7066
}
@@ -87,8 +83,10 @@ func (p *provider) UpdateAuthenticator(ctx context.Context, authenticators *sche
8783
if err != nil {
8884
return nil, err
8985
}
86+
convertMapValues(authenticatorsMap)
9087

9188
updateFields := ""
89+
var updateValues []interface{}
9290
for key, value := range authenticatorsMap {
9391
if key == "_id" {
9492
continue
@@ -103,18 +101,15 @@ func (p *provider) UpdateAuthenticator(ctx context.Context, authenticators *sche
103101
continue
104102
}
105103

106-
valueType := reflect.TypeOf(value)
107-
if valueType.Name() == "string" {
108-
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
109-
} else {
110-
updateFields += fmt.Sprintf("%s = %v, ", key, value)
111-
}
104+
updateFields += fmt.Sprintf("%s = ?, ", key)
105+
updateValues = append(updateValues, value)
112106
}
113107
updateFields = strings.Trim(updateFields, " ")
114108
updateFields = strings.TrimSuffix(updateFields, ",")
115109

116-
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+schemas.Collections.Authenticators, updateFields, authenticators.ID)
117-
err = p.db.Query(query).Exec()
110+
updateValues = append(updateValues, authenticators.ID)
111+
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = ?", KeySpace+"."+schemas.Collections.Authenticators, updateFields)
112+
err = p.db.Query(query, updateValues...).Exec()
118113
if err != nil {
119114
return nil, err
120115
}
@@ -124,8 +119,8 @@ func (p *provider) UpdateAuthenticator(ctx context.Context, authenticators *sche
124119

125120
func (p *provider) GetAuthenticatorDetailsByUserId(ctx context.Context, userId string, authenticatorType string) (*schemas.Authenticator, error) {
126121
var authenticators schemas.Authenticator
127-
query := fmt.Sprintf("SELECT id, user_id, method, secret, recovery_codes, verified_at, created_at, updated_at FROM %s WHERE user_id = '%s' AND method = '%s' LIMIT 1 ALLOW FILTERING", KeySpace+"."+schemas.Collections.Authenticators, userId, authenticatorType)
128-
err := p.db.Query(query).Consistency(gocql.One).Scan(&authenticators.ID, &authenticators.UserID, &authenticators.Method, &authenticators.Secret, &authenticators.RecoveryCodes, &authenticators.VerifiedAt, &authenticators.CreatedAt, &authenticators.UpdatedAt)
122+
query := fmt.Sprintf("SELECT id, user_id, method, secret, recovery_codes, verified_at, created_at, updated_at FROM %s WHERE user_id = ? AND method = ? LIMIT 1 ALLOW FILTERING", KeySpace+"."+schemas.Collections.Authenticators)
123+
err := p.db.Query(query, userId, authenticatorType).Consistency(gocql.One).Scan(&authenticators.ID, &authenticators.UserID, &authenticators.Method, &authenticators.Secret, &authenticators.RecoveryCodes, &authenticators.VerifiedAt, &authenticators.CreatedAt, &authenticators.UpdatedAt)
129124
if err != nil {
130125
return nil, err
131126
}

internal/storage/db/cassandradb/email_template.go

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7-
"reflect"
87
"strings"
98
"time"
109

@@ -27,8 +26,8 @@ func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate *schemas.
2726
if existingEmailTemplate != nil {
2827
return nil, fmt.Errorf("email template with %s event_name already exists", emailTemplate.EventName)
2928
}
30-
insertQuery := fmt.Sprintf("INSERT INTO %s (id, event_name, subject, design, template, created_at, updated_at) VALUES ('%s', '%s', '%s','%s','%s', %d, %d)", KeySpace+"."+schemas.Collections.EmailTemplate, emailTemplate.ID, emailTemplate.EventName, emailTemplate.Subject, emailTemplate.Design, emailTemplate.Template, emailTemplate.CreatedAt, emailTemplate.UpdatedAt)
31-
err := p.db.Query(insertQuery).Exec()
29+
insertQuery := fmt.Sprintf("INSERT INTO %s (id, event_name, subject, design, template, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", KeySpace+"."+schemas.Collections.EmailTemplate)
30+
err := p.db.Query(insertQuery, emailTemplate.ID, emailTemplate.EventName, emailTemplate.Subject, emailTemplate.Design, emailTemplate.Template, emailTemplate.CreatedAt, emailTemplate.UpdatedAt).Exec()
3231
if err != nil {
3332
return nil, err
3433
}
@@ -50,7 +49,9 @@ func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *schem
5049
if err != nil {
5150
return nil, err
5251
}
52+
convertMapValues(emailTemplateMap)
5353
updateFields := ""
54+
var updateValues []interface{}
5455
for key, value := range emailTemplateMap {
5556
if key == "_id" {
5657
continue
@@ -62,18 +63,15 @@ func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *schem
6263
updateFields += fmt.Sprintf("%s = null,", key)
6364
continue
6465
}
65-
valueType := reflect.TypeOf(value)
66-
if valueType.Name() == "string" {
67-
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
68-
} else {
69-
updateFields += fmt.Sprintf("%s = %v, ", key, value)
70-
}
66+
updateFields += fmt.Sprintf("%s = ?, ", key)
67+
updateValues = append(updateValues, value)
7168
}
7269
updateFields = strings.Trim(updateFields, " ")
7370
updateFields = strings.TrimSuffix(updateFields, ",")
7471

75-
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+schemas.Collections.EmailTemplate, updateFields, emailTemplate.ID)
76-
err = p.db.Query(query).Exec()
72+
updateValues = append(updateValues, emailTemplate.ID)
73+
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = ?", KeySpace+"."+schemas.Collections.EmailTemplate, updateFields)
74+
err = p.db.Query(query, updateValues...).Exec()
7775
if err != nil {
7876
return nil, err
7977
}
@@ -116,8 +114,8 @@ func (p *provider) ListEmailTemplate(ctx context.Context, pagination *model.Pagi
116114
// GetEmailTemplateByID to get EmailTemplate by id
117115
func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*schemas.EmailTemplate, error) {
118116
var emailTemplate schemas.EmailTemplate
119-
query := fmt.Sprintf(`SELECT id, event_name, subject, design, template, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1`, KeySpace+"."+schemas.Collections.EmailTemplate, emailTemplateID)
120-
err := p.db.Query(query).Consistency(gocql.One).Scan(&emailTemplate.ID, &emailTemplate.EventName, &emailTemplate.Subject, &emailTemplate.Design, &emailTemplate.Template, &emailTemplate.CreatedAt, &emailTemplate.UpdatedAt)
117+
query := fmt.Sprintf(`SELECT id, event_name, subject, design, template, created_at, updated_at FROM %s WHERE id = ? LIMIT 1`, KeySpace+"."+schemas.Collections.EmailTemplate)
118+
err := p.db.Query(query, emailTemplateID).Consistency(gocql.One).Scan(&emailTemplate.ID, &emailTemplate.EventName, &emailTemplate.Subject, &emailTemplate.Design, &emailTemplate.Template, &emailTemplate.CreatedAt, &emailTemplate.UpdatedAt)
121119
if err != nil {
122120
return nil, err
123121
}
@@ -127,8 +125,8 @@ func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID str
127125
// GetEmailTemplateByEventName to get EmailTemplate by event_name
128126
func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*schemas.EmailTemplate, error) {
129127
var emailTemplate schemas.EmailTemplate
130-
query := fmt.Sprintf(`SELECT id, event_name, subject, design, template, created_at, updated_at FROM %s WHERE event_name = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+schemas.Collections.EmailTemplate, eventName)
131-
err := p.db.Query(query).Consistency(gocql.One).Scan(&emailTemplate.ID, &emailTemplate.EventName, &emailTemplate.Subject, &emailTemplate.Design, &emailTemplate.Template, &emailTemplate.CreatedAt, &emailTemplate.UpdatedAt)
128+
query := fmt.Sprintf(`SELECT id, event_name, subject, design, template, created_at, updated_at FROM %s WHERE event_name = ? LIMIT 1 ALLOW FILTERING`, KeySpace+"."+schemas.Collections.EmailTemplate)
129+
err := p.db.Query(query, eventName).Consistency(gocql.One).Scan(&emailTemplate.ID, &emailTemplate.EventName, &emailTemplate.Subject, &emailTemplate.Design, &emailTemplate.Template, &emailTemplate.CreatedAt, &emailTemplate.UpdatedAt)
132130
if err != nil {
133131
return nil, err
134132
}
@@ -137,8 +135,8 @@ func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName st
137135

138136
// DeleteEmailTemplate to delete EmailTemplate
139137
func (p *provider) DeleteEmailTemplate(ctx context.Context, emailTemplate *schemas.EmailTemplate) error {
140-
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+schemas.Collections.EmailTemplate, emailTemplate.ID)
141-
err := p.db.Query(query).Exec()
138+
query := fmt.Sprintf("DELETE FROM %s WHERE id = ?", KeySpace+"."+schemas.Collections.EmailTemplate)
139+
err := p.db.Query(query, emailTemplate.ID).Exec()
142140
if err != nil {
143141
return err
144142
}

internal/storage/db/cassandradb/env.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ func (p *provider) AddEnv(ctx context.Context, env *schemas.Env) (*schemas.Env,
1818
}
1919
env.CreatedAt = time.Now().Unix()
2020
env.UpdatedAt = time.Now().Unix()
21-
insertEnvQuery := fmt.Sprintf("INSERT INTO %s (id, env, hash, created_at, updated_at) VALUES ('%s', '%s', '%s', %d, %d)", KeySpace+"."+schemas.Collections.Env, env.ID, env.EnvData, env.Hash, env.CreatedAt, env.UpdatedAt)
22-
err := p.db.Query(insertEnvQuery).Exec()
21+
insertEnvQuery := fmt.Sprintf("INSERT INTO %s (id, env, hash, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", KeySpace+"."+schemas.Collections.Env)
22+
err := p.db.Query(insertEnvQuery, env.ID, env.EnvData, env.Hash, env.CreatedAt, env.UpdatedAt).Exec()
2323
if err != nil {
2424
return nil, err
2525
}
@@ -30,8 +30,8 @@ func (p *provider) AddEnv(ctx context.Context, env *schemas.Env) (*schemas.Env,
3030
// UpdateEnv to update environment information in database
3131
func (p *provider) UpdateEnv(ctx context.Context, env *schemas.Env) (*schemas.Env, error) {
3232
env.UpdatedAt = time.Now().Unix()
33-
updateEnvQuery := fmt.Sprintf("UPDATE %s SET env = '%s', updated_at = %d WHERE id = '%s'", KeySpace+"."+schemas.Collections.Env, env.EnvData, env.UpdatedAt, env.ID)
34-
err := p.db.Query(updateEnvQuery).Exec()
33+
updateEnvQuery := fmt.Sprintf("UPDATE %s SET env = ?, updated_at = ? WHERE id = ?", KeySpace+"."+schemas.Collections.Env)
34+
err := p.db.Query(updateEnvQuery, env.EnvData, env.UpdatedAt, env.ID).Exec()
3535
if err != nil {
3636
return nil, err
3737
}

internal/storage/db/cassandradb/otp.go

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,17 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *schemas.OTP) (*schem
4848
otp.UpdatedAt = time.Now().Unix()
4949
query := ""
5050
if shouldCreate {
51-
query = fmt.Sprintf(`INSERT INTO %s (id, email, phone_number, otp, expires_at, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', %d, %d, %d)`, KeySpace+"."+schemas.Collections.OTP, otp.ID, otp.Email, otp.PhoneNumber, otp.Otp, otp.ExpiresAt, otp.CreatedAt, otp.UpdatedAt)
51+
query = fmt.Sprintf(`INSERT INTO %s (id, email, phone_number, otp, expires_at, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)`, KeySpace+"."+schemas.Collections.OTP)
52+
err := p.db.Query(query, otp.ID, otp.Email, otp.PhoneNumber, otp.Otp, otp.ExpiresAt, otp.CreatedAt, otp.UpdatedAt).Exec()
53+
if err != nil {
54+
return nil, err
55+
}
5256
} else {
53-
query = fmt.Sprintf(`UPDATE %s SET otp = '%s', expires_at = %d, updated_at = %d WHERE id = '%s'`, KeySpace+"."+schemas.Collections.OTP, otp.Otp, otp.ExpiresAt, otp.UpdatedAt, otp.ID)
54-
}
55-
56-
err := p.db.Query(query).Exec()
57-
if err != nil {
58-
return nil, err
57+
query = fmt.Sprintf(`UPDATE %s SET otp = ?, expires_at = ?, updated_at = ? WHERE id = ?`, KeySpace+"."+schemas.Collections.OTP)
58+
err := p.db.Query(query, otp.Otp, otp.ExpiresAt, otp.UpdatedAt, otp.ID).Exec()
59+
if err != nil {
60+
return nil, err
61+
}
5962
}
6063

6164
return otp, nil
@@ -64,8 +67,8 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *schemas.OTP) (*schem
6467
// GetOTPByEmail to get otp for a given email address
6568
func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*schemas.OTP, error) {
6669
var otp schemas.OTP
67-
query := fmt.Sprintf(`SELECT id, email, phone_number, otp, expires_at, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+schemas.Collections.OTP, emailAddress)
68-
err := p.db.Query(query).Consistency(gocql.One).Scan(&otp.ID, &otp.Email, &otp.PhoneNumber, &otp.Otp, &otp.ExpiresAt, &otp.CreatedAt, &otp.UpdatedAt)
70+
query := fmt.Sprintf(`SELECT id, email, phone_number, otp, expires_at, created_at, updated_at FROM %s WHERE email = ? LIMIT 1 ALLOW FILTERING`, KeySpace+"."+schemas.Collections.OTP)
71+
err := p.db.Query(query, emailAddress).Consistency(gocql.One).Scan(&otp.ID, &otp.Email, &otp.PhoneNumber, &otp.Otp, &otp.ExpiresAt, &otp.CreatedAt, &otp.UpdatedAt)
6972
if err != nil {
7073
return nil, err
7174
}
@@ -75,8 +78,8 @@ func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*sch
7578
// GetOTPByPhoneNumber to get otp for a given phone number
7679
func (p *provider) GetOTPByPhoneNumber(ctx context.Context, phoneNumber string) (*schemas.OTP, error) {
7780
var otp schemas.OTP
78-
query := fmt.Sprintf(`SELECT id, email, phone_number, otp, expires_at, created_at, updated_at FROM %s WHERE phone_number = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+schemas.Collections.OTP, phoneNumber)
79-
err := p.db.Query(query).Consistency(gocql.One).Scan(&otp.ID, &otp.Email, &otp.PhoneNumber, &otp.Otp, &otp.ExpiresAt, &otp.CreatedAt, &otp.UpdatedAt)
81+
query := fmt.Sprintf(`SELECT id, email, phone_number, otp, expires_at, created_at, updated_at FROM %s WHERE phone_number = ? LIMIT 1 ALLOW FILTERING`, KeySpace+"."+schemas.Collections.OTP)
82+
err := p.db.Query(query, phoneNumber).Consistency(gocql.One).Scan(&otp.ID, &otp.Email, &otp.PhoneNumber, &otp.Otp, &otp.ExpiresAt, &otp.CreatedAt, &otp.UpdatedAt)
8083
if err != nil {
8184
return nil, err
8285
}
@@ -85,8 +88,8 @@ func (p *provider) GetOTPByPhoneNumber(ctx context.Context, phoneNumber string)
8588

8689
// DeleteOTP to delete otp
8790
func (p *provider) DeleteOTP(ctx context.Context, otp *schemas.OTP) error {
88-
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+schemas.Collections.OTP, otp.ID)
89-
err := p.db.Query(query).Exec()
91+
query := fmt.Sprintf("DELETE FROM %s WHERE id = ?", KeySpace+"."+schemas.Collections.OTP)
92+
err := p.db.Query(query, otp.ID).Exec()
9093
if err != nil {
9194
return err
9295
}

internal/storage/db/cassandradb/provider.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cassandradb
33
import (
44
"crypto/tls"
55
"crypto/x509"
6+
"encoding/json"
67
"fmt"
78
"strings"
89
"time"
@@ -350,3 +351,17 @@ func NewProvider(cfg *config.Config, deps *Dependencies) (*provider, error) {
350351
db: session,
351352
}, err
352353
}
354+
355+
// convertMapValues converts json.Number values in a map to native Go types
356+
// (int64 or float64) so gocql can marshal them into CQL bigint/double columns.
357+
func convertMapValues(m map[string]interface{}) {
358+
for key, value := range m {
359+
if num, ok := value.(json.Number); ok {
360+
if i, err := num.Int64(); err == nil {
361+
m[key] = i
362+
} else if f, err := num.Float64(); err == nil {
363+
m[key] = f
364+
}
365+
}
366+
}
367+
}

internal/storage/db/cassandradb/session.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ func (p *provider) AddSession(ctx context.Context, session *schemas.Session) err
1717
}
1818
session.CreatedAt = time.Now().Unix()
1919
session.UpdatedAt = time.Now().Unix()
20-
insertSessionQuery := fmt.Sprintf("INSERT INTO %s (id, user_id, user_agent, ip, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', %d, %d)", KeySpace+"."+schemas.Collections.Session, session.ID, session.UserID, session.UserAgent, session.IP, session.CreatedAt, session.UpdatedAt)
21-
err := p.db.Query(insertSessionQuery).Exec()
20+
insertSessionQuery := fmt.Sprintf("INSERT INTO %s (id, user_id, user_agent, ip, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)", KeySpace+"."+schemas.Collections.Session)
21+
err := p.db.Query(insertSessionQuery, session.ID, session.UserID, session.UserAgent, session.IP, session.CreatedAt, session.UpdatedAt).Exec()
2222
if err != nil {
2323
return err
2424
}

0 commit comments

Comments
 (0)