Skip to content

Commit 0fcf135

Browse files
williamfisetclaude
andauthored
Refactor ModPow: fix overflow, simplify loop, add tests (#1315)
Replace fragile MAX/MIN overflow guards with overflow-safe mulMod via BigInteger, supporting the full long range. Simplify the bit-mask exponentiation loop to standard binary exponentiation. Remove inline main() test code and add 11 JUnit 5 tests including a randomized check against BigInteger.modPow. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0e0dc06 commit 0fcf135

File tree

3 files changed

+155
-143
lines changed

3 files changed

+155
-143
lines changed
Lines changed: 44 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1,172 +1,73 @@
11
/**
2-
* NOTE: An issue was found with this file when dealing with negative numbers when exponentiating!
3-
* See bug tracking progress on issue
2+
* Computes modular exponentiation: a^n mod m.
43
*
5-
* <p>An implementation of the modPow(a, n, mod) operation. This implementation is substantially
6-
* faster than Java's BigInteger class because it only uses primitive types.
4+
* Supports negative exponents via modular inverse (requires gcd(a, m) = 1) and negative bases.
5+
* Uses overflow-safe modular multiplication to handle the full range of long values.
76
*
8-
* <p>Time Complexity O(lg(n))
7+
* Time Complexity: O(log(n))
98
*
109
* @author William Fiset, william.alexandre.fiset@gmail.com
1110
*/
1211
package com.williamfiset.algorithms.math;
1312

14-
import java.math.BigInteger;
15-
1613
public class ModPow {
1714

18-
// The values placed into the modPow function cannot be greater
19-
// than MAX or less than MIN otherwise long overflow will
20-
// happen when the values get squared (they will exceed 2^63-1)
21-
private static final long MAX = (long) Math.sqrt(Long.MAX_VALUE);
22-
private static final long MIN = -MAX;
23-
24-
// Computes the Greatest Common Divisor (GCD) of a & b
25-
private static long gcd(long a, long b) {
26-
return b == 0 ? (a < 0 ? -a : a) : gcd(b, a % b);
27-
}
28-
29-
// This function performs the extended euclidean algorithm on two numbers a and b.
30-
// The function returns the gcd(a,b) as well as the numbers x and y such
31-
// that ax + by = gcd(a,b). This calculation is important in number theory
32-
// and can be used for several things such as finding modular inverses and
33-
// solutions to linear Diophantine equations.
34-
private static long[] egcd(long a, long b) {
35-
if (b == 0) return new long[] {a < 0 ? -a : a, 1L, 0L};
36-
long[] v = egcd(b, a % b);
37-
long tmp = v[1] - v[2] * (a / b);
38-
v[1] = v[2];
39-
v[2] = tmp;
40-
return v;
41-
}
42-
43-
// Returns the modular inverse of 'a' mod 'm'
44-
// Make sure m > 0 and 'a' & 'm' are relatively prime.
45-
private static long modInv(long a, long m) {
46-
47-
a = ((a % m) + m) % m;
48-
49-
long[] v = egcd(a, m);
50-
long x = v[1];
51-
52-
return ((x % m) + m) % m;
53-
}
54-
55-
// Computes a^n modulo mod very efficiently in O(lg(n)) time.
56-
// This function supports negative exponent values and a negative
57-
// base, however the modulus must be positive.
15+
/**
16+
* Computes a^n mod m.
17+
*
18+
* @throws ArithmeticException if mod <= 0, or if n < 0 and gcd(a, mod) != 1.
19+
*/
5820
public static long modPow(long a, long n, long mod) {
21+
if (mod <= 0)
22+
throw new ArithmeticException("mod must be > 0");
5923

60-
if (mod <= 0) throw new ArithmeticException("mod must be > 0");
61-
if (a > MAX || mod > MAX)
62-
throw new IllegalArgumentException("Long overflow is upon you, mod or base is too high!");
63-
if (a < MIN || mod < MIN)
64-
throw new IllegalArgumentException("Long overflow is upon you, mod or base is too low!");
65-
66-
// To handle negative exponents we can use the modular
67-
// inverse of a to our advantage since: a^-n mod m = (a^-1)^n mod m
24+
// a^-n mod m = (a^-1)^n mod m
6825
if (n < 0) {
6926
if (gcd(a, mod) != 1)
7027
throw new ArithmeticException("If n < 0 then must have gcd(a, mod) = 1");
7128
return modPow(modInv(a, mod), -n, mod);
7229
}
7330

74-
if (n == 0L) return 1L;
75-
long p = a, r = 1L;
31+
// Normalize base into [0, mod)
32+
a = ((a % mod) + mod) % mod;
7633

77-
for (long i = 0; n != 0; i++) {
78-
long mask = 1L << i;
79-
if ((n & mask) == mask) {
80-
r = (((r * p) % mod) + mod) % mod;
81-
n -= mask;
82-
}
83-
p = ((p * p) % mod + mod) % mod;
34+
long result = 1;
35+
while (n > 0) {
36+
if ((n & 1) == 1)
37+
result = mulMod(result, a, mod);
38+
a = mulMod(a, a, mod);
39+
n >>= 1;
8440
}
85-
86-
return ((r % mod) + mod) % mod;
41+
return result;
8742
}
8843

89-
// Example usage
90-
public static void main(String[] args) {
91-
92-
BigInteger A, N, M, r1;
93-
long a, n, m, r2;
94-
95-
A = BigInteger.valueOf(3);
96-
N = BigInteger.valueOf(4);
97-
M = BigInteger.valueOf(1000000);
98-
a = A.longValue();
99-
n = N.longValue();
100-
m = M.longValue();
101-
102-
// 3 ^ 4 mod 1000000
103-
r1 = A.modPow(N, M); // 81
104-
r2 = modPow(a, n, m); // 81
105-
System.out.println(r1 + " " + r2);
106-
107-
A = BigInteger.valueOf(-45);
108-
N = BigInteger.valueOf(12345);
109-
M = BigInteger.valueOf(987654321);
110-
a = A.longValue();
111-
n = N.longValue();
112-
m = M.longValue();
113-
114-
// Finds -45 ^ 12345 mod 987654321
115-
r1 = A.modPow(N, M); // 323182557
116-
r2 = modPow(a, n, m); // 323182557
117-
System.out.println(r1 + " " + r2);
118-
119-
A = BigInteger.valueOf(6);
120-
N = BigInteger.valueOf(-66);
121-
M = BigInteger.valueOf(101);
122-
a = A.longValue();
123-
n = N.longValue();
124-
m = M.longValue();
125-
126-
// Finds 6 ^ -66 mod 101
127-
r1 = A.modPow(N, M); // 84
128-
r2 = modPow(a, n, m); // 84
129-
System.out.println(r1 + " " + r2);
130-
131-
A = BigInteger.valueOf(-5);
132-
N = BigInteger.valueOf(-7);
133-
M = BigInteger.valueOf(1009);
134-
a = A.longValue();
135-
n = N.longValue();
136-
m = M.longValue();
137-
138-
// Finds -5 ^ -7 mod 1009
139-
r1 = A.modPow(N, M); // 675
140-
r2 = modPow(a, n, m); // 675
141-
System.out.println(r1 + " " + r2);
142-
143-
for (int i = 0; i < 1000; i++) {
144-
A = BigInteger.valueOf(a);
145-
N = BigInteger.valueOf(n);
146-
M = BigInteger.valueOf(m);
147-
a = Math.random() < 0.5 ? randLong(MAX) : -randLong(MAX);
148-
n = randLong();
149-
m = randLong(MAX);
150-
try {
151-
r1 = A.modPow(N, M);
152-
r2 = modPow(a, n, m);
153-
if (r1.longValue() != r2)
154-
System.out.printf("Broke with: a = %d, n = %d, m = %d\n", a, n, m);
155-
} catch (ArithmeticException e) {
156-
}
157-
}
44+
private static long modInv(long a, long m) {
45+
a = ((a % m) + m) % m;
46+
long x = egcd(a, m)[1];
47+
return ((x % m) + m) % m;
15848
}
15949

160-
/* TESTING RELATED METHODS */
161-
162-
static final java.util.Random RANDOM = new java.util.Random();
50+
private static long[] egcd(long a, long b) {
51+
if (b == 0)
52+
return new long[] {a < 0 ? -a : a, 1L, 0L};
53+
long[] v = egcd(b, a % b);
54+
long tmp = v[1] - v[2] * (a / b);
55+
v[1] = v[2];
56+
v[2] = tmp;
57+
return v;
58+
}
16359

164-
// Returns long between [1, bound]
165-
public static long randLong(long bound) {
166-
return java.util.concurrent.ThreadLocalRandom.current().nextLong(1, bound + 1);
60+
private static long gcd(long a, long b) {
61+
a = Math.abs(a);
62+
b = Math.abs(b);
63+
return b == 0 ? a : gcd(b, a % b);
16764
}
16865

169-
public static long randLong() {
170-
return RANDOM.nextLong();
66+
/** Overflow-safe modular multiplication: (a * b) % mod. */
67+
private static long mulMod(long a, long b, long mod) {
68+
return java.math.BigInteger.valueOf(a)
69+
.multiply(java.math.BigInteger.valueOf(b))
70+
.mod(java.math.BigInteger.valueOf(mod))
71+
.longValue();
17172
}
17273
}

src/test/java/com/williamfiset/algorithms/math/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,14 @@ java_test(
5858
runtime_deps = JUNIT5_RUNTIME_DEPS,
5959
deps = TEST_DEPS,
6060
)
61+
62+
# bazel test //src/test/java/com/williamfiset/algorithms/math:ModPowTest
63+
java_test(
64+
name = "ModPowTest",
65+
srcs = ["ModPowTest.java"],
66+
main_class = "org.junit.platform.console.ConsoleLauncher",
67+
use_testrunner = False,
68+
args = ["--select-class=com.williamfiset.algorithms.math.ModPowTest"],
69+
runtime_deps = JUNIT5_RUNTIME_DEPS,
70+
deps = TEST_DEPS,
71+
)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package com.williamfiset.algorithms.math;
2+
3+
import static com.google.common.truth.Truth.assertThat;
4+
import static org.junit.jupiter.api.Assertions.assertThrows;
5+
6+
import java.math.BigInteger;
7+
import java.util.concurrent.ThreadLocalRandom;
8+
import org.junit.jupiter.api.*;
9+
10+
public class ModPowTest {
11+
12+
@Test
13+
public void basicPositiveExponent() {
14+
// 3^4 mod 1000000 = 81
15+
assertThat(ModPow.modPow(3, 4, 1000000)).isEqualTo(81);
16+
}
17+
18+
@Test
19+
public void negativeBase() {
20+
// (-45)^12345 mod 987654321
21+
long expected =
22+
BigInteger.valueOf(-45)
23+
.modPow(BigInteger.valueOf(12345), BigInteger.valueOf(987654321))
24+
.longValue();
25+
assertThat(ModPow.modPow(-45, 12345, 987654321)).isEqualTo(expected);
26+
}
27+
28+
@Test
29+
public void negativeExponent() {
30+
// 6^-66 mod 101 = 84
31+
long expected =
32+
BigInteger.valueOf(6)
33+
.modPow(BigInteger.valueOf(-66), BigInteger.valueOf(101))
34+
.longValue();
35+
assertThat(ModPow.modPow(6, -66, 101)).isEqualTo(expected);
36+
}
37+
38+
@Test
39+
public void negativeBaseAndExponent() {
40+
// (-5)^-7 mod 1009
41+
long expected =
42+
BigInteger.valueOf(-5)
43+
.modPow(BigInteger.valueOf(-7), BigInteger.valueOf(1009))
44+
.longValue();
45+
assertThat(ModPow.modPow(-5, -7, 1009)).isEqualTo(expected);
46+
}
47+
48+
@Test
49+
public void exponentZero() {
50+
assertThat(ModPow.modPow(123, 0, 7)).isEqualTo(1);
51+
assertThat(ModPow.modPow(0, 0, 5)).isEqualTo(1);
52+
}
53+
54+
@Test
55+
public void baseZero() {
56+
assertThat(ModPow.modPow(0, 10, 7)).isEqualTo(0);
57+
}
58+
59+
@Test
60+
public void modOne() {
61+
// Anything mod 1 = 0
62+
assertThat(ModPow.modPow(999, 999, 1)).isEqualTo(0);
63+
}
64+
65+
@Test
66+
public void largeValues() {
67+
// Test with values that would overflow without safe multiplication
68+
long a = 1_000_000_000L;
69+
long n = 1_000_000_000L;
70+
long mod = 999_999_937L;
71+
long expected =
72+
BigInteger.valueOf(a).modPow(BigInteger.valueOf(n), BigInteger.valueOf(mod)).longValue();
73+
assertThat(ModPow.modPow(a, n, mod)).isEqualTo(expected);
74+
}
75+
76+
@Test
77+
public void modNonPositiveThrows() {
78+
assertThrows(ArithmeticException.class, () -> ModPow.modPow(2, 3, 0));
79+
assertThrows(ArithmeticException.class, () -> ModPow.modPow(2, 3, -5));
80+
}
81+
82+
@Test
83+
public void negativeExponentNotCoprime() {
84+
// gcd(4, 8) = 4 ≠ 1, so no modular inverse
85+
assertThrows(ArithmeticException.class, () -> ModPow.modPow(4, -1, 8));
86+
}
87+
88+
@Test
89+
public void matchesBigIntegerRandomized() {
90+
ThreadLocalRandom rng = ThreadLocalRandom.current();
91+
for (int i = 0; i < 500; i++) {
92+
long a = rng.nextLong(-1_000_000_000L, 1_000_000_000L);
93+
long n = rng.nextLong(0, 1_000_000_000L);
94+
long mod = rng.nextLong(1, 1_000_000_000L);
95+
long expected =
96+
BigInteger.valueOf(a).modPow(BigInteger.valueOf(n), BigInteger.valueOf(mod)).longValue();
97+
assertThat(ModPow.modPow(a, n, mod)).isEqualTo(expected);
98+
}
99+
}
100+
}

0 commit comments

Comments
 (0)