@@ -8,17 +8,18 @@ static void Main(string[] args)
88 {
99 int originalBitCount = int . Parse ( args [ 0 ] ) ;
1010 int newBitCount = int . Parse ( args [ 1 ] ) ;
11- EstimateConstants ( originalBitCount , newBitCount , out uint a , out uint b , out int c ) ;
11+ bool perfectMatch = EstimateConstants ( originalBitCount , newBitCount , out uint a , out long b , out int c ) ;
1212 Console . WriteLine ( "Done!" ) ;
1313 Console . WriteLine ( $ "Result: ({ a } * x + { b } ) >> { c } ") ;
14+ Console . WriteLine ( $ "Perfect match: { perfectMatch } ") ;
1415 }
1516
16- private static void EstimateConstants ( int originalBitCount , int newBitCount , out uint a , out uint b , out int c )
17+ private static bool EstimateConstants ( int originalBitCount , int newBitCount , out uint a , out long b , out int c )
1718 {
1819 a = default ;
1920 b = default ;
2021 c = default ;
21- float bestError = float . MaxValue ; //Sum of squared errors
22+ double bestError = double . MaxValue ; //Sum of squared errors
2223 uint originalMaxValue = GetIntegerMaxValue ( originalBitCount ) ;
2324 uint newMaxValue = GetIntegerMaxValue ( newBitCount ) ;
2425 int bitDiff = newBitCount - originalBitCount ;
@@ -27,61 +28,94 @@ private static void EstimateConstants(int originalBitCount, int newBitCount, out
2728 //If bitDiff is negative, it means we are going from large numbers to small numbers.
2829 //In other words, we will always have to shift right by at least enough to overcome the bit surplus.
2930 int minc = int . Max ( 0 , - bitDiff ) ;
31+ // We are limited to a 32-bit result, so we cannot shift left so much that we exceed 32 bits.
3032 int maxc = 32 - newBitCount ;
3133
3234 for ( int ic = minc ; ic <= maxc ; ic ++ )
3335 {
3436 Console . WriteLine ( ic ) ;
3537
36- //a: 2 ^ (c + bitDiff) to 2 ^ (c + bitDiff + 1)
37- uint mina = Pow2 ( ic + bitDiff ) ;
38- uint maxa = Pow2 ( ic + bitDiff + 1 ) ;
38+ //a: 2^c * newMaxValue / originalMaxValue rounded to nearest integer
39+ double ia_exact = ( double ) Pow2 ( ic ) * newMaxValue / originalMaxValue ;
3940
40- //b: 0 to 2^c - 1
41- //Any bigger and it will cause 0 not to correspond with 0.
42- uint minb = 0 ;
43- uint maxb = Pow2 ( ic ) - 1 ;
41+ ReadOnlySpan < uint > ia_values = ia_exact < 1 ? [ ( uint ) double . Ceiling ( ia_exact ) ] : [ ( uint ) double . Floor ( ia_exact ) , ( uint ) double . Ceiling ( ia_exact ) ] ;
4442
45- for ( uint ia = mina ; ia <= maxa ; ia ++ )
43+ foreach ( uint ia in ia_values )
4644 {
47- if ( ConvertEstimate ( originalMaxValue , ia , minb , ic ) > newMaxValue || ConvertEstimate ( originalMaxValue , ia , maxb , ic ) < newMaxValue )
45+ //b: 0 to 2^c - 1
46+ //Any bigger and it will cause 0 not to correspond with 0.
47+ uint minb ;
48+ uint maxb ;
4849 {
49- continue ;
50+ //We can restrict b further by ensuring that originalMaxValue maps to newMaxValue.
51+ long minb_one = ( long ) ( newMaxValue << ic ) - ( ia * originalMaxValue ) ;
52+ long maxb_one = minb_one + Pow2 ( ic ) - 1 ;
53+
54+ minb = ( uint ) long . Max ( 0 , minb_one ) ;
55+ maxb = ( uint ) long . Min ( Pow2 ( ic ) - 1 , long . Max ( 0 , maxb_one ) ) ;
56+ minb = uint . Min ( minb , maxb ) ;
5057 }
5158
52- for ( uint ib = minb ; ib <= maxb ; ib ++ )
59+ //Errors for b are nearly parabolic, so we try to find the minimum by ternary search.
60+ //This is not guaranteed to find the absolute minimum since there's noise in the error, but it should get close enough.
61+ uint current_b_low = minb ;
62+ uint current_b_high = maxb ;
63+ while ( current_b_high - current_b_low > 3 )
5364 {
54- if ( ConvertEstimate ( originalMaxValue , ia , ib , ic ) == newMaxValue )
65+ uint range = current_b_high - current_b_low ;
66+ uint b1 = current_b_low + range / 3 ;
67+ uint b2 = current_b_high - range / 3 ;
68+ double error1 = CalculateMeanSquaredError ( ia , b1 , ic , originalMaxValue , newMaxValue , out _ ) ;
69+ double error2 = CalculateMeanSquaredError ( ia , b2 , ic , originalMaxValue , newMaxValue , out _ ) ;
70+ if ( error1 < error2 )
71+ {
72+ current_b_high = b2 ;
73+ }
74+ else
75+ {
76+ current_b_low = b1 ;
77+ }
78+ }
79+
80+ for ( uint ib = current_b_low ; ib <= current_b_high ; ib ++ )
81+ {
82+ double error = CalculateMeanSquaredError ( ia , ib , ic , originalMaxValue , newMaxValue , out bool anyIncorrect ) ;
83+ if ( error < bestError )
84+ {
85+ bestError = error ;
86+ a = ia ;
87+ b = ib ;
88+ c = ic ;
89+ Console . WriteLine ( $ "New best: a={ a } , b={ b } , c={ c } , MSE={ bestError } ") ;
90+ }
91+ if ( ! anyIncorrect )
5592 {
56- float error = 0 ;
57- bool anyIncorrect = false ;
58- for ( uint x = 1 ; x < originalMaxValue ; x ++ )
59- {
60- uint y = ConvertEstimate ( x , ia , ib , ic ) ;
61- float yExact = ( float ) x * newMaxValue / originalMaxValue ;
62- uint yExactRounded = ( uint ) float . Round ( yExact ) ;
63- anyIncorrect |= ( y != yExactRounded ) ;
64- float diff = yExact - y ;
65- error += diff * diff ;
66- }
67- if ( ! anyIncorrect )
68- {
69- a = ia ;
70- b = ib ;
71- c = ic ;
72- return ;
73- }
74- if ( error < bestError )
75- {
76- a = ia ;
77- b = ib ;
78- c = ic ;
79- bestError = error ;
80- }
93+ //Perfect match
94+ return true ;
8195 }
8296 }
8397 }
8498 }
99+
100+ return false ;
101+ }
102+
103+ private static double CalculateMeanSquaredError ( uint a , uint b , int c , uint originalMaxValue , uint newMaxValue , out bool anyIncorrect )
104+ {
105+ double meanSquaredError = 0 ;
106+ anyIncorrect = false ;
107+ for ( uint x = 0 ; x <= originalMaxValue ; x ++ )
108+ {
109+ uint y = ConvertEstimate ( x , a , b , c ) ;
110+ double yExact = ( double ) x * newMaxValue / originalMaxValue ;
111+
112+ double diff = yExact - y ;
113+ meanSquaredError = meanSquaredError * x / ( x + 1 ) + ( diff * diff ) / ( x + 1 ) ;
114+
115+ uint yExactInt = ( uint ) double . Round ( yExact , MidpointRounding . AwayFromZero ) ;
116+ anyIncorrect |= ( y != yExactInt ) ;
117+ }
118+ return meanSquaredError ;
85119 }
86120
87121 [ MethodImpl ( MethodImplOptions . AggressiveInlining | MethodImplOptions . AggressiveOptimization ) ]
0 commit comments