11/**
2- * Use the chinese remainder theorem to solve a set of congruence equations .
2+ * Solve a set of congruence equations using the Chinese Remainder Theorem (CRT) .
33 *
4- * <p>The first method (eliminateCoefficient) is used to reduce an equation of the form cx≡a(mod
5- * m)cx≡a(mod m) to the form x≡a_new(mod m_new)x≡anew(mod m_new), which gets rids of the
6- * coefficient. A value of null is returned if the coefficient cannot be eliminated.
4+ * Given a system of simultaneous congruences:
75 *
8- * <p>The second method (reduce) is used to reduce a set of equations so that the moduli become
9- * pairwise co-prime (which means that we can apply the Chinese Remainder Theorem). The input and
10- * output are of the form x≡a_0(mod m_0),...,x≡a_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡a_n−1(mod
11- * m_n−1). Note that the number of equations may change during this process. A value of null is
12- * returned if the set of equations cannot be reduced to co-prime moduli.
6+ * x ≡ a_0 (mod m_0)
7+ * x ≡ a_1 (mod m_1)
8+ * ...
9+ * x ≡ a_{n-1} (mod m_{n-1})
1310 *
14- * <p>The third method (crt) is the actual Chinese Remainder Theorem. It assumes that all pairs of
15- * moduli are co-prime to one another. This solves a set of equations of the form x≡a_0(mod
16- * m_0),...,x≡v_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡v_n−1(mod m_n−1). It's output is of the form
17- * x≡a_new(mod m_new)x≡a_new(mod m_new).
11+ * where all moduli m_i are pairwise coprime (gcd(m_i, m_j) = 1 for i ≠ j), the CRT guarantees a
12+ * unique solution x modulo M = m_0 * m_1 * ... * m_{n-1}.
13+ *
14+ * The solution is constructed as x = sum of a_i * M_i * y_i (mod M), where M_i = M / m_i and y_i
15+ * is the modular inverse of M_i modulo m_i (found via the extended Euclidean algorithm). Each term
16+ * contributes a_i for the i-th congruence and vanishes (mod m_j) for all j ≠ i, so the sum
17+ * satisfies every equation simultaneously.
18+ *
19+ * When moduli are not pairwise coprime, the system must first be reduced. Each modulus is split
20+ * into prime-power factors (e.g. 12 = 4 * 3), converting one equation into several with
21+ * prime-power moduli. Redundant equations are removed and conflicting ones detected. After
22+ * reduction, the moduli are pairwise coprime and the standard CRT applies.
23+ *
24+ * The eliminateCoefficient method handles equations of the form cx ≡ a (mod m) by dividing through
25+ * by gcd(c, m) — which is only possible when gcd(c, m) divides a — and then multiplying by the
26+ * modular inverse of the reduced coefficient.
1827 *
1928 * @author Micah Stairs
2029 */
2433
2534public class ChineseRemainderTheorem {
2635
27- // eliminateCoefficient() takes cx≡a(mod m) and gives x≡a_new(mod m_new).
36+ /**
37+ * Reduces cx ≡ a (mod m) to x ≡ a' (mod m').
38+ *
39+ * @return {a', m'} or null if unsolvable.
40+ */
2841 public static long [] eliminateCoefficient (long c , long a , long m ) {
29-
3042 long d = egcd (c , m )[0 ];
3143
32- if (a % d != 0 ) return null ;
44+ if (a % d != 0 )
45+ return null ;
3346
3447 c /= d ;
3548 a /= d ;
@@ -42,36 +55,35 @@ public static long[] eliminateCoefficient(long c, long a, long m) {
4255 return new long [] {a , m };
4356 }
4457
45- // reduce() takes a set of equations and reduces them to an equivalent
46- // set with pairwise co-prime moduli (or null if not solvable).
58+ /**
59+ * Reduces a system x ≡ a[i] (mod m[i]) to an equivalent system with pairwise coprime moduli.
60+ *
61+ * @return {a[], m[]} with coprime moduli, or null if the system is inconsistent.
62+ */
4763 public static long [][] reduce (long [] a , long [] m ) {
64+ List <Long > aNew = new ArrayList <>();
65+ List <Long > mNew = new ArrayList <>();
4866
49- List <Long > aNew = new ArrayList <Long >();
50- List <Long > mNew = new ArrayList <Long >();
51-
52- // Split up each equation into prime factors
67+ // Split each modulus into prime-power factors
5368 for (int i = 0 ; i < a .length ; i ++) {
5469 List <Long > factors = primeFactorization (m [i ]);
5570 Collections .sort (factors );
56- ListIterator <Long > iterator = factors .listIterator ();
57- while (iterator .hasNext ()) {
58- long val = iterator .next ();
59- long total = val ;
60- while (iterator .hasNext ()) {
61- long nextVal = iterator .next ();
62- if (nextVal == val ) {
63- total *= val ;
64- } else {
65- iterator .previous ();
66- break ;
67- }
71+
72+ int j = 0 ;
73+ while (j < factors .size ()) {
74+ long p = factors .get (j );
75+ long pk = p ;
76+ j ++;
77+ while (j < factors .size () && factors .get (j ) == p ) {
78+ pk *= p ;
79+ j ++;
6880 }
69- aNew .add (a [i ] % total );
70- mNew .add (total );
81+ aNew .add (a [i ] % pk );
82+ mNew .add (pk );
7183 }
7284 }
7385
74- // Throw away repeated information and look for conflicts
86+ // Remove redundant equations and detect conflicts
7587 for (int i = 0 ; i < aNew .size (); i ++) {
7688 for (int j = i + 1 ; j < aNew .size (); j ++) {
7789 if (mNew .get (i ) % mNew .get (j ) == 0 || mNew .get (j ) % mNew .get (i ) == 0 ) {
@@ -81,111 +93,161 @@ public static long[][] reduce(long[] a, long[] m) {
8193 mNew .remove (j );
8294 j --;
8395 continue ;
84- } else return null ;
96+ } else
97+ return null ;
8598 } else {
8699 if ((aNew .get (j ) % mNew .get (i )) == aNew .get (i )) {
87100 aNew .remove (i );
88101 mNew .remove (i );
89102 i --;
90103 break ;
91- } else return null ;
104+ } else
105+ return null ;
92106 }
93107 }
94108 }
95109 }
96110
97- // Put result into an array
98111 long [][] res = new long [2 ][aNew .size ()];
99112 for (int i = 0 ; i < aNew .size (); i ++) {
100113 res [0 ][i ] = aNew .get (i );
101114 res [1 ][i ] = mNew .get (i );
102115 }
103-
104116 return res ;
105117 }
106118
119+ /**
120+ * Solves x ≡ a[i] (mod m[i]) assuming all moduli are pairwise coprime.
121+ *
122+ * @return {x, M} where M is the product of all moduli.
123+ */
107124 public static long [] crt (long [] a , long [] m ) {
108-
109125 long M = 1 ;
110- for (int i = 0 ; i < m .length ; i ++) M *= m [i ];
111-
112- long [] inv = new long [a .length ];
113- for (int i = 0 ; i < inv .length ; i ++) inv [i ] = egcd (M / m [i ], m [i ])[1 ];
126+ for (long mi : m )
127+ M *= mi ;
114128
115129 long x = 0 ;
116130 for (int i = 0 ; i < m .length ; i ++) {
117- x += (M / m [i ]) * a [i ] * inv [i ]; // Overflow could occur here
131+ long Mi = M / m [i ];
132+ long inv = egcd (Mi , m [i ])[1 ];
133+ x += Mi * a [i ] * inv ;
118134 x = ((x % M ) + M ) % M ;
119135 }
120136
121137 return new long [] {x , M };
122138 }
123139
124- private static ArrayList <Long > primeFactorization (long n ) {
125- ArrayList <Long > factors = new ArrayList <Long >();
126- if (n <= 0 ) throw new IllegalArgumentException ();
127- else if (n == 1 ) return factors ;
128- PriorityQueue <Long > divisorQueue = new PriorityQueue <Long >();
129- divisorQueue .add (n );
130- while (!divisorQueue .isEmpty ()) {
131- long divisor = divisorQueue .remove ();
132- if (isPrime (divisor )) {
133- factors .add (divisor );
134- continue ;
135- }
136- long next_divisor = pollardRho (divisor );
137- if (next_divisor == divisor ) {
138- divisorQueue .add (divisor );
139- } else {
140- divisorQueue .add (next_divisor );
141- divisorQueue .add (divisor / next_divisor );
142- }
143- }
140+ /** Extended Euclidean algorithm. Returns {gcd(a,b), x, y} where ax + by = gcd(a,b). */
141+ static long [] egcd (long a , long b ) {
142+ if (b == 0 )
143+ return new long [] {a , 1 , 0 };
144+ long [] ret = egcd (b , a % b );
145+ long tmp = ret [1 ] - ret [2 ] * (a / b );
146+ ret [1 ] = ret [2 ];
147+ ret [2 ] = tmp ;
148+ return ret ;
149+ }
150+
151+ private static List <Long > primeFactorization (long n ) {
152+ if (n <= 0 )
153+ throw new IllegalArgumentException ();
154+ List <Long > factors = new ArrayList <>();
155+ factor (n , factors );
144156 return factors ;
145157 }
146158
159+ private static void factor (long n , List <Long > factors ) {
160+ if (n == 1 )
161+ return ;
162+ if (isPrime (n )) {
163+ factors .add (n );
164+ return ;
165+ }
166+ long d = pollardRho (n );
167+ factor (d , factors );
168+ factor (n / d , factors );
169+ }
170+
147171 private static long pollardRho (long n ) {
148- if (n % 2 == 0 ) return 2 ;
149- // Get a number in the range [2, 10^6]
172+ if (n % 2 == 0 )
173+ return 2 ;
150174 long x = 2 + (long ) (999999 * Math .random ());
151175 long c = 2 + (long ) (999999 * Math .random ());
152176 long y = x ;
153177 long d = 1 ;
154178 while (d == 1 ) {
155- x = (x * x + c ) % n ;
156- y = (y * y + c ) % n ;
157- y = (y * y + c ) % n ;
158- d = gcf (Math .abs (x - y ), n );
159- if (d == n ) break ;
179+ x = mulMod (x , x , n ) + c ;
180+ if (x >= n )
181+ x -= n ;
182+ y = mulMod (y , y , n ) + c ;
183+ if (y >= n )
184+ y -= n ;
185+ y = mulMod (y , y , n ) + c ;
186+ if (y >= n )
187+ y -= n ;
188+ d = gcd (Math .abs (x - y ), n );
189+ if (d == n )
190+ break ;
160191 }
161192 return d ;
162193 }
163194
164- // Extended euclidean algorithm
165- private static long [] egcd (long a , long b ) {
166- if (b == 0 ) return new long [] {a , 1 , 0 };
167- else {
168- long [] ret = egcd (b , a % b );
169- long tmp = ret [1 ] - ret [2 ] * (a / b );
170- ret [1 ] = ret [2 ];
171- ret [2 ] = tmp ;
172- return ret ;
195+ /**
196+ * Deterministic Miller-Rabin primality test, correct for all long values. Uses 12 witnesses
197+ * guaranteed correct for n < 3.317 × 10^24.
198+ */
199+ private static boolean isPrime (long n ) {
200+ if (n < 2 )
201+ return false ;
202+ if (n < 4 )
203+ return true ;
204+ if (n % 2 == 0 || n % 3 == 0 )
205+ return false ;
206+
207+ long d = n - 1 ;
208+ int r = Long .numberOfTrailingZeros (d );
209+ d >>= r ;
210+
211+ for (long a : new long [] {2 , 3 , 5 , 7 , 11 , 13 , 17 , 19 , 23 , 29 , 31 , 37 }) {
212+ if (a >= n )
213+ continue ;
214+ long x = powMod (a , d , n );
215+ if (x == 1 || x == n - 1 )
216+ continue ;
217+ boolean composite = true ;
218+ for (int i = 0 ; i < r - 1 ; i ++) {
219+ x = mulMod (x , x , n );
220+ if (x == n - 1 ) {
221+ composite = false ;
222+ break ;
223+ }
224+ }
225+ if (composite )
226+ return false ;
173227 }
228+ return true ;
174229 }
175230
176- private static long gcf (long a , long b ) {
177- return b == 0 ? a : gcf (b , a % b );
231+ private static long powMod (long base , long exp , long mod ) {
232+ long result = 1 ;
233+ base %= mod ;
234+ while (exp > 0 ) {
235+ if ((exp & 1 ) == 1 )
236+ result = mulMod (result , base , mod );
237+ exp >>= 1 ;
238+ base = mulMod (base , base , mod );
239+ }
240+ return result ;
178241 }
179242
180- private static boolean isPrime (long n ) {
181- if (n < 2 ) return false ;
182- if (n == 2 || n == 3 ) return true ;
183- if (n % 2 == 0 || n % 3 == 0 ) return false ;
184-
185- int limit = (int ) Math .sqrt (n );
186-
187- for (int i = 5 ; i <= limit ; i += 6 ) if (n % i == 0 || n % (i + 2 ) == 0 ) return false ;
243+ private static long mulMod (long a , long b , long mod ) {
244+ return java .math .BigInteger .valueOf (a )
245+ .multiply (java .math .BigInteger .valueOf (b ))
246+ .mod (java .math .BigInteger .valueOf (mod ))
247+ .longValue ();
248+ }
188249
189- return true ;
250+ private static long gcd (long a , long b ) {
251+ return b == 0 ? a : gcd (b , a % b );
190252 }
191253}
0 commit comments