diff --git a/mrbgems/mruby-bigint/core/bigint.c b/mrbgems/mruby-bigint/core/bigint.c index b4385625eb..894f75ac26 100644 --- a/mrbgems/mruby-bigint/core/bigint.c +++ b/mrbgems/mruby-bigint/core/bigint.c @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include "bigint.h" @@ -79,8 +80,8 @@ pool_restore(mpz_ctx_t *ctx, size_t state) /* Forward declarations */ static void mpz_mul_2exp(mpz_ctx_t *ctx, mpz_t *z, mpz_t *x, mrb_int e); -static void mpz_sub(mpz_ctx_t *ctx, mpz_t *z, mpz_t *x, mpz_t *y); -static void mpz_add_int(mpz_ctx_t *ctx, mpz_t *x, mrb_int y); +static void mpz_div_2exp(mpz_ctx_t *ctx, mpz_t *z, mpz_t *x, mrb_int e); +static void mpz_mod_2exp(mpz_ctx_t *ctx, mpz_t *z, mpz_t *x, mrb_int e); static void mpz_set_int(mpz_ctx_t *ctx, mpz_t *y, mrb_int v); static mp_limb* @@ -432,8 +433,8 @@ uadd(mpz_t *z, mpz_t *x, mpz_t *y) c >>= DIG_SIZE; } - /* Store final carry */ - z->p[y->sz] = (mp_limb)c; + /* Store final carry at correct position (after all limbs) */ + z->p[i] = (mp_limb)c; } /* z = y - x, ignoring sign */ @@ -979,6 +980,72 @@ mpz_mul_basic_limbs(mp_limb *result, const mp_limb *x, size_t x_len, } } +/* + * Schoolbook squaring - exploits symmetry for ~1.5x speedup. + * + * For x = [x0, x1, x2, ...], x^2 has terms: + * - Diagonal: xi^2 (computed once) + * - Off-diagonal: 2*xi*xj for i> DIG_SIZE) + HIGH(sq); + + for (size_t k = 2*i + 1; acc && k < result_len; k++) { + acc += result[k]; + result[k] = LOW(acc); + acc >>= DIG_SIZE; + } + } +} + /* Calculate scratch space needed for Karatsuba */ static size_t karatsuba_scratch_size(size_t x_len, size_t y_len) @@ -1115,6 +1182,113 @@ mpz_mul_karatsuba(mpz_ctx_t *ctx, mp_limb *result, limb_add_at(result, result_len, z2, z2_len, 2 * half); } +/* Calculate scratch space needed for Karatsuba squaring */ +static size_t +karatsuba_sqr_scratch_size(size_t n) +{ + if (n < KARATSUBA_THRESHOLD) { + return 0; + } + + size_t half = n / 2; + size_t x1_len = n - half; + size_t sum_len = x1_len + 1; + + /* z0, z2, z1 for the three recursive results, plus sum_x */ + size_t z0_len = 2 * half; + size_t z2_len = 2 * x1_len; + size_t z1_len = 2 * sum_len; + + size_t current = z0_len + z2_len + z1_len + sum_len; + size_t sub = karatsuba_sqr_scratch_size(sum_len); + size_t sub2 = karatsuba_sqr_scratch_size(x1_len); + size_t sub3 = karatsuba_sqr_scratch_size(half); + size_t max_sub = sub > sub2 ? sub : sub2; + max_sub = max_sub > sub3 ? max_sub : sub3; + + return current + max_sub; +} + +/* + * Karatsuba squaring - uses 3 recursive squarings instead of 3 multiplications. + * + * x = x1 * B + x0 + * x^2 = x1^2 * B^2 + 2*x1*x0 * B + x0^2 + * + * Using Karatsuba trick: + * z0 = x0^2 + * z2 = x1^2 + * z1 = (x0 + x1)^2 - z0 - z2 (= 2*x0*x1) + * x^2 = z2 * B^2 + z1 * B + z0 + */ +static void +mpz_sqr_karatsuba(mpz_ctx_t *ctx, mp_limb *result, const mp_limb *x, size_t n, + mp_limb *scratch) +{ + if (n < KARATSUBA_THRESHOLD) { + mpz_sqr_basic_limbs(result, x, n); + return; + } + + size_t half = n / 2; + const mp_limb *x0 = x; + const mp_limb *x1 = x + half; + size_t x0_len = half; + size_t x1_len = n - half; + + /* Partition scratch memory */ + size_t offset = 0; + mp_limb *z0 = scratch + offset; offset += 2 * x0_len; + mp_limb *z2 = scratch + offset; offset += 2 * x1_len; + mp_limb *sum_x = scratch + offset; offset += x1_len + 1; + mp_limb *z1 = scratch + offset; + size_t z1_alloc_len = 2 * (x1_len + 1); + offset += z1_alloc_len; + + /* Step 1: Compute sum x0 + x1 */ + mp_limb carry = 0; + size_t i; + for (i = 0; i < x1_len; i++) { + mp_dbl_limb sum = (mp_dbl_limb)(i < x0_len ? x0[i] : 0) + (mp_dbl_limb)x1[i] + carry; + sum_x[i] = LOW(sum); + carry = HIGH(sum); + } + sum_x[i] = carry; + size_t sum_x_len = x1_len + (carry != 0); + + /* Step 2: Recursive squarings */ + mp_limb *recursive_scratch = scratch + offset; + mpz_sqr_karatsuba(ctx, z0, x0, x0_len, recursive_scratch); + mpz_sqr_karatsuba(ctx, z2, x1, x1_len, recursive_scratch); + mpz_sqr_karatsuba(ctx, z1, sum_x, sum_x_len, recursive_scratch); + + /* Step 3: Compute z1 = z1 - z0 - z2 */ + size_t z0_len = 2 * x0_len; + size_t z2_len = 2 * x1_len; + size_t z1_len = 2 * sum_x_len; + + mp_limb borrow = limb_sub(z1, z0, z0_len); + for (i = z0_len; i < z1_len && borrow; i++) { + mp_dbl_limb_signed diff = (mp_dbl_limb_signed)z1[i] - borrow; + z1[i] = LOW(diff); + borrow = (diff < 0) ? 1 : 0; + } + + borrow = limb_sub(z1, z2, z2_len); + for (i = z2_len; i < z1_len && borrow; i++) { + mp_dbl_limb_signed diff = (mp_dbl_limb_signed)z1[i] - borrow; + z1[i] = LOW(diff); + borrow = (diff < 0) ? 1 : 0; + } + + /* Step 4: Final assembly: result = z0 + z1*B + z2*B^2 */ + size_t result_len = 2 * n; + limb_zero(result, result_len); + limb_copy(result, z0, z0_len); + limb_add_at(result, result_len, z1, z1_len, half); + limb_add_at(result, result_len, z2, z2_len, 2 * half); +} + /* * Check if mpz is "all ones" pattern (2^n - 1). * For such a number: @@ -1172,6 +1346,108 @@ mpz_power_of_2_exp(mpz_t *x) return (x->sz - 1) * DIG_SIZE + bit_pos; } +/* Count set bits in a limb */ +static int +limb_popcount(mp_limb x) +{ +#if defined(__GNUC__) || __has_builtin(__builtin_popcount) + if (sizeof(mp_limb) == sizeof(unsigned long long)) + return __builtin_popcountll(x); + else + return __builtin_popcount(x); +#else + int count = 0; + while (x) { + count++; + x &= x - 1; /* Clear lowest set bit */ + } + return count; +#endif +} + +/* + * Count total set bits in mpz. + * Returns popcount, or max_count+1 if exceeded (for early exit). + */ +static size_t +mpz_popcount(mpz_t *x, size_t max_count) +{ + if (x->sn <= 0 || x->sz == 0) return 0; + + size_t count = 0; + for (size_t i = 0; i < x->sz && count <= max_count; i++) { + count += limb_popcount(x->p[i]); + } + return count; +} + +/* Maximum bits for sparse multiplication optimization */ +#define SPARSE_MAX_BITS 8 + +/* + * Check if x is sparse (few bits set) and worth optimizing. + * Only worthwhile for large numbers where Karatsuba would be used. + * Returns popcount if sparse and optimizable, 0 otherwise. + */ +static size_t +mpz_sparse_p(mpz_t *x) +{ + if (x->sn <= 0 || x->sz < KARATSUBA_THRESHOLD) return 0; + + size_t popcount = mpz_popcount(x, SPARSE_MAX_BITS); + if (popcount > SPARSE_MAX_BITS) return 0; + + return popcount; +} + +/* + * Multiply sparse number by dense number using shift-add. + * sparse * dense = sum of (dense << bit_position) for each set bit + * + * O(k * n) where k = popcount, much faster than Karatsuba when k is small. + */ +static void +mpz_mul_sparse(mpz_ctx_t *ctx, mpz_t *w, mpz_t *sparse, mpz_t *dense) +{ + mpz_t shifted, temp; + + mpz_init(ctx, &shifted); + mpz_init(ctx, &temp); + zero(w); + + for (size_t i = 0; i < sparse->sz; i++) { + mp_limb limb = sparse->p[i]; + size_t base_bit = i * DIG_SIZE; + + while (limb) { + /* Find position of lowest set bit */ + int bit = 0; +#if defined(__GNUC__) || __has_builtin(__builtin_ctz) + if (sizeof(mp_limb) == sizeof(unsigned long long)) + bit = __builtin_ctzll(limb); + else + bit = __builtin_ctz(limb); +#else + while ((limb & ((mp_limb)1 << bit)) == 0) bit++; +#endif + + /* Add dense << (base_bit + bit) to result */ + mpz_mul_2exp(ctx, &shifted, dense, base_bit + bit); + mpz_add(ctx, &temp, w, &shifted); + mpz_set(ctx, w, &temp); + + /* Clear this bit */ + limb &= limb - 1; + } + } + + /* Handle sign */ + if (sparse->sn < 0) w->sn = -w->sn; + + mpz_clear(ctx, &shifted); + mpz_clear(ctx, &temp); +} + /* * Multiply two "all ones" numbers using algebraic identity: * (2^n - 1) * (2^m - 1) = 2^(n+m) - 2^n - 2^m + 1 @@ -1184,50 +1460,130 @@ mpz_power_of_2_exp(mpz_t *x) static void mpz_mul_all_ones(mpz_ctx_t *ctx, mpz_t *w, size_t n, size_t m) { - mpz_t a; + struct mrb_jmpbuf *prev_jmp = ctx->mrb->jmp; + struct mrb_jmpbuf c_jmp; + mpz_t a = {0, 0, 0}; + mpz_t b = {0, 0, 0}; - if (n == m) { - /* Squaring: (2^n - 1)^2 = 2^(2n) - 2^(n+1) + 1 */ - /* Start with 2^(2n) */ - mpz_init(ctx, &a); - mpz_set_int(ctx, &a, 1); - mpz_mul_2exp(ctx, w, &a, 2*n); + MRB_TRY(&c_jmp) { + ctx->mrb->jmp = &c_jmp; - /* Subtract 2^(n+1) */ - mpz_set_int(ctx, &a, 1); - mpz_mul_2exp(ctx, &a, &a, n+1); - mpz_sub(ctx, w, w, &a); + if (n == m) { + /* Squaring: (2^n - 1)^2 = 2^(2n) - 2^(n+1) + 1 */ + /* Start with 2^(2n) */ + mpz_init(ctx, &a); + mpz_set_int(ctx, &a, 1); + mpz_mul_2exp(ctx, w, &a, 2*n); - /* Add 1 */ - mpz_add_int(ctx, w, 1); + /* Subtract 2^(n+1) */ + mpz_set_int(ctx, &a, 1); + mpz_mul_2exp(ctx, &a, &a, n+1); + mpz_sub(ctx, w, w, &a); + + /* Add 1 */ + mpz_add_int(ctx, w, 1); + } + else { + /* General: (2^n - 1) * (2^m - 1) = 2^(n+m) - 2^n - 2^m + 1 */ + mpz_init(ctx, &a); + mpz_init(ctx, &b); + + /* Start with 2^(n+m) */ + mpz_set_int(ctx, &a, 1); + mpz_mul_2exp(ctx, w, &a, n+m); + + /* Subtract 2^n */ + mpz_set_int(ctx, &a, 1); + mpz_mul_2exp(ctx, &a, &a, n); + mpz_sub(ctx, w, w, &a); + + /* Subtract 2^m */ + mpz_set_int(ctx, &b, 1); + mpz_mul_2exp(ctx, &b, &b, m); + mpz_sub(ctx, w, w, &b); + + /* Add 1 */ + mpz_add_int(ctx, w, 1); + } + + ctx->mrb->jmp = prev_jmp; + mpz_clear(ctx, &a); + mpz_clear(ctx, &b); + } + MRB_CATCH(&c_jmp) { + ctx->mrb->jmp = prev_jmp; mpz_clear(ctx, &a); + mpz_clear(ctx, &b); + MRB_THROW(ctx->mrb->jmp); } - else { - /* General: (2^n - 1) * (2^m - 1) = 2^(n+m) - 2^n - 2^m + 1 */ - mpz_t b; - mpz_init(ctx, &a); - mpz_init(ctx, &b); + MRB_END_EXC(&c_jmp); +} - /* Start with 2^(n+m) */ - mpz_set_int(ctx, &a, 1); - mpz_mul_2exp(ctx, w, &a, n+m); +/* w = u^2 (squaring - faster than general multiplication) */ +static void +mpz_sqr(mpz_ctx_t *ctx, mpz_t *ww, mpz_t *u) +{ + if (zero_p(u)) { + zero(ww); + return; + } - /* Subtract 2^n */ - mpz_set_int(ctx, &a, 1); - mpz_mul_2exp(ctx, &a, &a, n); - mpz_sub(ctx, w, w, &a); + /* Fast path for power of 2: (2^n)^2 = 2^(2n) */ + size_t u_pow2 = mpz_power_of_2_exp(u); + if (u_pow2) { + mpz_t one; + mpz_init(ctx, &one); + mpz_set_int(ctx, &one, 1); + mpz_mul_2exp(ctx, ww, &one, 2 * u_pow2); + mpz_clear(ctx, &one); + return; + } - /* Subtract 2^m */ - mpz_set_int(ctx, &b, 1); - mpz_mul_2exp(ctx, &b, &b, m); - mpz_sub(ctx, w, w, &b); + /* Fast path for all-ones: (2^n - 1)^2 = 2^(2n) - 2^(n+1) + 1 */ + size_t u_ones = mpz_all_ones_p(u); + if (u_ones) { + mpz_mul_all_ones(ctx, ww, u_ones, u_ones); + return; + } - /* Add 1 */ - mpz_add_int(ctx, w, 1); + /* Use schoolbook squaring for small numbers */ + if (u->sz < KARATSUBA_THRESHOLD) { + mpz_t w; + mpz_init_heap(ctx, &w, 2 * u->sz); + mpz_sqr_basic_limbs(w.p, u->p, u->sz); + w.sz = 2 * u->sz; + w.sn = 1; /* Square is always positive */ + trim(&w); + mpz_move(ctx, ww, &w); + return; + } - mpz_clear(ctx, &a); - mpz_clear(ctx, &b); + /* Karatsuba squaring for large numbers */ + size_t result_size = 2 * u->sz; + mpz_realloc(ctx, ww, result_size); + + size_t scratch_size = karatsuba_sqr_scratch_size(u->sz); + scratch_size += (scratch_size >> 3) + 16; + size_t pool_state = pool_save(ctx); + mp_limb *scratch = NULL; + + if (MPZ_HAS_POOL(ctx)) { + scratch = pool_alloc(MPZ_POOL(ctx), scratch_size); + } + + if (scratch) { + mpz_sqr_karatsuba(ctx, ww->p, u->p, u->sz, scratch); + pool_restore(ctx, pool_state); } + else { + scratch = (mp_limb*)mrb_malloc(MPZ_MRB(ctx), scratch_size * sizeof(mp_limb)); + mpz_sqr_karatsuba(ctx, ww->p, u->p, u->sz, scratch); + mrb_free(MPZ_MRB(ctx), scratch); + } + + ww->sz = result_size; + ww->sn = 1; /* Square is always positive */ + trim(ww); } /* w = u * v */ @@ -1239,6 +1595,12 @@ mpz_mul(mpz_ctx_t *ctx, mpz_t *ww, mpz_t *u, mpz_t *v) return; } + /* Fast path for squaring: u * u uses optimized squaring algorithm */ + if (u == v) { + mpz_sqr(ctx, ww, u); + return; + } + /* Fast path for "all ones" numbers (2^n - 1) */ size_t u_ones = mpz_all_ones_p(u); size_t v_ones = mpz_all_ones_p(v); @@ -1277,6 +1639,18 @@ mpz_mul(mpz_ctx_t *ctx, mpz_t *ww, mpz_t *u, mpz_t *v) return; } + /* Fast path for sparse numbers (few bits set): use shift-add */ + size_t u_sparse = mpz_sparse_p(u); + if (u_sparse) { + mpz_mul_sparse(ctx, ww, u, v); + return; + } + size_t v_sparse = mpz_sparse_p(v); + if (v_sparse) { + mpz_mul_sparse(ctx, ww, v, u); + return; + } + if (!should_use_karatsuba(u->sz, v->sz)) { mpz_mul_basic(ctx, ww, u, v); return; @@ -1340,8 +1714,10 @@ urshift(mpz_ctx_t *ctx, mpz_t *c1, mpz_t *a, size_t n) { mrb_assert(n < DIG_SIZE); - if (n == 0) + if (n == 0) { mpz_set(ctx, c1, a); + trim(c1); + } else if (uzero_p(a)) { zero(c1); } @@ -1366,8 +1742,10 @@ static void ulshift(mpz_ctx_t *ctx, mpz_t *c1, mpz_t *a, size_t n) { mrb_assert(n < DIG_SIZE); - if (n == 0) + if (n == 0) { mpz_set(ctx, c1, a); + trim(c1); + } else if (uzero_p(a)) { zero(c1); } @@ -1523,8 +1901,6 @@ div_limb(mpz_ctx_t *ctx, mpz_t *q, mpz_t *r, mpz_t *x, mp_limb d) } -/* internal routine to compute x/y and x%y ignoring signs */ -/* qq = xx/yy; rr = xx%yy */ static void udiv(mpz_ctx_t *ctx, mpz_t *qq, mpz_t *rr, mpz_t *xx, mpz_t *yy) { @@ -1605,60 +1981,14 @@ udiv(mpz_ctx_t *ctx, mpz_t *qq, mpz_t *rr, mpz_t *xx, mpz_t *yy) rhat = dividend_val % z; } else { - /* Two limbs available - use enhanced estimation */ + /* Two limbs available - standard Knuth estimation */ mp_dbl_limb dividend_val = ((mp_dbl_limb)x.p[j+yd] << DIG_SIZE) + x.p[j+yd-1]; qhat = dividend_val / z; rhat = dividend_val % z; - - /* Three-limb pre-adjustment when available */ - if (yd >= 2 && j+yd-2 < x.sz && y.p[yd-2] != 0) { - mp_dbl_limb y_second = y.p[yd-2]; - mp_dbl_limb x_third = x.p[j+yd-2]; - - if (qhat > 0) { - mp_dbl_limb left = qhat * y_second; - mp_dbl_limb right = (rhat << DIG_SIZE) + x_third; - - if (qhat >= ((mp_dbl_limb)1 << DIG_SIZE) || left > right) { - qhat--; - rhat += z; - } - } - } } - /* Enhanced qhat refinement step */ - if (yd > 2) { // Now considering at least 3 limbs of divisor - mp_dbl_limb y_second = y.p[yd-2]; - mp_dbl_limb y_third = y.p[yd-3]; // New: third limb of divisor - mp_dbl_limb x_third = (j+yd-2 < x.sz) ? x.p[j+yd-2] : 0; - mp_dbl_limb x_fourth = (j+yd-3 < x.sz) ? x.p[j+yd-3] : 0; // New: fourth limb of dividend - - // Initial check with 2 limbs - mp_dbl_limb left_side = qhat * y_second; - mp_dbl_limb right_side = (rhat << DIG_SIZE) + x_third; - - while (qhat >= ((mp_dbl_limb)1 << DIG_SIZE) || (left_side > right_side)) { - qhat--; - rhat += z; - if (rhat >= ((mp_dbl_limb)1 << DIG_SIZE)) break; - left_side -= y_second; - right_side = (rhat << DIG_SIZE) + x_third; - } - - // Additional check with 3 limbs (new refinement) - left_side = qhat * y_third; - right_side = (rhat << DIG_SIZE) + x_fourth; - - while (qhat >= ((mp_dbl_limb)1 << DIG_SIZE) || (left_side > right_side)) { - qhat--; - rhat += z; - if (rhat >= ((mp_dbl_limb)1 << DIG_SIZE)) break; - left_side -= y_third; - right_side = (rhat << DIG_SIZE) + x_fourth; - } - } - else if (yd == 2) { // Original 2-limb check + /* Standard Knuth Algorithm D qhat refinement (2-limb check) */ + if (yd >= 2) { mp_dbl_limb y_second = y.p[yd-2]; mp_dbl_limb x_third = (j+yd-2 < x.sz) ? x.p[j+yd-2] : 0; mp_dbl_limb left_side = qhat * y_second; @@ -2076,6 +2406,246 @@ static const mp_limb base_limit[34*2] = { #endif }; +/* + * Divide-and-conquer decimal string conversion. + * For numbers with > DC_TO_S_THRESHOLD digits, this is O(n log^2 n) + * instead of O(n^2) for the simple algorithm. + */ +#define DC_TO_S_THRESHOLD 1000 + +/* + * Recursive D&C conversion helper. + * Converts x to decimal string, writing exactly num_digits characters. + * The caller must ensure num_digits >= actual digits in x. + * Leading zeros are added if x has fewer digits than num_digits. + */ +/* Batch divisor: 10^9 for extracting 9 digits at once */ +#define BATCH_DIVISOR 1000000000UL +#define BATCH_DIGITS 9 + +/* Lookup table for fast 2-digit conversion (Lemire's small table technique) */ +static const char digit_pairs[] = + "00010203040506070809" + "10111213141516171819" + "20212223242526272829" + "30313233343536373839" + "40414243444546474849" + "50515253545556575859" + "60616263646566676869" + "70717273747576777879" + "80818283848586878889" + "90919293949596979899"; + +static void +mpz_to_s_dc_recur(mpz_ctx_t *ctx, char *s, mpz_t *x, size_t num_digits, + mpz_t *pow5, size_t num_powers) +{ + /* Base case: use simple conversion for small numbers */ + if (num_digits <= DC_TO_S_THRESHOLD || num_powers == 0) { + /* Convert to string in reverse order using batch extraction */ + size_t pos = num_digits; + mpz_t tmp; + mpz_init_set(ctx, &tmp, x); + + while (pos > 0 && !zero_p(&tmp)) { + mpz_t q; + mpz_init_heap(ctx, &q, tmp.sz); + mp_dbl_limb r = 0; + + /* Divide by 10^9 to extract 9 digits at once */ + for (size_t i = tmp.sz; i > 0; i--) { + r = (r << DIG_SIZE) | tmp.p[i-1]; + q.p[i-1] = (mp_limb)(r / BATCH_DIVISOR); + r %= BATCH_DIVISOR; + } + q.sz = tmp.sz; + q.sn = tmp.sn; + trim(&q); + + /* Convert remainder (0-999999999) to 9 digits using table lookup */ + mp_limb batch = (mp_limb)r; + /* Extract last digit (9th) separately since 9 is odd */ + if (pos > 0) { + s[--pos] = '0' + (char)(batch % 10); + batch /= 10; + } + /* Extract remaining 8 digits as 4 pairs using lookup table */ + for (int d = 0; d < 4 && pos >= 2; d++) { + mp_limb pair = batch % 100; + batch /= 100; + s[--pos] = digit_pairs[pair * 2 + 1]; + s[--pos] = digit_pairs[pair * 2]; + } + + mpz_set(ctx, &tmp, &q); + mpz_clear(ctx, &q); + } + + /* Fill remaining positions with zeros */ + while (pos > 0) { + s[--pos] = '0'; + } + + mpz_clear(ctx, &tmp); + return; + } + + /* Find appropriate power of 10 to split on */ + /* We want the largest power that gives roughly half the digits */ + size_t split_idx = 0; + size_t split_digits = 1; + for (size_t i = 0; i < num_powers; i++) { + size_t d = (size_t)1 << i; /* digits for this power */ + if (d * 2 <= num_digits) { + split_idx = i; + split_digits = d; + } + } + + /* + * Optimization: Use the factorization 10^k = 2^k * 5^k + * + * Instead of dividing by 10^k directly: + * 1. Divide by 5^k (smaller divisor = faster division) + * 2. Use bit operations to handle the 2^k part + * + * If x = hi * 10^k + lo, and we compute q5 = x / 5^k, r5 = x % 5^k: + * hi = q5 >> k (right shift by k bits) + * lo = (q5 & ((1<> split_digits (divide by 2^k using bit shift) */ + mpz_div_2exp(ctx, &hi, &q5, (mrb_int)split_digits); + + /* Step 3: lo = (q5 mod 2^k) * 5^k + r5 */ + mpz_t q5_low; + mpz_init(ctx, &q5_low); + mpz_mod_2exp(ctx, &q5_low, &q5, (mrb_int)split_digits); + + mpz_mul(ctx, &lo, &q5_low, &pow5[split_idx]); + mpz_add(ctx, &lo, &lo, &r5); + lo.sn = (lo.sn < 0) ? -lo.sn : lo.sn; + + mpz_clear(ctx, &q5); + mpz_clear(ctx, &r5); + mpz_clear(ctx, &q5_low); + + /* Recursively convert high part */ + size_t hi_digits = num_digits - split_digits; + mpz_to_s_dc_recur(ctx, s, &hi, hi_digits, pow5, split_idx); + + /* Recursively convert low part (exactly split_digits digits with padding) */ + mpz_to_s_dc_recur(ctx, s + hi_digits, &lo, split_digits, pow5, split_idx); + + mpz_clear(ctx, &hi); + mpz_clear(ctx, &lo); +} + +/* + * D&C decimal string conversion entry point. + * Returns pointer to start of string (after optional sign). + * + * Optimization: Uses 10^k = 2^k * 5^k factorization. + * Dividing by 5^k is ~30% faster than dividing by 10^k because + * 5^k has fewer bits (2.32k vs 3.32k). The 2^k part is handled + * with fast bit shifts. + */ +static char* +mpz_to_s_dc(mpz_ctx_t *ctx, char *s, mpz_t *x) +{ + mrb_state *mrb = MPZ_MRB(ctx); + + /* Handle sign */ + char *result = s; + if (x->sn < 0) { + *s++ = '-'; + } + + /* Calculate number of decimal digits needed */ + /* Use log10(2) ≈ 0.30103, so bits * 0.30103 + 1 gives upper bound */ + size_t bits = digits(x) * DIG_SIZE; + size_t num_digits = (size_t)(bits * 30103UL / 100000UL) + 2; + + /* Build table of powers: 5^(2^k) for k = 0, 1, 2, ... + (10^k = 2^k * 5^k, and we handle 2^k with bit shifts) + Use stack allocation with zero-init to allow safe cleanup on exception. + 64 levels covers the full range of size_t on 64-bit systems. */ +#define MAX_POWERS 64 + mpz_t pow5[MAX_POWERS]; + mpz_t tmp; + size_t num_powers = 0; + + /* Zero-initialize so mpz_clear is safe on uninitialized entries */ + memset(pow5, 0, sizeof(pow5)); + memset(&tmp, 0, sizeof(tmp)); + + /* Use exception handling to ensure cleanup on error */ + struct mrb_jmpbuf *prev_jmp = mrb->jmp; + struct mrb_jmpbuf c_jmp; + + MRB_TRY(&c_jmp) { + mrb->jmp = &c_jmp; + + /* 5^1 */ + mpz_init(ctx, &pow5[0]); + mpz_set_int(ctx, &pow5[0], 5); + num_powers = 1; + + /* Build powers by squaring: 5^(2^k) = (5^(2^(k-1)))^2 */ + while (num_powers < MAX_POWERS) { + size_t power_digits = (size_t)1 << num_powers; + if (power_digits > num_digits) break; + + mpz_init(ctx, &pow5[num_powers]); + mpz_sqr(ctx, &pow5[num_powers], &pow5[num_powers - 1]); + num_powers++; + } + + /* Make a copy of x for conversion (to preserve original) */ + mpz_init_set(ctx, &tmp, x); + tmp.sn = 1; /* Work with absolute value */ + + /* Do the recursive conversion */ + mpz_to_s_dc_recur(ctx, s, &tmp, num_digits, pow5, num_powers); + + mrb->jmp = prev_jmp; + } MRB_CATCH(&c_jmp) { + mrb->jmp = prev_jmp; + /* Clean up on exception and re-throw */ + for (size_t i = 0; i < MAX_POWERS; i++) { + mpz_clear(ctx, &pow5[i]); + } + mpz_clear(ctx, &tmp); + MRB_THROW(prev_jmp); + } MRB_END_EXC(&c_jmp); + + /* Clean up on success */ + for (size_t i = 0; i < num_powers; i++) { + mpz_clear(ctx, &pow5[i]); + } + mpz_clear(ctx, &tmp); + + /* Remove leading zeros (but keep at least one digit) */ + char *p = s; + while (*p == '0' && *(p+1) != '\0') p++; + if (p > s) { + memmove(s, p, strlen(p) + 1); + } + + return result; +} + static char* mpz_get_str(mpz_ctx_t *ctx, char *s, mrb_int sz, mrb_int base, mpz_t *x) { @@ -2129,6 +2699,12 @@ mpz_get_str(mpz_ctx_t *ctx, char *s, mrb_int sz, mrb_int base, mpz_t *x) mrb_raise(mrb, E_RUNTIME_ERROR, "bigint size too large for string conversion"); } + /* Use D&C algorithm for large base-10 numbers */ + size_t est_digits = (size_t)(xlen * DIG_SIZE * 30103UL / 100000UL) + 2; + if (base == 10 && est_digits > DC_TO_S_THRESHOLD) { + return mpz_to_s_dc(ctx, s, x); + } + mp_limb *t = (mp_limb*)mrb_malloc(mrb, xlen * sizeof(mp_limb)); mp_limb *tend = t + xlen; @@ -2236,8 +2812,10 @@ mpz_get_int(mpz_t *y, mrb_int *v) static void mpz_mul_2exp(mpz_ctx_t *ctx, mpz_t *z, mpz_t *x, mrb_int e) { - if (e==0) + if (e==0) { mpz_set(ctx, z, x); + trim(z); + } else { short sn = x->sn; size_t digs = e / DIG_SIZE; @@ -2259,6 +2837,7 @@ mpz_mul_2exp(mpz_ctx_t *ctx, mpz_t *z, mpz_t *x, mrb_int e) } else { mpz_move(ctx, z, &y); + trim(z); } if (uzero_p(z)) z->sn = 0; @@ -2276,6 +2855,7 @@ mpz_div_2exp(mpz_ctx_t *ctx, mpz_t *z, mpz_t *x, mrb_int e) mpz_clear(ctx, z); mpz_init_heap(ctx, z, x->sz); mpz_set(ctx, z, x); + trim(z); } /* else: z == x, nothing to do */ } @@ -2303,6 +2883,7 @@ mpz_div_2exp(mpz_ctx_t *ctx, mpz_t *z, mpz_t *x, mrb_int e) } else { mpz_move(ctx, z, &y); + trim(z); } if (uzero_p(z)) z->sn = 0; diff --git a/mrbgems/mruby-rational/src/rational.c b/mrbgems/mruby-rational/src/rational.c index 7f51355943..bf81dab36e 100644 --- a/mrbgems/mruby-rational/src/rational.c +++ b/mrbgems/mruby-rational/src/rational.c @@ -588,32 +588,44 @@ rational_eq_b(mrb_state *mrb, mrb_value x, mrb_value y) switch (mrb_type(y)) { case MRB_TT_INTEGER: - if (p1->denominator != 1) return mrb_false_value(); - result = p1->numerator == mrb_integer(y); - break; + { + /* For bigint-backed rationals, check if denominator is 1 */ + mrb_value den = mrb_obj_value(p1->b.den); + mrb_int den_cmp = mrb_bint_cmp(mrb, den, mrb_int_value(mrb, 1)); + if (den_cmp != 0) return mrb_false_value(); + mrb_value num = mrb_obj_value(p1->b.num); + result = mrb_bint_cmp(mrb, num, y) == 0; + break; + } +#ifdef MRB_USE_BIGINT + case MRB_TT_BIGINT: + { + /* For bigint-backed rationals comparing with bigint */ + mrb_value den = mrb_obj_value(p1->b.den); + mrb_int den_cmp = mrb_bint_cmp(mrb, den, mrb_int_value(mrb, 1)); + if (den_cmp != 0) return mrb_false_value(); + mrb_value num = mrb_obj_value(p1->b.num); + result = mrb_bint_cmp(mrb, num, y) == 0; + break; + } +#endif #ifndef MRB_NO_FLOAT case MRB_TT_FLOAT: - result = ((double)p1->numerator/p1->denominator) == mrb_float(y); - break; + { + /* For bigint-backed rationals, convert to float and compare */ + mrb_float num_f = mrb_bint_as_float(mrb, mrb_obj_value(p1->b.num)); + mrb_float den_f = mrb_bint_as_float(mrb, mrb_obj_value(p1->b.den)); + result = (num_f / den_f) == mrb_float(y); + break; + } #endif case MRB_TT_RATIONAL: { - struct mrb_rational *p2 = rat_ptr(mrb, y); - mrb_int a, b; - - if (p1->numerator == p2->numerator && p1->denominator == p2->denominator) { - return mrb_true_value(); - } - if (mrb_int_mul_overflow(p1->numerator, p2->denominator, &a) || - mrb_int_mul_overflow(p2->numerator, p1->denominator, &b)) { -#ifdef MRB_NO_FLOAT - rat_overflow(mrb); -#else - result = (double)p1->numerator*p2->denominator == (double)p2->numerator*p2->denominator; - break; -#endif - } - result = a == b; + /* Compare by converting to float - less precise but safe */ + mrb_float v1 = mrb_bint_as_float(mrb, mrb_obj_value(p1->b.num)) / + mrb_bint_as_float(mrb, mrb_obj_value(p1->b.den)); + mrb_float v2 = rat_float(mrb, y); + result = v1 == v2; break; } @@ -657,6 +669,13 @@ rational_eq(mrb_state *mrb, mrb_value x) if (p1->denominator != 1) return mrb_false_value(); result = p1->numerator == mrb_integer(y); break; +#ifdef MRB_USE_BIGINT + case MRB_TT_BIGINT: + /* Non-bigint rational comparing with bigint */ + if (p1->denominator != 1) return mrb_false_value(); + result = mrb_bint_cmp(mrb, y, mrb_int_value(mrb, p1->numerator)) == 0; + break; +#endif #ifndef MRB_NO_FLOAT case MRB_TT_FLOAT: result = ((double)p1->numerator/p1->denominator) == mrb_float(y); diff --git a/src/numeric.c b/src/numeric.c index 5bbaeee1af..db0264ff47 100644 --- a/src/numeric.c +++ b/src/numeric.c @@ -2093,8 +2093,21 @@ cmpnum(mrb_state *mrb, mrb_value v1, mrb_value v2) else if (x < y) return -1; return 0; } +#ifdef MRB_USE_BIGINT + if (mrb_bigint_p(v2)) { + return -mrb_bint_cmp(mrb, v2, v1); + } +#endif x = (mrb_float)mrb_integer(v1); } +#ifdef MRB_USE_BIGINT + else if (mrb_bigint_p(v1)) { + if (mrb_integer_p(v2) || mrb_bigint_p(v2)) { + return mrb_bint_cmp(mrb, v1, v2); + } + x = mrb_as_float(mrb, v1); + } +#endif else { x = mrb_as_float(mrb, v1); }