|
1 | 1 | /** |
2 | | - * NOTE: An issue was found with this file when dealing with negative numbers when exponentiating! |
3 | | - * See bug tracking progress on issue |
| 2 | + * Computes modular exponentiation: a^n mod m. |
4 | 3 | * |
5 | | - * <p>An implementation of the modPow(a, n, mod) operation. This implementation is substantially |
6 | | - * faster than Java's BigInteger class because it only uses primitive types. |
| 4 | + * Supports negative exponents via modular inverse (requires gcd(a, m) = 1) and negative bases. |
| 5 | + * Uses overflow-safe modular multiplication to handle the full range of long values. |
7 | 6 | * |
8 | | - * <p>Time Complexity O(lg(n)) |
| 7 | + * Time Complexity: O(log(n)) |
9 | 8 | * |
10 | 9 | * @author William Fiset, william.alexandre.fiset@gmail.com |
11 | 10 | */ |
12 | 11 | package com.williamfiset.algorithms.math; |
13 | 12 |
|
14 | | -import java.math.BigInteger; |
15 | | - |
16 | 13 | public class ModPow { |
17 | 14 |
|
18 | | - // The values placed into the modPow function cannot be greater |
19 | | - // than MAX or less than MIN otherwise long overflow will |
20 | | - // happen when the values get squared (they will exceed 2^63-1) |
21 | | - private static final long MAX = (long) Math.sqrt(Long.MAX_VALUE); |
22 | | - private static final long MIN = -MAX; |
23 | | - |
24 | | - // Computes the Greatest Common Divisor (GCD) of a & b |
25 | | - private static long gcd(long a, long b) { |
26 | | - return b == 0 ? (a < 0 ? -a : a) : gcd(b, a % b); |
27 | | - } |
28 | | - |
29 | | - // This function performs the extended euclidean algorithm on two numbers a and b. |
30 | | - // The function returns the gcd(a,b) as well as the numbers x and y such |
31 | | - // that ax + by = gcd(a,b). This calculation is important in number theory |
32 | | - // and can be used for several things such as finding modular inverses and |
33 | | - // solutions to linear Diophantine equations. |
34 | | - private static long[] egcd(long a, long b) { |
35 | | - if (b == 0) return new long[] {a < 0 ? -a : a, 1L, 0L}; |
36 | | - long[] v = egcd(b, a % b); |
37 | | - long tmp = v[1] - v[2] * (a / b); |
38 | | - v[1] = v[2]; |
39 | | - v[2] = tmp; |
40 | | - return v; |
41 | | - } |
42 | | - |
43 | | - // Returns the modular inverse of 'a' mod 'm' |
44 | | - // Make sure m > 0 and 'a' & 'm' are relatively prime. |
45 | | - private static long modInv(long a, long m) { |
46 | | - |
47 | | - a = ((a % m) + m) % m; |
48 | | - |
49 | | - long[] v = egcd(a, m); |
50 | | - long x = v[1]; |
51 | | - |
52 | | - return ((x % m) + m) % m; |
53 | | - } |
54 | | - |
55 | | - // Computes a^n modulo mod very efficiently in O(lg(n)) time. |
56 | | - // This function supports negative exponent values and a negative |
57 | | - // base, however the modulus must be positive. |
| 15 | + /** |
| 16 | + * Computes a^n mod m. |
| 17 | + * |
| 18 | + * @throws ArithmeticException if mod <= 0, or if n < 0 and gcd(a, mod) != 1. |
| 19 | + */ |
58 | 20 | public static long modPow(long a, long n, long mod) { |
| 21 | + if (mod <= 0) |
| 22 | + throw new ArithmeticException("mod must be > 0"); |
59 | 23 |
|
60 | | - if (mod <= 0) throw new ArithmeticException("mod must be > 0"); |
61 | | - if (a > MAX || mod > MAX) |
62 | | - throw new IllegalArgumentException("Long overflow is upon you, mod or base is too high!"); |
63 | | - if (a < MIN || mod < MIN) |
64 | | - throw new IllegalArgumentException("Long overflow is upon you, mod or base is too low!"); |
65 | | - |
66 | | - // To handle negative exponents we can use the modular |
67 | | - // inverse of a to our advantage since: a^-n mod m = (a^-1)^n mod m |
| 24 | + // a^-n mod m = (a^-1)^n mod m |
68 | 25 | if (n < 0) { |
69 | 26 | if (gcd(a, mod) != 1) |
70 | 27 | throw new ArithmeticException("If n < 0 then must have gcd(a, mod) = 1"); |
71 | 28 | return modPow(modInv(a, mod), -n, mod); |
72 | 29 | } |
73 | 30 |
|
74 | | - if (n == 0L) return 1L; |
75 | | - long p = a, r = 1L; |
| 31 | + // Normalize base into [0, mod) |
| 32 | + a = ((a % mod) + mod) % mod; |
76 | 33 |
|
77 | | - for (long i = 0; n != 0; i++) { |
78 | | - long mask = 1L << i; |
79 | | - if ((n & mask) == mask) { |
80 | | - r = (((r * p) % mod) + mod) % mod; |
81 | | - n -= mask; |
82 | | - } |
83 | | - p = ((p * p) % mod + mod) % mod; |
| 34 | + long result = 1; |
| 35 | + while (n > 0) { |
| 36 | + if ((n & 1) == 1) |
| 37 | + result = mulMod(result, a, mod); |
| 38 | + a = mulMod(a, a, mod); |
| 39 | + n >>= 1; |
84 | 40 | } |
85 | | - |
86 | | - return ((r % mod) + mod) % mod; |
| 41 | + return result; |
87 | 42 | } |
88 | 43 |
|
89 | | - // Example usage |
90 | | - public static void main(String[] args) { |
91 | | - |
92 | | - BigInteger A, N, M, r1; |
93 | | - long a, n, m, r2; |
94 | | - |
95 | | - A = BigInteger.valueOf(3); |
96 | | - N = BigInteger.valueOf(4); |
97 | | - M = BigInteger.valueOf(1000000); |
98 | | - a = A.longValue(); |
99 | | - n = N.longValue(); |
100 | | - m = M.longValue(); |
101 | | - |
102 | | - // 3 ^ 4 mod 1000000 |
103 | | - r1 = A.modPow(N, M); // 81 |
104 | | - r2 = modPow(a, n, m); // 81 |
105 | | - System.out.println(r1 + " " + r2); |
106 | | - |
107 | | - A = BigInteger.valueOf(-45); |
108 | | - N = BigInteger.valueOf(12345); |
109 | | - M = BigInteger.valueOf(987654321); |
110 | | - a = A.longValue(); |
111 | | - n = N.longValue(); |
112 | | - m = M.longValue(); |
113 | | - |
114 | | - // Finds -45 ^ 12345 mod 987654321 |
115 | | - r1 = A.modPow(N, M); // 323182557 |
116 | | - r2 = modPow(a, n, m); // 323182557 |
117 | | - System.out.println(r1 + " " + r2); |
118 | | - |
119 | | - A = BigInteger.valueOf(6); |
120 | | - N = BigInteger.valueOf(-66); |
121 | | - M = BigInteger.valueOf(101); |
122 | | - a = A.longValue(); |
123 | | - n = N.longValue(); |
124 | | - m = M.longValue(); |
125 | | - |
126 | | - // Finds 6 ^ -66 mod 101 |
127 | | - r1 = A.modPow(N, M); // 84 |
128 | | - r2 = modPow(a, n, m); // 84 |
129 | | - System.out.println(r1 + " " + r2); |
130 | | - |
131 | | - A = BigInteger.valueOf(-5); |
132 | | - N = BigInteger.valueOf(-7); |
133 | | - M = BigInteger.valueOf(1009); |
134 | | - a = A.longValue(); |
135 | | - n = N.longValue(); |
136 | | - m = M.longValue(); |
137 | | - |
138 | | - // Finds -5 ^ -7 mod 1009 |
139 | | - r1 = A.modPow(N, M); // 675 |
140 | | - r2 = modPow(a, n, m); // 675 |
141 | | - System.out.println(r1 + " " + r2); |
142 | | - |
143 | | - for (int i = 0; i < 1000; i++) { |
144 | | - A = BigInteger.valueOf(a); |
145 | | - N = BigInteger.valueOf(n); |
146 | | - M = BigInteger.valueOf(m); |
147 | | - a = Math.random() < 0.5 ? randLong(MAX) : -randLong(MAX); |
148 | | - n = randLong(); |
149 | | - m = randLong(MAX); |
150 | | - try { |
151 | | - r1 = A.modPow(N, M); |
152 | | - r2 = modPow(a, n, m); |
153 | | - if (r1.longValue() != r2) |
154 | | - System.out.printf("Broke with: a = %d, n = %d, m = %d\n", a, n, m); |
155 | | - } catch (ArithmeticException e) { |
156 | | - } |
157 | | - } |
| 44 | + private static long modInv(long a, long m) { |
| 45 | + a = ((a % m) + m) % m; |
| 46 | + long x = egcd(a, m)[1]; |
| 47 | + return ((x % m) + m) % m; |
158 | 48 | } |
159 | 49 |
|
160 | | - /* TESTING RELATED METHODS */ |
161 | | - |
162 | | - static final java.util.Random RANDOM = new java.util.Random(); |
| 50 | + private static long[] egcd(long a, long b) { |
| 51 | + if (b == 0) |
| 52 | + return new long[] {a < 0 ? -a : a, 1L, 0L}; |
| 53 | + long[] v = egcd(b, a % b); |
| 54 | + long tmp = v[1] - v[2] * (a / b); |
| 55 | + v[1] = v[2]; |
| 56 | + v[2] = tmp; |
| 57 | + return v; |
| 58 | + } |
163 | 59 |
|
164 | | - // Returns long between [1, bound] |
165 | | - public static long randLong(long bound) { |
166 | | - return java.util.concurrent.ThreadLocalRandom.current().nextLong(1, bound + 1); |
| 60 | + private static long gcd(long a, long b) { |
| 61 | + a = Math.abs(a); |
| 62 | + b = Math.abs(b); |
| 63 | + return b == 0 ? a : gcd(b, a % b); |
167 | 64 | } |
168 | 65 |
|
169 | | - public static long randLong() { |
170 | | - return RANDOM.nextLong(); |
| 66 | + /** Overflow-safe modular multiplication: (a * b) % mod. */ |
| 67 | + private static long mulMod(long a, long b, long mod) { |
| 68 | + return java.math.BigInteger.valueOf(a) |
| 69 | + .multiply(java.math.BigInteger.valueOf(b)) |
| 70 | + .mod(java.math.BigInteger.valueOf(mod)) |
| 71 | + .longValue(); |
171 | 72 | } |
172 | 73 | } |
0 commit comments