refactor(g115): improve coverage (#1462)

This commit is contained in:
oittaa
2026-01-12 11:37:18 +01:00
committed by GitHub
parent 0cc9e01a9d
commit 833d7919e0
3 changed files with 669 additions and 146 deletions

View File

@@ -18,6 +18,7 @@ import (
"fmt"
"go/token"
"math"
"math/bits"
"strconv"
"strings"
@@ -60,8 +61,8 @@ type overflowState struct {
}
type rangeCacheKey struct {
ifInstr *ssa.If
val ssa.Value
block *ssa.BasicBlock
val ssa.Value
}
func newOverflowState(pass *analysis.Pass) *overflowState {
@@ -91,6 +92,7 @@ func runConversionOverflow(pass *analysis.Pass) (any, error) {
if state.isSafeConversion(instr) {
continue
}
issue := newIssue(pass.Analyzer.Name,
fmt.Sprintf("integer overflow conversion %s -> %s", src, dst),
pass.Fset,
@@ -160,15 +162,9 @@ func (s *overflowState) hasRangeCheck(v ssa.Value, dstType string, block *ssa.Ba
isSrcUnsigned := strings.HasPrefix(v.Type().Underlying().String(), "uint")
// Get resolved range (definition + dominators)
rangeRes := s.resolveRange(v, block, make(map[ssa.Value]bool))
minValue := rangeRes.minValue
maxValue := rangeRes.maxValue
minValueSet := rangeRes.minValueSet
maxValueSet := rangeRes.maxValueSet
explicitPositiveVals := rangeRes.explicitPositiveVals
explicitNegativeVals := rangeRes.explicitNegativeVals
res := s.resolveRange(v, block, make(map[ssa.Value]bool))
minValue, minValueSet, maxValue, maxValueSet, isRangeCheck := res.minValue, res.minValueSet, res.maxValue, res.maxValueSet, res.isRangeCheck
explicitPositiveVals, explicitNegativeVals := res.explicitPositiveVals, res.explicitNegativeVals
if explicitValsInRange(explicitPositiveVals, explicitNegativeVals, dstInt) {
return true
}
@@ -187,6 +183,19 @@ func (s *overflowState) hasRangeCheck(v ssa.Value, dstType string, block *ssa.Ba
}
}
if explicitValsInRange(res.explicitPositiveVals, res.explicitNegativeVals, dstInt) {
return true
}
// Relax requirement: If we have a definitive range (both set) and it's safe,
// we allow it even if not explicitly "checked" by an IF,
// because definition-based ranges (like constants or arithmetic on constants) are certain.
isDefinitiveSafe := minValueSet && maxValueSet
if !isRangeCheck && !isDefinitiveSafe {
return false
}
// Check for impossible ranges (disjoint)
if !isSrcUnsigned {
if minValueSet && maxValueSet && toInt64(minValue) > toInt64(maxValue) {
@@ -197,22 +206,16 @@ func (s *overflowState) hasRangeCheck(v ssa.Value, dstType string, block *ssa.Ba
return true
}
var resFinal bool
if dstInt.Signed {
if isSrcUnsigned {
resFinal = maxValueSet && maxValue <= uint64(dstInt.Max)
} else {
resFinal = (minValueSet && toInt64(minValue) >= int64(dstInt.Min)) && (maxValueSet && toInt64(maxValue) <= toInt64(uint64(dstInt.Max)))
}
} else {
if isSrcUnsigned {
resFinal = maxValueSet && maxValue <= uint64(dstInt.Max)
} else {
resFinal = (minValueSet && toInt64(minValue) >= 0) && (maxValueSet && maxValue <= uint64(dstInt.Max))
return maxValueSet && maxValue <= uint64(dstInt.Max)
}
return (minValueSet && toInt64(minValue) >= int64(dstInt.Min)) && (maxValueSet && toInt64(maxValue) <= toInt64(uint64(dstInt.Max)))
}
return resFinal
if isSrcUnsigned {
return maxValueSet && maxValue <= uint64(dstInt.Max)
}
return (minValueSet && toInt64(minValue) >= 0) && (maxValueSet && maxValue <= uint64(dstInt.Max))
}
// minBounds computes the minimum of two uint64 values, treating them as signed if !isSrcUnsigned.
@@ -229,6 +232,23 @@ func minBounds(a, b uint64, isSrcUnsigned bool) uint64 {
return b
}
// updateRangeMinMax updates the min or max value of the result range if the new value is tighter.
func updateRangeMinMax(result *rangeResult, newVal uint64, isMin bool, isSrcUnsigned bool) {
if isMin {
if !result.minValueSet || (isSrcUnsigned && newVal > result.minValue) || (!isSrcUnsigned && toInt64(newVal) > toInt64(result.minValue)) {
result.minValue = newVal
result.minValueSet = true
result.isRangeCheck = true
}
} else {
if !result.maxValueSet || (isSrcUnsigned && newVal < result.maxValue) || (!isSrcUnsigned && toInt64(newVal) < toInt64(result.maxValue)) {
result.maxValue = newVal
result.maxValueSet = true
result.isRangeCheck = true
}
}
}
// maxBounds computes the maximum of two uint64 values, treating them as signed if !isSrcUnsigned.
func maxBounds(a, b uint64, isSrcUnsigned bool) uint64 {
if a == toUint64(minInt64) { // Using MinInt64 as "not set" for signed-capable minValue
@@ -300,20 +320,7 @@ func (s *overflowState) getResultRangeForIfEdge(vIf *ssa.If, isTrue bool, v ssa.
}
// getResultRangeForValue calculates the range of a value by analyzing the dominator tree and control flow.
func (s *overflowState) getResultRangeForValue(ifInstr *ssa.If, v ssa.Value, targetBlock *ssa.BasicBlock, visitedIfs map[*ssa.If]bool) rangeResult {
key := rangeCacheKey{ifInstr, v}
if res, ok := s.rangeCache[key]; ok {
return res
}
if visitedIfs[ifInstr] {
return rangeResult{
minValue: toUint64(minInt64),
maxValue: maxUint64,
}
}
visitedIfs[ifInstr] = true
func (s *overflowState) getResultRangeForValue(ifInstr *ssa.If, v ssa.Value, targetBlock *ssa.BasicBlock) rangeResult {
cond := ifInstr.Cond
binOp, ok := cond.(*ssa.BinOp)
if !ok || !isRangeCheck(binOp, v) {
@@ -334,13 +341,11 @@ func (s *overflowState) getResultRangeForValue(ifInstr *ssa.If, v ssa.Value, tar
elseFound := isReachable(ifInstr.Block().Succs[1], targetBlock, make(map[*ssa.BasicBlock]bool))
if thenFound && elseFound {
s.rangeCache[key] = result
return result
}
s.updateResultFromBinOpForValue(&result, binOp, v, thenFound)
s.rangeCache[key] = result
return result
}
@@ -370,6 +375,7 @@ func (s *overflowState) updateResultFromBinOpForValue(result *rangeResult, binOp
}
var matchSide ssa.Value
var inverseOp operationInfo
if isEquivalent(binOp.X, v) {
matchSide = binOp.Y
op = operationInfo{}
@@ -379,9 +385,24 @@ func (s *overflowState) updateResultFromBinOpForValue(result *rangeResult, binOp
op = operationInfo{}
} else if isSameOrRelated(binOp.X, compareVal) {
matchSide = binOp.Y
// check if binOp.X has an operation relative to compareVal
if rVal, rOp := getRealValueFromOperation(binOp.X); rVal == compareVal {
inverseOp = rOp
}
} else if rVal, rOp := getRealValueFromOperation(binOp.X); rVal == compareVal {
matchSide = binOp.Y
inverseOp = rOp
} else if isSameOrRelated(binOp.Y, compareVal) {
matchSide = binOp.X
operandsFlipped = true
// check if binOp.Y has an operation relative to compareVal
if rVal, rOp := getRealValueFromOperation(binOp.Y); rVal == compareVal {
inverseOp = rOp
}
} else if rVal, rOp := getRealValueFromOperation(binOp.Y); rVal == compareVal {
matchSide = binOp.X
operandsFlipped = true
inverseOp = rOp
} else {
return
}
@@ -392,6 +413,126 @@ func (s *overflowState) updateResultFromBinOpForValue(result *rangeResult, binOp
return
}
// Apply inverse operations to the limit 'val' before updating min/max
// e.g. if x << 2 < 100. val=100. inverseOp=<<.
// we want range for x. x < 100 >> 2.
if inverseOp.op != "" {
switch inverseOp.op {
case "<<":
if vShift, ok := GetConstantInt64(inverseOp.extra); ok && vShift >= 0 {
val = val >> uint(vShift)
}
case "+":
if vAdd, ok := GetConstantInt64(inverseOp.extra); ok {
val -= vAdd
}
case "-":
if vSub, ok := GetConstantInt64(inverseOp.extra); ok {
if inverseOp.flipped { // val = extra - x => x = extra - val
val = vSub - val
operandsFlipped = !operandsFlipped
} else { // val = x - extra => x = val + extra
val += vSub
}
}
case ">>":
if vShift, ok := GetConstantInt64(inverseOp.extra); ok && vShift >= 0 {
val = val << uint(vShift)
}
case "*":
if vMul, ok := GetConstantUint64(inverseOp.extra); ok && vMul > 0 {
val = toInt64(toUint64(val) / vMul)
}
case "/":
if vQuo, ok := GetConstantUint64(inverseOp.extra); ok && vQuo > 0 {
if inverseOp.flipped { // val = extra / x => x = extra / val
if val != 0 {
val = toInt64(vQuo / toUint64(val))
}
operandsFlipped = !operandsFlipped
} else { // val = x / extra => x = val * vQuo
val = toInt64(toUint64(val) * vQuo)
}
}
}
}
// Apply forward operations from 'op' to the limit 'val'
// e.g. if x < 30 and v is x * 10. val=30. op=*.
// we want range for v. v < 30 * 10.
if op.op != "" {
switch op.op {
case "<<":
if vShift, ok := GetConstantInt64(op.extra); ok && vShift >= 0 {
val = val << uint(vShift)
}
case "+":
if vAdd, ok := GetConstantInt64(op.extra); ok {
val += vAdd
}
case "-":
if vSub, ok := GetConstantInt64(op.extra); ok {
if op.flipped { // v = extra - x. x < val => v > extra - val
val = vSub - val
operandsFlipped = !operandsFlipped
} else { // v = x - extra. x < val => v < val - extra
val -= vSub
}
}
case ">>":
if vShift, ok := GetConstantInt64(op.extra); ok && vShift >= 0 {
val = val >> uint(vShift)
}
case "*":
isSrcUnsigned := strings.HasPrefix(v.Type().Underlying().String(), "uint")
if isSrcUnsigned {
if vMul, ok := GetConstantUint64(op.extra); ok && vMul != 0 {
hi, lo := bits.Mul64(toUint64(val), vMul)
if hi != 0 {
return
}
val = toInt64(lo)
}
} else {
if vMul, ok := GetConstantInt64(op.extra); ok && vMul != 0 {
if vMul > 0 {
if val >= 0 {
hi, lo := bits.Mul64(toUint64(val), uint64(vMul))
if hi != 0 {
return
}
val = toInt64(lo)
} else {
// Negative val, positive vMul
if val < minInt64/vMul {
return
}
val = val * vMul
}
} else {
// Negative vMul
val = val * vMul
operandsFlipped = !operandsFlipped
}
}
}
case "/":
if vQuo, ok := GetConstantInt64(op.extra); ok && vQuo > 0 {
if op.flipped { // v = extra / x. x < val => v > extra / val
if val != 0 {
val = vQuo / val
}
operandsFlipped = !operandsFlipped
} else { // v = x / extra. x < val => v < val / vQuo
val = val / vQuo
}
}
case "neg":
val = -val
operandsFlipped = !operandsFlipped
}
}
switch binOp.Op {
case token.LEQ, token.LSS:
updateMinMaxForLessOrEqual(result, val, binOp.Op, operandsFlipped, successPathConvert)
@@ -465,6 +606,15 @@ func (s *overflowState) updateResultFromBinOpForValue(result *rangeResult, binOp
result.maxValue >>= uint(val) // #nosec G115 - WORKAROUND for old golangci-lint, remove when updated
}
}
case "<<":
if val, ok := GetConstantInt64(op.extra); ok && val >= 0 {
if result.maxValueSet {
result.maxValue <<= uint(val) // #nosec G115 - WORKAROUND for old golangci-lint, remove when updated
}
if result.minValueSet {
result.minValue <<= uint(val) // #nosec G115 - WORKAROUND for old golangci-lint, remove when updated
}
}
case "%":
if val, ok := GetConstantInt64(op.extra); ok && val > 0 {
if (result.minValueSet && toInt64(result.minValue) >= 0) || isNonNegative(binOp.X) || isNonNegative(compareVal) {
@@ -504,10 +654,10 @@ func (s *overflowState) computeRange(v ssa.Value, block *ssa.BasicBlock, visited
// Definition-based range
switch v := v.(type) {
case *ssa.BinOp:
subResX := s.computeRange(v.X, block, visited)
switch v.Op {
case token.ADD:
subResY := s.computeRange(v.Y, block, visited)
subResX := s.resolveRange(v.X, block, visited)
subResY := s.resolveRange(v.Y, block, visited)
if subResX.minValueSet && subResY.minValueSet {
res.minValue = toUint64(toInt64(subResX.minValue) + toInt64(subResY.minValue))
res.minValueSet = true
@@ -518,6 +668,7 @@ func (s *overflowState) computeRange(v ssa.Value, block *ssa.BasicBlock, visited
}
res.isRangeCheck = subResX.isRangeCheck || subResY.isRangeCheck
case token.SUB:
subResX := s.resolveRange(v.X, block, visited)
if val, ok := GetConstantInt64(v.Y); ok {
// x - val
if subResX.minValueSet {
@@ -531,27 +682,33 @@ func (s *overflowState) computeRange(v ssa.Value, block *ssa.BasicBlock, visited
res.isRangeCheck = subResX.isRangeCheck
} else if val, ok := GetConstantInt64(v.X); ok {
// val - x
subResY := s.computeRange(v.Y, block, visited)
subResY := s.resolveRange(v.Y, block, visited)
if subResY.maxValueSet {
res.minValue = toUint64(val - toInt64(subResY.maxValue))
res.minValueSet = true
}
if subResY.minValueSet {
res.maxValue = toUint64(val - toInt64(subResY.minValue))
res.maxValueSet = true
}
res.isRangeCheck = subResY.isRangeCheck
}
case token.AND:
if val, ok := GetConstantInt64(v.Y); ok && val >= 0 {
// AND decreases magnitude usually.
if val, ok := GetConstantUint64(v.Y); ok {
res.minValue = 0
res.minValueSet = true
res.maxValue = uint64(val)
res.maxValue = val
res.maxValueSet = true
res.isRangeCheck = true
} else {
// If Y is not a constant, we can only say it's non-negative if X is.
if isNonNegative(v.X) {
res.minValue = 0
res.minValueSet = true
}
}
case token.SHR:
if val, ok := GetConstantInt64(v.Y); ok && val >= 0 {
subResX := s.resolveRange(v.X, block, visited)
if isNonNegative(v.X) {
res.minValue = 0
res.minValueSet = true
@@ -566,8 +723,32 @@ func (s *overflowState) computeRange(v ssa.Value, block *ssa.BasicBlock, visited
}
res.isRangeCheck = subResX.isRangeCheck
}
case token.SHL:
if val, ok := GetConstantInt64(v.Y); ok && val >= 0 {
subResX := s.resolveRange(v.X, block, visited)
if subResX.minValueSet {
newMin := subResX.minValue << uint(val) // #nosec G115 - WORKAROUND for old golangci-lint, remove when updated
// Check for overflow/wrap-around
// #nosec G115 - WORKAROUND for old golangci-lint, remove when updated
if newMin>>uint(val) == subResX.minValue {
res.minValue = newMin
res.minValueSet = true
}
}
if subResX.maxValueSet {
newMax := subResX.maxValue << uint(val) // #nosec G115 - WORKAROUND for old golangci-lint, remove when updated
// Check for overflow/wrap-around
// #nosec G115 - WORKAROUND for old golangci-lint, remove when updated
if newMax>>uint(val) == subResX.maxValue {
res.maxValue = newMax
res.maxValueSet = true
}
}
res.isRangeCheck = subResX.isRangeCheck
}
case token.REM:
if val, ok := GetConstantInt64(v.Y); ok && val > 0 {
subResX := s.resolveRange(v.X, block, visited)
if (subResX.minValueSet && toInt64(subResX.minValue) >= 0) || isNonNegative(v.X) {
res.minValue = 0
res.minValueSet = true
@@ -582,69 +763,81 @@ func (s *overflowState) computeRange(v ssa.Value, block *ssa.BasicBlock, visited
res.isRangeCheck = true
}
case token.MUL:
val, ok := GetConstantInt64(v.Y)
val, ok := GetConstantUint64(v.Y)
if !ok {
val, ok = GetConstantInt64(v.X)
val, ok = GetConstantUint64(v.X)
}
if ok && val != 0 {
var subRes rangeResult
if isSameOrRelated(v.Y, v.X) { // e.g. x*x, handled by generic fallback if not constant
// Should typically not happen if we found a constant
if isSameOrRelated(v.Y, v.X) {
// x*x handled by generic fallback
} else if _, isConst := v.Y.(*ssa.Const); isConst {
subRes = s.computeRange(v.X, block, visited)
subRes = s.resolveRange(v.X, block, visited)
} else {
subRes = s.computeRange(v.Y, block, visited)
subRes = s.resolveRange(v.Y, block, visited)
}
if val > 0 {
if subRes.minValueSet {
res.minValue = toUint64(toInt64(subRes.minValue) * val)
res.minValueSet = true
}
if subRes.maxValueSet {
res.maxValue = toUint64(toInt64(subRes.maxValue) * val)
res.maxValueSet = true
}
} else {
if subRes.maxValueSet {
res.minValue = toUint64(toInt64(subRes.maxValue) * val)
res.minValueSet = true
}
if subRes.minValueSet {
res.maxValue = toUint64(toInt64(subRes.minValue) * val)
res.maxValueSet = true
if subRes.maxValueSet {
hi, _ := bits.Mul64(subRes.maxValue, val)
if hi != 0 {
return res
}
}
if subRes.minValueSet {
res.minValue = subRes.minValue * val
res.minValueSet = true
}
if subRes.maxValueSet {
res.maxValue = subRes.maxValue * val
res.maxValueSet = true
}
res.isRangeCheck = subRes.isRangeCheck
}
case token.QUO:
if val, ok := GetConstantInt64(v.Y); ok && val != 0 {
subResX := s.computeRange(v.X, block, visited)
if val > 0 {
if val, ok := GetConstantUint64(v.Y); ok && val != 0 {
subResX := s.resolveRange(v.X, block, visited)
isSrcUnsigned := strings.HasPrefix(v.Type().Underlying().String(), "uint")
if isSrcUnsigned {
if subResX.minValueSet {
res.minValue = toUint64(toInt64(subResX.minValue) / val)
res.minValue = subResX.minValue / val
res.minValueSet = true
}
if subResX.maxValueSet {
res.maxValue = toUint64(toInt64(subResX.maxValue) / val)
res.maxValue = subResX.maxValue / val
res.maxValueSet = true
}
} else {
if subResX.maxValueSet {
res.minValue = toUint64(toInt64(subResX.maxValue) / val)
res.minValueSet = true
}
if subResX.minValueSet {
res.maxValue = toUint64(toInt64(subResX.minValue) / val)
res.maxValueSet = true
vVal := toInt64(val)
if vVal > 0 {
if subResX.minValueSet {
res.minValue = toUint64(toInt64(subResX.minValue) / vVal)
res.minValueSet = true
}
if subResX.maxValueSet {
res.maxValue = toUint64(toInt64(subResX.maxValue) / vVal)
res.maxValueSet = true
}
} else { // vVal < 0
if subResX.maxValueSet {
res.minValue = toUint64(toInt64(subResX.maxValue) / vVal)
res.minValueSet = true
}
if subResX.minValueSet {
res.maxValue = toUint64(toInt64(subResX.minValue) / vVal)
res.maxValueSet = true
}
}
}
res.isRangeCheck = subResX.isRangeCheck
}
}
case *ssa.UnOp:
if v.Op == token.SUB {
subRes := s.computeRange(v.X, block, visited)
subRes := s.resolveRange(v.X, block, visited)
switch v.Op {
case token.SUB:
// Negation: -x.
// Min = -Max. Max = -Min.
if subRes.maxValueSet {
res.minValue = toUint64(-toInt64(subRes.maxValue))
res.minValueSet = true
@@ -654,14 +847,45 @@ func (s *overflowState) computeRange(v ssa.Value, block *ssa.BasicBlock, visited
res.maxValueSet = true
}
res.isRangeCheck = subRes.isRangeCheck
case token.XOR:
// Bitwise NOT: ^x = -x - 1.
// Min = ^Max. Max = ^Min.
if subRes.maxValueSet {
res.minValue = toUint64(toInt64(^subRes.maxValue))
res.minValueSet = true
}
if subRes.minValueSet {
res.maxValue = toUint64(toInt64(^subRes.minValue))
res.maxValueSet = true
}
res.isRangeCheck = subRes.isRangeCheck
}
case *ssa.Call:
if fn, ok := v.Call.Value.(*ssa.Builtin); ok {
switch fn.Name() {
case "len", "cap":
res.minValue = 0
res.minValueSet = true
res.isRangeCheck = true
if len(v.Call.Args) == 1 {
arg := v.Call.Args[0]
if _, ok := arg.(*ssa.Slice); ok || arg.Type().String() == "string" {
// len(slice) or len(string) is non-negative
// Try to resolve range of the slice/string length if possible?
// For now, just >= 0.
// We can also check if the slice came from make()
argRes := s.resolveRange(arg, block, visited)
if argRes.minValueSet {
res.minValue = argRes.minValue
res.minValueSet = true
} else {
res.minValue = 0
res.minValueSet = true
}
if argRes.maxValueSet {
res.maxValue = argRes.maxValue
res.maxValueSet = true
}
res.isRangeCheck = true
}
}
case "min":
for i, arg := range v.Call.Args {
argRes := s.resolveRange(arg, block, visited)
@@ -756,6 +980,17 @@ func (s *overflowState) computeRange(v ssa.Value, block *ssa.BasicBlock, visited
}
}
}
case *ssa.Convert:
subRes := s.resolveRange(v.X, block, visited)
if subRes.minValueSet || subRes.maxValueSet {
res = subRes
}
case *ssa.ChangeType:
subRes := s.resolveRange(v.X, block, visited)
if subRes.minValueSet || subRes.maxValueSet {
res = subRes
}
case *ssa.Const:
if val, ok := GetConstantInt64(v); ok {
res.minValue = toUint64(val)
@@ -949,7 +1184,17 @@ func isRangeCheck(v ssa.Value, x ssa.Value) bool {
switch op.Op {
case token.LSS, token.LEQ, token.GTR, token.GEQ, token.EQL, token.NEQ:
leftMatch := isSameOrRelated(op.X, x) || isSameOrRelated(op.X, compareVal)
if !leftMatch {
if rVal, _ := getRealValueFromOperation(op.X); rVal == x || (compareVal != nil && rVal == compareVal) {
leftMatch = true
}
}
rightMatch := isSameOrRelated(op.Y, x) || isSameOrRelated(op.Y, compareVal)
if !rightMatch {
if rVal, _ := getRealValueFromOperation(op.Y); rVal == x || (compareVal != nil && rVal == compareVal) {
rightMatch = true
}
}
return leftMatch || rightMatch
}
}
@@ -989,7 +1234,7 @@ func getRealValueFromOperation(v ssa.Value) (ssa.Value, operationInfo) {
return v, operationInfo{}
case *ssa.BinOp:
switch v.Op {
case token.ADD, token.SUB, token.AND, token.SHR, token.REM:
case token.ADD, token.SUB, token.AND, token.SHR, token.SHL, token.REM, token.MUL, token.QUO:
if _, ok := v.Y.(*ssa.Const); ok {
return v.X, operationInfo{op: v.Op.String(), extra: v.Y}
}
@@ -1054,6 +1299,10 @@ func explicitValsInRange(explicitPosVals []uint, explicitNegVals []int, dstInt I
// resolveRange combines definition-based range analysis (computeRange) with dominator-based constraints (If blocks) to determine the full range of a value.
func (s *overflowState) resolveRange(v ssa.Value, block *ssa.BasicBlock, visited map[ssa.Value]bool) rangeResult {
key := rangeCacheKey{block: block, val: v}
if res, ok := s.rangeCache[key]; ok {
return res
}
isSrcUnsigned := strings.HasPrefix(v.Type().Underlying().String(), "uint")
// Track bounds
result := rangeResult{
@@ -1086,11 +1335,10 @@ func (s *overflowState) resolveRange(v ssa.Value, block *ssa.BasicBlock, visited
// Check all dominating If instructions.
idoms := getDominators(block)
visitedIfs := make(map[*ssa.If]bool)
for _, idom := range idoms {
for _, instr := range idom.Instrs {
if vIf, ok := instr.(*ssa.If); ok {
domRes := s.getResultRangeForValue(vIf, v, block, visitedIfs)
domRes := s.getResultRangeForValue(vIf, v, block)
if domRes.isRangeCheck {
result.isRangeCheck = true
if domRes.minValueSet {
@@ -1114,10 +1362,46 @@ func (s *overflowState) resolveRange(v ssa.Value, block *ssa.BasicBlock, visited
// to avoid regressions in pure definition-based constant handling.
if binOp, ok := v.(*ssa.BinOp); ok {
switch binOp.Op {
case token.ADD:
// Handle x+C or C+x
if val, ok := GetConstantInt64(binOp.Y); ok {
subRes := s.resolveRange(binOp.X, block, visited)
if subRes.isRangeCheck {
if subRes.minValueSet {
updateRangeMinMax(&result, toUint64(toInt64(subRes.minValue)+val), true, isSrcUnsigned)
}
if subRes.maxValueSet {
updateRangeMinMax(&result, toUint64(toInt64(subRes.maxValue)+val), false, isSrcUnsigned)
}
}
} else if val, ok := GetConstantInt64(binOp.X); ok {
subRes := s.resolveRange(binOp.Y, block, visited)
if subRes.isRangeCheck {
if subRes.minValueSet {
updateRangeMinMax(&result, toUint64(val+toInt64(subRes.minValue)), true, isSrcUnsigned)
}
if subRes.maxValueSet {
updateRangeMinMax(&result, toUint64(val+toInt64(subRes.maxValue)), false, isSrcUnsigned)
}
}
}
case token.SUB:
// Handle x-C. C-x logic is harder (inverts min/max), skipping for simplicity/safety unless needed.
if val, ok := GetConstantInt64(binOp.Y); ok {
subRes := s.resolveRange(binOp.X, block, visited)
if subRes.isRangeCheck {
if subRes.minValueSet {
updateRangeMinMax(&result, toUint64(toInt64(subRes.minValue)-val), true, isSrcUnsigned)
}
if subRes.maxValueSet {
updateRangeMinMax(&result, toUint64(toInt64(subRes.maxValue)-val), false, isSrcUnsigned)
}
}
}
case token.MUL:
val, ok := GetConstantInt64(binOp.Y)
val, ok := GetConstantUint64(binOp.Y)
if !ok {
val, ok = GetConstantInt64(binOp.X)
val, ok = GetConstantUint64(binOp.X)
}
if ok && val != 0 {
var subRes rangeResult
@@ -1127,41 +1411,51 @@ func (s *overflowState) resolveRange(v ssa.Value, block *ssa.BasicBlock, visited
subRes = s.resolveRange(binOp.Y, block, visited)
}
if val > 0 {
if subRes.minValueSet && subRes.isRangeCheck {
res := toUint64(toInt64(subRes.minValue) * val)
// Only update if tighter/set
if !result.minValueSet || (isSrcUnsigned && res > result.minValue) || (!isSrcUnsigned && toInt64(res) > toInt64(result.minValue)) {
result.minValue = res
result.minValueSet = true
// Inherit isRangeCheck to allow further propagation
result.isRangeCheck = true
if subRes.maxValueSet {
hi, _ := bits.Mul64(subRes.maxValue, val)
if hi != 0 {
break
}
}
if subRes.minValueSet && subRes.isRangeCheck {
updateRangeMinMax(&result, subRes.minValue*val, true, isSrcUnsigned)
}
if subRes.maxValueSet && subRes.isRangeCheck {
updateRangeMinMax(&result, subRes.maxValue*val, false, isSrcUnsigned)
}
}
case token.SHL:
if val, ok := GetConstantInt64(binOp.Y); ok && val >= 0 {
subRes := s.resolveRange(binOp.X, block, visited)
if subRes.isRangeCheck {
if subRes.minValueSet {
newMin := subRes.minValue << uint(val) // #nosec G115 - WORKAROUND for old golangci-lint, remove when updated
// Check for overflow/wrap-around
// #nosec G115 - WORKAROUND for old golangci-lint, remove when updated
if newMin>>uint(val) == subRes.minValue {
updateRangeMinMax(&result, newMin, true, isSrcUnsigned)
}
}
if subRes.maxValueSet && subRes.isRangeCheck {
res := toUint64(toInt64(subRes.maxValue) * val)
if !result.maxValueSet || (isSrcUnsigned && res < result.maxValue) || (!isSrcUnsigned && toInt64(res) < toInt64(result.maxValue)) {
result.maxValue = res
result.maxValueSet = true
result.isRangeCheck = true
if subRes.maxValueSet {
newMax := subRes.maxValue << uint(val) // #nosec G115 - WORKAROUND for old golangci-lint, remove when updated
// Check for overflow/wrap-around
// #nosec G115 - WORKAROUND for old golangci-lint, remove when updated
if newMax>>uint(val) == subRes.maxValue {
updateRangeMinMax(&result, newMax, false, isSrcUnsigned)
}
}
} else {
if subRes.maxValueSet && subRes.isRangeCheck {
res := toUint64(toInt64(subRes.maxValue) * val)
if !result.minValueSet || (isSrcUnsigned && res > result.minValue) || (!isSrcUnsigned && toInt64(res) > toInt64(result.minValue)) {
result.minValue = res
result.minValueSet = true
result.isRangeCheck = true
}
}
}
case token.SHR:
if val, ok := GetConstantInt64(binOp.Y); ok && val >= 0 {
subRes := s.resolveRange(binOp.X, block, visited)
if subRes.isRangeCheck {
if subRes.minValueSet {
updateRangeMinMax(&result, subRes.minValue>>uint(val), true, isSrcUnsigned) // #nosec G115 - WORKAROUND for old golangci-lint, remove when updated
}
if subRes.minValueSet && subRes.isRangeCheck {
res := toUint64(toInt64(subRes.minValue) * val)
if !result.maxValueSet || (isSrcUnsigned && res < result.maxValue) || (!isSrcUnsigned && toInt64(res) < toInt64(result.maxValue)) {
result.maxValue = res
result.maxValueSet = true
result.isRangeCheck = true
}
if subRes.maxValueSet {
updateRangeMinMax(&result, subRes.maxValue>>uint(val), false, isSrcUnsigned) // #nosec G115 - WORKAROUND for old golangci-lint, remove when updated
}
}
}
@@ -1170,37 +1464,17 @@ func (s *overflowState) resolveRange(v ssa.Value, block *ssa.BasicBlock, visited
subRes := s.resolveRange(binOp.X, block, visited)
if val > 0 {
if subRes.minValueSet && subRes.isRangeCheck {
res := toUint64(toInt64(subRes.minValue) / val)
if !result.minValueSet || (isSrcUnsigned && res > result.minValue) || (!isSrcUnsigned && toInt64(res) > toInt64(result.minValue)) {
result.minValue = res
result.minValueSet = true
result.isRangeCheck = true
}
updateRangeMinMax(&result, toUint64(toInt64(subRes.minValue)/val), true, isSrcUnsigned)
}
if subRes.maxValueSet && subRes.isRangeCheck {
res := toUint64(toInt64(subRes.maxValue) / val)
if !result.maxValueSet || (isSrcUnsigned && res < result.maxValue) || (!isSrcUnsigned && toInt64(res) < toInt64(result.maxValue)) {
result.maxValue = res
result.maxValueSet = true
result.isRangeCheck = true
}
updateRangeMinMax(&result, toUint64(toInt64(subRes.maxValue)/val), false, isSrcUnsigned)
}
} else {
if subRes.maxValueSet && subRes.isRangeCheck {
res := toUint64(toInt64(subRes.maxValue) / val)
if !result.minValueSet || (isSrcUnsigned && res > result.minValue) || (!isSrcUnsigned && toInt64(res) > toInt64(result.minValue)) {
result.minValue = res
result.minValueSet = true
result.isRangeCheck = true
}
updateRangeMinMax(&result, toUint64(toInt64(subRes.maxValue)/val), true, isSrcUnsigned)
}
if subRes.minValueSet && subRes.isRangeCheck {
res := toUint64(toInt64(subRes.minValue) / val)
if !result.maxValueSet || (isSrcUnsigned && res < result.maxValue) || (!isSrcUnsigned && toInt64(res) < toInt64(result.maxValue)) {
result.maxValue = res
result.maxValueSet = true
result.isRangeCheck = true
}
updateRangeMinMax(&result, toUint64(toInt64(subRes.minValue)/val), false, isSrcUnsigned)
}
}
}
@@ -1220,6 +1494,7 @@ func (s *overflowState) resolveRange(v ssa.Value, block *ssa.BasicBlock, visited
}
}
}
// Persist in cache
s.rangeCache[key] = result
return result
}

View File

@@ -218,6 +218,18 @@ func GetConstantInt64(v ssa.Value) (int64, bool) {
return 0, false
}
// GetConstantUint64 extracts a constant uint64 value from an ssa.Value
func GetConstantUint64(v ssa.Value) (uint64, bool) {
if c, ok := v.(*ssa.Const); ok {
if c.Value != nil {
if val, ok := constant.Uint64Val(c.Value); ok {
return val, true
}
}
}
return 0, false
}
// GetSliceBounds extracts low, high, and max indices from a slice instruction
func GetSliceBounds(s *ssa.Slice) (int, int, int) {
var low, high, maxIdx int

View File

@@ -1412,4 +1412,240 @@ func shrProp(x uint8) uint8 {
return x >> 1
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func shlProp(x uint64) uint16 {
if x < 256 {
return uint16(x << 8) // max 255 << 8 = 65280. Fits in uint16 (65535)
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func shlOverflow(x uint64) uint16 {
if x < 256 {
return uint16(x << 9) // max 255 << 9 = 130560. Overflows uint16.
}
return 0
}
`}, 1, gosec.NewConfig()},
{[]string{`
package main
func shlSafeCheck(x int) uint16 {
if x > 0 && x < 10 {
return uint16(x << 4) // max 9 << 4 = 144. Fits.
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func shlUnsafeCheck(x int) uint16 {
if x > 0 && x < 10000 {
return uint16(x << 4) // max 9999 << 4 = 159984. Overflows uint16.
}
return 0
}
`}, 1, gosec.NewConfig()},
{[]string{`
package main
func shlCompute(x int) uint8 {
// x & 0x0F -> range [0, 15]
// 15 << 2 = 60. Fits in uint8.
return uint8((x & 0x0F) << 2)
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func remUint(x uint) uint8 {
// x is uint (non-negative).
// x % 10 -> range [0, 9].
// Fits in uint8.
return uint8(x % 10)
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func shlCondition(x int) uint8 {
// if x << 2 < 100
// x range is inferred.
// x*4 < 100 => x < 25.
// uint8(x) is safe.
if (x << 2) < 100 && x >= 0 {
return uint8(x)
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func shlMinUpdate(x int) uint8 {
// x > 10 -> x in [11, Max]
// x << 2 -> [44, Max]
if x > 10 && x < 20 {
return uint8(x << 2) // [44, 76] fits uint8
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
type S struct { F int }
func fieldCompareRHS(s *S) uint8 {
// 10 < s.F -> s.F > 10
// s.F is struct field, different SSA reads.
if 10 < s.F && s.F < 250 {
return uint8(s.F)
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func rhsOpFallback(x int) uint8 {
// 100 > x << 2 => x << 2 < 100 => x < 25
if 100 > x << 2 && x >= 0 {
return uint8(x)
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func inverseAddSafe(x int) uint8 {
// x + 1000 < 1010 => x < 10
// If we miss inverse op, we see x < 1010 (unsafe)
if x + 1000 < 1010 && x >= 0 {
return uint8(x) // Safe
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func inverseSubUnsafe(x int) uint8 {
// x - 1000 < 10 => x < 1010
// If we miss inverse op, we see x < 10 (safe)
// Actually unsafe.
if x - 1000 < 10 && x >= 0 {
return uint8(x) // Unsafe
}
return 0
}
`}, 1, gosec.NewConfig()},
{[]string{`
package main
func inverseShrSafe(x int) uint8 {
// x >> 2 < 10 => x < 40 (approx 10 << 2)
// Actually [0, 39] >> 2 is [0, 9]. 40 >> 2 is 10.
// So distinct x < 40.
if x >> 2 < 10 && x >= 0 {
return uint8(x) // Safe
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func inverseMulSafe(x int) uint8 {
// x * 10 < 100 => x < 10
if x * 10 < 100 && x >= 0 {
return uint8(x) // Safe
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func mulMinUpdate(x int) uint8 {
// x > 10. x * 2 > 20.
// if x < 50. x * 2 < 100.
// result [22, 100]. Fits uint8.
// Hits MUL minValue update (recursive tightens forward).
if x > 10 && x < 50 {
return uint8(x * 2)
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func quoMinUpdate(x int) uint8 {
// x > 20. x / 2 > 10.
// x < 100. x / 2 < 50.
// result [10, 50]. Fits uint8.
// Hits QUO minValue update.
if x > 20 && x < 100 {
return uint8(x / 2)
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func mulOverflow64(x uint64) uint8 {
if x >= 1 && x <= 2 {
return uint8(x * 0x8000000000000001)
}
return 0
}
`}, 1, gosec.NewConfig()},
{[]string{`
package main
type T int64
func testChangeType(x T) int8 {
if x > 0 && x < 100 {
return int8(x) // Propagate through ChangeType (T is int64-based)
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func testCommutativeAdd(x int) uint8 {
if 10 + x < 30 && x > 0 {
return uint8(x) // Safe [1, 19]
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func testXOR(x uint8) int8 {
if x < 128 {
y := ^x // [0, 127] -> [128, 255]
return int8(y) // Unsafe
}
return 0
}
`}, 1, gosec.NewConfig()},
{[]string{`
package main
func testInvFlippedQuo(x int) uint16 {
if x > 0 && 10000 / x < 5 {
return uint16(x) // Unsafe: x > 2000.
}
return 0
}
`}, 1, gosec.NewConfig()},
{[]string{`
package main
func testInvQuo(x int64) uint8 {
if x > 0 && x / 10 < 5 {
return uint8(x) // Safe: x < 50
}
return 0
}
`}, 0, gosec.NewConfig()},
{[]string{`
package main
func testDoubleReturn(x int) (uint8, uint16) {
if x > 0 && x < 10 {
return uint8(x), uint16(x)
}
return 0, 0
}
`}, 0, gosec.NewConfig()},
}