Shor's algorithm


This tutorial presents a pedagogical demonstration of Shor's algorithm. It is a modified and expanded version of this Cirq example.

"""Install Cirq."""
try:
    import cirq
except ImportError:
    print("installing cirq...")
    !pip install --quiet cirq
    print("installed cirq.")

"""Imports for the notebook."""
import fractions
import math
import random

import numpy as np
import sympy
from typing import Callable, List, Optional, Sequence, Union

import cirq

Order finding

Factoring an integer $n$ can be reduced to finding the period of the modular exponential function (to be defined). Finding this period can be accomplished (with high probability) by finding the order of a randomly chosen element of the multiplicative group modulo $n$.

Let $n$ be a positive integer and

$$ \mathbb{Z}_n := \{x \in \mathbb{Z}_+ : x < n \text{ and } \text{gcd}(x, n) = 1\} $$

be the multiplicative group modulo $n$. The order finding problem is: given $x \in \mathbb{Z}_n$, compute the smallest positive integer $r$ such that $x^r \text{ mod } n = 1$. This $r$ is called the order of $x$.

It can be shown from group/number theory that:

(1) Such an integer $r$ exists. (Note that $g^{|G|} = 1_G$ for any finite group $G$ with cardinality $|G|$ and element $g \in G$, but it's possible that $r < |G|$.)

(2) If $n = pq$ for distinct primes $p$ and $q$, then $|\mathbb{Z}_n| = \phi(n) = (p - 1) (q - 1)$. (The function $\phi$ is called Euler's totient function.)

(3) The modular exponential function

$$ f_x(z) := x^z \mod n $$

is periodic with period $r$ (the order of the element $x \in \mathbb{Z}_n$). That is, $f_x(z + r) = f_x(z)$.

(4) If we know the period of the modular exponential function, we can (with high probability) figure out $p$ and $q$ -- that is, factor $n$.
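For a small modulus, facts (1) to (3) are easy to check numerically. The sketch below is self-contained plain Python; the brute-force helper `order` is written only for this check and is not part of the tutorial's code. It computes the order of every element of $\mathbb{Z}_{15}$ (with $15 = 3 \cdot 5$) and confirms that each order divides $\phi(15) = 8$ and that $f_x$ repeats with that period.

```python
import math

def order(x: int, n: int) -> int:
    """Brute-force order of x modulo n (assumes gcd(x, n) == 1)."""
    r, y = 1, x % n
    while y != 1:
        y = (y * x) % n
        r += 1
    return r

n = 15
group = [x for x in range(1, n) if math.gcd(x, n) == 1]
phi = len(group)
assert phi == (3 - 1) * (5 - 1)  # Fact (2) with p = 3, q = 5.

orders = {x: order(x, n) for x in group}
for x, r in orders.items():
    # Fact (1): r exists and divides |Z_n| = phi(n).
    assert phi % r == 0
    # Fact (3): f_x(z) = x**z mod n is periodic with period r.
    assert all(pow(x, z + r, n) == pow(x, z, n) for z in range(2 * phi))

print(phi, orders)
```

Every order here is 1, 2, or 4, so $r < \phi(15) = 8$ for all elements, which illustrates the remark in (1).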

As a refresher, we can list the elements of the multiplicative group $\mathbb{Z}_n$ for a given integer $n$ via the following simple function.

"""Function to compute the elements of Z_n."""
def multiplicative_group(n: int) -> List[int]:
    """Returns the multiplicative group modulo n.

    Args:
        n: Modulus of the multiplicative group (an integer greater than 1).
    """
    if n < 2:
        raise ValueError(f"Invalid modulus n={n}.")
    group = [1]
    for x in range(2, n):
        # Keep x only if it is relatively prime to n (2 is excluded for even n).
        if math.gcd(x, n) == 1:
            group.append(x)
    return group

For example, the multiplicative group modulo $n = 15$ is shown below.

"""Example of a multiplicative group."""
n = 15
print(f"The multiplicative group modulo n = {n} is:")
print(multiplicative_group(n))
The multiplicative group modulo n = 15 is:
[1, 2, 4, 7, 8, 11, 13, 14]

One can check that this set of elements indeed forms a group under multiplication modulo $n$.
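One way to do that check is sketched below. It assumes Python 3.8 or later, where `pow(a, -1, n)` returns the modular inverse of `a`. The sketch rebuilds the set for $n = 15$ and tests closure, the identity, and inverses under multiplication modulo $n$.

```python
import math

n = 15
group = [x for x in range(1, n) if math.gcd(x, n) == 1]
members = set(group)

# Closure: the product of two elements, reduced mod n, stays in the set.
closed = all((a * b) % n in members for a in group for b in group)

# Identity: 1 is in the set and leaves every element unchanged.
has_identity = 1 in members and all((1 * a) % n == a for a in group)

# Inverses: pow(a, -1, n) (Python 3.8+) returns b with a * b = 1 mod n.
inverses = {a: pow(a, -1, n) for a in group}
has_inverses = all(b in members and (a * b) % n == 1 for a, b in inverses.items())

print(closed, has_identity, has_inverses)
print(inverses)
```

Associativity is inherited from ordinary integer multiplication, so it needs no separate check.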

Classical order finding

A function for classically computing the order $r$ of an element $x \in \mathbb{Z}_n$ is provided below. This function simply computes the sequence

$$ x^2 \text{ mod } n $$
$$ x^3 \text{ mod } n $$
$$ x^4 \text{ mod } n $$
$$ \vdots $$

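until an integer $r$ is found such that $x^r \text{ mod } n = 1$. By Lagrange's theorem the order $r$ divides $|\mathbb{Z}_n| = \phi(n)$, so this algorithm takes at most $\phi(n)$ steps, giving worst-case time complexity $O(\phi(n))$. This is inefficient: for $n = pq$, $\phi(n) = n - p - q + 1$ is close to $n \approx 2^L$, where $L$ is the number of bits in $n$, so the running time is exponential in $L$.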

"""Function for classically computing the order of an element of Z_n."""
def classical_order_finder(x: int, n: int) -> Optional[int]:
    """Computes smallest positive r such that x**r mod n == 1.

    Args:
        x: Integer whose order is to be computed, must be greater than one
           and belong to the multiplicative group of integers modulo n (which
           consists of positive integers relatively prime to n).
        n: Modulus of the multiplicative group.

    Returns:
        Smallest positive integer r such that x**r == 1 mod n.
        Always succeeds (and hence never returns None).

    Raises:
        ValueError when x is 1 or not an element of the multiplicative
        group of integers modulo n.
    """
    # Make sure x is both valid and in Z_n.
    if x < 2 or x >= n or math.gcd(x, n) > 1:
        raise ValueError(f"Invalid x={x} for modulus n={n}.")

    # Determine the order.
    r, y = 1, x
    while y != 1:
        y = (x * y) % n
        r += 1
    return r

An example of computing $r$ for a given $x \in \mathbb{Z}_n$ and given $n$ is shown in the code block below.

"""Example of (classically) computing the order of an element."""
n = 15  # The multiplicative group is [1, 2, 4, 7, 8, 11, 13, 14].
x = 8
r = classical_order_finder(x, n)

# Check that the order is indeed correct.
print(f"x^r mod n = {x}^{r} mod {n} = {x**r % n}")
x^r mod n = 8^4 mod 15 = 1

The quantum part of Shor's algorithm is order finding, but done via a quantum circuit, which we'll discuss below.

Quantum order finding

Quantum order finding is essentially quantum phase estimation with unitary $U$ that computes the modular exponential function $f_x(z)$ for some randomly chosen $x \in \mathbb{Z}_n$. The full details of how $U$ is computed in terms of elementary gates can be complex to unravel, especially on a first reading. In this tutorial, we'll use arithmetic operations in Cirq which can implement such a unitary $U$ without fully delving into the details of elementary gates.
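To see why phase estimation with this $U$ reveals the order, the sketch below uses plain NumPy (no Cirq) with the example values $x = 7$ and $n = 15$, where $r = 4$. It builds $U|y\rangle = |xy \text{ mod } n\rangle$ as a permutation matrix on $L = 4$ qubits (the helper `modmul_unitary` is ours) and checks the standard fact that the states $|u_s\rangle = \frac{1}{\sqrt{r}} \sum_{k=0}^{r-1} e^{-2\pi i s k / r} |x^k \text{ mod } n\rangle$ are eigenvectors of $U$ with eigenvalue $e^{2 \pi i s / r}$. Phase estimation therefore returns an estimate of $s / r$.

```python
import numpy as np

def modmul_unitary(x: int, n: int, num_qubits: int) -> np.ndarray:
    """Permutation matrix of U|y> = |x*y mod n> for y < n, U|y> = |y> otherwise."""
    dim = 2**num_qubits
    U = np.zeros((dim, dim))
    for y in range(dim):
        image = (x * y) % n if y < n else y
        U[image, y] = 1  # column y holds the image of |y>
    return U

x, n = 7, 15
L = n.bit_length()
U = modmul_unitary(x, n, L)
assert np.allclose(U.T @ U, np.eye(2**L))  # a permutation matrix is unitary

r = 4  # order of 7 modulo 15: 7, 4, 13, 1
for s in range(r):
    u_s = np.zeros(2**L, dtype=complex)
    for k in range(r):
        u_s[pow(x, k, n)] += np.exp(-2j * np.pi * s * k / r) / np.sqrt(r)
    # Check U|u_s> = exp(2*pi*i*s/r) |u_s>.
    assert np.allclose(U @ u_s, np.exp(2j * np.pi * s / r) * u_s)
print("all eigenvalue checks passed")
```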

Below we first show an example of a simple arithmetic operation in Cirq (addition), then discuss the operation we care about (modular exponentiation).

Quantum arithmetic operations in Cirq

Here we discuss an example of defining an arithmetic operation in Cirq, namely modular addition. This operation adds the value of the input register into the target register. More specifically, this operation acts on two qubit registers as

$$ |a\rangle_i |b\rangle_t \mapsto |a\rangle_i |a + b \text{ mod } N_t \rangle_t . $$

Here, the subscripts $i$ and $t$ denote input and target register, respectively, and $N_t$ is the dimension of the target register.

To define this operation, called Adder, we inherit from cirq.ArithmeticOperation and implement the four methods shown below: __init__, registers, with_registers, and apply. The main method is apply, which defines the arithmetic. Here we simply state the expression as $a + b$ instead of the more accurate $a + b \text{ mod } N_t$ above; the cirq.ArithmeticOperation class is able to deduce what we mean by $a + b$ since the operation must be reversible.

"""Example of defining an arithmetic (quantum) operation in Cirq."""
class Adder(cirq.ArithmeticOperation):
    """Quantum addition."""
    def __init__(self, target_register, input_register):
        self.input_register = input_register
        self.target_register = target_register

    def registers(self):
        return self.target_register, self.input_register

    def with_registers(self, *new_registers):
        return Adder(*new_registers)

    def apply(self, target_value, input_value):
        return target_value + input_value

Now that we have the operation defined, we can use it in a circuit. The cell below creates two qubit registers and uses $X$ gates to set the first (input) register to $|10\rangle$ and the second (target) register to $|01\rangle$, both in binary. It then applies the Adder operation and measures all the qubits.

Since $10 + 01 = 11$ (in binary), we expect to measure $|11\rangle$ in the target register every time. Additionally, since we do not alter the input register, we expect to measure $|10\rangle$ in the input register every time. In short, the only bitstring we expect to measure is $1011$.

"""Example of using an Adder in a circuit."""
# Two qubit registers.
qreg1 = cirq.LineQubit.range(2)
qreg2 = cirq.LineQubit.range(2, 4)

# Define the circuit.
circ = cirq.Circuit(
    cirq.ops.X.on(qreg1[0]),
    cirq.ops.X.on(qreg2[1]),
    Adder(input_register=qreg1, target_register=qreg2),
    cirq.measure_each(*qreg1),
    cirq.measure_each(*qreg2)
)

# Display it.
print("Circuit:\n")
print(circ)

# Print the measurement outcomes.
print("\n\nMeasurement outcomes:\n")
print(cirq.sample(circ, repetitions=5).data)
Circuit:

0: ───X───#3──────────────────────────────────────────M───
          │
1: ───────#4──────────────────────────────────────────M───
          │
2: ───────<__main__.Adder object at 0x7f915460bf98>───M───
          │
3: ───X───#2──────────────────────────────────────────M───


Measurement outcomes:

   0  1  2  3
0  1  0  1  1
1  1  0  1  1
2  1  0  1  1
3  1  0  1  1
4  1  0  1  1

In the output of this code block, we first see the circuit which shows the initial $X$ gates, the Adder operation, then the final measurements. Next, we see the measurement outcomes which are all the bitstring $1011$ as expected.

It is also possible to see the unitary of the adder operation, which we do below. Here the target register consists of two qubits, and instead of a qubit register we pass the classical integer one as the input register, which plays the role of the register state $|01\rangle$. The resulting unitary maps each target basis state $|b\rangle$ to $|b + 1 \text{ mod } 4\rangle$.

"""Example of the unitary of an Adder operation."""
cirq.unitary(
    Adder(target_register=cirq.LineQubit.range(2),
          input_register=1)
).real
array([[0., 0., 0., 1.],
       [1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.]])

We can understand this unitary as follows. The $i$th column of the unitary is the state $|i + 1 \text{ mod } 4\rangle$. For example, if we look at the $0$th column of the unitary, we see the state $|i + 1 \text{ mod } 4\rangle = |0 + 1 \text{ mod } 4\rangle = |1\rangle$. If we look at the $1$st column of the unitary, we see the state $|i + 1 \text{ mod } 4\rangle = |1 + 1 \text{ mod } 4\rangle = |2\rangle$. Similarly for the last two columns.
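The same matrix can be rebuilt without Cirq, which makes the wrap-around modulo $N_t = 2^k$ explicit for a $k$-qubit target. In this sketch the helper `arithmetic_unitary` is hypothetical (it is not a Cirq function): column $b$ receives a single 1 in row $\text{apply}(b, a) \text{ mod } 2^k$ for the fixed classical input $a$. This only gives a unitary when the rule is a bijection on target values, which is why the operation must be reversible.

```python
import numpy as np
from typing import Callable

def arithmetic_unitary(apply: Callable[[int, int], int],
                       num_target_qubits: int,
                       input_value: int) -> np.ndarray:
    """Matrix of |b> -> |apply(b, input_value) mod 2**k> on a k-qubit target."""
    dim = 2**num_target_qubits
    U = np.zeros((dim, dim))
    for b in range(dim):
        # Column b holds the image of basis state |b>; results wrap modulo dim.
        U[apply(b, input_value) % dim, b] = 1
    return U

# Same rule as Adder.apply: target + input, with a 2-qubit target and input 1.
U = arithmetic_unitary(lambda target, inp: target + inp, 2, 1)
print(U)
```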

Modular exponential arithmetic operation

We can define the modular exponential arithmetic operation in a similar way to the simple addition arithmetic operation, shown below. For the purposes of understanding Shor's algorithm, the most important part of the following code block is the apply method which defines the arithmetic operation.

"""Defines the modular exponential operation used in Shor's algorithm."""
class ModularExp(cirq.ArithmeticOperation):
    """Quantum modular exponentiation.

    This class represents the unitary which multiplies base raised to exponent
    into the target modulo the given modulus. More precisely, it represents the
    unitary V which computes modular exponentiation x**e mod n:

        V|y⟩|e⟩ = |y * x**e mod n⟩ |e⟩     0 <= y < n
        V|y⟩|e⟩ = |y⟩ |e⟩                  n <= y

    where y is the target register, e is the exponent register, x is the base
    and n is the modulus. Consequently,

        V|y⟩|e⟩ = (U**e|y⟩)|e⟩

    where U is the unitary defined as

        U|y⟩ = |y * x mod n⟩      0 <= y < n
        U|y⟩ = |y⟩                n <= y
    """
    def __init__(
        self, 
        target: Sequence[cirq.Qid],
        exponent: Union[int, Sequence[cirq.Qid]], 
        base: int,
        modulus: int
    ) -> None:
        if len(target) < modulus.bit_length():
            raise ValueError(f'Register with {len(target)} qubits is too small '
                             f'for modulus {modulus}')
        self.target = target
        self.exponent = exponent
        self.base = base
        self.modulus = modulus

    def registers(self) -> Sequence[Union[int, Sequence[cirq.Qid]]]:
        return self.target, self.exponent, self.base, self.modulus

    def with_registers(
            self,
            *new_registers: Union[int, Sequence['cirq.Qid']],
    ) -> cirq.ArithmeticOperation:
        if len(new_registers) != 4:
            raise ValueError(f'Expected 4 registers (target, exponent, base, '
                             f'modulus), but got {len(new_registers)}')
        target, exponent, base, modulus = new_registers
        if not isinstance(target, Sequence):
            raise ValueError(
                f'Target must be a qubit register, got {type(target)}')
        if not isinstance(base, int):
            raise ValueError(
                f'Base must be a classical constant, got {type(base)}')
        if not isinstance(modulus, int):
            raise ValueError(
                f'Modulus must be a classical constant, got {type(modulus)}')
        return ModularExp(target, exponent, base, modulus)

    def apply(self, *register_values: int) -> int:
        assert len(register_values) == 4
        target, exponent, base, modulus = register_values
        if target >= modulus:
            return target
        return (target * base**exponent) % modulus

    def _circuit_diagram_info_(
            self,
            args: cirq.CircuitDiagramInfoArgs,
    ) -> cirq.CircuitDiagramInfo:
        assert args.known_qubits is not None
        wire_symbols: List[str] = []
        t, e = 0, 0
        for qubit in args.known_qubits:
            if qubit in self.target:
                if t == 0:
                    if isinstance(self.exponent, Sequence):
                        e_str = 'e'
                    else:
                        e_str = str(self.exponent)
                    wire_symbols.append(
                        f'ModularExp(t*{self.base}**{e_str} % {self.modulus})')
                else:
                    wire_symbols.append('t' + str(t))
                t += 1
            if isinstance(self.exponent, Sequence) and qubit in self.exponent:
                wire_symbols.append('e' + str(e))
                e += 1
        return cirq.CircuitDiagramInfo(wire_symbols=tuple(wire_symbols))

In the apply method, we see that we evaluate (target * base**exponent) % modulus. The target and the exponent depend on the values of the respective qubit registers, and the base and modulus are constant -- namely, the modulus is $n$ and the base is some $x \in \mathbb{Z}_n$.
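Because the operation must permute the basis states of the target register, the base has to be relatively prime to the modulus. The plain-Python sketch below checks this; `modexp_apply` re-implements the arithmetic of `ModularExp.apply` so that it runs without Cirq, and `is_reversible` is a hypothetical helper written for this illustration.

```python
def modexp_apply(target: int, exponent: int, base: int, modulus: int) -> int:
    """Same arithmetic as ModularExp.apply: values >= modulus are left unchanged."""
    if target >= modulus:
        return target
    return (target * base**exponent) % modulus

def is_reversible(base: int, modulus: int, exponent: int, num_qubits: int) -> bool:
    """True if the map permutes all 2**num_qubits target basis states."""
    values = range(2**num_qubits)
    images = [modexp_apply(y, exponent, base, modulus) for y in values]
    return sorted(images) == list(values)

print(is_reversible(7, 15, exponent=3, num_qubits=4))  # True: gcd(7, 15) = 1
print(is_reversible(5, 15, exponent=1, num_qubits=4))  # False: gcd(5, 15) = 5
```

With base 5, every value below 15 is sent to 0, 5, or 10, so the map is not invertible and no unitary implements it.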

The total number of qubits we will use is $3 (L + 1)$, where $L$ is the number of bits needed to store the integer $n$ to factor. The unitary which implements the modular exponential is therefore a $2^{3(L + 1)} \times 2^{3(L + 1)}$ matrix with $4^{3(L + 1)}$ entries. Even for a modest $n = 15$ ($L = 4$), storing it requires $2^{30}$ complex numbers, about 16 GiB at double precision, which is out of reach of most standard laptops.
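The arithmetic above can be packaged in a small helper. This is a sketch that assumes the unitary is stored as a dense matrix of 16-byte double-precision complex numbers (NumPy's default complex dtype); `shor_resources` is a hypothetical name, and the qubit layout is the one used in this tutorial ($L$ target qubits plus $2L + 3$ exponent qubits).

```python
def shor_resources(n: int) -> dict:
    """Qubit count and dense-unitary size for the circuit layout used here."""
    L = n.bit_length()
    qubits = 3 * L + 3               # L target qubits + (2L + 3) exponent qubits
    entries = (2**qubits) ** 2       # = 4**qubits matrix entries
    num_bytes = 16 * entries         # 16 bytes per double-precision complex number
    return {"L": L, "qubits": qubits, "entries": entries, "GiB": num_bytes / 2**30}

print(shor_resources(15))
print(shor_resources(184573)["qubits"])
```

For $n = 15$ this gives 15 qubits, $2^{30}$ entries, and 16 GiB; the larger $n = 184573$ ($L = 18$) would already need 57 qubits.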

"""Create the target and exponent registers for phase estimation,
and see the number of qubits needed for Shor's algorithm.
"""
n = 15
L = n.bit_length()

# The target register has L qubits.
target = cirq.LineQubit.range(L)

# The exponent register has 2L + 3 qubits.
exponent = cirq.LineQubit.range(L, 3 * L + 3)

# Display the total number of qubits to factor this n.
print(f"To factor n = {n} which has L = {L} bits, we need 3L + 3 = {3 * L + 3} qubits.")
To factor n = 15 which has L = 4 bits, we need 3L + 3 = 15 qubits.

As with the simple adder operation, this modular exponential operation has a unitary which we can display (memory permitting) as follows.

"""See (part of) the unitary for a modular exponential operation."""
# Pick some element of the multiplicative group modulo n.
x = 7  # gcd(7, 15) = 1, so 7 is in Z_15 (5 is not, since gcd(5, 15) = 5).

# Display (part of) the unitary. Uncomment if n is small enough.
# cirq.unitary(ModularExp(target, exponent, x, n))

Using the modular exponential operation in a circuit

The quantum part of Shor's algorithm is just phase estimation with the unitary $U$ corresponding to the modular exponential operation. The following cell defines a function which creates the circuit for Shor's algorithm using the ModularExp operation we defined above.

"""Function to make the quantum circuit for order finding."""
def make_order_finding_circuit(x: int, n: int) -> cirq.Circuit:
    """Returns quantum circuit which computes the order of x modulo n.

    The circuit uses Quantum Phase Estimation to compute an eigenvalue of
    the unitary

        U|y⟩ = |y * x mod n⟩      0 <= y < n
        U|y⟩ = |y⟩                n <= y

    Args:
        x: positive integer whose order modulo n is to be found
        n: modulus relative to which the order of x is to be found

    Returns:
        Quantum circuit for finding the order of x modulo n
    """
    L = n.bit_length()
    target = cirq.LineQubit.range(L)
    exponent = cirq.LineQubit.range(L, 3 * L + 3)
    return cirq.Circuit(
        cirq.X(target[L - 1]),
        cirq.H.on_each(*exponent),
        ModularExp(target, exponent, x, n),
        cirq.qft(*exponent, inverse=True),
        cirq.measure(*exponent, key='exponent'),
    )

Using this function, we can visualize the circuit for a given $x$ and $n$ as follows.

"""Example of the quantum circuit for period finding."""
n = 15
x = 7
circuit = make_order_finding_circuit(x, n)
print(circuit)
0: ────────ModularExp(t*7**e % 15)────────────────────────────
           │
1: ────────t1─────────────────────────────────────────────────
           │
2: ────────t2─────────────────────────────────────────────────
           │
3: ────X───t3─────────────────────────────────────────────────
           │
4: ────H───e0────────────────────────qft^-1───M('exponent')───
           │                         │        │
5: ────H───e1────────────────────────#2───────M───────────────
           │                         │        │
6: ────H───e2────────────────────────#3───────M───────────────
           │                         │        │
7: ────H───e3────────────────────────#4───────M───────────────
           │                         │        │
8: ────H───e4────────────────────────#5───────M───────────────
           │                         │        │
9: ────H───e5────────────────────────#6───────M───────────────
           │                         │        │
10: ───H───e6────────────────────────#7───────M───────────────
           │                         │        │
11: ───H───e7────────────────────────#8───────M───────────────
           │                         │        │
12: ───H───e8────────────────────────#9───────M───────────────
           │                         │        │
13: ───H───e9────────────────────────#10──────M───────────────
           │                         │        │
14: ───H───e10───────────────────────#11──────M───────────────

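We put the exponent register into an equal superposition via Hadamard gates. The $X$ gate on the last qubit of the target register prepares the target in the state $|1\rangle$ (the last qubit holds the least significant bit). This state is an equal superposition of eigenvectors of $U$ with eigenvalues $e^{2 \pi i s / r}$ for $s = 0, \dots, r - 1$, so the phases kicked back onto the exponent register encode $s / r$ for a uniformly random $s$. The modular exponential operation performs the sequence of controlled unitaries in phase estimation; then we apply the inverse quantum Fourier transform to the exponent register and measure it to read out the result.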
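For small $n$, the steps just described can be simulated directly in NumPy. The sketch below computes the noise-free probability of each integer read from the exponent register. The function name `ideal_exponent_distribution` is ours, and the sketch assumes the inverse QFT acts on the exponent register read as a big-endian integer, consistent with the outputs in this tutorial. For $n = 6$ and $x = 5$ it predicts the outcomes 0 and 256, each with probability 1/2; for $n = 15$ and $x = 7$ it predicts 0, 512, 1024, and 1536, each with probability 1/4.

```python
import numpy as np

def ideal_exponent_distribution(x: int, n: int) -> np.ndarray:
    """Noise-free outcome probabilities of the exponent register (NumPy sketch)."""
    L = n.bit_length()
    m = 2 * L + 3                      # exponent register size used in this tutorial
    N = 2**m
    # After the H gates and ModularExp (target starts in |1>), the state is
    # (1/sqrt(N)) * sum_e |x**e mod n>|e>.
    targets = np.array([pow(x, e, n) for e in range(N)])
    probs = np.zeros(N)
    for y in np.unique(targets):
        # Exponent-register amplitudes on the branch where the target holds y.
        amps = (targets == y) / np.sqrt(N)
        # Inverse QFT: b_k = (1/sqrt(N)) * sum_j a_j * exp(-2*pi*i*j*k/N).
        out = np.fft.fft(amps) / np.sqrt(N)
        probs += np.abs(out) ** 2
    return probs

probs = ideal_exponent_distribution(x=5, n=6)
likely = np.nonzero(probs > 1e-9)[0]
print(likely, probs[likely])
```

Each peak sits at an integer multiple of $2^m / r$, which is exactly what the classical post-processing exploits.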

To illustrate the measurement results, we can sample from a smaller circuit. (Note that in practice we would never run Shor's algorithm with $n = 6$ because it is even. This is just an example to illustrate the measurement outcomes.)

"""Measuring Shor's period finding circuit."""
circuit = make_order_finding_circuit(x=5, n=6)
res = cirq.sample(circuit, repetitions=8)

print("Raw measurements:")
print(res)

print("\nInteger in exponent register:")
print(res.data)
Raw measurements:
exponent=10010110, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000

Integer in exponent register:
   exponent
0       256
1         0
2         0
3       256
4         0
5       256
6       256
7         0

We interpret each measured bitstring as an integer, but what do these integers tell us? In the next section we look at how to classically post-process to interpret them.

Classical post-processing

The measured integer $k$, divided by $2^m$ where $m$ is the number of qubits in the exponent register, is close to $s / r$, where $r$ is the order of $x \in \mathbb{Z}_n$ and $0 \le s < r$ is an integer. We use the continued fractions algorithm to determine $r$ from this approximation of $s / r$, then return it if the order finding circuit succeeded; otherwise we return None.
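To make the continued fractions step concrete without running a circuit, the sketch below applies it to hand-picked integers. The helper `order_from_outcome` is hypothetical; it mirrors the logic of the function below but takes a plain integer. The example uses $x = 2$ and $n = 21$, where the order is $r = 6$ and the circuit in this tutorial would use $2L + 3 = 13$ exponent qubits. The integers tried are those closest to $s \cdot 2^{13} / 6$ for $s = 0, \dots, 5$.

```python
import fractions
from typing import Optional

def order_from_outcome(k: int, num_bits: int, x: int, n: int) -> Optional[int]:
    """Hypothetical helper: recover r from a measured integer k of num_bits bits."""
    # k / 2**num_bits approximates s / r; an exact Fraction avoids float rounding.
    phase = fractions.Fraction(k, 2**num_bits)
    # Continued fractions: closest fraction with denominator at most n.
    f = phase.limit_denominator(n)
    if f.numerator == 0:
        return None  # s = 0 carries no information about r
    r = f.denominator
    # When gcd(s, r) > 1 the denominator is a proper divisor of r and fails this check.
    return r if pow(x, r, n) == 1 else None

for k in [0, 1365, 2731, 4096, 5461, 6827]:
    print(k, order_from_outcome(k, 13, x=2, n=21))
```

Only $s = 1$ and $s = 5$, which are relatively prime to 6, recover $r = 6$; the other outcomes reduce to $1/3$, $1/2$, or $2/3$ and are rejected, which is why the order finder can fail and may need to be re-run.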

def process_measurement(result: cirq.TrialResult, x: int, n: int) -> Optional[int]:
    """Interprets the output of the order finding circuit.

    Specifically, it determines s/r such that exp(2πis/r) is an eigenvalue
    of the unitary

        U|y⟩ = |xy mod n⟩  0 <= y < n
        U|y⟩ = |y⟩         n <= y

    then computes r (by continued fractions) if possible, and returns it.

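    Args:
        result: trial result obtained by sampling the output of the
            circuit built by make_order_finding_circuit
        x: integer whose order modulo n is to be found
        n: modulus relative to which the order of x is to be found

    Returns:
        r, the order of x modulo n or None.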
    """
    # Read the output integer of the exponent register.
    exponent_as_integer = result.data["exponent"][0]
    exponent_num_bits = result.measurements["exponent"].shape[1]
    eigenphase = float(exponent_as_integer / 2**exponent_num_bits)

    # Run the continued fractions algorithm to determine f = s / r.
    f = fractions.Fraction.from_float(eigenphase).limit_denominator(n)

    # If the numerator is zero, the order finder failed.
    if f.numerator == 0:
        return None

    # Else, return the denominator if it is valid.
    r = f.denominator
    if x**r % n != 1:
        return None
    return r

The next code block shows an example of creating an order finding circuit, executing it, then using the classical postprocessing function to determine the order. Recall that the quantum part of the algorithm succeeds with some probability. If the order is None, try re-running the cell a few times.

"""Example of the classical post-processing."""
# Set n and x here
n = 6
x = 5

print(f"Finding the order of x = {x} modulo n = {n}\n")
circuit = make_order_finding_circuit(x, n)
measurement = cirq.sample(circuit, repetitions=1)
print("Raw measurements:")
print(measurement)

print("\nInteger in exponent register:")
print(measurement.data)

r = process_measurement(measurement, x, n)
print("\nOrder r =", r)
if r is not None:
    print(f"x^r mod n = {x}^{r} mod {n} = {x**r % n}")
Finding the order of x = 5 modulo n = 6

Raw measurements:
exponent=1, 0, 0, 0, 0, 0, 0, 0, 0

Integer in exponent register:
   exponent
0       256

Order r = 2
x^r mod n = 5^2 mod 6 = 1

You should see that the order of $x = 5$ in $\mathbb{Z}_6$ is $r = 2$. Indeed, $5^2 \text{ mod } 6 = 25 \text{ mod } 6 = 1$.

Quantum order finder

We can now define a streamlined function for the quantum version of order finding using the functions we have previously written. The quantum order finder below creates the circuit, executes it, and processes the measurement result.

def quantum_order_finder(x: int, n: int) -> Optional[int]:
    """Computes smallest positive r such that x**r mod n == 1.

    Args:
        x: integer whose order is to be computed, must be greater than one
           and belong to the multiplicative group of integers modulo n (which
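           consists of positive integers relatively prime to n).
        n: modulus of the multiplicative group.

    Returns:
        The order r of x modulo n, or None if the measurement could not be
        turned into a valid order.
    """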
    # Check that the integer x is a valid element of the multiplicative group
    # modulo n.
    if x < 2 or n <= x or math.gcd(x, n) > 1:
        raise ValueError(f'Invalid x={x} for modulus n={n}.')

    # Create the order finding circuit.
    circuit = make_order_finding_circuit(x, n)

    # Sample from the order finding circuit.
    measurement = cirq.sample(circuit)

    # Return the processed measurement result.
    return process_measurement(measurement, x, n)

This completes our quantum implementation of an order finder, and the quantum part of Shor's algorithm.

The complete factoring algorithm

We can use this quantum order finder (or the classical order finder) to complete Shor's algorithm. In the following code block, we add a few pre-processing steps which:

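(1) Check if $n$ is prime,

(2) Check if $n$ is even,

(3) Check if $n$ is a prime power,

all of which can be done efficiently with a classical computer. Additionally, we add the last necessary post-processing step, which uses the order $r$ to compute a non-trivial factor $p$ of $n$. Assuming $r$ is even, we compute $y = x^{r / 2} \text{ mod } n$ and then $p = \text{gcd}(y - 1, n)$. This works because $y^2 \text{ mod } n = 1$, so $n$ divides $(y - 1)(y + 1)$. Since $r$ is the order of $x$, $y \ne 1$; as long as also $y \ne n - 1$, $n$ divides neither factor on its own, so $\text{gcd}(y - 1, n)$ is a non-trivial factor of $n$.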
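As a worked instance of this post-processing step, the sketch below uses $n = 15$ and $x = 7$, whose order is $r = 4$ (the powers $7^1, 7^2, 7^3, 7^4 \text{ mod } 15$ are $7, 4, 13, 1$). The values are chosen by hand for illustration.

```python
import math

n, x = 15, 7
r = 4  # order of 7 modulo 15
assert pow(x, r, n) == 1 and r % 2 == 0

y = pow(x, r // 2, n)  # y = 7**2 mod 15 = 4, so y**2 mod n = 1
# n divides (y - 1)(y + 1); since y is neither 1 nor n - 1, both gcds are non-trivial.
p = math.gcd(y - 1, n)
q = math.gcd(y + 1, n)
print(y, p, q)
```

By contrast, $x = 14$ has order $r = 2$ and gives $y = 14 = n - 1$, so $\text{gcd}(y - 1, n) = \text{gcd}(13, 15) = 1$ is trivial; this is one reason find_factor may need several attempts.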

"""Functions for factoring from start to finish."""
def find_factor_of_prime_power(n: int) -> Optional[int]:
    """Returns non-trivial factor of n if n is a prime power, else None."""
    for k in range(2, math.floor(math.log2(n)) + 1):
        c = math.pow(n, 1 / k)
        c1 = math.floor(c)
        if c1**k == n:
            return c1
        c2 = math.ceil(c)
        if c2**k == n:
            return c2
    return None


def find_factor(
    n: int,
    order_finder: Callable[[int, int], Optional[int]] = quantum_order_finder,
    max_attempts: int = 30
) -> Optional[int]:
    """Returns a non-trivial factor of composite integer n.

    Args:
        n: Integer to factor.
        order_finder: Function for finding the order of elements of the
            multiplicative group of integers modulo n.
        max_attempts: number of random x's to try, also an upper limit
            on the number of order_finder invocations.

    Returns:
        Non-trivial factor of n or None if no such factor was found.
        Factor k of n is trivial if it is 1 or n.
    """
    # If the number is prime, there are no non-trivial factors.
    if sympy.isprime(n):
        print("n is prime!")
        return None

    # If the number is even, two is a non-trivial factor.
    if n % 2 == 0:
        return 2

    # If n is a prime power, we can find a non-trivial factor efficiently.
    c = find_factor_of_prime_power(n)
    if c is not None:
        return c

    for _ in range(max_attempts):
        # Choose a random number between 2 and n - 1.
        x = random.randint(2, n - 1)

        # Most likely x and n will be relatively prime.
        c = math.gcd(x, n)

        # If x and n are not relatively prime, we got lucky and found
        # a non-trivial factor.
        if 1 < c < n:
            return c

        # Compute the order r of x modulo n using the order finder.
        r = order_finder(x, n)

        # If the order finder failed, try again.
        if r is None:
            continue

        # If the order r is odd, try again.
        if r % 2 != 0:
            continue

        # Compute the non-trivial factor.
        y = pow(x, r // 2, n)
        assert 1 < y < n
        c = math.gcd(y - 1, n)
        if 1 < c < n:
            return c

    print(f"Failed to find a non-trivial factor in {max_attempts} attempts.")
    return None

The function find_factor uses the quantum_order_finder by default, in which case it is executing Shor's algorithm. As previously mentioned, classically simulating this circuit has large memory requirements that grow exponentially with its $3L + 3$ qubits, which is why we only sampled it for the small $n = 6$; the $n = 184573$ below has $L = 18$ bits and would need 57 qubits. However, we can use the classical order finder as a substitute.

"""Example of factoring via Shor's algorithm (order finding)."""
# Number to factor
n = 184573

# Attempt to find a factor
p = find_factor(n, order_finder=classical_order_finder)
q = n // p

print("Factoring n = pq =", n)
print("p =", p)
print("q =", q)
Factoring n = pq = 184573
p = 487
q = 379

"""Check the answer is correct."""
p * q == n
True