Skip to content

Commit b0d9273

Browse files
authored
Implements weighters (UnitTestBot#67)
Signed-off-by: Старцев Матвей <tozarin@yandex.ru>
1 parent ce262d0 commit b0d9273

29 files changed

Lines changed: 680 additions & 231 deletions
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package org.usvm.ps.weighters
2+
3+
class CombinedStateStableFloatWeighter<in State> : CombinedStateWeighter<State, Float> {
4+
5+
constructor(
6+
weighters: List<StateWeighterWithReport<State, Float>>
7+
) : super(weighters, ARITHMETIC)
8+
9+
constructor(
10+
weighters: List<StateWeighterWithReport<State, Float>>,
11+
metaWeights: List<Float>
12+
) : super(weighters, metaWeights, ARITHMETIC)
13+
14+
companion object {
15+
fun <State> withNorm(weighters: List<NormalizableWeighter<State, Float>>, metaWeights: List<Float>) =
16+
CombinedStateStableFloatWeighter<State>(weighters.map { it.normalize() }, metaWeights)
17+
18+
fun <State> fromInt(
19+
weighters: List<StateWeighterWithReport<State, Int>>,
20+
metaWeights: List<Float>
21+
) = CombinedStateStableFloatWeighter(
22+
weighters.map { StateWeighterWithCast(it, Int::toFloat) },
23+
metaWeights
24+
)
25+
26+
private val ARITHMETIC = StableFloatArithmetic
27+
}
28+
}
Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,13 @@
11
package org.usvm.ps.weighters
22

3-
import org.usvm.ps.StateWeighter
4-
5-
class CombinedStateStableIntWeighter<in State> : CombinedStateWeighter<State, Int, Float> {
3+
class CombinedStateStableIntWeighter<in State> : CombinedStateWeighter<State, Int> {
64

75
constructor(
8-
weighters: List<StateWeighter<State, Int>>
9-
) : super(weighters, Int::stableAdd)
6+
weighters: List<StateWeighterWithReport<State, Int>>
7+
) : super(weighters, StableIntArithmetic)
108

119
constructor(
12-
weighters: List<StateWeighter<State, Int>>,
13-
metaWeights: List<Float>
14-
) : super(
15-
weighters,
16-
metaWeights,
17-
Int::stableAdd,
18-
Int::stableMul
19-
)
10+
weighters: List<StateWeighterWithReport<State, Int>>,
11+
metaWeights: List<Int>
12+
) : super(weighters, metaWeights, StableIntArithmetic)
2013
}
Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,38 @@
11
package org.usvm.ps.weighters
22

