|
1 | | -use std::array; |
| 1 | +use std::array::TryFromSliceError; |
2 | 2 |
|
3 | 3 | use base64::prelude::*; |
4 | 4 | use log::{debug, warn}; |
5 | 5 |
|
6 | 6 | const STATE_HASH_SIZE: usize = 32; |
7 | 7 |
|
8 | | -pub fn verify_protocol_state_proof_integrity(proof: &[u8], public_input: &[u8]) -> bool { |
| 8 | +pub fn verify_proof_integrity(proof: &[u8], public_input: &[u8]) -> bool { |
9 | 9 | debug!("Checking Mina protocol state proof"); |
10 | | - if let Err(err) = check_protocol_state_proof(proof) { |
| 10 | + if let Err(err) = check_proof(proof) { |
11 | 11 | warn!("Protocol state proof check failed: {}", err); |
12 | 12 | return false; |
13 | 13 | } |
14 | 14 |
|
15 | 15 | debug!("Checking Mina protocol state public inputs"); |
16 | | - if let Err(err) = check_protocol_state_pub(public_input) { |
| 16 | + if let Err(err) = check_pub_inputs(public_input) { |
17 | 17 | warn!("Protocol state public inputs check failed: {}", err); |
18 | 18 | return false; |
19 | 19 | } |
20 | 20 |
|
21 | 21 | true |
22 | 22 | } |
23 | 23 |
|
24 | | -pub fn check_protocol_state_proof(protocol_state_proof_bytes: &[u8]) -> Result<(), String> { |
25 | | - // TODO(xqft): check binprot deserialization |
26 | | - let protocol_state_proof_base64 = |
27 | | - std::str::from_utf8(protocol_state_proof_bytes).map_err(|err| err.to_string())?; |
28 | | - BASE64_URL_SAFE |
29 | | - .decode(protocol_state_proof_base64) |
30 | | - .map_err(|err| err.to_string())?; |
| 24 | +pub fn check_hash(pub_inputs: &[u8], offset: &mut usize) -> Result<(), String> { |
| 25 | + pub_inputs |
| 26 | + .get(*offset..*offset + STATE_HASH_SIZE) |
| 27 | + .ok_or("Failed to slice candidate hash".to_string())?; |
| 28 | + |
| 29 | + *offset += STATE_HASH_SIZE; |
31 | 30 |
|
32 | 31 | Ok(()) |
33 | 32 | } |
34 | 33 |
|
35 | | -pub fn check_protocol_state_pub(protocol_state_pub: &[u8]) -> Result<(), String> { |
36 | | - // TODO(xqft): check hash and binprot deserialization |
37 | | - let candidate_protocol_state_len = |
38 | | - check_protocol_state_and_hash(protocol_state_pub, STATE_HASH_SIZE)?; |
| 34 | +pub fn check_state(pub_inputs: &[u8], offset: &mut usize) -> Result<(), String> { |
| 35 | + let state_len: usize = pub_inputs |
| 36 | + .get(*offset..*offset + 4) |
| 37 | + .ok_or("Failed to slice state len".to_string()) |
| 38 | + .and_then(|slice| { |
| 39 | + slice |
| 40 | + .try_into() |
| 41 | + .map_err(|err: TryFromSliceError| err.to_string()) |
| 42 | + }) |
| 43 | + .map(u32::from_be_bytes) |
| 44 | + .and_then(|len| usize::try_from(len).map_err(|err| err.to_string()))?; |
| 45 | + |
| 46 | + pub_inputs |
| 47 | + .get(*offset + 4..*offset + 4 + state_len) |
| 48 | + .ok_or("Failed to slice state".to_string()) |
| 49 | + .and_then(|bytes| std::str::from_utf8(bytes).map_err(|err| err.to_string())) |
| 50 | + .and_then(|base64| { |
| 51 | + BASE64_STANDARD |
| 52 | + .decode(base64) |
| 53 | + .map_err(|err| err.to_string()) |
| 54 | + })?; |
| 55 | + *offset += 4 + state_len; |
| 56 | + |
| 57 | + Ok(()) |
| 58 | +} |
39 | 59 |
|
40 | | - let _tip_protocol_state_len = check_protocol_state_and_hash( |
41 | | - protocol_state_pub, |
42 | | - STATE_HASH_SIZE + 4 + candidate_protocol_state_len + STATE_HASH_SIZE, |
43 | | - )?; |
| 60 | +pub fn check_pub_inputs(pub_inputs: &[u8]) -> Result<(), String> { |
| 61 | + let mut offset = 0; |
| 62 | + |
| 63 | + check_hash(pub_inputs, &mut offset)?; // candidate hash |
| 64 | + check_hash(pub_inputs, &mut offset)?; // tip hash |
| 65 | + |
| 66 | + check_state(pub_inputs, &mut offset)?; // candidate state |
| 67 | + check_state(pub_inputs, &mut offset)?; // tip state |
44 | 68 |
|
45 | 69 | Ok(()) |
46 | 70 | } |
47 | 71 |
|
48 | | -fn check_protocol_state_and_hash(protocol_state_pub: &[u8], start: usize) -> Result<usize, String> { |
49 | | - let protocol_state_len_vec: Vec<_> = protocol_state_pub.iter().skip(start).take(4).collect(); |
50 | | - let protocol_state_len_bytes: [u8; 4] = array::from_fn(|i| protocol_state_len_vec[i].clone()); |
51 | | - let protocol_state_len = u32::from_be_bytes(protocol_state_len_bytes) as usize; |
52 | | - |
53 | | - let protocol_state_bytes: Vec<_> = protocol_state_pub |
54 | | - .iter() |
55 | | - .skip(start + 4) |
56 | | - .take(protocol_state_len) |
57 | | - .map(|byte| byte.clone()) |
58 | | - .collect(); |
59 | | - let protocol_state_base64 = |
60 | | - std::str::from_utf8(protocol_state_bytes.as_slice()).map_err(|err| err.to_string())?; |
61 | | - BASE64_STANDARD |
62 | | - .decode(protocol_state_base64) |
63 | | - .map_err(|err| err.to_string())?; |
64 | | - |
65 | | - Ok(protocol_state_len) |
| 72 | +pub fn check_proof(proof_bytes: &[u8]) -> Result<(), String> { |
| 73 | + std::str::from_utf8(proof_bytes) |
| 74 | + .map_err(|err| err.to_string()) |
| 75 | + .and_then(|base64| { |
| 76 | + BASE64_URL_SAFE |
| 77 | + .decode(base64) |
| 78 | + .map_err(|err| err.to_string()) |
| 79 | + })?; |
| 80 | + Ok(()) |
66 | 81 | } |
67 | 82 |
|
68 | 83 | #[cfg(test)] |
69 | 84 | mod test { |
70 | | - use super::verify_protocol_state_proof_integrity; |
| 85 | + use super::verify_proof_integrity; |
71 | 86 |
|
72 | 87 | const PROTOCOL_STATE_PROOF_BYTES: &[u8] = |
73 | 88 | include_bytes!("../../../../batcher/aligned/test_files/mina/protocol_state.proof"); |
74 | 89 | const PROTOCOL_STATE_PUB_BYTES: &[u8] = |
75 | 90 | include_bytes!("../../../../batcher/aligned/test_files/mina/protocol_state.pub"); |
76 | 91 |
|
77 | 92 | #[test] |
78 | | - fn verify_protocol_state_proof_integrity_does_not_fail() { |
79 | | - assert!(verify_protocol_state_proof_integrity( |
| 93 | + fn verify_proof_integrity_does_not_fail() { |
| 94 | + assert!(verify_proof_integrity( |
80 | 95 | PROTOCOL_STATE_PROOF_BYTES, |
81 | 96 | PROTOCOL_STATE_PUB_BYTES, |
82 | 97 | )); |
|
0 commit comments