
Multiplication of long numbers by the Karatsuba method

The other day I needed to get to grips with this algorithm, but a cursory Google search turned up nothing worthwhile. On Habr, too, there was only one article, and it did not really help me. Having figured it out, I will try to share it with the public in an accessible form:



Algorithm



The Karatsuba algorithm is a fast multiplication method with a complexity of $n^{\log_2 3}$ operations, whereas the naive algorithm, column (schoolbook) multiplication, requires $n^2$ operations. Note that when the numbers are shorter than a few dozen digits (the exact threshold is determined experimentally), ordinary multiplication is faster.
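To get a feel for the difference between the two growth rates, here is a small sketch that counts digit-by-digit multiplications. The function names are my own, and the Karatsuba count assumes that n is a power of two and that the recursion runs all the way down to single digits, which a practical implementation would not do.

```cpp
#include <cassert>
#include <cstdint>

// Digit-by-digit multiplications used by schoolbook multiplication
// of two n-digit numbers.
std::uint64_t schoolbook_mults(std::uint64_t n) {
    return n * n;
}

// Digit-by-digit multiplications used by Karatsuba when every call is
// replaced by three half-size products, down to single digits.
// Assumes n is a power of two so that the halves are exact.
std::uint64_t karatsuba_mults(std::uint64_t n) {
    if (n <= 1) return 1;
    return 3 * karatsuba_mults(n / 2);
}
```

For n = 1024 this gives 59049 multiplications against 1048576 for the schoolbook method. The count ignores the extra additions and subtractions, which is why the naive method still wins on short numbers.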

Suppose there are two numbers A and B of length n in some number system with base BASE:

$A = a_{n-1} a_{n-2} \ldots a_0$

$B = b_{n-1} b_{n-2} \ldots b_0$, where $a_i$ and $b_i$ are the digits in the corresponding positions.

Each of them can be represented as the sum of two parts, halves of length m = n / 2 (if n is odd, one part is one digit shorter than the other):

$A_0 = a_{m-1} a_{m-2} \ldots a_0$

$A_1 = a_{n-1} a_{n-2} \ldots a_m$

$A = A_0 + A_1 \cdot \mathrm{BASE}^m$



$B_0 = b_{m-1} b_{m-2} \ldots b_0$

$B_1 = b_{n-1} b_{n-2} \ldots b_m$

$B = B_0 + B_1 \cdot \mathrm{BASE}^m$
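The split can be checked on a short number. The sketch below assumes base 10 and digits stored lowest-first; the names `Halves`, `split`, `to_value` and `base_pow` are made up for this illustration, and `to_value` only works for numbers small enough to fit in a `long long`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

const int BASE = 10;  // assumed base for the illustration

// Value of a digit vector stored lowest digit first.
// Only meant for short numbers that fit in long long.
long long to_value(const std::vector<int>& digits) {
    long long value = 0;
    for (std::size_t i = digits.size(); i-- > 0;)
        value = value * BASE + digits[i];
    return value;
}

// BASE raised to the power m.
long long base_pow(std::size_t m) {
    long long p = 1;
    while (m-- > 0) p *= BASE;
    return p;
}

struct Halves {
    std::vector<int> low;   // A0: digits a_{m-1} ... a_0
    std::vector<int> high;  // A1: digits a_{n-1} ... a_m
};

// Split at position m, where 0 <= m <= digits.size().
Halves split(const std::vector<int>& digits, std::size_t m) {
    const std::ptrdiff_t cut = static_cast<std::ptrdiff_t>(m);
    Halves h;
    h.low.assign(digits.begin(), digits.begin() + cut);
    h.high.assign(digits.begin() + cut, digits.end());
    return h;
}
```

For 123456 stored as {6, 5, 4, 3, 2, 1} and m = 3, the halves are 456 and 123, and 456 + 123 * 10^3 gives back 123456. No digit is changed by the split; the array is only cut in two.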



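Then:

$A \cdot B = (A_0 + A_1 \cdot \mathrm{BASE}^m) \cdot (B_0 + B_1 \cdot \mathrm{BASE}^m)$

$= A_0 \cdot B_0 + A_0 \cdot B_1 \cdot \mathrm{BASE}^m + A_1 \cdot B_0 \cdot \mathrm{BASE}^m + A_1 \cdot B_1 \cdot \mathrm{BASE}^{2m}$

$= A_0 \cdot B_0 + (A_0 \cdot B_1 + A_1 \cdot B_0) \cdot \mathrm{BASE}^m + A_1 \cdot B_1 \cdot \mathrm{BASE}^{2m}$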

This requires 4 multiplications (the factors $\mathrm{BASE}^m$ and $\mathrm{BASE}^{2m}$ are not real multiplications; they only indicate the position, that is, the digit offset, at which the result is written). But on the other hand:

$(A_0 + A_1) \cdot (B_0 + B_1) = A_0 \cdot B_0 + A_0 \cdot B_1 + A_1 \cdot B_0 + A_1 \cdot B_1$

Compare the term $A_0 \cdot B_1 + A_1 \cdot B_0$, which appears in both formulas. After a simple transformation, the number of multiplications can be reduced to 3: two multiplications are replaced by one multiplication plus several additions and subtractions, which are much cheaper because their cost grows only linearly with the number of digits:

$A_0 \cdot B_1 + A_1 \cdot B_0 = (A_0 + A_1) \cdot (B_0 + B_1) - A_0 \cdot B_0 - A_1 \cdot B_1$


The final form of the expression:

$A \cdot B = A_0 \cdot B_0 + ((A_0 + A_1) \cdot (B_0 + B_1) - A_0 \cdot B_0 - A_1 \cdot B_1) \cdot \mathrm{BASE}^m + A_1 \cdot B_1 \cdot \mathrm{BASE}^{2m}$
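A single step of this formula can be tried on ordinary machine integers. The sketch below is only an illustration, not part of the long-number code: it assumes base 10, a fixed split at m = 2 digits, and inputs below 10^4; `karatsuba_step` is a name made up here.

```cpp
#include <cassert>

// One Karatsuba step on machine integers, assuming base 10 and a split
// at m = 2 digits (BASE^m = 100). Inputs are expected to be below 10^4.
long long karatsuba_step(long long a, long long b) {
    const long long base_m = 100;  // BASE^m
    const long long a0 = a % base_m, a1 = a / base_m;
    const long long b0 = b % base_m, b1 = b / base_m;

    const long long low  = a0 * b0;               // multiplication 1
    const long long high = a1 * b1;               // multiplication 2
    const long long mid  = (a0 + a1) * (b0 + b1)  // multiplication 3
                           - low - high;          // equals a0*b1 + a1*b0

    // On machine integers the factors base_m are real multiplications;
    // in a digit array they are just offsets into the result.
    return low + mid * base_m + high * base_m * base_m;
}
```

For 1234 * 5678 the halves are 34, 12, 78 and 56. The three products are 34 * 78 = 2652, 12 * 56 = 672 and 46 * 134 = 6164, so the middle term is 6164 - 2652 - 672 = 2840, and the result is 2652 + 2840 * 100 + 672 * 10000 = 7006652.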



Graphic representation:

[Figure: multiplication scheme]




Example



For example, let us multiply two five-digit decimal numbers, 12345 and 98765:

image

The image clearly shows the recursive nature of the algorithm. For numbers shorter than four digits, naive multiplication was applied.



C++ implementation



It probably makes sense to start with how long numbers are stored. It is convenient to represent them as arrays in which each element corresponds to one digit, with the lower-order digits stored at smaller indices (that is, in reverse order); this makes them easier to process. For example:

int long_value[] = { 9, 8, 7, 6, 5, 4 }; // 456789
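Getting a number into and out of this reversed layout can be sketched as follows. The helper names `digits_from_string` and `digits_to_string` are my own, base 10 is assumed, and the input is assumed to contain only the characters '0' to '9'.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Convert a decimal string such as "456789" into digits stored
// lowest-first: {9, 8, 7, 6, 5, 4}. Assumes the string holds only
// the characters '0'..'9'.
std::vector<int> digits_from_string(const std::string& text) {
    std::vector<int> digits;
    digits.reserve(text.size());
    for (auto it = text.rbegin(); it != text.rend(); ++it)
        digits.push_back(*it - '0');
    return digits;
}

// Convert back to the usual highest-first notation.
std::string digits_to_string(const std::vector<int>& digits) {
    std::string text;
    text.reserve(digits.size());
    for (auto it = digits.rbegin(); it != digits.rend(); ++it)
        text.push_back(static_cast<char>('0' + *it));
    return text;
}
```

With this layout the digit with weight BASE^i always sits at index i, so carries move towards larger indices and a result can grow at the end of the array.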

To increase performance, it is desirable to choose the largest base that the built-in types allow. However, the base must satisfy the following conditions:

  1. The square of the largest digit in the chosen base must fit in the chosen built-in type. This is needed to store the product of one digit by another in intermediate calculations.
  2. The chosen built-in type should preferably be signed. This removes the need for several intermediate normalizations.
  3. Preferably, the sum of several squares of the largest digit should also fit in one digit. This, too, removes the need for several intermediate normalizations.
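These conditions can be checked at compile time. The sketch below is one possible way to do it, under my own assumptions: the helper names are invented, the base is at least 2, and signedness is treated as a hard requirement even though the list above only calls it desirable.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// How many products of two maximal digits can be accumulated in one
// element of type Digit without overflow (condition 3 asks for "several").
// Assumes base >= 2.
template <typename Digit>
constexpr std::int64_t max_accumulated_products(std::int64_t base) {
    return static_cast<std::int64_t>(std::numeric_limits<Digit>::max() /
                                     ((base - 1) * (base - 1)));
}

// Condition 1: (base - 1)^2 fits in Digit.
// Condition 2: Digit is signed (treated as required in this sketch).
template <typename Digit>
constexpr bool base_is_usable(std::int64_t base) {
    return std::numeric_limits<Digit>::is_signed &&
           max_accumulated_products<Digit>(base) >= 1;
}
```

With a 32-bit signed digit, base 10000 passes and leaves room for 21 accumulated products, while base 100000 fails because 99999 squared no longer fits.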




Below is a working multiplication function with comments, together with all the necessary auxiliary declarations and functions. For better performance, change the base of the number system and the type used to store the digits, and uncomment the small code fragment in the part responsible for naive multiplication:

#include <cstring>

#define BASE 10                     // base of the number system
#define MIN_LENGTH_FOR_KARATSUBA 4  // shorter numbers are multiplied by the quadratic algorithm

typedef int digit;                      // type of a single digit
typedef unsigned long int size_length;  // type for the length of a long number

using namespace std;

struct long_value {      // type for long numbers
    digit *values;       // array of digits, stored in reverse order
    size_length length;  // number of digits
};

long_value sum(long_value a, long_value b) {
    /* Adds two long numbers. If the lengths differ, the longer number
     * must be passed as the first argument. Returns a new,
     * unnormalized number.
     */
    long_value s;
    s.length = a.length + 1;
    s.values = new digit[s.length];
    memcpy(s.values, a.values, a.length * sizeof(digit));  // copy all digits of a
    s.values[a.length] = 0;                                // room for the carry
    for (size_length i = 0; i < b.length; ++i)
        s.values[i] += b.values[i];
    return s;
}

long_value &sub(long_value &a, long_value b) {
    /* Subtracts one long number from another. Modifies the contents of
     * the first number and returns a reference to it. The result is
     * not normalized.
     */
    for (size_length i = 0; i < b.length; ++i)
        a.values[i] -= b.values[i];
    return a;
}

void normalize(long_value l) {
    /* Normalization: brings every digit into the range [0, BASE).
     * The struct is passed by value, but the digits are modified
     * through the shared pointer.
     */
    for (size_length i = 0; i + 1 < l.length; ++i) {
        if (l.values[i] >= BASE) {  // digit too large: carry into the next one
            digit carryover = l.values[i] / BASE;
            l.values[i + 1] += carryover;
            l.values[i] -= carryover * BASE;
        } else if (l.values[i] < 0) {  // digit negative: borrow from the next one
            digit carryover = (l.values[i] + 1) / BASE - 1;
            l.values[i + 1] += carryover;
            l.values[i] -= carryover * BASE;
        }
    }
}

long_value karatsuba(long_value a, long_value b) {
    // a and b are expected to have the same length
    long_value product;  // resulting product
    product.length = a.length + b.length;
    product.values = new digit[product.length];

    if (a.length < MIN_LENGTH_FOR_KARATSUBA) {  // short numbers: naive multiplication
        memset(product.values, 0, sizeof(digit) * product.length);
        for (size_length i = 0; i < a.length; ++i)
            for (size_length j = 0; j < b.length; ++j) {
                product.values[i + j] += a.values[i] * b.values[j];
                /* If you change MIN_LENGTH_FOR_KARATSUBA or BASE, uncomment the
                 * following lines and pick suitable values to avoid overflowing
                 * a digit. For example, in the decimal system the number 100
                 * means a carry of 1 two positions up, 200 a carry of 2 two
                 * positions up, and 5000 a carry of 5 three positions up.
                 * if (product.values[i + j] >= 100) {
                 *     product.values[i + j] -= 100;
                 *     product.values[i + j + 2] += 1;
                 * }
                 */
            }
    } else {  // multiplication by the Karatsuba method
        long_value a_part1;  // lower half of a
        a_part1.values = a.values;
        a_part1.length = (a.length + 1) / 2;

        long_value a_part2;  // upper half of a
        a_part2.values = a.values + a_part1.length;
        a_part2.length = a.length / 2;

        long_value b_part1;  // lower half of b
        b_part1.values = b.values;
        b_part1.length = (b.length + 1) / 2;

        long_value b_part2;  // upper half of b
        b_part2.values = b.values + b_part1.length;
        b_part2.length = b.length / 2;

        long_value sum_of_a_parts = sum(a_part1, a_part2);  // sum of the halves of a
        normalize(sum_of_a_parts);
        long_value sum_of_b_parts = sum(b_part1, b_part2);  // sum of the halves of b
        normalize(sum_of_b_parts);

        // product of the sums of the halves
        long_value product_of_sums_of_parts = karatsuba(sum_of_a_parts, sum_of_b_parts);
        long_value product_of_first_parts = karatsuba(a_part1, b_part1);    // low-order term
        long_value product_of_second_parts = karatsuba(a_part2, b_part2);   // high-order term

        // sum of the middle terms; shares its storage with product_of_sums_of_parts
        long_value sum_of_middle_terms =
            sub(sub(product_of_sums_of_parts, product_of_first_parts), product_of_second_parts);

        /*
         * Summation of the polynomial
         */
        memcpy(product.values, product_of_first_parts.values,
               product_of_first_parts.length * sizeof(digit));
        memcpy(product.values + product_of_first_parts.length,
               product_of_second_parts.values,
               product_of_second_parts.length * sizeof(digit));
        // the bound on the index keeps the loop inside product for odd lengths;
        // the digits skipped this way are always zero
        for (size_length i = 0;
             i < sum_of_middle_terms.length && a_part1.length + i < product.length; ++i)
            product.values[a_part1.length + i] += sum_of_middle_terms.values[i];

        /*
         * Cleanup
         */
        delete[] sum_of_a_parts.values;
        delete[] sum_of_b_parts.values;
        delete[] product_of_sums_of_parts.values;
        delete[] product_of_first_parts.values;
        delete[] product_of_second_parts.values;
    }

    normalize(product);  // final normalization of the number
    return product;
}
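As a cross-check of the scheme, here is a compact sketch of the same idea built on std::vector, so that memory is managed automatically. It is not the listing above: the names `Digits`, `naive` and `MIN_LENGTH` are my own, base 10 is assumed, and both operands are assumed to have the same number of digits.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using Digits = std::vector<long long>;  // lowest digit first
const long long BASE = 10;              // assumed base
const std::size_t MIN_LENGTH = 4;       // below this, multiply naively

// Bring every digit into [0, BASE) by propagating carries and borrows.
void normalize(Digits& d) {
    for (std::size_t i = 0; i + 1 < d.size(); ++i) {
        long long carry = d[i] / BASE;
        if (d[i] % BASE < 0) --carry;  // floor division for negative digits
        d[i + 1] += carry;
        d[i] -= carry * BASE;
    }
}

// Schoolbook product; the result has a.size() + b.size() digits.
Digits naive(const Digits& a, const Digits& b) {
    Digits p(a.size() + b.size(), 0);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < b.size(); ++j)
            p[i + j] += a[i] * b[j];
    normalize(p);
    return p;
}

// Karatsuba product of two numbers with the same number of digits.
Digits karatsuba(const Digits& a, const Digits& b) {
    const std::size_t n = a.size();
    if (n < MIN_LENGTH) return naive(a, b);
    const std::size_t m = (n + 1) / 2;  // length of the low halves

    Digits a0(a.begin(), a.begin() + m), a1(a.begin() + m, a.end());
    Digits b0(b.begin(), b.begin() + m), b1(b.begin() + m, b.end());

    // Sums of the halves, padded to m + 1 digits to hold the carry.
    Digits sa(m + 1, 0), sb(m + 1, 0);
    for (std::size_t i = 0; i < m; ++i) { sa[i] = a0[i]; sb[i] = b0[i]; }
    for (std::size_t i = 0; i < a1.size(); ++i) { sa[i] += a1[i]; sb[i] += b1[i]; }
    normalize(sa);
    normalize(sb);

    Digits low = karatsuba(a0, b0);   // A0 * B0
    Digits high = karatsuba(a1, b1);  // A1 * B1
    Digits mid = karatsuba(sa, sb);   // (A0 + A1) * (B0 + B1)
    for (std::size_t i = 0; i < low.size(); ++i) mid[i] -= low[i];
    for (std::size_t i = 0; i < high.size(); ++i) mid[i] -= high[i];

    Digits p(2 * n, 0);
    std::copy(low.begin(), low.end(), p.begin());
    std::copy(high.begin(), high.end(), p.begin() + 2 * m);
    for (std::size_t i = 0; i < mid.size() && m + i < p.size(); ++i)
        p[m + i] += mid[i];  // middle term, shifted by m digits
    normalize(p);
    return p;
}
```

Multiplying {5, 4, 3, 2, 1} by {5, 6, 7, 8, 9}, that is 12345 by 98765, yields {5, 2, 9, 3, 5, 2, 9, 1, 2, 1}, which reads as 1219253925. Comparing `karatsuba` with `naive` on the same inputs is an easy way to test changes to the base or to the threshold.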

Source: https://habr.com/ru/post/124258/