3-
import org.usvm.ps.StateWeighter
3+
open class CombinedStateWeighter<in State, Weight> : StateWeighterWithReport<State, Weight> {
44

5-
open class CombinedStateWeighter<in State, out Weight, MetaWeight> : StateWeighter<State, Weight> {
5+
override val weighterName = "CombinedStateWeighter"
6+
private val arithmetic: Arithmetic<Weight>
67

7-
private val weightersWithWeights: List<Pair<StateWeighter<State, Weight>, MetaWeight?>>
8-
private val weightCombiner: (Weight, Weight) -> Weight
9-
private val metaWeightApplier: ((Weight, MetaWeight) -> Weight)?
8+
private val weightersWithWeights: List<Pair<StateWeighterWithReport<State, Weight>, Weight>>
109

1110
constructor(
12-
weighters: List<StateWeighter<State, Weight>>,
13-
weightCombiner: (Weight, Weight) -> Weight
14-
) {
15-
check(weighters.isNotEmpty()) { "CombinedStateWeighter must have at least one weighter" }
16-
this.weightersWithWeights = weighters.map { it to null }
17-
this.weightCombiner = weightCombiner
18-
this.metaWeightApplier = null
19-
}
11+
weighters: List<StateWeighterWithReport<State, Weight>>,
12+
arithmetic: Arithmetic<Weight>
13+
) : this(weighters, weighters.map { arithmetic.one }, arithmetic)
2014

2115
constructor(
22-
weighters: List<StateWeighter<State, Weight>>,
23-
metaWeights: List<MetaWeight>,
24-
weightCombiner: (Weight, Weight) -> Weight,
25-
metaWeightApplier: (Weight, MetaWeight) -> Weight
16+
weighters: List<StateWeighterWithReport<State, Weight>>,
17+
metaWeights: List<Weight>,
18+
arithmetic: Arithmetic<Weight>
2619
) {
2720
check(weighters.isNotEmpty()) { "CombinedStateWeighter must have at least one weighter" }
2821
this.weightersWithWeights = weighters.zip(metaWeights)
29-
this.weightCombiner = weightCombiner
30-
this.metaWeightApplier = metaWeightApplier
22+
this.arithmetic = arithmetic
3123
}
3224

33-
override fun weight(state: State): Weight {
34-
var result: Weight? = null
35-
if (metaWeightApplier != null) {
36-
val applier: ((Weight, MetaWeight) -> Weight) = metaWeightApplier
37-
for ((weighter, metaWeight) in weightersWithWeights) {
38-
val stateWeight = weighter.weight(state)
39-
val weight = applier(stateWeight, metaWeight!!)
40-
result = if (result == null) weight else weightCombiner(weight, result)
41-
}
42-
43-
return result!!
44-
}
25+
override fun weight(state: State): Weight = weightWithReport(state).weight
4526

46-
for ((weighter, _) in weightersWithWeights) {
47-
val weight = weighter.weight(state)
48-
result = if (result == null) weight else weightCombiner(weight, result)
27+
override fun weightWithReport(state: State) = with(arithmetic) {
28+
val reports = mutableListOf<WeighterReport<Weight>>()
29+
val result = weightersWithWeights.fold(zero) { sum, (weighter, metaWeight) ->
30+
val report = weighter.weightWithReport(state)
31+
val weight = report.weight.mulTo(metaWeight)
32+
reports.add(report)
33+
sum.plusTo(weight)
4934
}
5035

51-
return result!!
36+
CombinedWeighterReport(result, weighterName, reports)
5237
}
5338
}
Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,85 @@
11
package org.usvm.ps.weighters
22

3-
private fun Long.stableToInt(): Int {
4-
return when {
5-
this < Int.MIN_VALUE -> Int.MIN_VALUE
6-
this > Int.MAX_VALUE -> Int.MAX_VALUE
3+
abstract class Arithmetic<T> {
4+
abstract val maxValue: T
5+
abstract val minValue: T
6+
abstract val zero: T
7+
abstract val one: T
8+
abstract val negativeOne: T
9+
10+
open val comparator: Comparator<T>? = null
11+
12+
abstract fun plus(left: T, right: T): T
13+
abstract fun minus(left: T, right: T): T
14+
abstract fun mul(left: T, right: T): T
15+
abstract fun div(left: T, right: T): T
16+
17+
open fun negate(value: T) = mul(negativeOne, value)
18+
19+
open fun compare(left: T, right: T) = comparator?.compare(left, right)
20+
?: error("Comparator was not defined for arithmetic")
21+
22+
open fun isGreater(left: T, right: T) = compare(left, right) > 0
23+
open fun isLess(left: T, right: T) = compare(left, right) < 0
24+
open fun isEquals(left: T, right: T) = compare(left, right) == 0
25+
26+
fun max(left: T, right: T) = if (left.isGreaterTo(right)) left else right
27+
fun min(left: T, right: T) = if (left.isLessTo(right)) left else right
28+
29+
fun T.plusTo(other: T) = plus(this, other)
30+
fun T.minusTo(other: T) = minus(this, other)
31+
fun T.mulTo(other: T) = mul(this, other)
32+
fun T.divTo(other: T) = div(this, other)
33+
fun T.negateTo() = negate(this)
34+
35+
fun T.compareTo(other: T) = compare(this, other)
36+
fun T.isGreaterTo(other: T) = isGreater(this, other)
37+
fun T.isLessTo(other: T) = isLess(this, other)
38+
fun T.isEqualsTo(other: T) = isEquals(this, other)
39+
}
40+
41+
object StableIntArithmetic : Arithmetic<Int>() {
42+
override val maxValue = Int.MAX_VALUE
43+
override val minValue = Int.MIN_VALUE
44+
override val zero = 0
45+
override val one = 1
46+
override val negativeOne = -1
47+
48+
override val comparator: Comparator<Int> = compareBy<Int> { it }
49+
50+
private fun Long.stableToInt() = when {
51+
this < minValue -> minValue
52+
this > maxValue -> maxValue
753
else -> this.toInt()
854
}
9-
}
1055

11-
fun Int.stableAdd(other: Int): Int {
12-
val longResult = toLong() + other
13-
return longResult.stableToInt()
14-
}
56+
override fun plus(left: Int, right: Int) = (left.toLong() + right).stableToInt()
57+
override fun minus(left: Int, right: Int) = plus(negate(right), left)
58+
override fun mul(left: Int, right: Int) = (left.toLong() * right).stableToInt()
59+
override fun div(left: Int, right: Int) = (left.toLong() / right).stableToInt()
1560

16-
fun Int.stableMul(other: Int): Int {
17-
val longResult = toLong() * other
18-
return longResult.stableToInt()
61+
override fun negate(value: Int) = if (value == minValue) maxValue else -value
1962
}
2063

21-
fun Int.stableMul(other: Float): Int {
22-
val longResult = (this.toDouble() * other).toLong()
23-
return longResult.stableToInt()
24-
}
64+
object StableFloatArithmetic : Arithmetic<Float>() {
65+
override val maxValue = Float.POSITIVE_INFINITY
66+
override val minValue = Float.NEGATIVE_INFINITY
67+
override val zero = 0f
68+
override val one = 1f
69+
override val negativeOne = -1f
70+
71+
override val comparator: Comparator<Float> = compareBy<Float> { it }
72+
73+
private fun Double.stableToFloat() = when {
74+
this < minValue -> minValue
75+
this > maxValue -> maxValue
76+
else -> this.toFloat()
77+
}
78+
79+
override fun plus(left: Float, right: Float) = (left.toDouble() + right).stableToFloat()
80+
override fun minus(left: Float, right: Float) = plus(negate(right), left)
81+
override fun mul(left: Float, right: Float) = (left.toDouble() * right).stableToFloat()
82+
override fun div(left: Float, right: Float) = (left.toDouble() / right).stableToFloat()
2583

26-
fun Int.stableUnaryMinus(other: Int): Int {
27-
return if (this == Int.MIN_VALUE) Int.MAX_VALUE else -this
84+
override fun negate(value: Float) = if (value == minValue) maxValue else -value
2885
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package org.usvm.ps.weighters
2+
3+
class StateWeighterWithCast<in State, InWeight, OutWeight>(
4+
val weighter: StateWeighterWithReport<State, InWeight>,
5+
val cast: (InWeight) -> OutWeight
6+
): StateWeighterWithReport<State, OutWeight>() {
7+
override val weighterName = "${weighter.weighterName} (cast.)"
8+
9+
override fun weight(state: State) = cast(weighter.weight(state))
10+
override fun weightWithReport(state: State) =
11+
CastWeighterReport(weighterName, weighter.weightWithReport(state), cast)
12+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package org.usvm.ps.weighters
2+
3+
fun interface NormalizableWeighter<in State, out Weight> {
4+
fun normalize(): StateWeighterWithReport<State, Weight>
5+
}
6+
7+
enum class WeightNormalizerType {
8+
POSITIVE, // normed values should be from [0; 1]
9+
NEGATIVE, // -.- from [-1; 0]
10+
CUSTOM // own min and max values to norm
11+
}
12+
13+
class StateWeighterWithNorm<in State, InWeight, OutWeight>(
14+
val weighter: StateWeighterWithReport<State, InWeight>,
15+
val cast: InWeight.() -> OutWeight,
16+
val maxWeight: OutWeight,
17+
val minWeight: OutWeight,
18+
val arithmetic: Arithmetic<OutWeight>,
19+
val type: WeightNormalizerType
20+
) : StateWeighterWithReport<State, OutWeight>() {
21+
22+
init {
23+
check(type != WeightNormalizerType.CUSTOM) { "CUSTOM type for StateWeighterWithNorm is not implemented" }
24+
check(arithmetic.isGreater(maxWeight, minWeight)) { "MAX_WEIGHT is bigger that MIN_WEIGHT in Weighter norm" }
25+
}
26+
27+
override val weighterName = "${weighter.weighterName} (n.)"
28+
29+
override fun weight(state: State) = with(arithmetic) {
30+
val weight = weighter.weight(state).cast()
31+
val norm = when {
32+
weight.isGreaterTo(maxWeight) -> one
33+
weight.isLessTo(minWeight) -> zero
34+
else -> weight.minusTo(minWeight).divTo(maxWeight.minusTo(minWeight))
35+
}
36+
37+
if (type == WeightNormalizerType.POSITIVE) norm else norm.minusTo(one)
38+
}
39+
40+
companion object {
41+
fun <State> normalizeInt(
42+
weighter: StateWeighterWithReport<State, Int>,
43+
maxWeight: Int, minWeight: Int,
44+
type: WeightNormalizerType
45+
) = StateWeighterWithNorm(weighter, { this }, maxWeight, minWeight, StableIntArithmetic, type)
46+
47+
fun <State> normalizeFloat(
48+
weighter: StateWeighterWithReport<State, Float>,
49+
maxWeight: Float,
50+
minWeight: Float,
51+
type: WeightNormalizerType
52+
) = StateWeighterWithNorm(weighter, { this }, maxWeight, minWeight, StableFloatArithmetic, type)
53+
54+
fun <State> normalizeIntToFloat(
55+
weighter: StateWeighterWithReport<State, Int>,
56+
maxWeight: Int,
57+
minWeight: Int,
58+
type: WeightNormalizerType
59+
) = StateWeighterWithNorm(
60+
weighter,
61+
{ this.toFloat() },
62+
maxWeight.toFloat(),
63+
minWeight.toFloat(),
64+
StableFloatArithmetic,
65+
type
66+
)
67+
}
68+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package org.usvm.ps.weighters
2+
3+
import org.usvm.ps.StateWeighter
4+
5+
abstract class StateWeighterWithReport<in State, out Weight> : StateWeighter<State, Weight> {
6+
abstract val weighterName: String
7+
open fun weightWithReport(state: State): WeighterReport<Weight> = SingleWeighterReport(weight(state), weighterName)
8+
}
Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,18 @@
11
package org.usvm.ps.weighters
22

3-
import java.io.File
43
import org.usvm.UState
54
import org.usvm.ps.StateWeighter
65
import org.usvm.statistics.CoverageStatistics
7-
import java.io.PrintStream
86

9-
// TODO: delete #KEK
10-
val weightersLogFile = File(System.getProperty("user.dir")).resolve("weighters.log").also {
11-
if (it.exists()) it.delete()
12-
it.createNewFile()
13-
}
14-
15-
val weightersLog = PrintStream(weightersLogFile)
16-
17-
class UncoveredStateWeighter<Method, Statement, in State : UState<*, Method, Statement, *, *, in State>>(
7+
abstract class UncoveredStateWeighter<Method, Statement, in State : UState<*, Method, Statement, *, *, in State>>(
188
coverageStatistics: CoverageStatistics<Method, Statement, in State>,
19-
) : StateWeighter<State, Int> {
9+
) : StateWeighterWithReport<State, Int>() {
2010

21-
private val uncoveredStatements: HashSet<Statement> = HashSet(coverageStatistics.getUncoveredStatements())
11+
protected val uncoveredStatements: HashSet<Statement> = HashSet(coverageStatistics.getUncoveredStatements())
2212

2313
init {
2414
coverageStatistics.addOnCoveredObserver { _, _, statement ->
2515
uncoveredStatements.remove(statement)
2616
}
2717
}
28-
29-
override fun weight(state: State): Int {
30-
val result = state.pathNode.allStatements.count { it in uncoveredStatements }
31-
weightersLog.println("UncoveredStateWeighter: state = ${state.id}, weight = $result")
32-
return result
33-
}
3418
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package org.usvm.ps.weighters
2+
3+
abstract class WeighterReport<out Weight> {
4+
abstract val weight: Weight
5+
abstract val weighterName: String
6+
7+
open fun report2String() = "[$weighterName $weight]"
8+
}
9+
10+
class SingleWeighterReport<Weight>(
11+
override val weight: Weight,
12+
override val weighterName: String
13+
) : WeighterReport<Weight>() {
14+
}
15+
16+
class CombinedWeighterReport<Weight>(
17+
override val weight: Weight,
18+
override val weighterName: String,
19+
val reports: List<WeighterReport<Weight>>
20+
) : WeighterReport<Weight>() {
21+
override fun report2String() = "[${reports.joinToString(", ") { it.report2String() }}]"
22+
}
23+
24+
class CastWeighterReport<InWeight, OutWeight>(
25+
override val weighterName: String,
26+
val report: WeighterReport<InWeight>,
27+
val cast: (InWeight) -> OutWeight
28+
) : WeighterReport<OutWeight>() {
29+
override val weight = cast(report.weight)
30+
override fun report2String() = report.report2String() + "*"
31+
}

0 commit comments

Comments
 (0)