mirror of
https://github.com/securego/gosec.git
synced 2026-01-15 01:33:41 +08:00
feat(sql): enhance SQL injection detection with improved string concatenation checks (#1454)
* feat(sql): enhance SQL injection detection with improved string concatenation checks * optimize: only one ast.Inspect loop, use slices.ContainsFunc * refactor(sql): streamline SQL argument retrieval, replace constObject with TryResolve, minor cleanup * feat(sql): enhance query mutation checks for shadowed variables and add regression tests * remove deprecated ast.Object
This commit is contained in:
68
helpers.go
68
helpers.go
@@ -385,29 +385,44 @@ func GetPkgAbsPath(pkgPath string) (string, error) {
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
// ConcatString recursively concatenates strings from a binary expression
|
||||
func ConcatString(n *ast.BinaryExpr) (string, bool) {
|
||||
var s string
|
||||
// sub expressions are found in X object, Y object is always last BasicLit
|
||||
if rightOperand, ok := n.Y.(*ast.BasicLit); ok {
|
||||
if str, err := GetString(rightOperand); err == nil {
|
||||
s = str + s
|
||||
}
|
||||
} else {
|
||||
// ConcatString recursively concatenates constant strings from an expression
|
||||
// if the entire chain is fully constant-derived (using TryResolve).
|
||||
// Returns the concatenated string and true if successful.
|
||||
func ConcatString(expr ast.Expr, ctx *Context) (string, bool) {
|
||||
if expr == nil || !TryResolve(expr, ctx) {
|
||||
return "", false
|
||||
}
|
||||
if leftOperand, ok := n.X.(*ast.BinaryExpr); ok {
|
||||
if recursion, ok := ConcatString(leftOperand); ok {
|
||||
s = recursion + s
|
||||
|
||||
var build strings.Builder
|
||||
var traverse func(ast.Expr) bool
|
||||
traverse = func(e ast.Expr) bool {
|
||||
switch node := e.(type) {
|
||||
case *ast.BasicLit:
|
||||
if str, err := GetString(node); err == nil {
|
||||
build.WriteString(str)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
case *ast.Ident:
|
||||
values := GetIdentStringValuesRecursive(node)
|
||||
for _, v := range values {
|
||||
build.WriteString(v)
|
||||
}
|
||||
return len(values) > 0
|
||||
case *ast.BinaryExpr:
|
||||
if node.Op != token.ADD {
|
||||
return false
|
||||
}
|
||||
return traverse(node.X) && traverse(node.Y)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
} else if leftOperand, ok := n.X.(*ast.BasicLit); ok {
|
||||
if str, err := GetString(leftOperand); err == nil {
|
||||
s = str + s
|
||||
}
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
return s, true
|
||||
|
||||
if traverse(expr) {
|
||||
return build.String(), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// FindVarIdentities returns array of all variable identities in a given binary expression
|
||||
@@ -574,3 +589,18 @@ func CLIBuildTags(buildTags []string) []string {
|
||||
|
||||
return buildFlags
|
||||
}
|
||||
|
||||
// ContainingFile returns the *ast.File from ctx.PkgFiles that contains the given node.
|
||||
// Returns nil if not found (shouldn't happen for nodes from the analyzed package).
|
||||
func ContainingFile(n ast.Node, ctx *Context) *ast.File {
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
pos := n.Pos()
|
||||
for _, f := range ctx.PkgFiles {
|
||||
if f.Pos() <= pos && pos < f.End() {
|
||||
return f
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
376
rules/sql.go
376
rules/sql.go
@@ -17,6 +17,8 @@ package rules
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"regexp"
|
||||
|
||||
"github.com/securego/gosec/v2"
|
||||
@@ -60,33 +62,27 @@ var sqlCallIdents = map[string]map[string]int{
|
||||
},
|
||||
}
|
||||
|
||||
// findQueryArg locates the argument taking raw SQL
|
||||
// findQueryArg locates the argument taking raw SQL.
|
||||
func findQueryArg(call *ast.CallExpr, ctx *gosec.Context) (ast.Expr, error) {
|
||||
typeName, fnName, err := gosec.GetCallInfo(call, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i := -1
|
||||
if ni, ok := sqlCallIdents[typeName]; ok {
|
||||
if i, ok = ni[fnName]; !ok {
|
||||
i = -1
|
||||
|
||||
if methods, ok := sqlCallIdents[typeName]; ok {
|
||||
if i, ok := methods[fnName]; ok && i < len(call.Args) {
|
||||
return call.Args[i], nil
|
||||
}
|
||||
}
|
||||
if i == -1 {
|
||||
return nil, fmt.Errorf("SQL argument index not found for %s.%s", typeName, fnName)
|
||||
}
|
||||
if i >= len(call.Args) {
|
||||
return nil, nil
|
||||
}
|
||||
query := call.Args[i]
|
||||
return query, nil
|
||||
|
||||
return nil, fmt.Errorf("SQL argument index not found for %s.%s", typeName, fnName)
|
||||
}
|
||||
|
||||
func (s *sqlStatement) ID() string {
|
||||
return s.MetaData.ID
|
||||
}
|
||||
|
||||
// See if the string matches the patterns for the statement.
|
||||
// MatchPatterns checks if the string matches all required SQL patterns.
|
||||
func (s *sqlStatement) MatchPatterns(str string) bool {
|
||||
for _, pattern := range s.patterns {
|
||||
if !pattern.MatchString(str) {
|
||||
@@ -104,8 +100,9 @@ func (s *sqlStrConcat) ID() string {
|
||||
return s.MetaData.ID
|
||||
}
|
||||
|
||||
// findInjectionInBranch walks diwb a set if expressions, and will create new issues if it finds SQL injections
|
||||
// This method assumes you've already verified that the branch contains SQL syntax
|
||||
// findInjectionInBranch walks through a set of expressions and returns the first
|
||||
// binary expression containing a potential injection (non-constant operand).
|
||||
// This method assumes the branch already contains SQL syntax.
|
||||
func (s *sqlStrConcat) findInjectionInBranch(ctx *gosec.Context, branch []ast.Expr) *ast.BinaryExpr {
|
||||
for _, node := range branch {
|
||||
be, ok := node.(*ast.BinaryExpr)
|
||||
@@ -113,114 +110,194 @@ func (s *sqlStrConcat) findInjectionInBranch(ctx *gosec.Context, branch []ast.Ex
|
||||
continue
|
||||
}
|
||||
|
||||
operands := gosec.GetBinaryExprOperands(be)
|
||||
|
||||
for _, op := range operands {
|
||||
if _, ok := op.(*ast.BasicLit); ok {
|
||||
for _, op := range gosec.GetBinaryExprOperands(be) {
|
||||
if gosec.TryResolve(op, ctx) {
|
||||
continue
|
||||
}
|
||||
|
||||
if ident, ok := op.(*ast.Ident); ok && s.checkObject(ident, ctx) {
|
||||
continue
|
||||
}
|
||||
|
||||
return be
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// see if we can figure out what it is
|
||||
func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {
|
||||
if n.Obj != nil {
|
||||
return n.Obj.Kind != ast.Var && n.Obj.Kind != ast.Fun
|
||||
}
|
||||
|
||||
// Try to resolve unresolved identifiers using other files in same package
|
||||
for _, file := range c.PkgFiles {
|
||||
if node, ok := file.Scope.Objects[n.String()]; ok {
|
||||
return node.Kind != ast.Var && node.Kind != ast.Fun
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// checkQuery verifies if the query parameters is a string concatenation
|
||||
// checkQuery verifies if the query parameter involves risky string concatenation.
|
||||
func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issue.Issue, error) {
|
||||
query, err := findQueryArg(call, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Direct binary concatenation (e.g., "SELECT ..." + tainted)
|
||||
if be, ok := query.(*ast.BinaryExpr); ok {
|
||||
operands := gosec.GetBinaryExprOperands(be)
|
||||
if start, ok := operands[0].(*ast.BasicLit); ok {
|
||||
if str, e := gosec.GetString(start); e == nil {
|
||||
if !s.MatchPatterns(str) {
|
||||
return nil, nil
|
||||
if str, e := gosec.GetString(start); e == nil && s.MatchPatterns(str) {
|
||||
for _, op := range operands[1:] {
|
||||
if gosec.TryResolve(op, ctx) {
|
||||
continue
|
||||
}
|
||||
return ctx.NewIssue(be, s.ID(), s.What, s.Severity, s.Confidence), nil
|
||||
}
|
||||
}
|
||||
for _, op := range operands[1:] {
|
||||
if _, ok := op.(*ast.BasicLit); ok {
|
||||
continue
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Must be an identifier to continue (e.g., var query = ...; query += ...)
|
||||
ident, ok := query.(*ast.Ident)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
v, ok := ctx.Info.ObjectOf(ident).(*types.Var)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Determine search scope (package-level or local)
|
||||
isPkgLevel := ctx.Pkg != nil && v.Parent() == ctx.Pkg.Scope()
|
||||
|
||||
var filesToSearch []*ast.File
|
||||
if isPkgLevel {
|
||||
filesToSearch = ctx.PkgFiles
|
||||
} else {
|
||||
callFile := gosec.ContainingFile(call, ctx)
|
||||
if callFile == nil {
|
||||
return nil, nil
|
||||
}
|
||||
filesToSearch = []*ast.File{callFile}
|
||||
}
|
||||
|
||||
// Find the defining declaration and check for SQL patterns / initial risky concatenation
|
||||
declRHS := []ast.Expr{}
|
||||
foundDecl := false
|
||||
|
||||
// Determine the file containing the variable's defining position
|
||||
var declFile *ast.File
|
||||
if ctx.FileSet != nil {
|
||||
if posFile := ctx.FileSet.File(v.Pos()); posFile != nil {
|
||||
targetName := posFile.Name()
|
||||
for _, f := range filesToSearch {
|
||||
if fileInfo := ctx.FileSet.File(f.Pos()); fileInfo != nil && fileInfo.Name() == targetName {
|
||||
declFile = f
|
||||
break
|
||||
}
|
||||
if op, ok := op.(*ast.Ident); ok && s.checkObject(op, ctx) {
|
||||
continue
|
||||
}
|
||||
return ctx.NewIssue(be, s.ID(), s.What, s.Severity, s.Confidence), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the case where an injection occurs as an infixed string concatenation, ie "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1"
|
||||
if id, ok := query.(*ast.Ident); ok {
|
||||
var match bool
|
||||
for _, str := range gosec.GetIdentStringValuesRecursive(id) {
|
||||
if s.MatchPatterns(str) {
|
||||
match = true
|
||||
if declFile != nil {
|
||||
ast.Inspect(declFile, func(n ast.Node) bool {
|
||||
switch d := n.(type) {
|
||||
case *ast.ValueSpec:
|
||||
for _, name := range d.Names {
|
||||
if name.Pos() == v.Pos() && ctx.Info.ObjectOf(name) == v {
|
||||
declRHS = d.Values
|
||||
foundDecl = true
|
||||
return false // Stop inspection
|
||||
}
|
||||
}
|
||||
case *ast.AssignStmt:
|
||||
if d.Tok == token.DEFINE { // Only short variable declarations define new vars
|
||||
for _, lhs := range d.Lhs {
|
||||
if id, ok := lhs.(*ast.Ident); ok && id.Pos() == v.Pos() && ctx.Info.ObjectOf(id) == v {
|
||||
declRHS = d.Rhs
|
||||
foundDecl = true
|
||||
return false // Stop inspection
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if foundDecl {
|
||||
// Check for SQL patterns in initial values
|
||||
hasSQLPattern := false
|
||||
for _, val := range declRHS {
|
||||
if str, err := gosec.GetStringRecursive(val); err == nil && s.MatchPatterns(str) {
|
||||
hasSQLPattern = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !match {
|
||||
return nil, nil
|
||||
// Check for risky initial concatenation
|
||||
if inj := s.findInjectionInBranch(ctx, declRHS); inj != nil {
|
||||
return ctx.NewIssue(inj, s.ID(), s.What, s.Severity, s.Confidence), nil
|
||||
}
|
||||
|
||||
switch decl := id.Obj.Decl.(type) {
|
||||
case *ast.AssignStmt:
|
||||
if injection := s.findInjectionInBranch(ctx, decl.Rhs); injection != nil {
|
||||
return ctx.NewIssue(injection, s.ID(), s.What, s.Severity, s.Confidence), nil
|
||||
if !hasSQLPattern {
|
||||
return nil, nil
|
||||
}
|
||||
} else {
|
||||
// No defining declaration found → assume not SQL-related
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check for risky mutations (query += tainted or query = query + tainted)
|
||||
for _, f := range filesToSearch {
|
||||
var found *ast.AssignStmt
|
||||
ast.Inspect(f, func(n ast.Node) bool {
|
||||
assign, ok := n.(*ast.AssignStmt)
|
||||
if !ok || len(assign.Lhs) != 1 || len(assign.Rhs) != 1 {
|
||||
return true
|
||||
}
|
||||
case *ast.ValueSpec:
|
||||
// handle: var query string = "SELECT ...'" + user
|
||||
if injection := s.findInjectionInBranch(ctx, decl.Values); injection != nil {
|
||||
return ctx.NewIssue(injection, s.ID(), s.What, s.Severity, s.Confidence), nil
|
||||
lIdent, ok := assign.Lhs[0].(*ast.Ident)
|
||||
if !ok || ctx.Info.ObjectOf(lIdent) != v {
|
||||
return true
|
||||
}
|
||||
|
||||
var appended ast.Expr
|
||||
switch assign.Tok {
|
||||
case token.ADD_ASSIGN:
|
||||
appended = assign.Rhs[0]
|
||||
case token.ASSIGN:
|
||||
be, ok := assign.Rhs[0].(*ast.BinaryExpr)
|
||||
if !ok || be.Op != token.ADD {
|
||||
return true
|
||||
}
|
||||
left, ok := be.X.(*ast.Ident)
|
||||
if !ok || ctx.Info.ObjectOf(left) != v {
|
||||
return true
|
||||
}
|
||||
appended = be.Y
|
||||
default:
|
||||
return true
|
||||
}
|
||||
|
||||
if !gosec.TryResolve(appended, ctx) {
|
||||
found = assign
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if found != nil {
|
||||
return ctx.NewIssue(found, s.ID(), s.What, s.Severity, s.Confidence), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Checks SQL query concatenation issues such as "SELECT * FROM table WHERE " + " ' OR 1=1"
|
||||
// Match looks for SQL execution calls and checks for concatenation issues.
|
||||
func (s *sqlStrConcat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, error) {
|
||||
switch stmt := n.(type) {
|
||||
case *ast.AssignStmt:
|
||||
for _, expr := range stmt.Rhs {
|
||||
if sqlQueryCall, ok := expr.(*ast.CallExpr); ok && s.ContainsCallExpr(expr, ctx) != nil {
|
||||
return s.checkQuery(sqlQueryCall, ctx)
|
||||
if call, ok := expr.(*ast.CallExpr); ok && s.ContainsCallExpr(expr, ctx) != nil {
|
||||
return s.checkQuery(call, ctx)
|
||||
}
|
||||
}
|
||||
case *ast.ExprStmt:
|
||||
if sqlQueryCall, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(stmt.X, ctx) != nil {
|
||||
return s.checkQuery(sqlQueryCall, ctx)
|
||||
if call, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(call, ctx) != nil {
|
||||
return s.checkQuery(call, ctx)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// NewSQLStrConcat looks for cases where we are building SQL strings via concatenation
|
||||
// NewSQLStrConcat creates a rule for detecting SQL string concatenation.
|
||||
func NewSQLStrConcat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) {
|
||||
rule := &sqlStrConcat{
|
||||
sqlStatement: sqlStatement{
|
||||
@@ -237,9 +314,9 @@ func NewSQLStrConcat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) {
|
||||
},
|
||||
}
|
||||
|
||||
for s, si := range sqlCallIdents {
|
||||
for i := range si {
|
||||
rule.Add(s, i)
|
||||
for typ, methods := range sqlCallIdents {
|
||||
for method := range methods {
|
||||
rule.Add(typ, method)
|
||||
}
|
||||
}
|
||||
return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
|
||||
@@ -253,65 +330,77 @@ type sqlStrFormat struct {
|
||||
noIssueQuoted gosec.CallList
|
||||
}
|
||||
|
||||
// see if we can figure out what it is
|
||||
func (s *sqlStrFormat) constObject(e ast.Expr, c *gosec.Context) bool {
|
||||
n, ok := e.(*ast.Ident)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if n.Obj != nil {
|
||||
return n.Obj.Kind == ast.Con
|
||||
}
|
||||
|
||||
// Try to resolve unresolved identifiers using other files in same package
|
||||
for _, file := range c.PkgFiles {
|
||||
if node, ok := file.Scope.Objects[n.String()]; ok {
|
||||
return node.Kind == ast.Con
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// checkQuery verifies if the query parameter involves risky formatting.
|
||||
func (s *sqlStrFormat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issue.Issue, error) {
|
||||
query, err := findQueryArg(call, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ident, ok := query.(*ast.Ident); ok && ident.Obj != nil {
|
||||
decl := ident.Obj.Decl
|
||||
if assign, ok := decl.(*ast.AssignStmt); ok {
|
||||
for _, expr := range assign.Rhs {
|
||||
issue := s.checkFormatting(expr, ctx)
|
||||
if issue != nil {
|
||||
return issue, err
|
||||
}
|
||||
}
|
||||
}
|
||||
// Must be a variable identifier (short-declared with :=)
|
||||
ident, ok := query.(*ast.Ident)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
v, ok := ctx.Info.ObjectOf(ident).(*types.Var)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Short variable declarations are always local → use the file containing the call
|
||||
callFile := gosec.ContainingFile(call, ctx)
|
||||
if callFile == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Find the defining short declaration (query := fmt.Sprintf(...))
|
||||
var foundIssue *issue.Issue
|
||||
ast.Inspect(callFile, func(n ast.Node) bool {
|
||||
assign, ok := n.(*ast.AssignStmt)
|
||||
if !ok || assign.Tok != token.DEFINE {
|
||||
return true
|
||||
}
|
||||
|
||||
// Find the LHS identifier that defines this variable
|
||||
for _, lhs := range assign.Lhs {
|
||||
if defIdent, ok := lhs.(*ast.Ident); ok &&
|
||||
defIdent.Pos() == v.Pos() && ctx.Info.ObjectOf(defIdent) == v {
|
||||
|
||||
// Check every initializer expression on the RHS
|
||||
for _, expr := range assign.Rhs {
|
||||
if expr == nil {
|
||||
continue
|
||||
}
|
||||
if iss := s.checkFormatting(expr, ctx); iss != nil {
|
||||
foundIssue = iss
|
||||
return false // Stop entire inspection
|
||||
}
|
||||
}
|
||||
return false // Declaration found and processed
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return foundIssue, nil
|
||||
}
|
||||
|
||||
// checkFormatting checks if a formatting call builds a risky SQL query.
|
||||
func (s *sqlStrFormat) checkFormatting(n ast.Node, ctx *gosec.Context) *issue.Issue {
|
||||
// argIndex changes the function argument which gets matched to the regex
|
||||
argIndex := 0
|
||||
if node := s.fmtCalls.ContainsPkgCallExpr(n, ctx, false); node != nil {
|
||||
// if the function is fmt.Fprintf, search for SQL statement in Args[1] instead
|
||||
if sel, ok := node.Fun.(*ast.SelectorExpr); ok {
|
||||
if sel.Sel.Name == "Fprintf" {
|
||||
// if os.Stderr or os.Stdout is in Arg[0], mark as no issue
|
||||
if arg, ok := node.Args[0].(*ast.SelectorExpr); ok {
|
||||
if ident, ok := arg.X.(*ast.Ident); ok {
|
||||
if s.noIssue.Contains(ident.Name, arg.Sel.Name) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if sel, ok := node.Fun.(*ast.SelectorExpr); ok && sel.Sel.Name == "Fprintf" {
|
||||
// if os.Stderr or os.Stdout is in Arg[0], mark as no issue
|
||||
if arg, ok := node.Args[0].(*ast.SelectorExpr); ok {
|
||||
if ident, ok := arg.X.(*ast.Ident); ok && s.noIssue.Contains(ident.Name, arg.Sel.Name) {
|
||||
return nil
|
||||
}
|
||||
// the function is Fprintf so set argIndex = 1
|
||||
argIndex = 1
|
||||
}
|
||||
// the function is Fprintf so set argIndex = 1
|
||||
argIndex = 1
|
||||
}
|
||||
|
||||
// no formatter
|
||||
@@ -319,17 +408,8 @@ func (s *sqlStrFormat) checkFormatting(n ast.Node, ctx *gosec.Context) *issue.Is
|
||||
return nil
|
||||
}
|
||||
|
||||
var formatter string
|
||||
|
||||
// concats callexpr arg strings together if needed before regex evaluation
|
||||
if argExpr, ok := node.Args[argIndex].(*ast.BinaryExpr); ok {
|
||||
if fullStr, ok := gosec.ConcatString(argExpr); ok {
|
||||
formatter = fullStr
|
||||
}
|
||||
} else if arg, e := gosec.GetString(node.Args[argIndex]); e == nil {
|
||||
formatter = arg
|
||||
}
|
||||
if len(formatter) <= 0 {
|
||||
formatter, ok := gosec.ConcatString(node.Args[argIndex], ctx)
|
||||
if !ok || formatter == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -337,7 +417,7 @@ func (s *sqlStrFormat) checkFormatting(n ast.Node, ctx *gosec.Context) *issue.Is
|
||||
if argIndex+1 < len(node.Args) {
|
||||
allSafe := true
|
||||
for _, arg := range node.Args[argIndex+1:] {
|
||||
if n := s.noIssueQuoted.ContainsPkgCallExpr(arg, ctx, true); n == nil && !s.constObject(arg, ctx) {
|
||||
if s.noIssueQuoted.ContainsPkgCallExpr(arg, ctx, true) == nil && !gosec.TryResolve(arg, ctx) {
|
||||
allSafe = false
|
||||
break
|
||||
}
|
||||
@@ -346,6 +426,7 @@ func (s *sqlStrFormat) checkFormatting(n ast.Node, ctx *gosec.Context) *issue.Is
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.MatchPatterns(formatter) {
|
||||
return ctx.NewIssue(n, s.ID(), s.What, s.Severity, s.Confidence)
|
||||
}
|
||||
@@ -353,37 +434,31 @@ func (s *sqlStrFormat) checkFormatting(n ast.Node, ctx *gosec.Context) *issue.Is
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check SQL query formatting issues such as "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)"
|
||||
// Match looks for SQL calls involving formatted strings.
|
||||
func (s *sqlStrFormat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, error) {
|
||||
switch stmt := n.(type) {
|
||||
case *ast.AssignStmt:
|
||||
for _, expr := range stmt.Rhs {
|
||||
if call, ok := expr.(*ast.CallExpr); ok {
|
||||
selector, ok := call.Fun.(*ast.SelectorExpr)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
sqlQueryCall, ok := selector.X.(*ast.CallExpr)
|
||||
if ok && s.ContainsCallExpr(sqlQueryCall, ctx) != nil {
|
||||
issue, err := s.checkQuery(sqlQueryCall, ctx)
|
||||
if err == nil && issue != nil {
|
||||
return issue, err
|
||||
if sel, ok := call.Fun.(*ast.SelectorExpr); ok {
|
||||
if sqlCall, ok := sel.X.(*ast.CallExpr); ok && s.ContainsCallExpr(sqlCall, ctx) != nil {
|
||||
return s.checkQuery(sqlCall, ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
if sqlQueryCall, ok := expr.(*ast.CallExpr); ok && s.ContainsCallExpr(expr, ctx) != nil {
|
||||
return s.checkQuery(sqlQueryCall, ctx)
|
||||
if s.ContainsCallExpr(expr, ctx) != nil {
|
||||
return s.checkQuery(call, ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
case *ast.ExprStmt:
|
||||
if sqlQueryCall, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(stmt.X, ctx) != nil {
|
||||
return s.checkQuery(sqlQueryCall, ctx)
|
||||
if call, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(call, ctx) != nil {
|
||||
return s.checkQuery(call, ctx)
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// NewSQLStrFormat looks for cases where we're building SQL query strings using format strings
|
||||
// NewSQLStrFormat creates a rule for detecting SQL string formatting.
|
||||
func NewSQLStrFormat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) {
|
||||
rule := &sqlStrFormat{
|
||||
CallList: gosec.NewCallList(),
|
||||
@@ -403,14 +478,13 @@ func NewSQLStrFormat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) {
|
||||
},
|
||||
},
|
||||
}
|
||||
for s, si := range sqlCallIdents {
|
||||
for i := range si {
|
||||
rule.Add(s, i)
|
||||
for typ, methods := range sqlCallIdents {
|
||||
for method := range methods {
|
||||
rule.Add(typ, method)
|
||||
}
|
||||
}
|
||||
rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
|
||||
rule.noIssue.AddAll("os", "Stdout", "Stderr")
|
||||
rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier")
|
||||
|
||||
return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
|
||||
}
|
||||
|
||||
@@ -427,5 +427,147 @@ func main() {
|
||||
}
|
||||
defer stmt.Close()
|
||||
}
|
||||
`}, 0, gosec.NewConfig()},
|
||||
{[]string{`
|
||||
// Safe verb (%d) with tainted input - no string injection risk
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
id, _ := strconv.Atoi(os.Args[1]) // tainted but used with %d
|
||||
q := fmt.Sprintf("SELECT * FROM foo WHERE id = %d", id)
|
||||
rows, err := db.Query(q)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}
|
||||
`}, 0, gosec.NewConfig()},
|
||||
{[]string{`
|
||||
// Mixed args: unsafe %s (tainted) + safe %d (constant)
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
q := fmt.Sprintf("SELECT * FROM %s WHERE id = %d", os.Args[1], 42) // tainted table + safe int
|
||||
rows, err := db.Query(q)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}
|
||||
`}, 1, gosec.NewConfig()},
|
||||
{[]string{`
|
||||
// All args constant but unsafe verb present - safe
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const name = "admin"
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
q := fmt.Sprintf("SELECT * FROM users WHERE name = '%s'", name)
|
||||
rows, err := db.Query(q)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}
|
||||
`}, 0, gosec.NewConfig()},
|
||||
{[]string{`
|
||||
// Formatter from concatenation - risky
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
base := "SELECT * FROM foo WHERE"
|
||||
q := fmt.Sprintf(base + " name = '%s'", os.Args[1])
|
||||
rows, err := db.Query(q)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}
|
||||
`}, 1, gosec.NewConfig()},
|
||||
{[]string{`
|
||||
// No unsafe % verb but SQL pattern + tainted concat - G202, not G201
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
q := "SELECT * FROM foo WHERE name = " + os.Args[1] // concat, no %
|
||||
rows, err := db.Query(q)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}
|
||||
`}, 0, gosec.NewConfig()}, // G201 should NOT flag (G202 does)
|
||||
{[]string{`
|
||||
// Fprintf to os.Stderr - no issue
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
q := fmt.Sprintf("SELECT * FROM foo WHERE name = '%s'", os.Args[1])
|
||||
fmt.Fprintf(os.Stderr, "Debug query: %s\n", q) // log, not exec
|
||||
rows, err := db.Query("SELECT * FROM foo")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}
|
||||
`}, 0, gosec.NewConfig()},
|
||||
}
|
||||
|
||||
@@ -335,5 +335,163 @@ func main() {
|
||||
}
|
||||
defer rows.Close()
|
||||
}
|
||||
`}, 1, gosec.NewConfig()},
|
||||
{[]string{`
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
query := "SELECT * FROM album WHERE id = "
|
||||
query += os.Args[0]
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}
|
||||
`}, 1, gosec.NewConfig()},
|
||||
{[]string{`
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
query := "SELECT * FROM album WHERE id = "
|
||||
query += os.Args[0]
|
||||
fmt.Println(query)
|
||||
}
|
||||
`}, 0, gosec.NewConfig()},
|
||||
{[]string{`
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
query := "SELECT * FROM album WHERE id = "
|
||||
query = query + os.Args[0] // risky reassignment concatenation
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}
|
||||
`}, 1, gosec.NewConfig()},
|
||||
{[]string{`
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
query := "SELECT * FROM album WHERE id = "
|
||||
query = query + "42" // safe literal reassignment concatenation
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}
|
||||
`}, 0, gosec.NewConfig()},
|
||||
{[]string{`
|
||||
// Shadowing edge case: tainted mutation on shadowed variable - should NOT flag
|
||||
// The outer 'query' is safe and passed to db.Query.
|
||||
// The inner shadowed 'query' is mutated with tainted input (irrelevant).
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
query := "SELECT * FROM foo WHERE id = 42" // safe outer query
|
||||
{
|
||||
query := "base" // shadows outer query
|
||||
query += os.Args[1] // tainted mutation on shadow - should be ignored
|
||||
_ = query // prevent unused warning
|
||||
}
|
||||
rows, err := db.Query(query) // uses safe outer query
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}
|
||||
`}, 0, gosec.NewConfig()},
|
||||
{[]string{`
|
||||
// Shadowing edge case: no mutation on shadow, safe outer - regression guard
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
query := "SELECT * FROM foo WHERE id = 42"
|
||||
{
|
||||
query := "shadowed but unused"
|
||||
_ = query
|
||||
}
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
}
|
||||
`}, 0, gosec.NewConfig()},
|
||||
{[]string{`
|
||||
// package-level SQL string with tainted concatenation in init()
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
var query string = "SELECT * FROM foo WHERE name = "
|
||||
|
||||
func init() {
|
||||
query += os.Args[1]
|
||||
}
|
||||
`, `
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, _ := sql.Open("sqlite3", ":memory:")
|
||||
_, _ = db.Query(query)
|
||||
}
|
||||
`}, 1, gosec.NewConfig()},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user