diff --git a/helpers.go b/helpers.go index 7f5724b..feb74c5 100644 --- a/helpers.go +++ b/helpers.go @@ -385,29 +385,44 @@ func GetPkgAbsPath(pkgPath string) (string, error) { return absPath, nil } -// ConcatString recursively concatenates strings from a binary expression -func ConcatString(n *ast.BinaryExpr) (string, bool) { - var s string - // sub expressions are found in X object, Y object is always last BasicLit - if rightOperand, ok := n.Y.(*ast.BasicLit); ok { - if str, err := GetString(rightOperand); err == nil { - s = str + s - } - } else { +// ConcatString recursively concatenates constant strings from an expression +// if the entire chain is fully constant-derived (using TryResolve). +// Returns the concatenated string and true if successful. +func ConcatString(expr ast.Expr, ctx *Context) (string, bool) { + if expr == nil || !TryResolve(expr, ctx) { return "", false } - if leftOperand, ok := n.X.(*ast.BinaryExpr); ok { - if recursion, ok := ConcatString(leftOperand); ok { - s = recursion + s + + var build strings.Builder + var traverse func(ast.Expr) bool + traverse = func(e ast.Expr) bool { + switch node := e.(type) { + case *ast.BasicLit: + if str, err := GetString(node); err == nil { + build.WriteString(str) + return true + } + return false + case *ast.Ident: + values := GetIdentStringValuesRecursive(node) + for _, v := range values { + build.WriteString(v) + } + return len(values) > 0 + case *ast.BinaryExpr: + if node.Op != token.ADD { + return false + } + return traverse(node.X) && traverse(node.Y) + default: + return false } - } else if leftOperand, ok := n.X.(*ast.BasicLit); ok { - if str, err := GetString(leftOperand); err == nil { - s = str + s - } - } else { - return "", false } - return s, true + + if traverse(expr) { + return build.String(), true + } + return "", false } // FindVarIdentities returns array of all variable identities in a given binary expression @@ -574,3 +589,18 @@ func CLIBuildTags(buildTags []string) []string { return buildFlags } + +// ContainingFile returns the *ast.File from ctx.PkgFiles that contains the given node. +// Returns nil if not found (shouldn't happen for nodes from the analyzed package). +func ContainingFile(n ast.Node, ctx *Context) *ast.File { + if n == nil { + return nil + } + pos := n.Pos() + for _, f := range ctx.PkgFiles { + if f.Pos() <= pos && pos < f.End() { + return f + } + } + return nil +} diff --git a/rules/sql.go b/rules/sql.go index 622c2fe..0801339 100644 --- a/rules/sql.go +++ b/rules/sql.go @@ -17,6 +17,8 @@ package rules import ( "fmt" "go/ast" + "go/token" + "go/types" "regexp" "github.com/securego/gosec/v2" @@ -60,33 +62,27 @@ var sqlCallIdents = map[string]map[string]int{ }, } -// findQueryArg locates the argument taking raw SQL +// findQueryArg locates the argument taking raw SQL. func findQueryArg(call *ast.CallExpr, ctx *gosec.Context) (ast.Expr, error) { typeName, fnName, err := gosec.GetCallInfo(call, ctx) if err != nil { return nil, err } - i := -1 - if ni, ok := sqlCallIdents[typeName]; ok { - if i, ok = ni[fnName]; !ok { - i = -1 + + if methods, ok := sqlCallIdents[typeName]; ok { + if i, ok := methods[fnName]; ok && i < len(call.Args) { + return call.Args[i], nil } } - if i == -1 { - return nil, fmt.Errorf("SQL argument index not found for %s.%s", typeName, fnName) - } - if i >= len(call.Args) { - return nil, nil - } - query := call.Args[i] - return query, nil + + return nil, fmt.Errorf("SQL argument index not found for %s.%s", typeName, fnName) } func (s *sqlStatement) ID() string { return s.MetaData.ID } -// See if the string matches the patterns for the statement. +// MatchPatterns checks if the string matches all required SQL patterns. func (s *sqlStatement) MatchPatterns(str string) bool { for _, pattern := range s.patterns { if !pattern.MatchString(str) { @@ -104,8 +100,9 @@ func (s *sqlStrConcat) ID() string { return s.MetaData.ID } -// findInjectionInBranch walks diwb a set if expressions, and will create new issues if it finds SQL injections -// This method assumes you've already verified that the branch contains SQL syntax +// findInjectionInBranch walks through a set of expressions and returns the first +// binary expression containing a potential injection (non-constant operand). +// This method assumes the branch already contains SQL syntax. func (s *sqlStrConcat) findInjectionInBranch(ctx *gosec.Context, branch []ast.Expr) *ast.BinaryExpr { for _, node := range branch { be, ok := node.(*ast.BinaryExpr) @@ -113,114 +110,194 @@ func (s *sqlStrConcat) findInjectionInBranch(ctx *gosec.Context, branch []ast.Ex continue } - operands := gosec.GetBinaryExprOperands(be) - - for _, op := range operands { - if _, ok := op.(*ast.BasicLit); ok { + for _, op := range gosec.GetBinaryExprOperands(be) { + if gosec.TryResolve(op, ctx) { continue } - - if ident, ok := op.(*ast.Ident); ok && s.checkObject(ident, ctx) { - continue - } - return be } } return nil } -// see if we can figure out what it is -func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool { - if n.Obj != nil { - return n.Obj.Kind != ast.Var && n.Obj.Kind != ast.Fun - } - - // Try to resolve unresolved identifiers using other files in same package - for _, file := range c.PkgFiles { - if node, ok := file.Scope.Objects[n.String()]; ok { - return node.Kind != ast.Var && node.Kind != ast.Fun - } - } - return false -} - -// checkQuery verifies if the query parameters is a string concatenation +// checkQuery verifies if the query parameter involves risky string concatenation. func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issue.Issue, error) { query, err := findQueryArg(call, ctx) if err != nil { return nil, err } + // Direct binary concatenation (e.g., "SELECT ..." + tainted) if be, ok := query.(*ast.BinaryExpr); ok { operands := gosec.GetBinaryExprOperands(be) if start, ok := operands[0].(*ast.BasicLit); ok { - if str, e := gosec.GetString(start); e == nil { - if !s.MatchPatterns(str) { - return nil, nil + if str, e := gosec.GetString(start); e == nil && s.MatchPatterns(str) { + for _, op := range operands[1:] { + if gosec.TryResolve(op, ctx) { + continue + } + return ctx.NewIssue(be, s.ID(), s.What, s.Severity, s.Confidence), nil } } - for _, op := range operands[1:] { - if _, ok := op.(*ast.BasicLit); ok { - continue + } + return nil, nil + } + + // Must be an identifier to continue (e.g., var query = ...; query += ...) + ident, ok := query.(*ast.Ident) + if !ok { + return nil, nil + } + + v, ok := ctx.Info.ObjectOf(ident).(*types.Var) + if !ok { + return nil, nil + } + + // Determine search scope (package-level or local) + isPkgLevel := ctx.Pkg != nil && v.Parent() == ctx.Pkg.Scope() + + var filesToSearch []*ast.File + if isPkgLevel { + filesToSearch = ctx.PkgFiles + } else { + callFile := gosec.ContainingFile(call, ctx) + if callFile == nil { + return nil, nil + } + filesToSearch = []*ast.File{callFile} + } + + // Find the defining declaration and check for SQL patterns / initial risky concatenation + declRHS := []ast.Expr{} + foundDecl := false + + // Determine the file containing the variable's defining position + var declFile *ast.File + if ctx.FileSet != nil { + if posFile := ctx.FileSet.File(v.Pos()); posFile != nil { + targetName := posFile.Name() + for _, f := range filesToSearch { + if fileInfo := ctx.FileSet.File(f.Pos()); fileInfo != nil && fileInfo.Name() == targetName { + declFile = f + break } - if op, ok := op.(*ast.Ident); ok && s.checkObject(op, ctx) { - continue - } - return ctx.NewIssue(be, s.ID(), s.What, s.Severity, s.Confidence), nil } } } - // Handle the case where an injection occurs as an infixed string concatenation, ie "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1" - if id, ok := query.(*ast.Ident); ok { - var match bool - for _, str := range gosec.GetIdentStringValuesRecursive(id) { - if s.MatchPatterns(str) { - match = true + if declFile != nil { + ast.Inspect(declFile, func(n ast.Node) bool { + switch d := n.(type) { + case *ast.ValueSpec: + for _, name := range d.Names { + if name.Pos() == v.Pos() && ctx.Info.ObjectOf(name) == v { + declRHS = d.Values + foundDecl = true + return false // Stop inspection + } + } + case *ast.AssignStmt: + if d.Tok == token.DEFINE { // Only short variable declarations define new vars + for _, lhs := range d.Lhs { + if id, ok := lhs.(*ast.Ident); ok && id.Pos() == v.Pos() && ctx.Info.ObjectOf(id) == v { + declRHS = d.Rhs + foundDecl = true + return false // Stop inspection + } + } + } + } + return true + }) + } + + if foundDecl { + // Check for SQL patterns in initial values + hasSQLPattern := false + for _, val := range declRHS { + if str, err := gosec.GetStringRecursive(val); err == nil && s.MatchPatterns(str) { + hasSQLPattern = true break } } - if !match { - return nil, nil + // Check for risky initial concatenation + if inj := s.findInjectionInBranch(ctx, declRHS); inj != nil { + return ctx.NewIssue(inj, s.ID(), s.What, s.Severity, s.Confidence), nil } - switch decl := id.Obj.Decl.(type) { - case *ast.AssignStmt: - if injection := s.findInjectionInBranch(ctx, decl.Rhs); injection != nil { - return ctx.NewIssue(injection, s.ID(), s.What, s.Severity, s.Confidence), nil + if !hasSQLPattern { + return nil, nil + } + } else { + // No defining declaration found → assume not SQL-related + return nil, nil + } + + // Check for risky mutations (query += tainted or query = query + tainted) + for _, f := range filesToSearch { + var found *ast.AssignStmt + ast.Inspect(f, func(n ast.Node) bool { + assign, ok := n.(*ast.AssignStmt) + if !ok || len(assign.Lhs) != 1 || len(assign.Rhs) != 1 { + return true } - case *ast.ValueSpec: - // handle: var query string = "SELECT ...'" + user - if injection := s.findInjectionInBranch(ctx, decl.Values); injection != nil { - return ctx.NewIssue(injection, s.ID(), s.What, s.Severity, s.Confidence), nil + lIdent, ok := assign.Lhs[0].(*ast.Ident) + if !ok || ctx.Info.ObjectOf(lIdent) != v { + return true } + + var appended ast.Expr + switch assign.Tok { + case token.ADD_ASSIGN: + appended = assign.Rhs[0] + case token.ASSIGN: + be, ok := assign.Rhs[0].(*ast.BinaryExpr) + if !ok || be.Op != token.ADD { + return true + } + left, ok := be.X.(*ast.Ident) + if !ok || ctx.Info.ObjectOf(left) != v { + return true + } + appended = be.Y + default: + return true + } + + if !gosec.TryResolve(appended, ctx) { + found = assign + return false + } + return true + }) + if found != nil { + return ctx.NewIssue(found, s.ID(), s.What, s.Severity, s.Confidence), nil } } return nil, nil } -// Checks SQL query concatenation issues such as "SELECT * FROM table WHERE " + " ' OR 1=1" +// Match looks for SQL execution calls and checks for concatenation issues. func (s *sqlStrConcat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, error) { switch stmt := n.(type) { case *ast.AssignStmt: for _, expr := range stmt.Rhs { - if sqlQueryCall, ok := expr.(*ast.CallExpr); ok && s.ContainsCallExpr(expr, ctx) != nil { - return s.checkQuery(sqlQueryCall, ctx) + if call, ok := expr.(*ast.CallExpr); ok && s.ContainsCallExpr(expr, ctx) != nil { + return s.checkQuery(call, ctx) } } case *ast.ExprStmt: - if sqlQueryCall, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(stmt.X, ctx) != nil { - return s.checkQuery(sqlQueryCall, ctx) + if call, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(call, ctx) != nil { + return s.checkQuery(call, ctx) } } - return nil, nil } -// NewSQLStrConcat looks for cases where we are building SQL strings via concatenation +// NewSQLStrConcat creates a rule for detecting SQL string concatenation. func NewSQLStrConcat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) { rule := &sqlStrConcat{ sqlStatement: sqlStatement{ @@ -237,9 +314,9 @@ func NewSQLStrConcat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) { }, } - for s, si := range sqlCallIdents { - for i := range si { - rule.Add(s, i) + for typ, methods := range sqlCallIdents { + for method := range methods { + rule.Add(typ, method) } } return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)} @@ -253,65 +330,77 @@ type sqlStrFormat struct { noIssueQuoted gosec.CallList } -// see if we can figure out what it is -func (s *sqlStrFormat) constObject(e ast.Expr, c *gosec.Context) bool { - n, ok := e.(*ast.Ident) - if !ok { - return false - } - - if n.Obj != nil { - return n.Obj.Kind == ast.Con - } - - // Try to resolve unresolved identifiers using other files in same package - for _, file := range c.PkgFiles { - if node, ok := file.Scope.Objects[n.String()]; ok { - return node.Kind == ast.Con - } - } - return false -} - +// checkQuery verifies if the query parameter involves risky formatting. func (s *sqlStrFormat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issue.Issue, error) { query, err := findQueryArg(call, ctx) if err != nil { return nil, err } - if ident, ok := query.(*ast.Ident); ok && ident.Obj != nil { - decl := ident.Obj.Decl - if assign, ok := decl.(*ast.AssignStmt); ok { - for _, expr := range assign.Rhs { - issue := s.checkFormatting(expr, ctx) - if issue != nil { - return issue, err - } - } - } + // Must be a variable identifier (short-declared with :=) + ident, ok := query.(*ast.Ident) + if !ok { + return nil, nil } - return nil, nil + v, ok := ctx.Info.ObjectOf(ident).(*types.Var) + if !ok { + return nil, nil + } + + // Short variable declarations are always local → use the file containing the call + callFile := gosec.ContainingFile(call, ctx) + if callFile == nil { + return nil, nil + } + + // Find the defining short declaration (query := fmt.Sprintf(...)) + var foundIssue *issue.Issue + ast.Inspect(callFile, func(n ast.Node) bool { + assign, ok := n.(*ast.AssignStmt) + if !ok || assign.Tok != token.DEFINE { + return true + } + + // Find the LHS identifier that defines this variable + for _, lhs := range assign.Lhs { + if defIdent, ok := lhs.(*ast.Ident); ok && + defIdent.Pos() == v.Pos() && ctx.Info.ObjectOf(defIdent) == v { + + // Check every initializer expression on the RHS + for _, expr := range assign.Rhs { + if expr == nil { + continue + } + if iss := s.checkFormatting(expr, ctx); iss != nil { + foundIssue = iss + return false // Stop entire inspection + } + } + return false // Declaration found and processed + } + } + return true + }) + + return foundIssue, nil } +// checkFormatting checks if a formatting call builds a risky SQL query. func (s *sqlStrFormat) checkFormatting(n ast.Node, ctx *gosec.Context) *issue.Issue { // argIndex changes the function argument which gets matched to the regex argIndex := 0 if node := s.fmtCalls.ContainsPkgCallExpr(n, ctx, false); node != nil { // if the function is fmt.Fprintf, search for SQL statement in Args[1] instead - if sel, ok := node.Fun.(*ast.SelectorExpr); ok { - if sel.Sel.Name == "Fprintf" { - // if os.Stderr or os.Stdout is in Arg[0], mark as no issue - if arg, ok := node.Args[0].(*ast.SelectorExpr); ok { - if ident, ok := arg.X.(*ast.Ident); ok { - if s.noIssue.Contains(ident.Name, arg.Sel.Name) { - return nil - } - } + if sel, ok := node.Fun.(*ast.SelectorExpr); ok && sel.Sel.Name == "Fprintf" { + // if os.Stderr or os.Stdout is in Arg[0], mark as no issue + if arg, ok := node.Args[0].(*ast.SelectorExpr); ok { + if ident, ok := arg.X.(*ast.Ident); ok && s.noIssue.Contains(ident.Name, arg.Sel.Name) { + return nil } - // the function is Fprintf so set argIndex = 1 - argIndex = 1 } + // the function is Fprintf so set argIndex = 1 + argIndex = 1 } // no formatter @@ -319,17 +408,8 @@ func (s *sqlStrFormat) checkFormatting(n ast.Node, ctx *gosec.Context) *issue.Is return nil } - var formatter string - - // concats callexpr arg strings together if needed before regex evaluation - if argExpr, ok := node.Args[argIndex].(*ast.BinaryExpr); ok { - if fullStr, ok := gosec.ConcatString(argExpr); ok { - formatter = fullStr - } - } else if arg, e := gosec.GetString(node.Args[argIndex]); e == nil { - formatter = arg - } - if len(formatter) <= 0 { + formatter, ok := gosec.ConcatString(node.Args[argIndex], ctx) + if !ok || formatter == "" { return nil } @@ -337,7 +417,7 @@ func (s *sqlStrFormat) checkFormatting(n ast.Node, ctx *gosec.Context) *issue.Is if argIndex+1 < len(node.Args) { allSafe := true for _, arg := range node.Args[argIndex+1:] { - if n := s.noIssueQuoted.ContainsPkgCallExpr(arg, ctx, true); n == nil && !s.constObject(arg, ctx) { + if s.noIssueQuoted.ContainsPkgCallExpr(arg, ctx, true) == nil && !gosec.TryResolve(arg, ctx) { allSafe = false break } @@ -346,6 +426,7 @@ func (s *sqlStrFormat) checkFormatting(n ast.Node, ctx *gosec.Context) *issue.Is return nil } } + if s.MatchPatterns(formatter) { return ctx.NewIssue(n, s.ID(), s.What, s.Severity, s.Confidence) } @@ -353,37 +434,31 @@ func (s *sqlStrFormat) checkFormatting(n ast.Node, ctx *gosec.Context) *issue.Is return nil } -// Check SQL query formatting issues such as "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)" +// Match looks for SQL calls involving formatted strings. func (s *sqlStrFormat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, error) { switch stmt := n.(type) { case *ast.AssignStmt: for _, expr := range stmt.Rhs { if call, ok := expr.(*ast.CallExpr); ok { - selector, ok := call.Fun.(*ast.SelectorExpr) - if !ok { - continue - } - sqlQueryCall, ok := selector.X.(*ast.CallExpr) - if ok && s.ContainsCallExpr(sqlQueryCall, ctx) != nil { - issue, err := s.checkQuery(sqlQueryCall, ctx) - if err == nil && issue != nil { - return issue, err + if sel, ok := call.Fun.(*ast.SelectorExpr); ok { + if sqlCall, ok := sel.X.(*ast.CallExpr); ok && s.ContainsCallExpr(sqlCall, ctx) != nil { + return s.checkQuery(sqlCall, ctx) } } - } - if sqlQueryCall, ok := expr.(*ast.CallExpr); ok && s.ContainsCallExpr(expr, ctx) != nil { - return s.checkQuery(sqlQueryCall, ctx) + if s.ContainsCallExpr(expr, ctx) != nil { + return s.checkQuery(call, ctx) + } } } case *ast.ExprStmt: - if sqlQueryCall, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(stmt.X, ctx) != nil { - return s.checkQuery(sqlQueryCall, ctx) + if call, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(call, ctx) != nil { + return s.checkQuery(call, ctx) } } return nil, nil } -// NewSQLStrFormat looks for cases where we're building SQL query strings using format strings +// NewSQLStrFormat creates a rule for detecting SQL string formatting. func NewSQLStrFormat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) { rule := &sqlStrFormat{ CallList: gosec.NewCallList(), @@ -403,14 +478,13 @@ func NewSQLStrFormat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) { }, }, } - for s, si := range sqlCallIdents { - for i := range si { - rule.Add(s, i) + for typ, methods := range sqlCallIdents { + for method := range methods { + rule.Add(typ, method) } } rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf") rule.noIssue.AddAll("os", "Stdout", "Stderr") rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier") - return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)} } diff --git a/testutils/g201_samples.go b/testutils/g201_samples.go index b2811e7..83719c7 100644 --- a/testutils/g201_samples.go +++ b/testutils/g201_samples.go @@ -427,5 +427,147 @@ func main() { } defer stmt.Close() } +`}, 0, gosec.NewConfig()}, + {[]string{` +// Safe verb (%d) with tainted input - no string injection risk +package main + +import ( + "database/sql" + "fmt" + "os" + "strconv" +) + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + id, _ := strconv.Atoi(os.Args[1]) // tainted but used with %d + q := fmt.Sprintf("SELECT * FROM foo WHERE id = %d", id) + rows, err := db.Query(q) + if err != nil { + panic(err) + } + defer rows.Close() +} +`}, 0, gosec.NewConfig()}, + {[]string{` +// Mixed args: unsafe %s (tainted) + safe %d (constant) +package main + +import ( + "database/sql" + "fmt" + "os" +) + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + q := fmt.Sprintf("SELECT * FROM %s WHERE id = %d", os.Args[1], 42) // tainted table + safe int + rows, err := db.Query(q) + if err != nil { + panic(err) + } + defer rows.Close() +} +`}, 1, gosec.NewConfig()}, + {[]string{` +// All args constant but unsafe verb present - safe +package main + +import ( + "database/sql" + "fmt" +) + +const name = "admin" + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + q := fmt.Sprintf("SELECT * FROM users WHERE name = '%s'", name) + rows, err := db.Query(q) + if err != nil { + panic(err) + } + defer rows.Close() +} +`}, 0, gosec.NewConfig()}, + {[]string{` +// Formatter from concatenation - risky +package main + +import ( + "database/sql" + "fmt" + "os" +) + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + base := "SELECT * FROM foo WHERE" + q := fmt.Sprintf(base + " name = '%s'", os.Args[1]) + rows, err := db.Query(q) + if err != nil { + panic(err) + } + defer rows.Close() +} +`}, 1, gosec.NewConfig()}, + {[]string{` +// No unsafe % verb but SQL pattern + tainted concat - G202, not G201 +package main + +import ( + "database/sql" + "os" +) + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + q := "SELECT * FROM foo WHERE name = " + os.Args[1] // concat, no % + rows, err := db.Query(q) + if err != nil { + panic(err) + } + defer rows.Close() +} +`}, 0, gosec.NewConfig()}, // G201 should NOT flag (G202 does) + {[]string{` +// Fprintf to os.Stderr - no issue +package main + +import ( + "database/sql" + "fmt" + "os" +) + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + q := fmt.Sprintf("SELECT * FROM foo WHERE name = '%s'", os.Args[1]) + fmt.Fprintf(os.Stderr, "Debug query: %s\n", q) // log, not exec + rows, err := db.Query("SELECT * FROM foo") + if err != nil { + panic(err) + } + defer rows.Close() +} `}, 0, gosec.NewConfig()}, } diff --git a/testutils/g202_samples.go b/testutils/g202_samples.go index 58a153a..c67e1c2 100644 --- a/testutils/g202_samples.go +++ b/testutils/g202_samples.go @@ -335,5 +335,163 @@ func main() { } defer rows.Close() } +`}, 1, gosec.NewConfig()}, + {[]string{` +package main + +import ( + "database/sql" + "os" +) + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + query := "SELECT * FROM album WHERE id = " + query += os.Args[0] + rows, err := db.Query(query) + if err != nil { + panic(err) + } + defer rows.Close() +} +`}, 1, gosec.NewConfig()}, + {[]string{` +package main + +import ( + "fmt" + "os" +) + +func main() { + query := "SELECT * FROM album WHERE id = " + query += os.Args[0] + fmt.Println(query) +} +`}, 0, gosec.NewConfig()}, + {[]string{` +package main + +import ( + "database/sql" + "os" +) + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + query := "SELECT * FROM album WHERE id = " + query = query + os.Args[0] // risky reassignment concatenation + rows, err := db.Query(query) + if err != nil { + panic(err) + } + defer rows.Close() +} +`}, 1, gosec.NewConfig()}, + {[]string{` +package main + +import ( + "database/sql" +) + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + query := "SELECT * FROM album WHERE id = " + query = query + "42" // safe literal reassignment concatenation + rows, err := db.Query(query) + if err != nil { + panic(err) + } + defer rows.Close() +} +`}, 0, gosec.NewConfig()}, + {[]string{` +// Shadowing edge case: tainted mutation on shadowed variable - should NOT flag +// The outer 'query' is safe and passed to db.Query. +// The inner shadowed 'query' is mutated with tainted input (irrelevant). +package main + +import ( + "database/sql" + "os" +) + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + query := "SELECT * FROM foo WHERE id = 42" // safe outer query + { + query := "base" // shadows outer query + query += os.Args[1] // tainted mutation on shadow - should be ignored + _ = query // prevent unused warning + } + rows, err := db.Query(query) // uses safe outer query + if err != nil { + panic(err) + } + defer rows.Close() +} +`}, 0, gosec.NewConfig()}, + {[]string{` +// Shadowing edge case: no mutation on shadow, safe outer - regression guard +package main + +import ( + "database/sql" +) + +func main() { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(err) + } + query := "SELECT * FROM foo WHERE id = 42" + { + query := "shadowed but unused" + _ = query + } + rows, err := db.Query(query) + if err != nil { + panic(err) + } + defer rows.Close() +} +`}, 0, gosec.NewConfig()}, + {[]string{` +// package-level SQL string with tainted concatenation in init() +package main + +import ( + "os" +) + +var query string = "SELECT * FROM foo WHERE name = " + +func init() { + query += os.Args[1] +} +`, ` +package main + +import ( + "database/sql" +) + +func main() { + db, _ := sql.Open("sqlite3", ":memory:") + _, _ = db.Query(query) +} `}, 1, gosec.NewConfig()}, }