Skip to content

Commit aacbfbd

Browse files
committed
More efficient constant estimation
Some accuracy is sacrificed
1 parent 93fe694 commit aacbfbd

1 file changed

Lines changed: 74 additions & 40 deletions

File tree

IntegerConversion.ConstantEstimation/Program.cs

Lines changed: 74 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,18 @@ static void Main(string[] args)
88
{
99
int originalBitCount = int.Parse(args[0]);
1010
int newBitCount = int.Parse(args[1]);
11-
EstimateConstants(originalBitCount, newBitCount, out uint a, out uint b, out int c);
11+
bool perfectMatch = EstimateConstants(originalBitCount, newBitCount, out uint a, out long b, out int c);
1212
Console.WriteLine("Done!");
1313
Console.WriteLine($"Result: ({a} * x + {b}) >> {c}");
14+
Console.WriteLine($"Perfect match: {perfectMatch}");
1415
}
1516

16-
private static void EstimateConstants(int originalBitCount, int newBitCount, out uint a, out uint b, out int c)
17+
private static bool EstimateConstants(int originalBitCount, int newBitCount, out uint a, out long b, out int c)
1718
{
1819
a = default;
1920
b = default;
2021
c = default;
21-
float bestError = float.MaxValue;//Sum of squared errors
22+
double bestError = double.MaxValue;//Sum of squared errors
2223
uint originalMaxValue = GetIntegerMaxValue(originalBitCount);
2324
uint newMaxValue = GetIntegerMaxValue(newBitCount);
2425
int bitDiff = newBitCount - originalBitCount;
@@ -27,61 +28,94 @@ private static void EstimateConstants(int originalBitCount, int newBitCount, out
2728
//If bitDiff is negative, it means we are going from large numbers to small numbers.
2829
//In other words, we will always have to shift right by at least enough to overcome the bit surplus.
2930
int minc = int.Max(0, -bitDiff);
31+
// We are limited to a 32-bit result, so we cannot shift left so much that we exceed 32 bits.
3032
int maxc = 32 - newBitCount;
3133

3234
for (int ic = minc; ic <= maxc; ic++)
3335
{
3436
Console.WriteLine(ic);
3537

36-
//a: 2 ^ (c + bitDiff) to 2 ^ (c + bitDiff + 1)
37-
uint mina = Pow2(ic + bitDiff);
38-
uint maxa = Pow2(ic + bitDiff + 1);
38+
//a: 2^c * newMaxValue / originalMaxValue rounded to nearest integer
39+
double ia_exact = (double)Pow2(ic) * newMaxValue / originalMaxValue;
3940

40-
//b: 0 to 2^c - 1
41-
//Any bigger and it will cause 0 not to correspond with 0.
42-
uint minb = 0;
43-
uint maxb = Pow2(ic) - 1;
41+
ReadOnlySpan<uint> ia_values = ia_exact < 1 ? [(uint)double.Ceiling(ia_exact)] : [(uint)double.Floor(ia_exact), (uint)double.Ceiling(ia_exact)];
4442

45-
for (uint ia = mina; ia <= maxa; ia++)
43+
foreach (uint ia in ia_values)
4644
{
47-
if (ConvertEstimate(originalMaxValue, ia, minb, ic) > newMaxValue || ConvertEstimate(originalMaxValue, ia, maxb, ic) < newMaxValue)
45+
//b: 0 to 2^c - 1
46+
//Any bigger and it will cause 0 not to correspond with 0.
47+
uint minb;
48+
uint maxb;
4849
{
49-
continue;
50+
//We can restrict b further by ensuring that originalMaxValue maps to newMaxValue.
51+
long minb_one = (long)(newMaxValue << ic) - (ia * originalMaxValue);
52+
long maxb_one = minb_one + Pow2(ic) - 1;
53+
54+
minb = (uint)long.Max(0, minb_one);
55+
maxb = (uint)long.Min(Pow2(ic) - 1, long.Max(0, maxb_one));
56+
minb = uint.Min(minb, maxb);
5057
}
5158

52-
for (uint ib = minb; ib <= maxb; ib++)
59+
//Errors for b are nearly parabolic, so we try to find the minimum by ternary search.
60+
//This is not guaranteed to find the absolute minimum since there's noise in the error, but it should get close enough.
61+
uint current_b_low = minb;
62+
uint current_b_high = maxb;
63+
while (current_b_high - current_b_low > 3)
5364
{
54-
if (ConvertEstimate(originalMaxValue, ia, ib, ic) == newMaxValue)
65+
uint range = current_b_high - current_b_low;
66+
uint b1 = current_b_low + range / 3;
67+
uint b2 = current_b_high - range / 3;
68+
double error1 = CalculateMeanSquaredError(ia, b1, ic, originalMaxValue, newMaxValue, out _);
69+
double error2 = CalculateMeanSquaredError(ia, b2, ic, originalMaxValue, newMaxValue, out _);
70+
if (error1 < error2)
71+
{
72+
current_b_high = b2;
73+
}
74+
else
75+
{
76+
current_b_low = b1;
77+
}
78+
}
79+
80+
for (uint ib = current_b_low; ib <= current_b_high; ib++)
81+
{
82+
double error = CalculateMeanSquaredError(ia, ib, ic, originalMaxValue, newMaxValue, out bool anyIncorrect);
83+
if (error < bestError)
84+
{
85+
bestError = error;
86+
a = ia;
87+
b = ib;
88+
c = ic;
89+
Console.WriteLine($"New best: a={a}, b={b}, c={c}, MSE={bestError}");
90+
}
91+
if (!anyIncorrect)
5592
{
56-
float error = 0;
57-
bool anyIncorrect = false;
58-
for (uint x = 1; x < originalMaxValue; x++)
59-
{
60-
uint y = ConvertEstimate(x, ia, ib, ic);
61-
float yExact = (float)x * newMaxValue / originalMaxValue;
62-
uint yExactRounded = (uint)float.Round(yExact);
63-
anyIncorrect |= (y != yExactRounded);
64-
float diff = yExact - y;
65-
error += diff * diff;
66-
}
67-
if (!anyIncorrect)
68-
{
69-
a = ia;
70-
b = ib;
71-
c = ic;
72-
return;
73-
}
74-
if (error < bestError)
75-
{
76-
a = ia;
77-
b = ib;
78-
c = ic;
79-
bestError = error;
80-
}
93+
//Perfect match
94+
return true;
8195
}
8296
}
8397
}
8498
}
99+
100+
return false;
101+
}
102+
103+
private static double CalculateMeanSquaredError(uint a, uint b, int c, uint originalMaxValue, uint newMaxValue, out bool anyIncorrect)
104+
{
105+
double meanSquaredError = 0;
106+
anyIncorrect = false;
107+
for (uint x = 0; x <= originalMaxValue; x++)
108+
{
109+
uint y = ConvertEstimate(x, a, b, c);
110+
double yExact = (double)x * newMaxValue / originalMaxValue;
111+
112+
double diff = yExact - y;
113+
meanSquaredError = meanSquaredError * x / (x + 1) + (diff * diff) / (x + 1);
114+
115+
uint yExactInt = (uint)double.Round(yExact, MidpointRounding.AwayFromZero);
116+
anyIncorrect |= (y != yExactInt);
117+
}
118+
return meanSquaredError;
85119
}
86120

87121
[MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)]

0 commit comments

Comments
 (0)