|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""BIP-0376 reference implementation and test vector runner. |
| 3 | +
|
| 4 | +Run: |
| 5 | + ./bip-0376/reference.py bip-0376/test-vectors.json |
| 6 | +""" |
| 7 | + |
| 8 | +import json |
| 9 | +import sys |
| 10 | +import hashlib |
| 11 | +from pathlib import Path |
| 12 | +from typing import Optional, Tuple |
| 13 | + |
| 14 | +p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F |
| 15 | +n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 |
| 16 | +G = ( |
| 17 | + 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, |
| 18 | + 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8, |
| 19 | +) |
| 20 | + |
| 21 | +Point = Tuple[int, int] |
| 22 | + |
| 23 | + |
| 24 | +def tagged_hash(tag: str, msg: bytes) -> bytes: |
| 25 | + tag_hash = hashlib.sha256(tag.encode("utf-8")).digest() |
| 26 | + return hashlib.sha256(tag_hash + tag_hash + msg).digest() |
| 27 | + |
| 28 | + |
| 29 | +def int_from_bytes(data: bytes) -> int: |
| 30 | + return int.from_bytes(data, byteorder="big") |
| 31 | + |
| 32 | + |
| 33 | +def bytes_from_int(x: int) -> bytes: |
| 34 | + return x.to_bytes(32, byteorder="big") |
| 35 | + |
| 36 | + |
| 37 | +def has_even_y(P: Point) -> bool: |
| 38 | + return (P[1] % 2) == 0 |
| 39 | + |
| 40 | + |
| 41 | +def bytes_from_point(P: Point) -> bytes: |
| 42 | + return bytes_from_int(P[0]) |
| 43 | + |
| 44 | + |
| 45 | +def xor_bytes(a: bytes, b: bytes) -> bytes: |
| 46 | + return bytes(x ^ y for (x, y) in zip(a, b)) |
| 47 | + |
| 48 | + |
| 49 | +def lift_x(x_coord: int) -> Optional[Point]: |
| 50 | + if x_coord >= p: |
| 51 | + return None |
| 52 | + y_sq = (pow(x_coord, 3, p) + 7) % p |
| 53 | + y_coord = pow(y_sq, (p + 1) // 4, p) |
| 54 | + if pow(y_coord, 2, p) != y_sq: |
| 55 | + return None |
| 56 | + return (x_coord, y_coord if (y_coord % 2) == 0 else p - y_coord) |
| 57 | + |
| 58 | + |
| 59 | +def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]: |
| 60 | + if P1 is None: |
| 61 | + return P2 |
| 62 | + if P2 is None: |
| 63 | + return P1 |
| 64 | + if (P1[0] == P2[0]) and (P1[1] != P2[1]): |
| 65 | + return None |
| 66 | + if P1 == P2: |
| 67 | + lam = (3 * P1[0] * P1[0] * pow(2 * P1[1], p - 2, p)) % p |
| 68 | + else: |
| 69 | + lam = ((P2[1] - P1[1]) * pow(P2[0] - P1[0], p - 2, p)) % p |
| 70 | + x3 = (lam * lam - P1[0] - P2[0]) % p |
| 71 | + y3 = (lam * (P1[0] - x3) - P1[1]) % p |
| 72 | + return (x3, y3) |
| 73 | + |
| 74 | + |
| 75 | +def point_mul(P: Optional[Point], scalar: int) -> Optional[Point]: |
| 76 | + R = None |
| 77 | + for i in range(256): |
| 78 | + if (scalar >> i) & 1: |
| 79 | + R = point_add(R, P) |
| 80 | + P = point_add(P, P) |
| 81 | + return R |
| 82 | + |
| 83 | + |
| 84 | +def schnorr_verify(msg: bytes, pubkey: bytes, sig: bytes) -> bool: |
| 85 | + if len(pubkey) != 32 or len(sig) != 64: |
| 86 | + return False |
| 87 | + P = lift_x(int_from_bytes(pubkey)) |
| 88 | + r = int_from_bytes(sig[0:32]) |
| 89 | + s = int_from_bytes(sig[32:64]) |
| 90 | + if P is None or r >= p or s >= n: |
| 91 | + return False |
| 92 | + e = int_from_bytes(tagged_hash("BIP0340/challenge", sig[0:32] + pubkey + msg)) % n |
| 93 | + R = point_add(point_mul(G, s), point_mul(P, n - e)) |
| 94 | + if R is None: |
| 95 | + return False |
| 96 | + return has_even_y(R) and (R[0] == r) |
| 97 | + |
| 98 | + |
| 99 | +def schnorr_sign(msg: bytes, seckey: bytes, aux_rand: bytes) -> bytes: |
| 100 | + d0 = int_from_bytes(seckey) |
| 101 | + if not (1 <= d0 <= n - 1): |
| 102 | + raise ValueError("The secret key must be in the range 1..n-1.") |
| 103 | + if len(aux_rand) != 32: |
| 104 | + raise ValueError("aux_rand must be 32 bytes.") |
| 105 | + P = point_mul(G, d0) |
| 106 | + assert P is not None |
| 107 | + d = d0 if has_even_y(P) else n - d0 |
| 108 | + t = xor_bytes(bytes_from_int(d), tagged_hash("BIP0340/aux", aux_rand)) |
| 109 | + k0 = int_from_bytes(tagged_hash("BIP0340/nonce", t + bytes_from_point(P) + msg)) % n |
| 110 | + if k0 == 0: |
| 111 | + raise RuntimeError("Failure. This happens only with negligible probability.") |
| 112 | + R = point_mul(G, k0) |
| 113 | + assert R is not None |
| 114 | + k = k0 if has_even_y(R) else n - k0 |
| 115 | + e = int_from_bytes(tagged_hash("BIP0340/challenge", bytes_from_point(R) + bytes_from_point(P) + msg)) % n |
| 116 | + sig = bytes_from_point(R) + bytes_from_int((k + e * d) % n) |
| 117 | + if not schnorr_verify(msg, bytes_from_point(P), sig): |
| 118 | + raise RuntimeError("The created signature does not pass verification.") |
| 119 | + return sig |
| 120 | + |
| 121 | + |
| 122 | +def parse_hex(data: str, expected_len: int, field_name: str) -> bytes: |
| 123 | + raw = bytes.fromhex(data) |
| 124 | + if len(raw) != expected_len: |
| 125 | + raise ValueError(f"{field_name} must be {expected_len} bytes.") |
| 126 | + return raw |
| 127 | + |
| 128 | + |
| 129 | +def derive_signing_key(spend_seckey: bytes, tweak: bytes, output_pubkey: bytes) -> Tuple[int, int, bool]: |
| 130 | + b_spend = int_from_bytes(spend_seckey) |
| 131 | + if not (1 <= b_spend <= n - 1): |
| 132 | + raise ValueError("spend key out of range") |
| 133 | + |
| 134 | + tweak_int = int_from_bytes(tweak) |
| 135 | + d_raw = (b_spend + tweak_int) % n |
| 136 | + if d_raw == 0: |
| 137 | + raise ValueError("tweaked private key is zero") |
| 138 | + |
| 139 | + Q = point_mul(G, d_raw) |
| 140 | + assert Q is not None |
| 141 | + negated = not has_even_y(Q) |
| 142 | + d = d_raw if not negated else n - d_raw |
| 143 | + |
| 144 | + Q_even = point_mul(G, d) |
| 145 | + assert Q_even is not None |
| 146 | + if bytes_from_point(Q_even) != output_pubkey: |
| 147 | + raise ValueError("tweaked key does not match output key") |
| 148 | + |
| 149 | + return d_raw, d, negated |
| 150 | + |
| 151 | + |
| 152 | +def run_test_vectors(path: Path) -> bool: |
| 153 | + vectors = json.loads(path.read_text(encoding="utf-8")) |
| 154 | + all_passed = True |
| 155 | + |
| 156 | + valid_vectors = vectors.get("valid", []) |
| 157 | + invalid_vectors = vectors.get("invalid", []) |
| 158 | + |
| 159 | + print(f"Running {len(valid_vectors)} valid vectors") |
| 160 | + for index, vector in enumerate(valid_vectors): |
| 161 | + description = vector["description"] |
| 162 | + given = vector["given"] |
| 163 | + expected = vector["expected"] |
| 164 | + print(f"- valid[{index}] {description}") |
| 165 | + try: |
| 166 | + spend_seckey = parse_hex(given["spend_seckey"], 32, "spend_seckey") |
| 167 | + tweak = parse_hex(given["tweak"], 32, "tweak") |
| 168 | + output_pubkey = parse_hex(given["output_pubkey"], 32, "output_pubkey") |
| 169 | + message = parse_hex(given["message"], 32, "message") |
| 170 | + aux_rand = parse_hex(given["aux_rand"], 32, "aux_rand") |
| 171 | + |
| 172 | + d_raw, d, negated = derive_signing_key(spend_seckey, tweak, output_pubkey) |
| 173 | + signature = schnorr_sign(message, bytes_from_int(d), aux_rand) |
| 174 | + |
| 175 | + assert bytes_from_int(d_raw).hex() == expected["raw_tweaked_seckey"] |
| 176 | + assert negated == expected["negated"] |
| 177 | + assert bytes_from_int(d).hex() == expected["final_seckey"] |
| 178 | + assert signature.hex() == expected["signature"] |
| 179 | + except Exception as exc: |
| 180 | + all_passed = False |
| 181 | + print(f" FAILED: {exc}") |
| 182 | + |
| 183 | + print(f"Running {len(invalid_vectors)} invalid vectors") |
| 184 | + for index, vector in enumerate(invalid_vectors): |
| 185 | + description = vector["description"] |
| 186 | + given = vector["given"] |
| 187 | + error_substr = vector["error_substr"] |
| 188 | + print(f"- invalid[{index}] {description}") |
| 189 | + try: |
| 190 | + spend_seckey = parse_hex(given["spend_seckey"], 32, "spend_seckey") |
| 191 | + tweak = parse_hex(given["tweak"], 32, "tweak") |
| 192 | + output_pubkey = parse_hex(given["output_pubkey"], 32, "output_pubkey") |
| 193 | + derive_signing_key(spend_seckey, tweak, output_pubkey) |
| 194 | + all_passed = False |
| 195 | + print(" FAILED: expected an exception") |
| 196 | + except Exception as exc: |
| 197 | + if error_substr not in str(exc): |
| 198 | + all_passed = False |
| 199 | + print(f" FAILED: wrong error, got: {exc}") |
| 200 | + |
| 201 | + print("All test vectors passed." if all_passed else "Some test vectors failed.") |
| 202 | + return all_passed |
| 203 | + |
| 204 | + |
| 205 | +def main() -> int: |
| 206 | + if len(sys.argv) > 2: |
| 207 | + print(f"Usage: {sys.argv[0]} [test-vectors.json]") |
| 208 | + return 1 |
| 209 | + |
| 210 | + if len(sys.argv) == 2: |
| 211 | + vector_path = Path(sys.argv[1]) |
| 212 | + else: |
| 213 | + vector_path = Path(__file__).with_name("test-vectors.json") |
| 214 | + |
| 215 | + if not vector_path.is_file(): |
| 216 | + print(f"Vector file not found: {vector_path}") |
| 217 | + return 1 |
| 218 | + |
| 219 | + return 0 if run_test_vectors(vector_path) else 1 |
| 220 | + |
| 221 | + |
| 222 | +if __name__ == "__main__": |
| 223 | + raise SystemExit(main()) |
0 commit comments