From ac02ac57e0f0225a5aad8ffa47ec0d089c353010 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 15 Sep 2025 20:47:26 -0400 Subject: [PATCH 01/41] Add native sumcheck chip --- extensions/native/circuit/src/extension.rs | 21 +- extensions/native/circuit/src/lib.rs | 2 + .../native/circuit/src/poseidon2/trace.rs | 4 +- extensions/native/circuit/src/sumcheck/air.rs | 44 + .../native/circuit/src/sumcheck/chip.rs | 1296 +++++++++++++++++ extensions/native/circuit/src/sumcheck/mod.rs | 5 + .../native/circuit/src/sumcheck/trace.rs | 59 + .../native/compiler/src/asm/compiler.rs | 6 + .../native/compiler/src/asm/instruction.rs | 5 + .../native/compiler/src/conversion/mod.rs | 19 +- .../native/compiler/src/ir/instructions.rs | 9 + extensions/native/compiler/src/ir/mod.rs | 1 + extensions/native/compiler/src/ir/sumcheck.rs | 26 + extensions/native/compiler/src/lib.rs | 12 + 14 files changed, 1495 insertions(+), 14 deletions(-) create mode 100644 extensions/native/circuit/src/sumcheck/air.rs create mode 100644 extensions/native/circuit/src/sumcheck/chip.rs create mode 100644 extensions/native/circuit/src/sumcheck/mod.rs create mode 100644 extensions/native/circuit/src/sumcheck/trace.rs create mode 100644 extensions/native/compiler/src/ir/sumcheck.rs diff --git a/extensions/native/circuit/src/extension.rs b/extensions/native/circuit/src/extension.rs index 1ee1af6885..eecf6370e0 100644 --- a/extensions/native/circuit/src/extension.rs +++ b/extensions/native/circuit/src/extension.rs @@ -15,9 +15,7 @@ use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode, PhantomDiscriminant}; use openvm_native_compiler::{ - CastfOpcode, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, - NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, - NativeRangeCheckOpcode, Poseidon2Opcode, VerifyBatchOpcode, BLOCK_LOAD_STORE_SIZE, + CastfOpcode, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, NativeRangeCheckOpcode, Poseidon2Opcode, SumcheckOpcode, VerifyBatchOpcode, BLOCK_LOAD_STORE_SIZE }; use openvm_poseidon2_air::Poseidon2Config; use openvm_rv32im_circuit::{ @@ -29,10 +27,7 @@ use serde::{Deserialize, Serialize}; use strum::IntoEnumIterator; use crate::{ - adapters::{convert_adapter::ConvertAdapterChip, *}, - poseidon2::chip::NativePoseidon2Chip, - phantom::*, - *, + adapters::{convert_adapter::ConvertAdapterChip, *}, phantom::*, poseidon2::chip::NativePoseidon2Chip, sumcheck::chip::NativeSumcheckChip, * }; #[derive(Clone, Debug, Serialize, Deserialize, VmConfig, derive_new::new)] @@ -76,6 +71,7 @@ pub enum NativeExecutor { FieldExtension(FieldExtensionChip), FriReducedOpening(FriReducedOpeningChip), VerifyBatch(NativePoseidon2Chip), + SumcheckLayerEval(NativeSumcheckChip), } #[derive(From, ChipUsageGetter, Chip, AnyEnum)] @@ -207,6 +203,17 @@ impl VmExtension for Native { ], )?; + let sumcheck_chip = NativeSumcheckChip::new( + builder.system_port(), + offline_memory.clone(), + ); + inventory.add_executor( + sumcheck_chip, + [ + SumcheckOpcode::SUMCHECK_LAYER_EVAL.global_opcode(), + ] + )?; + builder.add_phantom_sub_executor( NativeHintInputSubEx, PhantomDiscriminant(NativePhantom::HintInput as u16), diff --git a/extensions/native/circuit/src/lib.rs b/extensions/native/circuit/src/lib.rs index 46c6bc890f..89261f73d5 100644 --- a/extensions/native/circuit/src/lib.rs +++ b/extensions/native/circuit/src/lib.rs @@ -8,6 +8,7 @@ mod fri; mod jal; mod loadstore; mod poseidon2; +mod sumcheck; pub use branch_eq::*; pub use castf::*; @@ -17,6 +18,7 @@ pub use fri::*; pub use jal::*; pub use loadstore::*; pub use poseidon2::*; +pub use sumcheck::*; mod extension; pub use extension::*; diff --git a/extensions/native/circuit/src/poseidon2/trace.rs b/extensions/native/circuit/src/poseidon2/trace.rs index 27ffe858a9..b212df67a6 100644 --- a/extensions/native/circuit/src/poseidon2/trace.rs +++ b/extensions/native/circuit/src/poseidon2/trace.rs @@ -15,9 +15,9 @@ use openvm_stark_backend::{ }; use crate::{ - chip::TranscriptObservationRecord, poseidon2::{ + poseidon2::{ chip::{ - CellRecord, IncorporateRowRecord, IncorporateSiblingRecord, InsideRowRecord, NativePoseidon2Chip, SimplePoseidonRecord, VerifyBatchRecord, NUM_INITIAL_READS + TranscriptObservationRecord, CellRecord, IncorporateRowRecord, IncorporateSiblingRecord, InsideRowRecord, NativePoseidon2Chip, SimplePoseidonRecord, VerifyBatchRecord, NUM_INITIAL_READS }, columns::{ InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, TopLevelSpecificCols diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs new file mode 100644 index 0000000000..bd55e88091 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -0,0 +1,44 @@ +use openvm_circuit::{ + arch::{ExecutionBridge, ExecutionState}, + system::memory::{offline_checker::MemoryBridge, MemoryAddress, CHUNK}, +}; +use openvm_stark_backend::{ + air_builders::sub::SubAirBuilder, + interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +#[derive(Clone, Debug)] +pub struct NativeSumcheckAir { + pub execution_bridge: ExecutionBridge, + pub memory_bridge: MemoryBridge, + pub(crate) address_space: F, +} + +impl BaseAir for NativeSumcheckAir { + fn width(&self) -> usize { + // _debug + 0 + } +} + +impl BaseAirWithPublicValues + for NativeSumcheckAir +{ +} + +impl PartitionedBaseAir + for NativeSumcheckAir +{ +} + +impl Air + for NativeSumcheckAir +{ + fn eval(&self, builder: &mut AB) { + // _debug + } +} \ No newline at end of file diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs new file mode 100644 index 0000000000..e5d9603f60 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -0,0 +1,1296 @@ +use std::sync::{Arc, Mutex}; +use openvm_circuit::{ + arch::{ + ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, Streams, SystemPort, + }, + system::memory::{MemoryController, OfflineMemory, RecordId}, +}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_stark_backend::{ + p3_field::{Field, PrimeField, PrimeField32}, + p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}, +}; +use openvm_native_compiler::{ + conversion::AS, + SumcheckOpcode::SUMCHECK_LAYER_EVAL, +}; +use crate::sumcheck::air::NativeSumcheckAir; +use crate::{ + field_extension::{FieldExtension, EXT_DEG}, + utils::const_max, +}; + +pub struct NativeSumcheckChip { + pub height: usize, + pub(super) air: NativeSumcheckAir, + pub(super) offline_memory: Arc>>, + // pub record_set: NativeSumcheckRecordSet, + // pub(super) streams: Arc>>, +} + +impl NativeSumcheckChip { + pub fn new( + port: SystemPort, + offline_memory: Arc>>, + ) -> Self { + let air = NativeSumcheckAir { + execution_bridge: ExecutionBridge::new(port.execution_bus, port.program_bus), + memory_bridge: port.memory_bridge, + address_space: F::from_canonical_u32(AS::Native as u32), + }; + + Self { + height: 0, + air, + offline_memory, + } + } +} + +impl InstructionExecutor for NativeSumcheckChip { + fn execute( + &mut self, + memory: &mut MemoryController, + instruction: &Instruction, + from_state: ExecutionState, + ) -> Result, ExecutionError> { + let &Instruction { + opcode: op, + a: output_register, + b: input_register_1, + c: input_register_2, + d: data_address_space, + e: register_address_space, + f: input_register_3, + g: input_register_4, + } = instruction; + + if op == SUMCHECK_LAYER_EVAL.global_opcode() { + println!("=> SUMCHECK_LAYER_EVAL"); + + let (read_ctx_pointer, ctx_pointer) = + memory.read_cell(register_address_space, input_register_1); + let (read_cs_pointer, cs_pointer) = + memory.read_cell(register_address_space, input_register_2); + let (read_prod_pointer, prod_ptr) = + memory.read_cell(register_address_space, input_register_3); + let (read_logup_pointer, logup_ptr) = + memory.read_cell(register_address_space, input_register_4); + + let (alpha_read, alpha) = memory.read::(data_address_space, cs_pointer); + let (c1_read, c1) = memory.read::(data_address_space, cs_pointer + F::from_canonical_usize(EXT_DEG * 1)); + let (c2_read, c2) = memory.read::(data_address_space, cs_pointer + F::from_canonical_usize(EXT_DEG * 2)); + + println!("=> c1: {:?}", c1); + println!("=> c2: {:?}", c2); + + // _debug: Calculation formula + } else { + unreachable!() + } + + Ok(ExecutionState { + pc: from_state.pc + DEFAULT_PC_STEP, + timestamp: memory.timestamp(), + }) + } + + + fn get_opcode_name(&self, opcode: usize) -> String { + if opcode == SUMCHECK_LAYER_EVAL.global_opcode().as_usize() { + String::from("SUMCHECK_LAYER_EVAL") + } else { + unreachable!("unsupported opcode: {}", opcode) + } + } +} + +// impl InstructionExecutor +// for NativePoseidon2Chip +// { +// fn execute( +// &mut self, +// memory: &mut MemoryController, +// instruction: &Instruction, +// from_state: ExecutionState, +// ) -> Result, ExecutionError> { +// if instruction.opcode == PERM_POS2.global_opcode() +// || instruction.opcode == COMP_POS2.global_opcode() +// { +// let &Instruction { +// a: output_register, +// b: input_register_1, +// c: input_register_2, +// d: register_address_space, +// e: data_address_space, +// .. +// } = instruction; + +// let (read_output_pointer, output_pointer) = +// memory.read_cell(register_address_space, output_register); +// let (read_input_pointer_1, input_pointer_1) = +// memory.read_cell(register_address_space, input_register_1); +// let (read_input_pointer_2, input_pointer_2) = +// if instruction.opcode == PERM_POS2.global_opcode() { +// memory.increment_timestamp(); +// (None, input_pointer_1 + F::from_canonical_usize(CHUNK)) +// } else { +// let (read_input_pointer_2, input_pointer_2) = +// memory.read_cell(register_address_space, input_register_2); +// (Some(read_input_pointer_2), input_pointer_2) +// }; +// let (read_data_1, data_1) = memory.read::(data_address_space, input_pointer_1); +// let (read_data_2, data_2) = memory.read::(data_address_space, input_pointer_2); +// let p2_input = std::array::from_fn(|i| { +// if i < CHUNK { +// data_1[i] +// } else { +// data_2[i - CHUNK] +// } +// }); +// let output = self.subchip.permute(p2_input); +// let (write_data_1, _) = memory.write::( +// data_address_space, +// output_pointer, +// std::array::from_fn(|i| output[i]), +// ); +// let write_data_2 = if instruction.opcode == PERM_POS2.global_opcode() { +// Some( +// memory +// .write::( +// data_address_space, +// output_pointer + F::from_canonical_usize(CHUNK), +// std::array::from_fn(|i| output[CHUNK + i]), +// ) +// .0, +// ) +// } else { +// memory.increment_timestamp(); +// None +// }; + +// assert_eq!( +// memory.timestamp(), +// from_state.timestamp + NUM_SIMPLE_ACCESSES +// ); + +// self.record_set +// .simple_permute_records +// .push(SimplePoseidonRecord { +// from_state, +// instruction: instruction.clone(), +// read_input_pointer_1, +// read_input_pointer_2, +// read_output_pointer, +// read_data_1, +// read_data_2, +// write_data_1, +// write_data_2, +// input_pointer_1, +// input_pointer_2, +// output_pointer, +// p2_input, +// }); +// self.height += 1; +// } else if instruction.opcode == MULTI_OBSERVE.global_opcode() { +// let mut observation_records: Vec> = vec![]; + +// let &Instruction { +// a: output_register, +// b: input_register_1, +// c: input_register_2, +// d: data_address_space, +// e: register_address_space, +// f: input_register_3, +// .. +// } = instruction; + +// let (read_sponge_ptr, sponge_ptr) = memory.read_cell(register_address_space, output_register); +// let (read_init_pos, pos) = memory.read_cell(register_address_space, input_register_1); +// let (read_arr_ptr, arr_ptr) = memory.read_cell(register_address_space, input_register_2); +// let init_pos = pos.clone(); + +// let mut pos = pos.as_canonical_u32() as usize; +// let (read_len, len) = memory.read_cell(register_address_space, input_register_3); +// let init_len = len.as_canonical_u32() as usize; +// let mut len = len.as_canonical_u32() as usize; + +// let mut header_record: TranscriptObservationRecord = TranscriptObservationRecord { +// from_state, +// instruction: instruction.clone(), +// curr_timestamp_increment: 0, +// is_first: true, +// state_ptr: sponge_ptr, +// input_ptr: arr_ptr, +// init_pos, +// len: init_len, +// input_register_1, +// input_register_2, +// input_register_3, +// output_register, +// ..Default::default() +// }; +// header_record.read_input_data[0] = read_sponge_ptr; +// header_record.read_input_data[1] = read_arr_ptr; +// header_record.read_input_data[2] = read_init_pos; +// header_record.read_input_data[3] = read_len; + +// observation_records.push(header_record); +// self.height += 1; + +// // Observe bytes +// let mut observation_chunks: Vec<(usize, usize, bool)> = vec![]; +// while len > 0 { +// if len >= (CHUNK - pos) { +// observation_chunks.push((pos.clone(), CHUNK.clone(), true)); +// len -= CHUNK - pos; +// pos = 0; +// } else { +// observation_chunks.push((pos.clone(), pos + len, false)); +// len = 0; +// pos = pos + len; +// } +// } + +// let mut curr_timestamp = 4usize; +// let mut input_idx: usize = 0; +// for chunk in observation_chunks { +// let mut record: TranscriptObservationRecord = TranscriptObservationRecord { +// from_state, +// instruction: instruction.clone(), + +// start_idx: chunk.0, +// end_idx: chunk.1, + +// curr_timestamp_increment: curr_timestamp, +// state_ptr: sponge_ptr, +// input_ptr: arr_ptr, +// init_pos, +// len: init_len, +// curr_len: input_idx, +// input_register_1, +// input_register_2, +// input_register_3, +// output_register, +// ..Default::default() +// }; + +// for j in chunk.0..chunk.1 { +// let (n_read, n_f) = memory.read_cell(data_address_space, arr_ptr + F::from_canonical_usize(input_idx)); +// record.read_input_data[j] = n_read; +// record.input_data[j] = n_f; +// input_idx += 1; +// curr_timestamp += 1; + +// let (n_write, _) = memory.write_cell(data_address_space, sponge_ptr + F::from_canonical_usize(j), n_f); +// record.write_input_data[j] = n_write; +// curr_timestamp += 1; +// } + +// if record.end_idx >= CHUNK { +// let (read_sponge_record, permutation_input) = memory.read::<{CHUNK * 2}>(data_address_space, sponge_ptr); +// let output = self.subchip.permute(permutation_input); +// let (write_sponge_record, _) = memory.write::<{CHUNK * 2}>(data_address_space, sponge_ptr, std::array::from_fn(|i| output[i])); + +// curr_timestamp += 2; + +// record.should_permute = true; +// record.read_sponge_state = read_sponge_record; +// record.write_sponge_state = write_sponge_record; +// record.permutation_input = permutation_input; +// record.permutation_output = output; +// } + +// observation_records.push(record); +// self.height += 1; +// } + +// let last_record = observation_records.last_mut().unwrap(); +// let final_idx = last_record.end_idx % CHUNK; +// let (write_final, _) = memory.write_cell(register_address_space, input_register_1, F::from_canonical_usize(final_idx)); +// last_record.is_last = true; +// last_record.write_final_idx = write_final; +// last_record.final_idx = final_idx; +// curr_timestamp += 1; + +// for record in &mut observation_records { +// record.final_timestamp_increment = curr_timestamp; +// } +// self.record_set.transcript_observation_records.extend(observation_records); +// } else if instruction.opcode == VERIFY_BATCH.global_opcode() { +// let &Instruction { +// a: dim_register, +// b: opened_register, +// c: opened_length_register, +// d: proof_id_ptr, +// e: index_register, +// f: commit_register, +// g: opened_element_size_inv, +// .. +// } = instruction; +// let address_space = self.air.address_space; +// // calc inverse fast assuming opened_element_size in {1, 4} +// let mut opened_element_size = F::ONE; +// while opened_element_size * opened_element_size_inv != F::ONE { +// opened_element_size += F::ONE; +// } + +// let proof_id = memory.unsafe_read_cell(address_space, proof_id_ptr); +// let (dim_base_pointer_read, dim_base_pointer) = +// memory.read_cell(address_space, dim_register); +// let (opened_base_pointer_read, opened_base_pointer) = +// memory.read_cell(address_space, opened_register); +// let (opened_length_read, opened_length) = +// memory.read_cell(address_space, opened_length_register); +// let (index_base_pointer_read, index_base_pointer) = +// memory.read_cell(address_space, index_register); +// let (commit_pointer_read, commit_pointer) = +// memory.read_cell(address_space, commit_register); +// let (commit_read, commit) = memory.read(address_space, commit_pointer); + +// let opened_length = opened_length.as_canonical_u32() as usize; + +// let initial_log_height = memory +// .unsafe_read_cell(address_space, dim_base_pointer) +// .as_canonical_u32(); +// let mut log_height = initial_log_height as i32; +// let mut sibling_index = 0; +// let mut opened_index = 0; +// let mut top_level = vec![]; + +// let mut root = [F::ZERO; CHUNK]; +// let sibling_proof: Vec<[F; CHUNK]> = { +// let streams = self.streams.lock().unwrap(); +// let proof_idx = proof_id.as_canonical_u32() as usize; +// streams.hint_space[proof_idx] +// .par_chunks(CHUNK) +// .map(|c| c.try_into().unwrap()) +// .collect() +// }; + +// while log_height >= 0 { +// let incorporate_row = if opened_index < opened_length +// && memory.unsafe_read_cell( +// address_space, +// dim_base_pointer + F::from_canonical_usize(opened_index), +// ) == F::from_canonical_u32(log_height as u32) +// { +// let initial_opened_index = opened_index; +// for _ in 0..NUM_INITIAL_READS { +// memory.increment_timestamp(); +// } +// let mut chunks = vec![]; + +// let mut row_pointer = 0; +// let mut row_end = 0; + +// let mut prev_rolling_hash: Option<[F; 2 * CHUNK]> = None; +// let mut rolling_hash = [F::ZERO; 2 * CHUNK]; + +// let mut is_first_in_segment = true; + +// loop { +// let mut cells = vec![]; +// for chunk_elem in rolling_hash.iter_mut().take(CHUNK) { +// let read_row_pointer_and_length = if is_first_in_segment +// || row_pointer == row_end +// { +// if is_first_in_segment { +// is_first_in_segment = false; +// } else { +// opened_index += 1; +// if opened_index == opened_length +// || memory.unsafe_read_cell( +// address_space, +// dim_base_pointer +// + F::from_canonical_usize(opened_index), +// ) != F::from_canonical_u32(log_height as u32) +// { +// break; +// } +// } +// let (result, [new_row_pointer, row_len]) = memory.read( +// address_space, +// opened_base_pointer + F::from_canonical_usize(2 * opened_index), +// ); +// row_pointer = new_row_pointer.as_canonical_u32() as usize; +// row_end = row_pointer +// + (opened_element_size * row_len).as_canonical_u32() as usize; +// Some(result) +// } else { +// memory.increment_timestamp(); +// None +// }; +// let (read, value) = memory +// .read_cell(address_space, F::from_canonical_usize(row_pointer)); +// cells.push(CellRecord { +// read, +// opened_index, +// read_row_pointer_and_length, +// row_pointer, +// row_end, +// }); +// *chunk_elem = value; +// row_pointer += 1; +// } +// if cells.is_empty() { +// break; +// } +// let cells_len = cells.len(); +// chunks.push(InsideRowRecord { +// cells, +// p2_input: rolling_hash, +// }); +// self.height += 1; +// prev_rolling_hash = Some(rolling_hash); +// self.subchip.permute_mut(&mut rolling_hash); +// if cells_len < CHUNK { +// for _ in 0..CHUNK - cells_len { +// memory.increment_timestamp(); +// memory.increment_timestamp(); +// } +// break; +// } +// } +// let final_opened_index = opened_index - 1; +// let (initial_height_read, height_check) = memory.read_cell( +// address_space, +// dim_base_pointer + F::from_canonical_usize(initial_opened_index), +// ); +// assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); +// let (final_height_read, height_check) = memory.read_cell( +// address_space, +// dim_base_pointer + F::from_canonical_usize(final_opened_index), +// ); +// assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); + +// let hash: [F; CHUNK] = std::array::from_fn(|i| rolling_hash[i]); + +// let (p2_input, new_root) = if log_height as u32 == initial_log_height { +// (prev_rolling_hash.unwrap(), hash) +// } else { +// self.compress(root, hash) +// }; +// root = new_root; + +// self.height += 1; +// Some(IncorporateRowRecord { +// chunks, +// initial_opened_index, +// final_opened_index, +// initial_height_read, +// final_height_read, +// p2_input, +// }) +// } else { +// None +// }; + +// let incorporate_sibling = if log_height == 0 { +// None +// } else { +// for _ in 0..NUM_INITIAL_READS { +// memory.increment_timestamp(); +// } + +// let (read_sibling_is_on_right, sibling_is_on_right) = memory.read_cell( +// address_space, +// index_base_pointer + F::from_canonical_usize(sibling_index), +// ); +// let sibling_is_on_right = sibling_is_on_right == F::ONE; +// let sibling = sibling_proof[sibling_index]; +// let (p2_input, new_root) = if sibling_is_on_right { +// self.compress(sibling, root) +// } else { +// self.compress(root, sibling) +// }; +// root = new_root; + +// self.height += 1; +// Some(IncorporateSiblingRecord { +// read_sibling_is_on_right, +// sibling_is_on_right, +// p2_input, +// }) +// }; + +// top_level.push(TopLevelRecord { +// incorporate_row, +// incorporate_sibling, +// }); + +// log_height -= 1; +// sibling_index += 1; +// } + +// assert_eq!(commit, root); +// self.record_set +// .verify_batch_records +// .push(VerifyBatchRecord { +// from_state, +// instruction: instruction.clone(), +// dim_base_pointer, +// opened_base_pointer, +// opened_length, +// index_base_pointer, +// commit_pointer, +// dim_base_pointer_read, +// opened_base_pointer_read, +// opened_length_read, +// index_base_pointer_read, +// commit_pointer_read, +// commit_read, +// initial_log_height: initial_log_height as usize, +// top_level, +// }); +// } else { +// unreachable!() +// } +// Ok(ExecutionState { +// pc: from_state.pc + DEFAULT_PC_STEP, +// timestamp: memory.timestamp(), +// }) +// } + +// fn get_opcode_name(&self, opcode: usize) -> String { +// if opcode == VERIFY_BATCH.global_opcode().as_usize() { +// String::from("VERIFY_BATCH") +// } else if opcode == PERM_POS2.global_opcode().as_usize() { +// String::from("PERM_POS2") +// } else if opcode == COMP_POS2.global_opcode().as_usize() { +// String::from("COMP_POS2") +// } else if opcode == MULTI_OBSERVE.global_opcode().as_usize() { +// String::from("MULTI_OBSERVE") +// }else { +// unreachable!("unsupported opcode: {}", opcode) +// } +// } +// } + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +// use std::sync::{Arc, Mutex}; + +// use openvm_circuit::{ +// arch::{ +// ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, Streams, SystemPort, +// }, +// system::memory::{MemoryController, OfflineMemory, RecordId}, +// }; +// use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +// use openvm_native_compiler::{ +// conversion::AS, +// Poseidon2Opcode::{COMP_POS2, PERM_POS2, MULTI_OBSERVE}, +// VerifyBatchOpcode::VERIFY_BATCH, +// }; +// use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubAir, Poseidon2SubChip}; +// use openvm_stark_backend::{ +// p3_field::{Field, PrimeField32}, +// p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}, +// }; +// use serde::{Deserialize, Serialize}; + + +// #[derive(Debug, Clone, Serialize, Deserialize)] +// #[serde(bound = "F: Field")] +// pub struct VerifyBatchRecord { +// pub from_state: ExecutionState, +// pub instruction: Instruction, + +// pub dim_base_pointer: F, +// pub opened_base_pointer: F, +// pub opened_length: usize, +// pub index_base_pointer: F, +// pub commit_pointer: F, + +// pub dim_base_pointer_read: RecordId, +// pub opened_base_pointer_read: RecordId, +// pub opened_length_read: RecordId, +// pub index_base_pointer_read: RecordId, +// pub commit_pointer_read: RecordId, + +// pub commit_read: RecordId, +// pub initial_log_height: usize, +// pub top_level: Vec>, +// } + +// impl VerifyBatchRecord { +// pub fn opened_element_size_inv(&self) -> F { +// self.instruction.g +// } +// } + +// #[derive(Debug, Clone, Serialize, Deserialize)] +// #[serde(bound = "F: Field")] +// pub struct TopLevelRecord { +// // must be present in first record +// pub incorporate_row: Option>, +// // must be present in all bust last record +// pub incorporate_sibling: Option>, +// } + +// #[repr(C)] +// #[derive(Debug, Clone, Serialize, Deserialize)] +// #[serde(bound = "F: Field")] +// pub struct IncorporateSiblingRecord { +// pub read_sibling_is_on_right: RecordId, +// pub sibling_is_on_right: bool, +// pub p2_input: [F; 2 * CHUNK], +// } + +// #[derive(Debug, Clone, Serialize, Deserialize)] +// #[serde(bound = "F: Field")] +// pub struct IncorporateRowRecord { +// pub chunks: Vec>, +// pub initial_opened_index: usize, +// pub final_opened_index: usize, +// pub initial_height_read: RecordId, +// pub final_height_read: RecordId, +// pub p2_input: [F; 2 * CHUNK], +// } + +// #[derive(Debug, Clone, Serialize, Deserialize)] +// #[serde(bound = "F: Field")] +// pub struct InsideRowRecord { +// pub cells: Vec, +// pub p2_input: [F; 2 * CHUNK], +// } + +// #[repr(C)] +// #[derive(Debug, Clone, Serialize, Deserialize)] +// pub struct CellRecord { +// pub read: RecordId, +// pub opened_index: usize, +// pub read_row_pointer_and_length: Option, +// pub row_pointer: usize, +// pub row_end: usize, +// } + +// #[repr(C)] +// #[derive(Debug, Clone, Serialize, Deserialize)] +// #[serde(bound = "F: Field")] +// pub struct SimplePoseidonRecord { +// pub from_state: ExecutionState, +// pub instruction: Instruction, + +// pub read_input_pointer_1: RecordId, +// pub read_input_pointer_2: Option, +// pub read_output_pointer: RecordId, +// pub read_data_1: RecordId, +// pub read_data_2: RecordId, +// pub write_data_1: RecordId, +// pub write_data_2: Option, + +// pub input_pointer_1: F, +// pub input_pointer_2: F, +// pub output_pointer: F, +// pub p2_input: [F; 2 * CHUNK], +// } + +// #[repr(C)] +// #[derive(Debug, Clone, Serialize, Deserialize, Default)] +// #[serde(bound = "F: Field")] +// pub struct TranscriptObservationRecord { +// pub from_state: ExecutionState, +// pub instruction: Instruction, +// pub start_idx: usize, +// pub end_idx: usize, +// pub is_first: bool, +// pub is_last: bool, +// pub curr_timestamp_increment: usize, +// pub final_timestamp_increment: usize, + +// pub state_ptr: F, +// pub input_ptr: F, +// pub init_pos: F, +// pub len: usize, +// pub curr_len: usize, +// pub should_permute: bool, + +// pub read_input_data: [RecordId; CHUNK], +// pub write_input_data: [RecordId; CHUNK], +// pub input_data: [F; CHUNK], + +// pub read_sponge_state: RecordId, +// pub write_sponge_state: RecordId, +// pub permutation_input: [F; 2 * CHUNK], +// pub permutation_output: [F; 2 * CHUNK], + +// pub write_final_idx: RecordId, +// pub final_idx: usize, + +// pub input_register_1: F, +// pub input_register_2: F, +// pub input_register_3: F, +// pub output_register: F, +// } + +// #[derive(Debug, Clone, Serialize, Deserialize, Default)] +// #[serde(bound = "F: Field")] +// pub struct NativePoseidon2RecordSet { +// pub verify_batch_records: Vec>, +// pub simple_permute_records: Vec>, +// pub transcript_observation_records: Vec>, +// } + +// pub struct NativePoseidon2Chip { +// pub(super) air: NativePoseidon2Air, +// pub record_set: NativePoseidon2RecordSet, +// pub height: usize, +// pub(super) offline_memory: Arc>>, +// pub(super) subchip: Poseidon2SubChip, +// pub(super) streams: Arc>>, +// } + +// impl NativePoseidon2Chip { +// pub fn new( +// port: SystemPort, +// offline_memory: Arc>>, +// poseidon2_config: Poseidon2Config, +// verify_batch_bus: VerifyBatchBus, +// streams: Arc>>, +// ) -> Self { +// let air = NativePoseidon2Air { +// execution_bridge: ExecutionBridge::new(port.execution_bus, port.program_bus), +// memory_bridge: port.memory_bridge, +// internal_bus: verify_batch_bus, +// subair: Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())), +// address_space: F::from_canonical_u32(AS::Native as u32), +// }; +// Self { +// record_set: Default::default(), +// air, +// height: 0, +// offline_memory, +// subchip: Poseidon2SubChip::new(poseidon2_config.constants), +// streams, +// } +// } + +// fn compress(&self, left: [F; CHUNK], right: [F; CHUNK]) -> ([F; 2 * CHUNK], [F; CHUNK]) { +// let concatenated = +// std::array::from_fn(|i| if i < CHUNK { left[i] } else { right[i - CHUNK] }); +// let permuted = self.subchip.permute(concatenated); +// (concatenated, std::array::from_fn(|i| permuted[i])) +// } +// } + +// pub(super) const NUM_INITIAL_READS: usize = 6; +// pub(super) const NUM_SIMPLE_ACCESSES: u32 = 7; + +// impl InstructionExecutor +// for NativePoseidon2Chip +// { +// fn execute( +// &mut self, +// memory: &mut MemoryController, +// instruction: &Instruction, +// from_state: ExecutionState, +// ) -> Result, ExecutionError> { +// if instruction.opcode == PERM_POS2.global_opcode() +// || instruction.opcode == COMP_POS2.global_opcode() +// { +// let &Instruction { +// a: output_register, +// b: input_register_1, +// c: input_register_2, +// d: register_address_space, +// e: data_address_space, +// .. +// } = instruction; + +// let (read_output_pointer, output_pointer) = +// memory.read_cell(register_address_space, output_register); +// let (read_input_pointer_1, input_pointer_1) = +// memory.read_cell(register_address_space, input_register_1); +// let (read_input_pointer_2, input_pointer_2) = +// if instruction.opcode == PERM_POS2.global_opcode() { +// memory.increment_timestamp(); +// (None, input_pointer_1 + F::from_canonical_usize(CHUNK)) +// } else { +// let (read_input_pointer_2, input_pointer_2) = +// memory.read_cell(register_address_space, input_register_2); +// (Some(read_input_pointer_2), input_pointer_2) +// }; +// let (read_data_1, data_1) = memory.read::(data_address_space, input_pointer_1); +// let (read_data_2, data_2) = memory.read::(data_address_space, input_pointer_2); +// let p2_input = std::array::from_fn(|i| { +// if i < CHUNK { +// data_1[i] +// } else { +// data_2[i - CHUNK] +// } +// }); +// let output = self.subchip.permute(p2_input); +// let (write_data_1, _) = memory.write::( +// data_address_space, +// output_pointer, +// std::array::from_fn(|i| output[i]), +// ); +// let write_data_2 = if instruction.opcode == PERM_POS2.global_opcode() { +// Some( +// memory +// .write::( +// data_address_space, +// output_pointer + F::from_canonical_usize(CHUNK), +// std::array::from_fn(|i| output[CHUNK + i]), +// ) +// .0, +// ) +// } else { +// memory.increment_timestamp(); +// None +// }; + +// assert_eq!( +// memory.timestamp(), +// from_state.timestamp + NUM_SIMPLE_ACCESSES +// ); + +// self.record_set +// .simple_permute_records +// .push(SimplePoseidonRecord { +// from_state, +// instruction: instruction.clone(), +// read_input_pointer_1, +// read_input_pointer_2, +// read_output_pointer, +// read_data_1, +// read_data_2, +// write_data_1, +// write_data_2, +// input_pointer_1, +// input_pointer_2, +// output_pointer, +// p2_input, +// }); +// self.height += 1; +// } else if instruction.opcode == MULTI_OBSERVE.global_opcode() { +// let mut observation_records: Vec> = vec![]; + +// let &Instruction { +// a: output_register, +// b: input_register_1, +// c: input_register_2, +// d: data_address_space, +// e: register_address_space, +// f: input_register_3, +// .. +// } = instruction; + +// let (read_sponge_ptr, sponge_ptr) = memory.read_cell(register_address_space, output_register); +// let (read_init_pos, pos) = memory.read_cell(register_address_space, input_register_1); +// let (read_arr_ptr, arr_ptr) = memory.read_cell(register_address_space, input_register_2); +// let init_pos = pos.clone(); + +// let mut pos = pos.as_canonical_u32() as usize; +// let (read_len, len) = memory.read_cell(register_address_space, input_register_3); +// let init_len = len.as_canonical_u32() as usize; +// let mut len = len.as_canonical_u32() as usize; + +// let mut header_record: TranscriptObservationRecord = TranscriptObservationRecord { +// from_state, +// instruction: instruction.clone(), +// curr_timestamp_increment: 0, +// is_first: true, +// state_ptr: sponge_ptr, +// input_ptr: arr_ptr, +// init_pos, +// len: init_len, +// input_register_1, +// input_register_2, +// input_register_3, +// output_register, +// ..Default::default() +// }; +// header_record.read_input_data[0] = read_sponge_ptr; +// header_record.read_input_data[1] = read_arr_ptr; +// header_record.read_input_data[2] = read_init_pos; +// header_record.read_input_data[3] = read_len; + +// observation_records.push(header_record); +// self.height += 1; + +// // Observe bytes +// let mut observation_chunks: Vec<(usize, usize, bool)> = vec![]; +// while len > 0 { +// if len >= (CHUNK - pos) { +// observation_chunks.push((pos.clone(), CHUNK.clone(), true)); +// len -= CHUNK - pos; +// pos = 0; +// } else { +// observation_chunks.push((pos.clone(), pos + len, false)); +// len = 0; +// pos = pos + len; +// } +// } + +// let mut curr_timestamp = 4usize; +// let mut input_idx: usize = 0; +// for chunk in observation_chunks { +// let mut record: TranscriptObservationRecord = TranscriptObservationRecord { +// from_state, +// instruction: instruction.clone(), + +// start_idx: chunk.0, +// end_idx: chunk.1, + +// curr_timestamp_increment: curr_timestamp, +// state_ptr: sponge_ptr, +// input_ptr: arr_ptr, +// init_pos, +// len: init_len, +// curr_len: input_idx, +// input_register_1, +// input_register_2, +// input_register_3, +// output_register, +// ..Default::default() +// }; + +// for j in chunk.0..chunk.1 { +// let (n_read, n_f) = memory.read_cell(data_address_space, arr_ptr + F::from_canonical_usize(input_idx)); +// record.read_input_data[j] = n_read; +// record.input_data[j] = n_f; +// input_idx += 1; +// curr_timestamp += 1; + +// let (n_write, _) = memory.write_cell(data_address_space, sponge_ptr + F::from_canonical_usize(j), n_f); +// record.write_input_data[j] = n_write; +// curr_timestamp += 1; +// } + +// if record.end_idx >= CHUNK { +// let (read_sponge_record, permutation_input) = memory.read::<{CHUNK * 2}>(data_address_space, sponge_ptr); +// let output = self.subchip.permute(permutation_input); +// let (write_sponge_record, _) = memory.write::<{CHUNK * 2}>(data_address_space, sponge_ptr, std::array::from_fn(|i| output[i])); + +// curr_timestamp += 2; + +// record.should_permute = true; +// record.read_sponge_state = read_sponge_record; +// record.write_sponge_state = write_sponge_record; +// record.permutation_input = permutation_input; +// record.permutation_output = output; +// } + +// observation_records.push(record); +// self.height += 1; +// } + +// let last_record = observation_records.last_mut().unwrap(); +// let final_idx = last_record.end_idx % CHUNK; +// let (write_final, _) = memory.write_cell(register_address_space, input_register_1, F::from_canonical_usize(final_idx)); +// last_record.is_last = true; +// last_record.write_final_idx = write_final; +// last_record.final_idx = final_idx; +// curr_timestamp += 1; + +// for record in &mut observation_records { +// record.final_timestamp_increment = curr_timestamp; +// } +// self.record_set.transcript_observation_records.extend(observation_records); +// } else if instruction.opcode == VERIFY_BATCH.global_opcode() { +// let &Instruction { +// a: dim_register, +// b: opened_register, +// c: opened_length_register, +// d: proof_id_ptr, +// e: index_register, +// f: commit_register, +// g: opened_element_size_inv, +// .. +// } = instruction; +// let address_space = self.air.address_space; +// // calc inverse fast assuming opened_element_size in {1, 4} +// let mut opened_element_size = F::ONE; +// while opened_element_size * opened_element_size_inv != F::ONE { +// opened_element_size += F::ONE; +// } + +// let proof_id = memory.unsafe_read_cell(address_space, proof_id_ptr); +// let (dim_base_pointer_read, dim_base_pointer) = +// memory.read_cell(address_space, dim_register); +// let (opened_base_pointer_read, opened_base_pointer) = +// memory.read_cell(address_space, opened_register); +// let (opened_length_read, opened_length) = +// memory.read_cell(address_space, opened_length_register); +// let (index_base_pointer_read, index_base_pointer) = +// memory.read_cell(address_space, index_register); +// let (commit_pointer_read, commit_pointer) = +// memory.read_cell(address_space, commit_register); +// let (commit_read, commit) = memory.read(address_space, commit_pointer); + +// let opened_length = opened_length.as_canonical_u32() as usize; + +// let initial_log_height = memory +// .unsafe_read_cell(address_space, dim_base_pointer) +// .as_canonical_u32(); +// let mut log_height = initial_log_height as i32; +// let mut sibling_index = 0; +// let mut opened_index = 0; +// let mut top_level = vec![]; + +// let mut root = [F::ZERO; CHUNK]; +// let sibling_proof: Vec<[F; CHUNK]> = { +// let streams = self.streams.lock().unwrap(); +// let proof_idx = proof_id.as_canonical_u32() as usize; +// streams.hint_space[proof_idx] +// .par_chunks(CHUNK) +// .map(|c| c.try_into().unwrap()) +// .collect() +// }; + +// while log_height >= 0 { +// let incorporate_row = if opened_index < opened_length +// && memory.unsafe_read_cell( +// address_space, +// dim_base_pointer + F::from_canonical_usize(opened_index), +// ) == F::from_canonical_u32(log_height as u32) +// { +// let initial_opened_index = opened_index; +// for _ in 0..NUM_INITIAL_READS { +// memory.increment_timestamp(); +// } +// let mut chunks = vec![]; + +// let mut row_pointer = 0; +// let mut row_end = 0; + +// let mut prev_rolling_hash: Option<[F; 2 * CHUNK]> = None; +// let mut rolling_hash = [F::ZERO; 2 * CHUNK]; + +// let mut is_first_in_segment = true; + +// loop { +// let mut cells = vec![]; +// for chunk_elem in rolling_hash.iter_mut().take(CHUNK) { +// let read_row_pointer_and_length = if is_first_in_segment +// || row_pointer == row_end +// { +// if is_first_in_segment { +// is_first_in_segment = false; +// } else { +// opened_index += 1; +// if opened_index == opened_length +// || memory.unsafe_read_cell( +// address_space, +// dim_base_pointer +// + F::from_canonical_usize(opened_index), +// ) != F::from_canonical_u32(log_height as u32) +// { +// break; +// } +// } +// let (result, [new_row_pointer, row_len]) = memory.read( +// address_space, +// opened_base_pointer + F::from_canonical_usize(2 * opened_index), +// ); +// row_pointer = new_row_pointer.as_canonical_u32() as usize; +// row_end = row_pointer +// + (opened_element_size * row_len).as_canonical_u32() as usize; +// Some(result) +// } else { +// memory.increment_timestamp(); +// None +// }; +// let (read, value) = memory +// .read_cell(address_space, F::from_canonical_usize(row_pointer)); +// cells.push(CellRecord { +// read, +// opened_index, +// read_row_pointer_and_length, +// row_pointer, +// row_end, +// }); +// *chunk_elem = value; +// row_pointer += 1; +// } +// if cells.is_empty() { +// break; +// } +// let cells_len = cells.len(); +// chunks.push(InsideRowRecord { +// cells, +// p2_input: rolling_hash, +// }); +// self.height += 1; +// prev_rolling_hash = Some(rolling_hash); +// self.subchip.permute_mut(&mut rolling_hash); +// if cells_len < CHUNK { +// for _ in 0..CHUNK - cells_len { +// memory.increment_timestamp(); +// memory.increment_timestamp(); +// } +// break; +// } +// } +// let final_opened_index = opened_index - 1; +// let (initial_height_read, height_check) = memory.read_cell( +// address_space, +// dim_base_pointer + F::from_canonical_usize(initial_opened_index), +// ); +// assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); +// let (final_height_read, height_check) = memory.read_cell( +// address_space, +// dim_base_pointer + F::from_canonical_usize(final_opened_index), +// ); +// assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); + +// let hash: [F; CHUNK] = std::array::from_fn(|i| rolling_hash[i]); + +// let (p2_input, new_root) = if log_height as u32 == initial_log_height { +// (prev_rolling_hash.unwrap(), hash) +// } else { +// self.compress(root, hash) +// }; +// root = new_root; + +// self.height += 1; +// Some(IncorporateRowRecord { +// chunks, +// initial_opened_index, +// final_opened_index, +// initial_height_read, +// final_height_read, +// p2_input, +// }) +// } else { +// None +// }; + +// let incorporate_sibling = if log_height == 0 { +// None +// } else { +// for _ in 0..NUM_INITIAL_READS { +// memory.increment_timestamp(); +// } + +// let (read_sibling_is_on_right, sibling_is_on_right) = memory.read_cell( +// address_space, +// index_base_pointer + F::from_canonical_usize(sibling_index), +// ); +// let sibling_is_on_right = sibling_is_on_right == F::ONE; +// let sibling = sibling_proof[sibling_index]; +// let (p2_input, new_root) = if sibling_is_on_right { +// self.compress(sibling, root) +// } else { +// self.compress(root, sibling) +// }; +// root = new_root; + +// self.height += 1; +// Some(IncorporateSiblingRecord { +// read_sibling_is_on_right, +// sibling_is_on_right, +// p2_input, +// }) +// }; + +// top_level.push(TopLevelRecord { +// incorporate_row, +// incorporate_sibling, +// }); + +// log_height -= 1; +// sibling_index += 1; +// } + +// assert_eq!(commit, root); +// self.record_set +// .verify_batch_records +// .push(VerifyBatchRecord { +// from_state, +// instruction: instruction.clone(), +// dim_base_pointer, +// opened_base_pointer, +// opened_length, +// index_base_pointer, +// commit_pointer, +// dim_base_pointer_read, +// opened_base_pointer_read, +// opened_length_read, +// index_base_pointer_read, +// commit_pointer_read, +// commit_read, +// initial_log_height: initial_log_height as usize, +// top_level, +// }); +// } else { +// unreachable!() +// } +// Ok(ExecutionState { +// pc: from_state.pc + DEFAULT_PC_STEP, +// timestamp: memory.timestamp(), +// }) +// } + +// fn get_opcode_name(&self, opcode: usize) -> String { +// if opcode == VERIFY_BATCH.global_opcode().as_usize() { +// String::from("VERIFY_BATCH") +// } else if opcode == PERM_POS2.global_opcode().as_usize() { +// String::from("PERM_POS2") +// } else if opcode == COMP_POS2.global_opcode().as_usize() { +// String::from("COMP_POS2") +// } else if opcode == MULTI_OBSERVE.global_opcode().as_usize() { +// String::from("MULTI_OBSERVE") +// }else { +// unreachable!("unsupported opcode: {}", opcode) +// } +// } +// } diff --git a/extensions/native/circuit/src/sumcheck/mod.rs b/extensions/native/circuit/src/sumcheck/mod.rs new file mode 100644 index 0000000000..34ab23860b --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/mod.rs @@ -0,0 +1,5 @@ +pub mod air; +pub mod chip; +// mod columns; +// mod tests; +mod trace; \ No newline at end of file diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs new file mode 100644 index 0000000000..d6c733d392 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -0,0 +1,59 @@ +use std::{borrow::BorrowMut, sync::Arc}; + +use openvm_circuit::system::memory::{MemoryAuxColsFactory, OfflineMemory}; +use openvm_circuit_primitives::utils::next_power_of_two_or_zero; +use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_native_compiler::Poseidon2Opcode::COMP_POS2; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_air::BaseAir, + p3_field::{Field, PrimeField32}, + p3_matrix::dense::RowMajorMatrix, + p3_maybe_rayon::prelude::*, + prover::types::AirProofInput, + AirRef, Chip, ChipUsageGetter, +}; +use crate::sumcheck::chip::NativeSumcheckChip; + +impl NativeSumcheckChip { + fn generate_trace(self) -> RowMajorMatrix { + let width = self.trace_width(); + let height = next_power_of_two_or_zero(self.height); + let mut flat_trace = F::zero_vec(width * height); + let memory = self.offline_memory.lock().unwrap(); + let aux_cols_factory = memory.aux_cols_factory(); + let mut used_cells = 0; + + RowMajorMatrix::new(flat_trace, width) + } +} + +impl ChipUsageGetter + for NativeSumcheckChip +{ + fn air_name(&self) -> String { + "SumcheckLayerEval".to_string() + } + + fn current_trace_height(&self) -> usize { + self.height + } + + fn trace_width(&self) -> usize { + // _debug + 0 + } +} + +impl Chip + for NativeSumcheckChip> +where + Val: PrimeField32, +{ + fn air(&self) -> AirRef { + Arc::new(self.air.clone()) + } + fn generate_air_proof_input(self) -> AirProofInput { + AirProofInput::simple_no_pis(self.generate_trace()) + } +} \ No newline at end of file diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index 0e0db9cff9..20f4a9dd90 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -632,6 +632,12 @@ impl + TwoAdicField> AsmCo ); } } + DslIr::SumcheckLayerEval(input_ctx, challenges, prod_ptr, logup_ptr) => { + self.push( + AsmInstruction::SumcheckLayerEval(input_ctx.ptr().fp(), challenges.ptr().fp(), prod_ptr.fp(), logup_ptr.fp()), + debug_info, + ); + } _ => unimplemented!(), } } diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index cd4990b08b..70e6878023 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -171,6 +171,8 @@ pub enum AsmInstruction { CycleTrackerStart(), CycleTrackerEnd(), + + SumcheckLayerEval(i32, i32, i32, i32), } impl> AsmInstruction { @@ -403,6 +405,9 @@ impl> AsmInstruction { AsmInstruction::RangeCheck(fp, lo_bits, hi_bits) => { write!(f, "range_check_fp ({})fp, ({}), ({})", fp, lo_bits, hi_bits) } + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr) => { + write!(f, "sumcheck_layer_eval ({})fp, ({})fp, ({})fp, ({})fp", ctx, cs, p_ptr, l_ptr) + } } } } diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index f8da82c30b..dedefceb52 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -9,10 +9,7 @@ use openvm_stark_backend::p3_field::{ExtensionField, PrimeField32, PrimeField64} use serde::{Deserialize, Serialize}; use crate::{ - asm::{AsmInstruction, AssemblyCode}, - FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, - NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, - NativeRangeCheckOpcode, Poseidon2Opcode, VerifyBatchOpcode, + asm::{AsmInstruction, AssemblyCode}, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, NativeRangeCheckOpcode, Poseidon2Opcode, SumcheckOpcode, VerifyBatchOpcode }; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] @@ -535,7 +532,19 @@ fn convert_instruction>( // Here it just requires a 0 AS::Immediate, )] - } + }, + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr) => vec![ + Instruction { + opcode: options.opcode_with_offset(SumcheckOpcode::SUMCHECK_LAYER_EVAL), + a: F::ZERO, // _debug + b: i32_f(ctx), + c: i32_f(cs), + d: AS::Native.to_field(), + e: AS::Native.to_field(), + f: i32_f(p_ptr), + g: i32_f(l_ptr), + } + ], }; let debug_infos = vec![debug_info; instructions.len()]; diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index 13f5c4a653..09999f67dd 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -319,6 +319,15 @@ pub enum DslIr { CycleTrackerStart(String), /// End the cycle tracker used by a block of code annotated by the string input. CycleTrackerEnd(String), + + /// Sumcheck calculate layer eval + SumcheckLayerEval( + Array>, // Input ctx: round, num_prod_spec, num_logup_spec, num_variables + Array>, // Challenges: alpha, coeffs + Ptr, // prod_specs_eval + Ptr, // logup_specs_eval + // Ptr // output + ) } impl Default for DslIr { diff --git a/extensions/native/compiler/src/ir/mod.rs b/extensions/native/compiler/src/ir/mod.rs index 47e901cd3a..29bb52f086 100644 --- a/extensions/native/compiler/src/ir/mod.rs +++ b/extensions/native/compiler/src/ir/mod.rs @@ -23,6 +23,7 @@ mod types; mod utils; mod var; mod verify_batch; +mod sumcheck; pub trait Config: Clone + Default { type N: PrimeField; diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs new file mode 100644 index 0000000000..15a46f180d --- /dev/null +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -0,0 +1,26 @@ +use openvm_native_compiler_derive::iter_zip; +use openvm_stark_backend::p3_field::FieldAlgebra; +use crate::ir::Variable; +use super::{Array, ArrayLike, Builder, Config, DslIr, Ext, Felt, MemIndex, Ptr, Usize, Var}; + +impl Builder { + /// Extends native VM ability to calculate the evaluation for a sumcheck layer + pub fn sumcheck_layer_eval ( + &mut self, + input_ctx: Array>, + challenges: Array>, + prod_specs_eval: &Array>, + logup_specs_eval: &Array>, + // r_evals: &Array>, + ) -> Usize { + self.operations.push(DslIr::SumcheckLayerEval( + input_ctx, + challenges, + prod_specs_eval.ptr(), + logup_specs_eval.ptr(), + // r_evals.ptr(), + )); + + Usize::from(0) + } +} \ No newline at end of file diff --git a/extensions/native/compiler/src/lib.rs b/extensions/native/compiler/src/lib.rs index 66c786fbd9..7496afbd1f 100644 --- a/extensions/native/compiler/src/lib.rs +++ b/extensions/native/compiler/src/lib.rs @@ -212,3 +212,15 @@ pub enum VerifyBatchOpcode { /// per column polynomial, per opening point VERIFY_BATCH, } + +/// Opcodes for sumcheck. +#[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, +)] +#[opcode_offset = 0x180] +#[repr(usize)] +#[allow(non_camel_case_types)] +pub enum SumcheckOpcode { + /// Computer the evaluation for a sumcheck layer + SUMCHECK_LAYER_EVAL, +} From d18ef80a6b3ba19b00009425c80e972ea358f4b9 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 16 Sep 2025 20:01:52 -0400 Subject: [PATCH 02/41] Complete sumcheck layer execution logic --- extensions/native/circuit/src/fri/mod.rs | 2 +- .../native/circuit/src/sumcheck/chip.rs | 109 +++++++++++++++++- .../native/compiler/src/asm/compiler.rs | 4 +- .../native/compiler/src/asm/instruction.rs | 6 +- .../native/compiler/src/conversion/mod.rs | 4 +- .../native/compiler/src/ir/instructions.rs | 2 +- extensions/native/compiler/src/ir/sumcheck.rs | 4 +- 7 files changed, 114 insertions(+), 17 deletions(-) diff --git a/extensions/native/circuit/src/fri/mod.rs b/extensions/native/circuit/src/fri/mod.rs index 7dbc3fd851..eabf22ef5c 100644 --- a/extensions/native/circuit/src/fri/mod.rs +++ b/extensions/native/circuit/src/fri/mod.rs @@ -538,7 +538,7 @@ fn assert_array_eq, I2: Into, const } } -fn elem_to_ext(elem: F) -> [F; EXT_DEG] { +pub fn elem_to_ext(elem: F) -> [F; EXT_DEG] { let mut ret = [F::ZERO; EXT_DEG]; ret[0] = elem; ret diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index e5d9603f60..924216a9c9 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -10,6 +10,7 @@ use openvm_stark_backend::{ p3_field::{Field, PrimeField, PrimeField32}, p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}, }; +use crate::fri::elem_to_ext; use openvm_native_compiler::{ conversion::AS, SumcheckOpcode::SUMCHECK_LAYER_EVAL, @@ -20,6 +21,16 @@ use crate::{ utils::const_max, }; +fn calculate_3d_ext_idx( + inner_inner_len: F, + inner_len: F, + outer_idx: F, + inner_idx: F, + inner_inner_idx: F, +) -> F { + (inner_inner_len * inner_len * outer_idx + inner_inner_len * inner_idx + inner_inner_idx) * F::from_canonical_usize(EXT_DEG) +} + pub struct NativeSumcheckChip { pub height: usize, pub(super) air: NativeSumcheckAir, @@ -66,8 +77,6 @@ impl InstructionExecutor for NativeSumcheckChip { } = instruction; if op == SUMCHECK_LAYER_EVAL.global_opcode() { - println!("=> SUMCHECK_LAYER_EVAL"); - let (read_ctx_pointer, ctx_pointer) = memory.read_cell(register_address_space, input_register_1); let (read_cs_pointer, cs_pointer) = @@ -76,15 +85,103 @@ impl InstructionExecutor for NativeSumcheckChip { memory.read_cell(register_address_space, input_register_3); let (read_logup_pointer, logup_ptr) = memory.read_cell(register_address_space, input_register_4); + let (read_result_pointer, r_ptr) = + memory.read_cell(register_address_space, output_register); + + let (ctx_read, ctx) = memory.read::<{EXT_DEG * 2}>(data_address_space, ctx_pointer); + + let [ + round, + num_prod_spec, + num_logup_spec, + prod_specs_inner_len, + prod_specs_inner_inner_len, + logup_specs_inner_len, + logup_specs_inner_inner_len, + _, + ] = ctx; let (alpha_read, alpha) = memory.read::(data_address_space, cs_pointer); let (c1_read, c1) = memory.read::(data_address_space, cs_pointer + F::from_canonical_usize(EXT_DEG * 1)); let (c2_read, c2) = memory.read::(data_address_space, cs_pointer + F::from_canonical_usize(EXT_DEG * 2)); - println!("=> c1: {:?}", c1); - println!("=> c2: {:?}", c2); - - // _debug: Calculation formula + let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); + let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); + + let mut i = F::ZERO; + while i < num_prod_spec { + let (read_max_round, max_round) = memory.read_cell(data_address_space, ctx_pointer + F::from_canonical_usize(EXT_DEG * 2) + i); + + if round < (max_round - F::from_canonical_usize(1)) { + let start = calculate_3d_ext_idx( + prod_specs_inner_inner_len, + prod_specs_inner_len, + i, + round, + F::from_canonical_usize(0), + ); + let (read_p1, p1) = memory.read::(data_address_space, prod_ptr + start); + let (read_p2, p2) = memory.read::(data_address_space, prod_ptr + start + F::from_canonical_usize(EXT_DEG)); + let evals = FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ); + + let (write_slice_eval_1, _) = memory.write::(data_address_space, r_ptr + (F::ONE + i) * F::from_canonical_usize(EXT_DEG), evals); + + if (round + F::from_canonical_usize(1)) < (max_round - F::from_canonical_usize(1)) { + eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, evals)); + } + } + + alpha_acc = FieldExtension::multiply(alpha_acc, alpha); + + i = i + F::ONE; + } + + let mut i = F::ZERO; + while i < num_logup_spec { + let (read_max_round, max_round) = memory.read_cell(data_address_space, ctx_pointer + num_prod_spec + F::from_canonical_usize(EXT_DEG * 2) + i); + + if round < (max_round - F::from_canonical_usize(1)) { + let start = calculate_3d_ext_idx( + logup_specs_inner_inner_len, + logup_specs_inner_len, + i, + round, + F::from_canonical_usize(0), + ); + + let (read_p1, p1) = memory.read::(data_address_space, logup_ptr + start); + let (read_p2, p2) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG)); + let (read_q1, q1) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG * 2)); + let (read_q2, q2) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG * 3)); + + let p_evals = FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ); + let q_evals = FieldExtension::add( + FieldExtension::multiply(q1, c1), + FieldExtension::multiply(q2, c2), + ); + + let (write_slice_eval_1, _) = memory.write::(data_address_space, r_ptr + (F::ONE + num_prod_spec + i) * F::from_canonical_usize(EXT_DEG), p_evals); + let (write_slice_eval_2, _) = memory.write::(data_address_space, r_ptr + (F::ONE + num_prod_spec + num_logup_spec + i) * F::from_canonical_usize(EXT_DEG), q_evals); + + if (round + F::from_canonical_usize(1)) < (max_round - F::from_canonical_usize(1)) { + eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, p_evals)); + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); + eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_denominator, q_evals)); + } + } + + alpha_acc = FieldExtension::multiply(FieldExtension::multiply(alpha_acc, alpha), alpha); + + i = i + F::ONE; + } + + let (write_r, _) = memory.write::(data_address_space, r_ptr, eval_acc); } else { unreachable!() } diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index 20f4a9dd90..e5e8655ebb 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -632,9 +632,9 @@ impl + TwoAdicField> AsmCo ); } } - DslIr::SumcheckLayerEval(input_ctx, challenges, prod_ptr, logup_ptr) => { + DslIr::SumcheckLayerEval(input_ctx, challenges, prod_ptr, logup_ptr, r_ptr) => { self.push( - AsmInstruction::SumcheckLayerEval(input_ctx.ptr().fp(), challenges.ptr().fp(), prod_ptr.fp(), logup_ptr.fp()), + AsmInstruction::SumcheckLayerEval(input_ctx.ptr().fp(), challenges.ptr().fp(), prod_ptr.fp(), logup_ptr.fp(), r_ptr.fp()), debug_info, ); } diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index 70e6878023..8f3ef82733 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -172,7 +172,7 @@ pub enum AsmInstruction { CycleTrackerStart(), CycleTrackerEnd(), - SumcheckLayerEval(i32, i32, i32, i32), + SumcheckLayerEval(i32, i32, i32, i32, i32), } impl> AsmInstruction { @@ -405,8 +405,8 @@ impl> AsmInstruction { AsmInstruction::RangeCheck(fp, lo_bits, hi_bits) => { write!(f, "range_check_fp ({})fp, ({}), ({})", fp, lo_bits, hi_bits) } - AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr) => { - write!(f, "sumcheck_layer_eval ({})fp, ({})fp, ({})fp, ({})fp", ctx, cs, p_ptr, l_ptr) + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => { + write!(f, "sumcheck_layer_eval ({})fp, ({})fp, ({})fp, ({})fp, ({})fp", ctx, cs, p_ptr, l_ptr, r_ptr) } } } diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index dedefceb52..f6fffc2db0 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -533,10 +533,10 @@ fn convert_instruction>( AS::Immediate, )] }, - AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr) => vec![ + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => vec![ Instruction { opcode: options.opcode_with_offset(SumcheckOpcode::SUMCHECK_LAYER_EVAL), - a: F::ZERO, // _debug + a: i32_f(r_ptr), b: i32_f(ctx), c: i32_f(cs), d: AS::Native.to_field(), diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index 09999f67dd..a14515cd89 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -326,7 +326,7 @@ pub enum DslIr { Array>, // Challenges: alpha, coeffs Ptr, // prod_specs_eval Ptr, // logup_specs_eval - // Ptr // output + Ptr // output ) } diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs index 15a46f180d..725edcbc86 100644 --- a/extensions/native/compiler/src/ir/sumcheck.rs +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -11,14 +11,14 @@ impl Builder { challenges: Array>, prod_specs_eval: &Array>, logup_specs_eval: &Array>, - // r_evals: &Array>, + r_evals: &Array>, ) -> Usize { self.operations.push(DslIr::SumcheckLayerEval( input_ctx, challenges, prod_specs_eval.ptr(), logup_specs_eval.ptr(), - // r_evals.ptr(), + r_evals.ptr(), )); Usize::from(0) From 7846a3ca078045f43567d7a217e27892183bb54d Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 16 Sep 2025 20:06:59 -0400 Subject: [PATCH 03/41] Remove debug flags --- .../native/circuit/src/sumcheck/chip.rs | 1192 +---------------- 1 file changed, 1 insertion(+), 1191 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 924216a9c9..625a7b6803 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -200,1194 +200,4 @@ impl InstructionExecutor for NativeSumcheckChip { unreachable!("unsupported opcode: {}", opcode) } } -} - -// impl InstructionExecutor -// for NativePoseidon2Chip -// { -// fn execute( -// &mut self, -// memory: &mut MemoryController, -// instruction: &Instruction, -// from_state: ExecutionState, -// ) -> Result, ExecutionError> { -// if instruction.opcode == PERM_POS2.global_opcode() -// || instruction.opcode == COMP_POS2.global_opcode() -// { -// let &Instruction { -// a: output_register, -// b: input_register_1, -// c: input_register_2, -// d: register_address_space, -// e: data_address_space, -// .. -// } = instruction; - -// let (read_output_pointer, output_pointer) = -// memory.read_cell(register_address_space, output_register); -// let (read_input_pointer_1, input_pointer_1) = -// memory.read_cell(register_address_space, input_register_1); -// let (read_input_pointer_2, input_pointer_2) = -// if instruction.opcode == PERM_POS2.global_opcode() { -// memory.increment_timestamp(); -// (None, input_pointer_1 + F::from_canonical_usize(CHUNK)) -// } else { -// let (read_input_pointer_2, input_pointer_2) = -// memory.read_cell(register_address_space, input_register_2); -// (Some(read_input_pointer_2), input_pointer_2) -// }; -// let (read_data_1, data_1) = memory.read::(data_address_space, input_pointer_1); -// let (read_data_2, data_2) = memory.read::(data_address_space, input_pointer_2); -// let p2_input = std::array::from_fn(|i| { -// if i < CHUNK { -// data_1[i] -// } else { -// data_2[i - CHUNK] -// } -// }); -// let output = self.subchip.permute(p2_input); -// let (write_data_1, _) = memory.write::( -// data_address_space, -// output_pointer, -// std::array::from_fn(|i| output[i]), -// ); -// let write_data_2 = if instruction.opcode == PERM_POS2.global_opcode() { -// Some( -// memory -// .write::( -// data_address_space, -// output_pointer + F::from_canonical_usize(CHUNK), -// std::array::from_fn(|i| output[CHUNK + i]), -// ) -// .0, -// ) -// } else { -// memory.increment_timestamp(); -// None -// }; - -// assert_eq!( -// memory.timestamp(), -// from_state.timestamp + NUM_SIMPLE_ACCESSES -// ); - -// self.record_set -// .simple_permute_records -// .push(SimplePoseidonRecord { -// from_state, -// instruction: instruction.clone(), -// read_input_pointer_1, -// read_input_pointer_2, -// read_output_pointer, -// read_data_1, -// read_data_2, -// write_data_1, -// write_data_2, -// input_pointer_1, -// input_pointer_2, -// output_pointer, -// p2_input, -// }); -// self.height += 1; -// } else if instruction.opcode == MULTI_OBSERVE.global_opcode() { -// let mut observation_records: Vec> = vec![]; - -// let &Instruction { -// a: output_register, -// b: input_register_1, -// c: input_register_2, -// d: data_address_space, -// e: register_address_space, -// f: input_register_3, -// .. -// } = instruction; - -// let (read_sponge_ptr, sponge_ptr) = memory.read_cell(register_address_space, output_register); -// let (read_init_pos, pos) = memory.read_cell(register_address_space, input_register_1); -// let (read_arr_ptr, arr_ptr) = memory.read_cell(register_address_space, input_register_2); -// let init_pos = pos.clone(); - -// let mut pos = pos.as_canonical_u32() as usize; -// let (read_len, len) = memory.read_cell(register_address_space, input_register_3); -// let init_len = len.as_canonical_u32() as usize; -// let mut len = len.as_canonical_u32() as usize; - -// let mut header_record: TranscriptObservationRecord = TranscriptObservationRecord { -// from_state, -// instruction: instruction.clone(), -// curr_timestamp_increment: 0, -// is_first: true, -// state_ptr: sponge_ptr, -// input_ptr: arr_ptr, -// init_pos, -// len: init_len, -// input_register_1, -// input_register_2, -// input_register_3, -// output_register, -// ..Default::default() -// }; -// header_record.read_input_data[0] = read_sponge_ptr; -// header_record.read_input_data[1] = read_arr_ptr; -// header_record.read_input_data[2] = read_init_pos; -// header_record.read_input_data[3] = read_len; - -// observation_records.push(header_record); -// self.height += 1; - -// // Observe bytes -// let mut observation_chunks: Vec<(usize, usize, bool)> = vec![]; -// while len > 0 { -// if len >= (CHUNK - pos) { -// observation_chunks.push((pos.clone(), CHUNK.clone(), true)); -// len -= CHUNK - pos; -// pos = 0; -// } else { -// observation_chunks.push((pos.clone(), pos + len, false)); -// len = 0; -// pos = pos + len; -// } -// } - -// let mut curr_timestamp = 4usize; -// let mut input_idx: usize = 0; -// for chunk in observation_chunks { -// let mut record: TranscriptObservationRecord = TranscriptObservationRecord { -// from_state, -// instruction: instruction.clone(), - -// start_idx: chunk.0, -// end_idx: chunk.1, - -// curr_timestamp_increment: curr_timestamp, -// state_ptr: sponge_ptr, -// input_ptr: arr_ptr, -// init_pos, -// len: init_len, -// curr_len: input_idx, -// input_register_1, -// input_register_2, -// input_register_3, -// output_register, -// ..Default::default() -// }; - -// for j in chunk.0..chunk.1 { -// let (n_read, n_f) = memory.read_cell(data_address_space, arr_ptr + F::from_canonical_usize(input_idx)); -// record.read_input_data[j] = n_read; -// record.input_data[j] = n_f; -// input_idx += 1; -// curr_timestamp += 1; - -// let (n_write, _) = memory.write_cell(data_address_space, sponge_ptr + F::from_canonical_usize(j), n_f); -// record.write_input_data[j] = n_write; -// curr_timestamp += 1; -// } - -// if record.end_idx >= CHUNK { -// let (read_sponge_record, permutation_input) = memory.read::<{CHUNK * 2}>(data_address_space, sponge_ptr); -// let output = self.subchip.permute(permutation_input); -// let (write_sponge_record, _) = memory.write::<{CHUNK * 2}>(data_address_space, sponge_ptr, std::array::from_fn(|i| output[i])); - -// curr_timestamp += 2; - -// record.should_permute = true; -// record.read_sponge_state = read_sponge_record; -// record.write_sponge_state = write_sponge_record; -// record.permutation_input = permutation_input; -// record.permutation_output = output; -// } - -// observation_records.push(record); -// self.height += 1; -// } - -// let last_record = observation_records.last_mut().unwrap(); -// let final_idx = last_record.end_idx % CHUNK; -// let (write_final, _) = memory.write_cell(register_address_space, input_register_1, F::from_canonical_usize(final_idx)); -// last_record.is_last = true; -// last_record.write_final_idx = write_final; -// last_record.final_idx = final_idx; -// curr_timestamp += 1; - -// for record in &mut observation_records { -// record.final_timestamp_increment = curr_timestamp; -// } -// self.record_set.transcript_observation_records.extend(observation_records); -// } else if instruction.opcode == VERIFY_BATCH.global_opcode() { -// let &Instruction { -// a: dim_register, -// b: opened_register, -// c: opened_length_register, -// d: proof_id_ptr, -// e: index_register, -// f: commit_register, -// g: opened_element_size_inv, -// .. -// } = instruction; -// let address_space = self.air.address_space; -// // calc inverse fast assuming opened_element_size in {1, 4} -// let mut opened_element_size = F::ONE; -// while opened_element_size * opened_element_size_inv != F::ONE { -// opened_element_size += F::ONE; -// } - -// let proof_id = memory.unsafe_read_cell(address_space, proof_id_ptr); -// let (dim_base_pointer_read, dim_base_pointer) = -// memory.read_cell(address_space, dim_register); -// let (opened_base_pointer_read, opened_base_pointer) = -// memory.read_cell(address_space, opened_register); -// let (opened_length_read, opened_length) = -// memory.read_cell(address_space, opened_length_register); -// let (index_base_pointer_read, index_base_pointer) = -// memory.read_cell(address_space, index_register); -// let (commit_pointer_read, commit_pointer) = -// memory.read_cell(address_space, commit_register); -// let (commit_read, commit) = memory.read(address_space, commit_pointer); - -// let opened_length = opened_length.as_canonical_u32() as usize; - -// let initial_log_height = memory -// .unsafe_read_cell(address_space, dim_base_pointer) -// .as_canonical_u32(); -// let mut log_height = initial_log_height as i32; -// let mut sibling_index = 0; -// let mut opened_index = 0; -// let mut top_level = vec![]; - -// let mut root = [F::ZERO; CHUNK]; -// let sibling_proof: Vec<[F; CHUNK]> = { -// let streams = self.streams.lock().unwrap(); -// let proof_idx = proof_id.as_canonical_u32() as usize; -// streams.hint_space[proof_idx] -// .par_chunks(CHUNK) -// .map(|c| c.try_into().unwrap()) -// .collect() -// }; - -// while log_height >= 0 { -// let incorporate_row = if opened_index < opened_length -// && memory.unsafe_read_cell( -// address_space, -// dim_base_pointer + F::from_canonical_usize(opened_index), -// ) == F::from_canonical_u32(log_height as u32) -// { -// let initial_opened_index = opened_index; -// for _ in 0..NUM_INITIAL_READS { -// memory.increment_timestamp(); -// } -// let mut chunks = vec![]; - -// let mut row_pointer = 0; -// let mut row_end = 0; - -// let mut prev_rolling_hash: Option<[F; 2 * CHUNK]> = None; -// let mut rolling_hash = [F::ZERO; 2 * CHUNK]; - -// let mut is_first_in_segment = true; - -// loop { -// let mut cells = vec![]; -// for chunk_elem in rolling_hash.iter_mut().take(CHUNK) { -// let read_row_pointer_and_length = if is_first_in_segment -// || row_pointer == row_end -// { -// if is_first_in_segment { -// is_first_in_segment = false; -// } else { -// opened_index += 1; -// if opened_index == opened_length -// || memory.unsafe_read_cell( -// address_space, -// dim_base_pointer -// + F::from_canonical_usize(opened_index), -// ) != F::from_canonical_u32(log_height as u32) -// { -// break; -// } -// } -// let (result, [new_row_pointer, row_len]) = memory.read( -// address_space, -// opened_base_pointer + F::from_canonical_usize(2 * opened_index), -// ); -// row_pointer = new_row_pointer.as_canonical_u32() as usize; -// row_end = row_pointer -// + (opened_element_size * row_len).as_canonical_u32() as usize; -// Some(result) -// } else { -// memory.increment_timestamp(); -// None -// }; -// let (read, value) = memory -// .read_cell(address_space, F::from_canonical_usize(row_pointer)); -// cells.push(CellRecord { -// read, -// opened_index, -// read_row_pointer_and_length, -// row_pointer, -// row_end, -// }); -// *chunk_elem = value; -// row_pointer += 1; -// } -// if cells.is_empty() { -// break; -// } -// let cells_len = cells.len(); -// chunks.push(InsideRowRecord { -// cells, -// p2_input: rolling_hash, -// }); -// self.height += 1; -// prev_rolling_hash = Some(rolling_hash); -// self.subchip.permute_mut(&mut rolling_hash); -// if cells_len < CHUNK { -// for _ in 0..CHUNK - cells_len { -// memory.increment_timestamp(); -// memory.increment_timestamp(); -// } -// break; -// } -// } -// let final_opened_index = opened_index - 1; -// let (initial_height_read, height_check) = memory.read_cell( -// address_space, -// dim_base_pointer + F::from_canonical_usize(initial_opened_index), -// ); -// assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); -// let (final_height_read, height_check) = memory.read_cell( -// address_space, -// dim_base_pointer + F::from_canonical_usize(final_opened_index), -// ); -// assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); - -// let hash: [F; CHUNK] = std::array::from_fn(|i| rolling_hash[i]); - -// let (p2_input, new_root) = if log_height as u32 == initial_log_height { -// (prev_rolling_hash.unwrap(), hash) -// } else { -// self.compress(root, hash) -// }; -// root = new_root; - -// self.height += 1; -// Some(IncorporateRowRecord { -// chunks, -// initial_opened_index, -// final_opened_index, -// initial_height_read, -// final_height_read, -// p2_input, -// }) -// } else { -// None -// }; - -// let incorporate_sibling = if log_height == 0 { -// None -// } else { -// for _ in 0..NUM_INITIAL_READS { -// memory.increment_timestamp(); -// } - -// let (read_sibling_is_on_right, sibling_is_on_right) = memory.read_cell( -// address_space, -// index_base_pointer + F::from_canonical_usize(sibling_index), -// ); -// let sibling_is_on_right = sibling_is_on_right == F::ONE; -// let sibling = sibling_proof[sibling_index]; -// let (p2_input, new_root) = if sibling_is_on_right { -// self.compress(sibling, root) -// } else { -// self.compress(root, sibling) -// }; -// root = new_root; - -// self.height += 1; -// Some(IncorporateSiblingRecord { -// read_sibling_is_on_right, -// sibling_is_on_right, -// p2_input, -// }) -// }; - -// top_level.push(TopLevelRecord { -// incorporate_row, -// incorporate_sibling, -// }); - -// log_height -= 1; -// sibling_index += 1; -// } - -// assert_eq!(commit, root); -// self.record_set -// .verify_batch_records -// .push(VerifyBatchRecord { -// from_state, -// instruction: instruction.clone(), -// dim_base_pointer, -// opened_base_pointer, -// opened_length, -// index_base_pointer, -// commit_pointer, -// dim_base_pointer_read, -// opened_base_pointer_read, -// opened_length_read, -// index_base_pointer_read, -// commit_pointer_read, -// commit_read, -// initial_log_height: initial_log_height as usize, -// top_level, -// }); -// } else { -// unreachable!() -// } -// Ok(ExecutionState { -// pc: from_state.pc + DEFAULT_PC_STEP, -// timestamp: memory.timestamp(), -// }) -// } - -// fn get_opcode_name(&self, opcode: usize) -> String { -// if opcode == VERIFY_BATCH.global_opcode().as_usize() { -// String::from("VERIFY_BATCH") -// } else if opcode == PERM_POS2.global_opcode().as_usize() { -// String::from("PERM_POS2") -// } else if opcode == COMP_POS2.global_opcode().as_usize() { -// String::from("COMP_POS2") -// } else if opcode == MULTI_OBSERVE.global_opcode().as_usize() { -// String::from("MULTI_OBSERVE") -// }else { -// unreachable!("unsupported opcode: {}", opcode) -// } -// } -// } - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -// use std::sync::{Arc, Mutex}; - -// use openvm_circuit::{ -// arch::{ -// ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, Streams, SystemPort, -// }, -// system::memory::{MemoryController, OfflineMemory, RecordId}, -// }; -// use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; -// use openvm_native_compiler::{ -// conversion::AS, -// Poseidon2Opcode::{COMP_POS2, PERM_POS2, MULTI_OBSERVE}, -// VerifyBatchOpcode::VERIFY_BATCH, -// }; -// use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubAir, Poseidon2SubChip}; -// use openvm_stark_backend::{ -// p3_field::{Field, PrimeField32}, -// p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}, -// }; -// use serde::{Deserialize, Serialize}; - - -// #[derive(Debug, Clone, Serialize, Deserialize)] -// #[serde(bound = "F: Field")] -// pub struct VerifyBatchRecord { -// pub from_state: ExecutionState, -// pub instruction: Instruction, - -// pub dim_base_pointer: F, -// pub opened_base_pointer: F, -// pub opened_length: usize, -// pub index_base_pointer: F, -// pub commit_pointer: F, - -// pub dim_base_pointer_read: RecordId, -// pub opened_base_pointer_read: RecordId, -// pub opened_length_read: RecordId, -// pub index_base_pointer_read: RecordId, -// pub commit_pointer_read: RecordId, - -// pub commit_read: RecordId, -// pub initial_log_height: usize, -// pub top_level: Vec>, -// } - -// impl VerifyBatchRecord { -// pub fn opened_element_size_inv(&self) -> F { -// self.instruction.g -// } -// } - -// #[derive(Debug, Clone, Serialize, Deserialize)] -// #[serde(bound = "F: Field")] -// pub struct TopLevelRecord { -// // must be present in first record -// pub incorporate_row: Option>, -// // must be present in all bust last record -// pub incorporate_sibling: Option>, -// } - -// #[repr(C)] -// #[derive(Debug, Clone, Serialize, Deserialize)] -// #[serde(bound = "F: Field")] -// pub struct IncorporateSiblingRecord { -// pub read_sibling_is_on_right: RecordId, -// pub sibling_is_on_right: bool, -// pub p2_input: [F; 2 * CHUNK], -// } - -// #[derive(Debug, Clone, Serialize, Deserialize)] -// #[serde(bound = "F: Field")] -// pub struct IncorporateRowRecord { -// pub chunks: Vec>, -// pub initial_opened_index: usize, -// pub final_opened_index: usize, -// pub initial_height_read: RecordId, -// pub final_height_read: RecordId, -// pub p2_input: [F; 2 * CHUNK], -// } - -// #[derive(Debug, Clone, Serialize, Deserialize)] -// #[serde(bound = "F: Field")] -// pub struct InsideRowRecord { -// pub cells: Vec, -// pub p2_input: [F; 2 * CHUNK], -// } - -// #[repr(C)] -// #[derive(Debug, Clone, Serialize, Deserialize)] -// pub struct CellRecord { -// pub read: RecordId, -// pub opened_index: usize, -// pub read_row_pointer_and_length: Option, -// pub row_pointer: usize, -// pub row_end: usize, -// } - -// #[repr(C)] -// #[derive(Debug, Clone, Serialize, Deserialize)] -// #[serde(bound = "F: Field")] -// pub struct SimplePoseidonRecord { -// pub from_state: ExecutionState, -// pub instruction: Instruction, - -// pub read_input_pointer_1: RecordId, -// pub read_input_pointer_2: Option, -// pub read_output_pointer: RecordId, -// pub read_data_1: RecordId, -// pub read_data_2: RecordId, -// pub write_data_1: RecordId, -// pub write_data_2: Option, - -// pub input_pointer_1: F, -// pub input_pointer_2: F, -// pub output_pointer: F, -// pub p2_input: [F; 2 * CHUNK], -// } - -// #[repr(C)] -// #[derive(Debug, Clone, Serialize, Deserialize, Default)] -// #[serde(bound = "F: Field")] -// pub struct TranscriptObservationRecord { -// pub from_state: ExecutionState, -// pub instruction: Instruction, -// pub start_idx: usize, -// pub end_idx: usize, -// pub is_first: bool, -// pub is_last: bool, -// pub curr_timestamp_increment: usize, -// pub final_timestamp_increment: usize, - -// pub state_ptr: F, -// pub input_ptr: F, -// pub init_pos: F, -// pub len: usize, -// pub curr_len: usize, -// pub should_permute: bool, - -// pub read_input_data: [RecordId; CHUNK], -// pub write_input_data: [RecordId; CHUNK], -// pub input_data: [F; CHUNK], - -// pub read_sponge_state: RecordId, -// pub write_sponge_state: RecordId, -// pub permutation_input: [F; 2 * CHUNK], -// pub permutation_output: [F; 2 * CHUNK], - -// pub write_final_idx: RecordId, -// pub final_idx: usize, - -// pub input_register_1: F, -// pub input_register_2: F, -// pub input_register_3: F, -// pub output_register: F, -// } - -// #[derive(Debug, Clone, Serialize, Deserialize, Default)] -// #[serde(bound = "F: Field")] -// pub struct NativePoseidon2RecordSet { -// pub verify_batch_records: Vec>, -// pub simple_permute_records: Vec>, -// pub transcript_observation_records: Vec>, -// } - -// pub struct NativePoseidon2Chip { -// pub(super) air: NativePoseidon2Air, -// pub record_set: NativePoseidon2RecordSet, -// pub height: usize, -// pub(super) offline_memory: Arc>>, -// pub(super) subchip: Poseidon2SubChip, -// pub(super) streams: Arc>>, -// } - -// impl NativePoseidon2Chip { -// pub fn new( -// port: SystemPort, -// offline_memory: Arc>>, -// poseidon2_config: Poseidon2Config, -// verify_batch_bus: VerifyBatchBus, -// streams: Arc>>, -// ) -> Self { -// let air = NativePoseidon2Air { -// execution_bridge: ExecutionBridge::new(port.execution_bus, port.program_bus), -// memory_bridge: port.memory_bridge, -// internal_bus: verify_batch_bus, -// subair: Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())), -// address_space: F::from_canonical_u32(AS::Native as u32), -// }; -// Self { -// record_set: Default::default(), -// air, -// height: 0, -// offline_memory, -// subchip: Poseidon2SubChip::new(poseidon2_config.constants), -// streams, -// } -// } - -// fn compress(&self, left: [F; CHUNK], right: [F; CHUNK]) -> ([F; 2 * CHUNK], [F; CHUNK]) { -// let concatenated = -// std::array::from_fn(|i| if i < CHUNK { left[i] } else { right[i - CHUNK] }); -// let permuted = self.subchip.permute(concatenated); -// (concatenated, std::array::from_fn(|i| permuted[i])) -// } -// } - -// pub(super) const NUM_INITIAL_READS: usize = 6; -// pub(super) const NUM_SIMPLE_ACCESSES: u32 = 7; - -// impl InstructionExecutor -// for NativePoseidon2Chip -// { -// fn execute( -// &mut self, -// memory: &mut MemoryController, -// instruction: &Instruction, -// from_state: ExecutionState, -// ) -> Result, ExecutionError> { -// if instruction.opcode == PERM_POS2.global_opcode() -// || instruction.opcode == COMP_POS2.global_opcode() -// { -// let &Instruction { -// a: output_register, -// b: input_register_1, -// c: input_register_2, -// d: register_address_space, -// e: data_address_space, -// .. -// } = instruction; - -// let (read_output_pointer, output_pointer) = -// memory.read_cell(register_address_space, output_register); -// let (read_input_pointer_1, input_pointer_1) = -// memory.read_cell(register_address_space, input_register_1); -// let (read_input_pointer_2, input_pointer_2) = -// if instruction.opcode == PERM_POS2.global_opcode() { -// memory.increment_timestamp(); -// (None, input_pointer_1 + F::from_canonical_usize(CHUNK)) -// } else { -// let (read_input_pointer_2, input_pointer_2) = -// memory.read_cell(register_address_space, input_register_2); -// (Some(read_input_pointer_2), input_pointer_2) -// }; -// let (read_data_1, data_1) = memory.read::(data_address_space, input_pointer_1); -// let (read_data_2, data_2) = memory.read::(data_address_space, input_pointer_2); -// let p2_input = std::array::from_fn(|i| { -// if i < CHUNK { -// data_1[i] -// } else { -// data_2[i - CHUNK] -// } -// }); -// let output = self.subchip.permute(p2_input); -// let (write_data_1, _) = memory.write::( -// data_address_space, -// output_pointer, -// std::array::from_fn(|i| output[i]), -// ); -// let write_data_2 = if instruction.opcode == PERM_POS2.global_opcode() { -// Some( -// memory -// .write::( -// data_address_space, -// output_pointer + F::from_canonical_usize(CHUNK), -// std::array::from_fn(|i| output[CHUNK + i]), -// ) -// .0, -// ) -// } else { -// memory.increment_timestamp(); -// None -// }; - -// assert_eq!( -// memory.timestamp(), -// from_state.timestamp + NUM_SIMPLE_ACCESSES -// ); - -// self.record_set -// .simple_permute_records -// .push(SimplePoseidonRecord { -// from_state, -// instruction: instruction.clone(), -// read_input_pointer_1, -// read_input_pointer_2, -// read_output_pointer, -// read_data_1, -// read_data_2, -// write_data_1, -// write_data_2, -// input_pointer_1, -// input_pointer_2, -// output_pointer, -// p2_input, -// }); -// self.height += 1; -// } else if instruction.opcode == MULTI_OBSERVE.global_opcode() { -// let mut observation_records: Vec> = vec![]; - -// let &Instruction { -// a: output_register, -// b: input_register_1, -// c: input_register_2, -// d: data_address_space, -// e: register_address_space, -// f: input_register_3, -// .. -// } = instruction; - -// let (read_sponge_ptr, sponge_ptr) = memory.read_cell(register_address_space, output_register); -// let (read_init_pos, pos) = memory.read_cell(register_address_space, input_register_1); -// let (read_arr_ptr, arr_ptr) = memory.read_cell(register_address_space, input_register_2); -// let init_pos = pos.clone(); - -// let mut pos = pos.as_canonical_u32() as usize; -// let (read_len, len) = memory.read_cell(register_address_space, input_register_3); -// let init_len = len.as_canonical_u32() as usize; -// let mut len = len.as_canonical_u32() as usize; - -// let mut header_record: TranscriptObservationRecord = TranscriptObservationRecord { -// from_state, -// instruction: instruction.clone(), -// curr_timestamp_increment: 0, -// is_first: true, -// state_ptr: sponge_ptr, -// input_ptr: arr_ptr, -// init_pos, -// len: init_len, -// input_register_1, -// input_register_2, -// input_register_3, -// output_register, -// ..Default::default() -// }; -// header_record.read_input_data[0] = read_sponge_ptr; -// header_record.read_input_data[1] = read_arr_ptr; -// header_record.read_input_data[2] = read_init_pos; -// header_record.read_input_data[3] = read_len; - -// observation_records.push(header_record); -// self.height += 1; - -// // Observe bytes -// let mut observation_chunks: Vec<(usize, usize, bool)> = vec![]; -// while len > 0 { -// if len >= (CHUNK - pos) { -// observation_chunks.push((pos.clone(), CHUNK.clone(), true)); -// len -= CHUNK - pos; -// pos = 0; -// } else { -// observation_chunks.push((pos.clone(), pos + len, false)); -// len = 0; -// pos = pos + len; -// } -// } - -// let mut curr_timestamp = 4usize; -// let mut input_idx: usize = 0; -// for chunk in observation_chunks { -// let mut record: TranscriptObservationRecord = TranscriptObservationRecord { -// from_state, -// instruction: instruction.clone(), - -// start_idx: chunk.0, -// end_idx: chunk.1, - -// curr_timestamp_increment: curr_timestamp, -// state_ptr: sponge_ptr, -// input_ptr: arr_ptr, -// init_pos, -// len: init_len, -// curr_len: input_idx, -// input_register_1, -// input_register_2, -// input_register_3, -// output_register, -// ..Default::default() -// }; - -// for j in chunk.0..chunk.1 { -// let (n_read, n_f) = memory.read_cell(data_address_space, arr_ptr + F::from_canonical_usize(input_idx)); -// record.read_input_data[j] = n_read; -// record.input_data[j] = n_f; -// input_idx += 1; -// curr_timestamp += 1; - -// let (n_write, _) = memory.write_cell(data_address_space, sponge_ptr + F::from_canonical_usize(j), n_f); -// record.write_input_data[j] = n_write; -// curr_timestamp += 1; -// } - -// if record.end_idx >= CHUNK { -// let (read_sponge_record, permutation_input) = memory.read::<{CHUNK * 2}>(data_address_space, sponge_ptr); -// let output = self.subchip.permute(permutation_input); -// let (write_sponge_record, _) = memory.write::<{CHUNK * 2}>(data_address_space, sponge_ptr, std::array::from_fn(|i| output[i])); - -// curr_timestamp += 2; - -// record.should_permute = true; -// record.read_sponge_state = read_sponge_record; -// record.write_sponge_state = write_sponge_record; -// record.permutation_input = permutation_input; -// record.permutation_output = output; -// } - -// observation_records.push(record); -// self.height += 1; -// } - -// let last_record = observation_records.last_mut().unwrap(); -// let final_idx = last_record.end_idx % CHUNK; -// let (write_final, _) = memory.write_cell(register_address_space, input_register_1, F::from_canonical_usize(final_idx)); -// last_record.is_last = true; -// last_record.write_final_idx = write_final; -// last_record.final_idx = final_idx; -// curr_timestamp += 1; - -// for record in &mut observation_records { -// record.final_timestamp_increment = curr_timestamp; -// } -// self.record_set.transcript_observation_records.extend(observation_records); -// } else if instruction.opcode == VERIFY_BATCH.global_opcode() { -// let &Instruction { -// a: dim_register, -// b: opened_register, -// c: opened_length_register, -// d: proof_id_ptr, -// e: index_register, -// f: commit_register, -// g: opened_element_size_inv, -// .. -// } = instruction; -// let address_space = self.air.address_space; -// // calc inverse fast assuming opened_element_size in {1, 4} -// let mut opened_element_size = F::ONE; -// while opened_element_size * opened_element_size_inv != F::ONE { -// opened_element_size += F::ONE; -// } - -// let proof_id = memory.unsafe_read_cell(address_space, proof_id_ptr); -// let (dim_base_pointer_read, dim_base_pointer) = -// memory.read_cell(address_space, dim_register); -// let (opened_base_pointer_read, opened_base_pointer) = -// memory.read_cell(address_space, opened_register); -// let (opened_length_read, opened_length) = -// memory.read_cell(address_space, opened_length_register); -// let (index_base_pointer_read, index_base_pointer) = -// memory.read_cell(address_space, index_register); -// let (commit_pointer_read, commit_pointer) = -// memory.read_cell(address_space, commit_register); -// let (commit_read, commit) = memory.read(address_space, commit_pointer); - -// let opened_length = opened_length.as_canonical_u32() as usize; - -// let initial_log_height = memory -// .unsafe_read_cell(address_space, dim_base_pointer) -// .as_canonical_u32(); -// let mut log_height = initial_log_height as i32; -// let mut sibling_index = 0; -// let mut opened_index = 0; -// let mut top_level = vec![]; - -// let mut root = [F::ZERO; CHUNK]; -// let sibling_proof: Vec<[F; CHUNK]> = { -// let streams = self.streams.lock().unwrap(); -// let proof_idx = proof_id.as_canonical_u32() as usize; -// streams.hint_space[proof_idx] -// .par_chunks(CHUNK) -// .map(|c| c.try_into().unwrap()) -// .collect() -// }; - -// while log_height >= 0 { -// let incorporate_row = if opened_index < opened_length -// && memory.unsafe_read_cell( -// address_space, -// dim_base_pointer + F::from_canonical_usize(opened_index), -// ) == F::from_canonical_u32(log_height as u32) -// { -// let initial_opened_index = opened_index; -// for _ in 0..NUM_INITIAL_READS { -// memory.increment_timestamp(); -// } -// let mut chunks = vec![]; - -// let mut row_pointer = 0; -// let mut row_end = 0; - -// let mut prev_rolling_hash: Option<[F; 2 * CHUNK]> = None; -// let mut rolling_hash = [F::ZERO; 2 * CHUNK]; - -// let mut is_first_in_segment = true; - -// loop { -// let mut cells = vec![]; -// for chunk_elem in rolling_hash.iter_mut().take(CHUNK) { -// let read_row_pointer_and_length = if is_first_in_segment -// || row_pointer == row_end -// { -// if is_first_in_segment { -// is_first_in_segment = false; -// } else { -// opened_index += 1; -// if opened_index == opened_length -// || memory.unsafe_read_cell( -// address_space, -// dim_base_pointer -// + F::from_canonical_usize(opened_index), -// ) != F::from_canonical_u32(log_height as u32) -// { -// break; -// } -// } -// let (result, [new_row_pointer, row_len]) = memory.read( -// address_space, -// opened_base_pointer + F::from_canonical_usize(2 * opened_index), -// ); -// row_pointer = new_row_pointer.as_canonical_u32() as usize; -// row_end = row_pointer -// + (opened_element_size * row_len).as_canonical_u32() as usize; -// Some(result) -// } else { -// memory.increment_timestamp(); -// None -// }; -// let (read, value) = memory -// .read_cell(address_space, F::from_canonical_usize(row_pointer)); -// cells.push(CellRecord { -// read, -// opened_index, -// read_row_pointer_and_length, -// row_pointer, -// row_end, -// }); -// *chunk_elem = value; -// row_pointer += 1; -// } -// if cells.is_empty() { -// break; -// } -// let cells_len = cells.len(); -// chunks.push(InsideRowRecord { -// cells, -// p2_input: rolling_hash, -// }); -// self.height += 1; -// prev_rolling_hash = Some(rolling_hash); -// self.subchip.permute_mut(&mut rolling_hash); -// if cells_len < CHUNK { -// for _ in 0..CHUNK - cells_len { -// memory.increment_timestamp(); -// memory.increment_timestamp(); -// } -// break; -// } -// } -// let final_opened_index = opened_index - 1; -// let (initial_height_read, height_check) = memory.read_cell( -// address_space, -// dim_base_pointer + F::from_canonical_usize(initial_opened_index), -// ); -// assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); -// let (final_height_read, height_check) = memory.read_cell( -// address_space, -// dim_base_pointer + F::from_canonical_usize(final_opened_index), -// ); -// assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); - -// let hash: [F; CHUNK] = std::array::from_fn(|i| rolling_hash[i]); - -// let (p2_input, new_root) = if log_height as u32 == initial_log_height { -// (prev_rolling_hash.unwrap(), hash) -// } else { -// self.compress(root, hash) -// }; -// root = new_root; - -// self.height += 1; -// Some(IncorporateRowRecord { -// chunks, -// initial_opened_index, -// final_opened_index, -// initial_height_read, -// final_height_read, -// p2_input, -// }) -// } else { -// None -// }; - -// let incorporate_sibling = if log_height == 0 { -// None -// } else { -// for _ in 0..NUM_INITIAL_READS { -// memory.increment_timestamp(); -// } - -// let (read_sibling_is_on_right, sibling_is_on_right) = memory.read_cell( -// address_space, -// index_base_pointer + F::from_canonical_usize(sibling_index), -// ); -// let sibling_is_on_right = sibling_is_on_right == F::ONE; -// let sibling = sibling_proof[sibling_index]; -// let (p2_input, new_root) = if sibling_is_on_right { -// self.compress(sibling, root) -// } else { -// self.compress(root, sibling) -// }; -// root = new_root; - -// self.height += 1; -// Some(IncorporateSiblingRecord { -// read_sibling_is_on_right, -// sibling_is_on_right, -// p2_input, -// }) -// }; - -// top_level.push(TopLevelRecord { -// incorporate_row, -// incorporate_sibling, -// }); - -// log_height -= 1; -// sibling_index += 1; -// } - -// assert_eq!(commit, root); -// self.record_set -// .verify_batch_records -// .push(VerifyBatchRecord { -// from_state, -// instruction: instruction.clone(), -// dim_base_pointer, -// opened_base_pointer, -// opened_length, -// index_base_pointer, -// commit_pointer, -// dim_base_pointer_read, -// opened_base_pointer_read, -// opened_length_read, -// index_base_pointer_read, -// commit_pointer_read, -// commit_read, -// initial_log_height: initial_log_height as usize, -// top_level, -// }); -// } else { -// unreachable!() -// } -// Ok(ExecutionState { -// pc: from_state.pc + DEFAULT_PC_STEP, -// timestamp: memory.timestamp(), -// }) -// } - -// fn get_opcode_name(&self, opcode: usize) -> String { -// if opcode == VERIFY_BATCH.global_opcode().as_usize() { -// String::from("VERIFY_BATCH") -// } else if opcode == PERM_POS2.global_opcode().as_usize() { -// String::from("PERM_POS2") -// } else if opcode == COMP_POS2.global_opcode().as_usize() { -// String::from("COMP_POS2") -// } else if opcode == MULTI_OBSERVE.global_opcode().as_usize() { -// String::from("MULTI_OBSERVE") -// }else { -// unreachable!("unsupported opcode: {}", opcode) -// } -// } -// } +} \ No newline at end of file From 99c9f4ced1da3ab89e68d87489cb4d7940de0f7d Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 17 Sep 2025 18:33:07 -0400 Subject: [PATCH 04/41] Add variation --- .../native/circuit/src/sumcheck/chip.rs | 49 +++++++++++++------ .../native/compiler/src/asm/compiler.rs | 2 +- .../native/compiler/src/ir/instructions.rs | 10 ++-- extensions/native/compiler/src/ir/sumcheck.rs | 8 +-- 4 files changed, 44 insertions(+), 25 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 625a7b6803..68ad0a1209 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -98,7 +98,7 @@ impl InstructionExecutor for NativeSumcheckChip { prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, - _, + in_round, ] = ctx; let (alpha_read, alpha) = memory.read::(data_address_space, cs_pointer); @@ -122,14 +122,20 @@ impl InstructionExecutor for NativeSumcheckChip { ); let (read_p1, p1) = memory.read::(data_address_space, prod_ptr + start); let (read_p2, p2) = memory.read::(data_address_space, prod_ptr + start + F::from_canonical_usize(EXT_DEG)); - let evals = FieldExtension::add( - FieldExtension::multiply(p1, c1), - FieldExtension::multiply(p2, c2), - ); + let evals = if in_round > F::ZERO { + FieldExtension::multiply(p1, p2) + } else { + FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ) + }; + let (write_slice_eval_1, _) = memory.write::(data_address_space, r_ptr + (F::ONE + i) * F::from_canonical_usize(EXT_DEG), evals); - if (round + F::from_canonical_usize(1)) < (max_round - F::from_canonical_usize(1)) { + let not_in_round = F::ONE - in_round; + if (round + not_in_round) < (max_round - F::from_canonical_usize(1)) { eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, evals)); } } @@ -157,19 +163,32 @@ impl InstructionExecutor for NativeSumcheckChip { let (read_q1, q1) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG * 2)); let (read_q2, q2) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG * 3)); - let p_evals = FieldExtension::add( - FieldExtension::multiply(p1, c1), - FieldExtension::multiply(p2, c2), - ); - let q_evals = FieldExtension::add( - FieldExtension::multiply(q1, c1), - FieldExtension::multiply(q2, c2), - ); + let p_evals = if in_round > F::ZERO { + FieldExtension::add( + FieldExtension::multiply(p1, q2), + FieldExtension::multiply(p2, q1), + ) + } else { + FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ) + }; + + let q_evals = if in_round > F::ZERO { + FieldExtension::multiply(q1, q2) + } else { + FieldExtension::add( + FieldExtension::multiply(q1, c1), + FieldExtension::multiply(q2, c2), + ) + }; let (write_slice_eval_1, _) = memory.write::(data_address_space, r_ptr + (F::ONE + num_prod_spec + i) * F::from_canonical_usize(EXT_DEG), p_evals); let (write_slice_eval_2, _) = memory.write::(data_address_space, r_ptr + (F::ONE + num_prod_spec + num_logup_spec + i) * F::from_canonical_usize(EXT_DEG), q_evals); - if (round + F::from_canonical_usize(1)) < (max_round - F::from_canonical_usize(1)) { + let not_in_round = F::ONE - in_round; + if (round + not_in_round) < (max_round - F::from_canonical_usize(1)) { eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, p_evals)); let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_denominator, q_evals)); diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index e5e8655ebb..b701844008 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -634,7 +634,7 @@ impl + TwoAdicField> AsmCo } DslIr::SumcheckLayerEval(input_ctx, challenges, prod_ptr, logup_ptr, r_ptr) => { self.push( - AsmInstruction::SumcheckLayerEval(input_ctx.ptr().fp(), challenges.ptr().fp(), prod_ptr.fp(), logup_ptr.fp(), r_ptr.fp()), + AsmInstruction::SumcheckLayerEval(input_ctx.fp(), challenges.fp(), prod_ptr.fp(), logup_ptr.fp(), r_ptr.fp()), debug_info, ); } diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index a14515cd89..15dde8d7d4 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -322,11 +322,11 @@ pub enum DslIr { /// Sumcheck calculate layer eval SumcheckLayerEval( - Array>, // Input ctx: round, num_prod_spec, num_logup_spec, num_variables - Array>, // Challenges: alpha, coeffs - Ptr, // prod_specs_eval - Ptr, // logup_specs_eval - Ptr // output + Ptr, // Input ctx: round, num_prod_spec, num_logup_spec, num_variables + Ptr, // Challenges: alpha, coeffs + Ptr, // prod_specs_eval + Ptr, // logup_specs_eval + Ptr // output ) } diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs index 725edcbc86..5c1564afd9 100644 --- a/extensions/native/compiler/src/ir/sumcheck.rs +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -7,15 +7,15 @@ impl Builder { /// Extends native VM ability to calculate the evaluation for a sumcheck layer pub fn sumcheck_layer_eval ( &mut self, - input_ctx: Array>, - challenges: Array>, + input_ctx: &Array>, + challenges: &Array>, prod_specs_eval: &Array>, logup_specs_eval: &Array>, r_evals: &Array>, ) -> Usize { self.operations.push(DslIr::SumcheckLayerEval( - input_ctx, - challenges, + input_ctx.ptr(), + challenges.ptr(), prod_specs_eval.ptr(), logup_specs_eval.ptr(), r_evals.ptr(), From 1e06dc7d546a2eda8291099a4604fed9a6263f94 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 21 Sep 2025 04:43:13 -0400 Subject: [PATCH 05/41] Cherry pick column reduction commits --- .../native/circuit/src/sumcheck/chip.rs | 100 ++++++++++++++++-- .../native/circuit/src/sumcheck/columns.rs | 99 +++++++++++++++++ extensions/native/circuit/src/sumcheck/mod.rs | 2 +- 3 files changed, 190 insertions(+), 11 deletions(-) create mode 100644 extensions/native/circuit/src/sumcheck/columns.rs diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 68ad0a1209..89c9050d60 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -10,7 +10,7 @@ use openvm_stark_backend::{ p3_field::{Field, PrimeField, PrimeField32}, p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}, }; -use crate::fri::elem_to_ext; +use crate::{fri::elem_to_ext, sumcheck::columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}}; use openvm_native_compiler::{ conversion::AS, SumcheckOpcode::SUMCHECK_LAYER_EVAL, @@ -20,6 +20,60 @@ use crate::{ field_extension::{FieldExtension, EXT_DEG}, utils::const_max, }; +use serde::{Deserialize, Serialize}; + +#[repr(C)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(bound = "F: Field")] +pub struct SumcheckEvalRecord { + pub from_state: ExecutionState, + pub instruction: Instruction, + pub row_type: usize, // 0 - header; 1 - prod; 2 - logup + pub curr_timestamp_increment: usize, + pub final_timestamp_increment: usize, + + pub ctx: [F; EXT_DEG * 2], + pub challenges: [F; EXT_DEG * 4], + pub read_data_records: [RecordId; 7], + pub write_data_records: [RecordId; 2], + + pub register_ptrs: [F; 5], + +} +// pub struct TranscriptObservationRecord { +// pub from_state: ExecutionState, +// pub instruction: Instruction, +// pub start_idx: usize, +// pub end_idx: usize, +// pub is_first: bool, +// pub is_last: bool, +// pub curr_timestamp_increment: usize, +// pub final_timestamp_increment: usize, + +// pub state_ptr: F, +// pub input_ptr: F, +// pub init_pos: F, +// pub len: usize, +// pub curr_len: usize, +// pub should_permute: bool, + +// pub read_input_data: [RecordId; CHUNK], +// pub write_input_data: [RecordId; CHUNK], +// pub input_data: [F; CHUNK], + +// pub read_sponge_state: RecordId, +// pub write_sponge_state: RecordId, +// pub permutation_input: [F; 2 * CHUNK], +// pub permutation_output: [F; 2 * CHUNK], + +// pub write_final_idx: RecordId, +// pub final_idx: usize, + +// pub input_register_1: F, +// pub input_register_2: F, +// pub input_register_3: F, +// pub output_register: F, +// } fn calculate_3d_ext_idx( inner_inner_len: F, @@ -77,6 +131,11 @@ impl InstructionExecutor for NativeSumcheckChip { } = instruction; if op == SUMCHECK_LAYER_EVAL.global_opcode() { + println!("=> column width: {:?}", NativeSumcheckCols::::width()); + println!("=> header width: {:?}", HeaderSpecificCols::::width()); + println!("=> prod width: {:?}", ProdSpecificCols::::width()); + println!("=> logup width: {:?}", LogupSpecificCols::::width()); + let (read_ctx_pointer, ctx_pointer) = memory.read_cell(register_address_space, input_register_1); let (read_cs_pointer, cs_pointer) = @@ -101,12 +160,19 @@ impl InstructionExecutor for NativeSumcheckChip { in_round, ] = ctx; - let (alpha_read, alpha) = memory.read::(data_address_space, cs_pointer); - let (c1_read, c1) = memory.read::(data_address_space, cs_pointer + F::from_canonical_usize(EXT_DEG * 1)); - let (c2_read, c2) = memory.read::(data_address_space, cs_pointer + F::from_canonical_usize(EXT_DEG * 2)); + let (challenges_read, challenges): (RecordId, [F; EXT_DEG * 4]) = memory.read::<{EXT_DEG * 4}>(data_address_space, cs_pointer); + + let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().expect(""); + let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().expect(""); + + // let (alpha_read, alpha) = memory.read::(data_address_space, cs_pointer); + // let (c1_read, c1) = memory.read::(data_address_space, cs_pointer + F::from_canonical_usize(EXT_DEG * 1)); + // let (c2_read, c2) = memory.read::(data_address_space, cs_pointer + F::from_canonical_usize(EXT_DEG * 2)); let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); + self.height += 1; let mut i = F::ZERO; while i < num_prod_spec { @@ -120,8 +186,12 @@ impl InstructionExecutor for NativeSumcheckChip { round, F::from_canonical_usize(0), ); - let (read_p1, p1) = memory.read::(data_address_space, prod_ptr + start); - let (read_p2, p2) = memory.read::(data_address_space, prod_ptr + start + F::from_canonical_usize(EXT_DEG)); + // let (read_p1, p1) = memory.read::(data_address_space, prod_ptr + start); + // let (read_p2, p2) = memory.read::(data_address_space, prod_ptr + start + F::from_canonical_usize(EXT_DEG)); + + let (read_p, ps) = memory.read::<{EXT_DEG * 2}>(data_address_space, prod_ptr + start); + let p1: [F; 4] = ps[0..EXT_DEG].try_into().expect(""); + let p2: [F; 4] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); let evals = if in_round > F::ZERO { FieldExtension::multiply(p1, p2) @@ -143,6 +213,7 @@ impl InstructionExecutor for NativeSumcheckChip { alpha_acc = FieldExtension::multiply(alpha_acc, alpha); i = i + F::ONE; + self.height += 1; } let mut i = F::ZERO; @@ -158,10 +229,16 @@ impl InstructionExecutor for NativeSumcheckChip { F::from_canonical_usize(0), ); - let (read_p1, p1) = memory.read::(data_address_space, logup_ptr + start); - let (read_p2, p2) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG)); - let (read_q1, q1) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG * 2)); - let (read_q2, q2) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG * 3)); + // let (read_p1, p1) = memory.read::(data_address_space, logup_ptr + start); + // let (read_p2, p2) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG)); + // let (read_q1, q1) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG * 2)); + // let (read_q2, q2) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG * 3)); + + let (read_pqs, pqs) = memory.read::<{EXT_DEG * 4}>(data_address_space, logup_ptr + start); + let p1: [F; 4] = pqs[0..EXT_DEG].try_into().expect(""); + let p2: [F; 4] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + let q1: [F; 4] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().expect(""); + let q2: [F; 4] = pqs[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().expect(""); let p_evals = if in_round > F::ZERO { FieldExtension::add( @@ -198,9 +275,12 @@ impl InstructionExecutor for NativeSumcheckChip { alpha_acc = FieldExtension::multiply(FieldExtension::multiply(alpha_acc, alpha), alpha); i = i + F::ONE; + self.height += 1; } let (write_r, _) = memory.write::(data_address_space, r_ptr, eval_acc); + + println!("=> current_height: {:?}", self.height); } else { unreachable!() } diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs new file mode 100644 index 0000000000..e0e466c0f9 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -0,0 +1,99 @@ +use openvm_circuit::system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}; +use openvm_circuit_primitives_derive::AlignedBorrow; +use crate::field_extension::EXT_DEG; +use crate::utils::const_max; + +const fn max3(a: usize, b: usize, c: usize) -> usize { + const_max(a, const_max(b, c)) +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct NativeSumcheckCols { + /// Indicates that this row is the header for a layer sum operation + pub header_row: T, + /// Indicates that this row is a step for prod_spec in the layer sum operation + pub prod_row: T, + /// Indicates that this row is a step for logup_spec in the layer sum operation + pub logup_row: T, + + // Register values + pub register_ptrs: [T; 5], + + // Context variables + // [ + // round, + // num_prod_spec, + // num_logup_spec, + // prod_spec_inner_len, + // prod_spec_inner_inner_len, + // logup_spec_inner_len, + // logup_spec_inner_inner_len, + // in_layer, + // ] + pub ctx: [T; EXT_DEG * 2], + + pub curr_prod_n: T, + pub curr_logup_n: T, + + // alpha1, c1, c2, alpha2 (for logup rows) + pub challenges: [T; EXT_DEG * 4], + + // Specific to each row + pub max_round: T, + + // The current final evaluation accumulator. Extension element. + pub eval_acc: [T; EXT_DEG], + + // /// 1. For header row, 5 registers, ctx, challenges + // /// 2. For the rest: max_variables, p1, p2, q1, q2 + // pub read_records: [MemoryReadAuxCols; 7], + // /// 1. For header row, write final result + // /// 2. For prod rows: write prod_evals + // /// 3. For logup rows: write q_evals, p_evals + // pub write_records: [MemoryWriteAuxCols; 2], + + pub specific: [T; max3( + HeaderSpecificCols::::width(), + ProdSpecificCols::::width(), + LogupSpecificCols::::width(), + )] +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct HeaderSpecificCols { + /// 5 register reads + ctx read + challenges read + pub read_records: [MemoryReadAuxCols; 7], + /// Write the final evaluation + pub write_records: MemoryWriteAuxCols +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct ProdSpecificCols { + /// 2 extension elements + pub p: [T; EXT_DEG * 2], + /// read 2 p values + pub read_record: MemoryReadAuxCols, + /// Calculated p evals + pub p_evals: [T; EXT_DEG], + /// write p_evals + pub write_record: MemoryWriteAuxCols, +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct LogupSpecificCols { + /// 4 extension elements + pub pq: [T; EXT_DEG * 4], + /// read 4 values: p1, p2, q1, q2 + pub read_record: MemoryReadAuxCols, + /// Calculated p evals + pub p_evals: [T; EXT_DEG], + /// Calculated q evals + pub q_evals: [T; EXT_DEG], + + /// write both p_evals and q_evals + pub write_records: [MemoryWriteAuxCols; 2], +} \ No newline at end of file diff --git a/extensions/native/circuit/src/sumcheck/mod.rs b/extensions/native/circuit/src/sumcheck/mod.rs index 34ab23860b..8b6cab5165 100644 --- a/extensions/native/circuit/src/sumcheck/mod.rs +++ b/extensions/native/circuit/src/sumcheck/mod.rs @@ -1,5 +1,5 @@ pub mod air; pub mod chip; -// mod columns; +mod columns; // mod tests; mod trace; \ No newline at end of file From 8c33bc6cb56606f1624b8bda5245739971f44453 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 21 Sep 2025 20:58:05 -0400 Subject: [PATCH 06/41] Add record assignments --- extensions/native/circuit/src/sumcheck/air.rs | 76 +++++++- .../native/circuit/src/sumcheck/chip.rs | 168 +++++++++++++----- .../native/circuit/src/sumcheck/columns.rs | 8 + .../native/circuit/src/sumcheck/trace.rs | 58 ++++-- 4 files changed, 242 insertions(+), 68 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index bd55e88091..7cdba76f1d 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -1,7 +1,9 @@ +use std::{array::from_fn, borrow::Borrow, sync::Arc}; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, system::memory::{offline_checker::MemoryBridge, MemoryAddress, CHUNK}, }; +use openvm_circuit_primitives::utils::{assert_array_eq, not}; use openvm_stark_backend::{ air_builders::sub::SubAirBuilder, interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, @@ -11,6 +13,8 @@ use openvm_stark_backend::{ rap::{BaseAirWithPublicValues, PartitionedBaseAir}, }; +use crate::{sumcheck::columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}, EXT_DEG}; + #[derive(Clone, Debug)] pub struct NativeSumcheckAir { pub execution_bridge: ExecutionBridge, @@ -20,8 +24,7 @@ pub struct NativeSumcheckAir { impl BaseAir for NativeSumcheckAir { fn width(&self) -> usize { - // _debug - 0 + NativeSumcheckCols::::width() } } @@ -39,6 +42,73 @@ impl Air for NativeSumcheckAir { fn eval(&self, builder: &mut AB) { - // _debug + let main = builder.main(); + let local = main.row_slice(0); + let local: &NativeSumcheckCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &NativeSumcheckCols = (*next).borrow(); + + let &NativeSumcheckCols { + header_row, + prod_row, + logup_row, + first_timestamp, + start_timestamp, + last_timestamp, + register_ptrs, + ctx, + curr_prod_n, + curr_logup_n, + alpha, + challenges, + max_round, + should_acc, + eval_acc, + specific, + } = local; + + builder.assert_bool(header_row); + builder.assert_bool(prod_row); + builder.assert_bool(logup_row); + let enabled = header_row + prod_row + logup_row; + builder.assert_bool(enabled.clone()); + + // Carry along columns + assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); + assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), ctx, next.ctx); + assert_array_eq::<_, _, _, {EXT_DEG * 2}>( + &mut builder.when(next.prod_row + next.logup_row), + challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect(""), + next.challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect("") + ); + + // Row transitions + builder + .when(header_row) + .when(next.logup_row) + .assert_zero(ctx[1]); + builder + .when(next.prod_row) + .assert_eq(curr_prod_n + AB::F::ONE, next.curr_prod_n); + builder + .when(next.logup_row) + .assert_eq(curr_logup_n + AB::F::ONE, next.curr_logup_n); + builder + .when(prod_row) + .when(next.logup_row) + .assert_eq(ctx[1], curr_prod_n); + builder + .when(logup_row) + .when(not(next.logup_row)) + .assert_eq(ctx[2], curr_logup_n); + + let header_row_specific: &HeaderSpecificCols = + specific[..HeaderSpecificCols::::width()].borrow(); + let prod_row_specific: &ProdSpecificCols = + specific[..ProdSpecificCols::::width()].borrow(); + let logup_row_specific: &LogupSpecificCols = + specific[..LogupSpecificCols::::width()].borrow(); + + } } \ No newline at end of file diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 89c9050d60..e967bafc75 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -32,48 +32,27 @@ pub struct SumcheckEvalRecord { pub curr_timestamp_increment: usize, pub final_timestamp_increment: usize, + pub register_ptrs: [F; 5], pub ctx: [F; EXT_DEG * 2], pub challenges: [F; EXT_DEG * 4], pub read_data_records: [RecordId; 7], pub write_data_records: [RecordId; 2], - pub register_ptrs: [F; 5], - + pub max_round: F, + pub should_acc: bool, + pub prod_spec_n: usize, + pub logup_spec_n: usize, + pub alpha: [F; EXT_DEG], + pub alpha1: [F; EXT_DEG], + pub alpha2: [F; EXT_DEG], + pub p1: [F; EXT_DEG], + pub p2: [F; EXT_DEG], + pub q1: [F; EXT_DEG], + pub q2: [F; EXT_DEG], + pub p_evals: [F; EXT_DEG], + pub q_evals: [F; EXT_DEG], + pub eval_acc: [F; EXT_DEG], } -// pub struct TranscriptObservationRecord { -// pub from_state: ExecutionState, -// pub instruction: Instruction, -// pub start_idx: usize, -// pub end_idx: usize, -// pub is_first: bool, -// pub is_last: bool, -// pub curr_timestamp_increment: usize, -// pub final_timestamp_increment: usize, - -// pub state_ptr: F, -// pub input_ptr: F, -// pub init_pos: F, -// pub len: usize, -// pub curr_len: usize, -// pub should_permute: bool, - -// pub read_input_data: [RecordId; CHUNK], -// pub write_input_data: [RecordId; CHUNK], -// pub input_data: [F; CHUNK], - -// pub read_sponge_state: RecordId, -// pub write_sponge_state: RecordId, -// pub permutation_input: [F; 2 * CHUNK], -// pub permutation_output: [F; 2 * CHUNK], - -// pub write_final_idx: RecordId, -// pub final_idx: usize, - -// pub input_register_1: F, -// pub input_register_2: F, -// pub input_register_3: F, -// pub output_register: F, -// } fn calculate_3d_ext_idx( inner_inner_len: F, @@ -89,7 +68,8 @@ pub struct NativeSumcheckChip { pub height: usize, pub(super) air: NativeSumcheckAir, pub(super) offline_memory: Arc>>, - // pub record_set: NativeSumcheckRecordSet, + pub record_set: Vec>, + // _debug // pub(super) streams: Arc>>, } @@ -108,6 +88,7 @@ impl NativeSumcheckChip { height: 0, air, offline_memory, + record_set: Default::default(), } } } @@ -131,6 +112,10 @@ impl InstructionExecutor for NativeSumcheckChip { } = instruction; if op == SUMCHECK_LAYER_EVAL.global_opcode() { + let mut observation_records: Vec> = vec![]; + let mut curr_timestamp: usize = 0; + + // _debug println!("=> column width: {:?}", NativeSumcheckCols::::width()); println!("=> header width: {:?}", HeaderSpecificCols::::width()); println!("=> prod width: {:?}", ProdSpecificCols::::width()); @@ -147,7 +132,7 @@ impl InstructionExecutor for NativeSumcheckChip { let (read_result_pointer, r_ptr) = memory.read_cell(register_address_space, output_register); - let (ctx_read, ctx) = memory.read::<{EXT_DEG * 2}>(data_address_space, ctx_pointer); + let (ctx_read, ctx): (RecordId, [F; EXT_DEG * 2]) = memory.read::<{EXT_DEG * 2}>(data_address_space, ctx_pointer); let [ round, @@ -166,17 +151,55 @@ impl InstructionExecutor for NativeSumcheckChip { let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().expect(""); - // let (alpha_read, alpha) = memory.read::(data_address_space, cs_pointer); - // let (c1_read, c1) = memory.read::(data_address_space, cs_pointer + F::from_canonical_usize(EXT_DEG * 1)); - // let (c2_read, c2) = memory.read::(data_address_space, cs_pointer + F::from_canonical_usize(EXT_DEG * 2)); - let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); + + let register_ptrs: [F; 5] = [ctx_pointer, cs_pointer, prod_ptr, logup_ptr, r_ptr]; + let mut header_row: SumcheckEvalRecord = SumcheckEvalRecord { + from_state, + instruction: instruction.clone(), + row_type: 0, + curr_timestamp_increment: curr_timestamp, + register_ptrs, + ctx, + challenges, + read_data_records: [ + read_ctx_pointer, + read_cs_pointer, + read_prod_pointer, + read_logup_pointer, + read_result_pointer, + ctx_read, + challenges_read, + ], + alpha, + ..Default::default() + }; + observation_records.push(header_row); self.height += 1; + curr_timestamp += 7; let mut i = F::ZERO; + let mut i_usize = 0usize; while i < num_prod_spec { + let mut prod_row: SumcheckEvalRecord = SumcheckEvalRecord { + from_state, + instruction: instruction.clone(), + row_type: 1, + curr_timestamp_increment: curr_timestamp, + register_ptrs, + ctx, + challenges, + alpha, + prod_spec_n: i_usize, + ..Default::default() + }; + prod_row.alpha1 = alpha_acc; + let (read_max_round, max_round) = memory.read_cell(data_address_space, ctx_pointer + F::from_canonical_usize(EXT_DEG * 2) + i); + prod_row.max_round = max_round; + prod_row.read_data_records[0] = read_max_round; + curr_timestamp += 1; if round < (max_round - F::from_canonical_usize(1)) { let start = calculate_3d_ext_idx( @@ -186,13 +209,15 @@ impl InstructionExecutor for NativeSumcheckChip { round, F::from_canonical_usize(0), ); - // let (read_p1, p1) = memory.read::(data_address_space, prod_ptr + start); - // let (read_p2, p2) = memory.read::(data_address_space, prod_ptr + start + F::from_canonical_usize(EXT_DEG)); let (read_p, ps) = memory.read::<{EXT_DEG * 2}>(data_address_space, prod_ptr + start); let p1: [F; 4] = ps[0..EXT_DEG].try_into().expect(""); let p2: [F; 4] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + prod_row.read_data_records[1] = read_p; + prod_row.p1 = p1; + prod_row.p2 = p2; + let evals = if in_round > F::ZERO { FieldExtension::multiply(p1, p2) } else { @@ -201,24 +226,46 @@ impl InstructionExecutor for NativeSumcheckChip { FieldExtension::multiply(p2, c2), ) }; + prod_row.p_evals = evals; let (write_slice_eval_1, _) = memory.write::(data_address_space, r_ptr + (F::ONE + i) * F::from_canonical_usize(EXT_DEG), evals); + prod_row.write_data_records[0] = write_slice_eval_1; let not_in_round = F::ONE - in_round; if (round + not_in_round) < (max_round - F::from_canonical_usize(1)) { eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, evals)); + prod_row.should_acc = true; + prod_row.eval_acc = eval_acc.clone(); } + curr_timestamp += 2; } alpha_acc = FieldExtension::multiply(alpha_acc, alpha); i = i + F::ONE; + i_usize += 1; + observation_records.push(prod_row); self.height += 1; } let mut i = F::ZERO; + let mut i_usize = 0usize; while i < num_logup_spec { + let mut logup_row: SumcheckEvalRecord = SumcheckEvalRecord { + from_state, + instruction: instruction.clone(), + row_type: 2, + curr_timestamp_increment: curr_timestamp, + register_ptrs, + ctx, + challenges, + logup_spec_n: i_usize, + ..Default::default() + }; let (read_max_round, max_round) = memory.read_cell(data_address_space, ctx_pointer + num_prod_spec + F::from_canonical_usize(EXT_DEG * 2) + i); + logup_row.max_round = max_round; + logup_row.read_data_records[0] = read_max_round; + curr_timestamp += 1; if round < (max_round - F::from_canonical_usize(1)) { let start = calculate_3d_ext_idx( @@ -229,17 +276,18 @@ impl InstructionExecutor for NativeSumcheckChip { F::from_canonical_usize(0), ); - // let (read_p1, p1) = memory.read::(data_address_space, logup_ptr + start); - // let (read_p2, p2) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG)); - // let (read_q1, q1) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG * 2)); - // let (read_q2, q2) = memory.read::(data_address_space, logup_ptr + start + F::from_canonical_usize(EXT_DEG * 3)); - let (read_pqs, pqs) = memory.read::<{EXT_DEG * 4}>(data_address_space, logup_ptr + start); let p1: [F; 4] = pqs[0..EXT_DEG].try_into().expect(""); let p2: [F; 4] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); let q1: [F; 4] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().expect(""); let q2: [F; 4] = pqs[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().expect(""); + logup_row.read_data_records[1] = read_pqs; + logup_row.p1 = p1; + logup_row.p2 = p2; + logup_row.q1 = q1; + logup_row.q2 = q2; + let p_evals = if in_round > F::ZERO { FieldExtension::add( FieldExtension::multiply(p1, q2), @@ -261,25 +309,47 @@ impl InstructionExecutor for NativeSumcheckChip { ) }; + logup_row.p_evals = p_evals; + logup_row.q_evals = q_evals; + let (write_slice_eval_1, _) = memory.write::(data_address_space, r_ptr + (F::ONE + num_prod_spec + i) * F::from_canonical_usize(EXT_DEG), p_evals); let (write_slice_eval_2, _) = memory.write::(data_address_space, r_ptr + (F::ONE + num_prod_spec + num_logup_spec + i) * F::from_canonical_usize(EXT_DEG), q_evals); + logup_row.write_data_records[0] = write_slice_eval_1; + logup_row.write_data_records[1] = write_slice_eval_2; + let not_in_round = F::ONE - in_round; if (round + not_in_round) < (max_round - F::from_canonical_usize(1)) { eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, p_evals)); let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_denominator, q_evals)); + + logup_row.should_acc = true; + logup_row.alpha1 = alpha_acc; + logup_row.alpha2 = alpha_denominator; + logup_row.eval_acc = eval_acc.clone(); } + curr_timestamp += 3; } alpha_acc = FieldExtension::multiply(FieldExtension::multiply(alpha_acc, alpha), alpha); i = i + F::ONE; + i_usize += 1; + observation_records.push(logup_row); self.height += 1; } let (write_r, _) = memory.write::(data_address_space, r_ptr, eval_acc); + curr_timestamp += 1; + observation_records[0].write_data_records[0] = write_r; + + for record in &mut observation_records { + record.final_timestamp_increment = curr_timestamp; + record.eval_acc = FieldExtension::subtract(eval_acc, record.eval_acc); + } + self.record_set.extend(observation_records); println!("=> current_height: {:?}", self.height); } else { unreachable!() diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index e0e466c0f9..82f47f1396 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -17,6 +17,11 @@ pub struct NativeSumcheckCols { /// Indicates that this row is a step for logup_spec in the layer sum operation pub logup_row: T, + /// Timestamps + pub first_timestamp: T, + pub start_timestamp: T, + pub last_timestamp: T, + // Register values pub register_ptrs: [T; 5], @@ -37,10 +42,13 @@ pub struct NativeSumcheckCols { pub curr_logup_n: T, // alpha1, c1, c2, alpha2 (for logup rows) + pub alpha: [T; EXT_DEG], pub challenges: [T; EXT_DEG * 4], // Specific to each row pub max_round: T, + // Should the evaluation be accumualted + pub should_acc: T, // The current final evaluation accumulator. Extension element. pub eval_acc: [T; EXT_DEG], diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index d6c733d392..71277d2025 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -13,20 +13,7 @@ use openvm_stark_backend::{ prover::types::AirProofInput, AirRef, Chip, ChipUsageGetter, }; -use crate::sumcheck::chip::NativeSumcheckChip; - -impl NativeSumcheckChip { - fn generate_trace(self) -> RowMajorMatrix { - let width = self.trace_width(); - let height = next_power_of_two_or_zero(self.height); - let mut flat_trace = F::zero_vec(width * height); - let memory = self.offline_memory.lock().unwrap(); - let aux_cols_factory = memory.aux_cols_factory(); - let mut used_cells = 0; - - RowMajorMatrix::new(flat_trace, width) - } -} +use crate::sumcheck::{chip::NativeSumcheckChip, columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}}; impl ChipUsageGetter for NativeSumcheckChip @@ -40,8 +27,47 @@ impl ChipUsageGetter } fn trace_width(&self) -> usize { - // _debug - 0 + NativeSumcheckCols::::width() + } +} + +impl NativeSumcheckChip { + fn generate_trace(self) -> RowMajorMatrix { + let width = self.trace_width(); + let height = next_power_of_two_or_zero(self.height); + let mut flat_trace: Vec = F::zero_vec(width * height); + + let memory = self.offline_memory.lock().unwrap(); + let aux_cols_factory = memory.aux_cols_factory(); + + let mut used_cells = 0; + for record in self.record_set { + let slice = &mut flat_trace[used_cells..used_cells + width]; + let cols: &mut NativeSumcheckCols = slice.borrow_mut(); + cols.first_timestamp = F::from_canonical_u32(record.from_state.timestamp); + cols.start_timestamp = F::from_canonical_usize(record.from_state.timestamp as usize + record.curr_timestamp_increment); + cols.last_timestamp = F::from_canonical_usize(record.final_timestamp_increment); + + if record.row_type == 0 { + cols.header_row = F::ONE; + let header: &mut HeaderSpecificCols = + cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); + } else if record.row_type == 1 { + cols.prod_row = F::ONE; + let prod: &mut ProdSpecificCols = + cols.specific[..ProdSpecificCols::::width()].borrow_mut(); + } else if record.row_type == 2 { + cols.logup_row = F::ONE; + let logup: &mut LogupSpecificCols = + cols.specific[..LogupSpecificCols::::width()].borrow_mut(); + } else { + unreachable!() + } + + used_cells += width; + } + + RowMajorMatrix::new(flat_trace, width) } } From 5f1903e77b404032d665e44a0c8e8903d7e1c128 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 23 Sep 2025 15:33:56 -0400 Subject: [PATCH 07/41] Adjust unit test --- .gitignore | 3 + extensions/native/circuit/src/sumcheck/air.rs | 155 ++++++- .../native/circuit/src/sumcheck/chip.rs | 13 +- .../native/circuit/src/sumcheck/columns.rs | 10 +- extensions/native/recursion/tests/sumcheck.rs | 422 ++++++++++++++++++ 5 files changed, 593 insertions(+), 10 deletions(-) create mode 100644 extensions/native/recursion/tests/sumcheck.rs diff --git a/.gitignore b/.gitignore index d794a5dc57..aceaa65679 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,9 @@ Cargo.lock **/.env .DS_Store +# Log outputs +*.log + .cache/ rustc-* diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 7cdba76f1d..2737bd1ffc 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -1,9 +1,11 @@ use std::{array::from_fn, borrow::Borrow, sync::Arc}; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, - system::memory::{offline_checker::MemoryBridge, MemoryAddress, CHUNK}, + system::memory::{offline_checker::MemoryBridge, MemoryAddress}, }; use openvm_circuit_primitives::utils::{assert_array_eq, not}; +use openvm_instructions::LocalOpcode; +use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; use openvm_stark_backend::{ air_builders::sub::SubAirBuilder, interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, @@ -12,14 +14,13 @@ use openvm_stark_backend::{ p3_matrix::Matrix, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, }; - use crate::{sumcheck::columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}, EXT_DEG}; #[derive(Clone, Debug)] pub struct NativeSumcheckAir { pub execution_bridge: ExecutionBridge, pub memory_bridge: MemoryBridge, - pub(crate) address_space: F, + pub address_space: F, } impl BaseAir for NativeSumcheckAir { @@ -73,6 +74,68 @@ impl Air let enabled = header_row + prod_row + logup_row; builder.assert_bool(enabled.clone()); + // Header + let header_row_specific: &HeaderSpecificCols = + specific[..HeaderSpecificCols::::width()].borrow(); + let registers = header_row_specific.registers; + + self.execution_bridge + .execute_and_increment_pc( + AB::Expr::from_canonical_usize(SUMCHECK_LAYER_EVAL.global_opcode().as_usize()), + [ + registers[4].into(), + registers[0].into(), + registers[1].into(), + self.address_space.into(), + self.address_space.into(), + registers[2].into(), + registers[3].into(), + ], + ExecutionState::new(header_row_specific.pc, first_timestamp), + last_timestamp, + ) + .eval(builder, header_row); + + for i in 0..5usize { + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, registers[i]), + [register_ptrs[i]], + first_timestamp + AB::F::from_canonical_usize(i), + &header_row_specific.read_records[i], + ) + .eval(builder, header_row); + } + + /* _debug + + // Read context variables + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, register_ptrs[0]), + ctx, + first_timestamp + AB::F::from_canonical_usize(6), + &header_row_specific.read_records[5], + ) + .eval(builder, header_row); + + // Read challenges + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, register_ptrs[1]), + challenges, + first_timestamp + AB::F::from_canonical_usize(7), + &header_row_specific.read_records[6], + ) + .eval(builder, header_row); + + + // Separate aggregate column clusters + let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().expect(""); + let c1: [_; EXT_DEG] = challenges[EXT_DEG..{EXT_DEG * 2}].try_into().expect(""); + let c2: [_; EXT_DEG] = challenges[{EXT_DEG * 2}..{EXT_DEG * 3}].try_into().expect(""); + let alpha2: [_; EXT_DEG] = challenges[{EXT_DEG * 3}..{EXT_DEG * 4}].try_into().expect(""); + // Carry along columns assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), ctx, next.ctx); @@ -81,6 +144,7 @@ impl Air challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect(""), next.challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect("") ); + assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); // Row transitions builder @@ -102,13 +166,94 @@ impl Air .when(not(next.logup_row)) .assert_eq(ctx[2], curr_logup_n); - let header_row_specific: &HeaderSpecificCols = - specific[..HeaderSpecificCols::::width()].borrow(); + + + + + // Prod spec evaluation let prod_row_specific: &ProdSpecificCols = specific[..ProdSpecificCols::::width()].borrow(); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, register_ptrs[0] + AB::F::from_canonical_usize(EXT_DEG * 2 - 1) + curr_prod_n), + [max_round], + start_timestamp, + &prod_row_specific.read_records[0], + ) + .eval(builder, prod_row); + + self.memory_bridge + .read( + MemoryAddress::new( + self.address_space, + register_ptrs[2] + (ctx[4] * ctx[3] * (curr_prod_n - AB::F::ONE) + ctx[4] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG), + ), + prod_row_specific.p, + start_timestamp + AB::F::ONE, + &prod_row_specific.read_records[1], + ) + .eval(builder, prod_row); + + let p1: [_; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().expect(""); + let p2: [_; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), + ), + prod_row_specific.p_evals, + start_timestamp + AB::F::TWO, + &prod_row_specific.write_record, + ) + .eval(builder, prod_row); + + // Logup spec evaluation let logup_row_specific: &LogupSpecificCols = specific[..LogupSpecificCols::::width()].borrow(); + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, register_ptrs[0] + ctx[1] + AB::F::from_canonical_usize(EXT_DEG * 2 - 1) + curr_logup_n), + [max_round], + start_timestamp, + &prod_row_specific.read_records[0], + ) + .eval(builder, prod_row); + + self.memory_bridge + .read( + MemoryAddress::new( + self.address_space, + register_ptrs[2] + (ctx[4] * ctx[3] * (curr_prod_n - AB::F::ONE) + ctx[4] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG), + ), + prod_row_specific.p, + start_timestamp + AB::F::ONE, + &prod_row_specific.read_records[1], + ) + .eval(builder, prod_row); + + let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().expect(""); + let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + let q1: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 2)..{EXT_DEG * 3}].try_into().expect(""); + let q2: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().expect(""); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), + ), + prod_row_specific.p_evals, + start_timestamp + AB::F::TWO, + &prod_row_specific.write_record, + ) + .eval(builder, prod_row); + + // Termination condition + */ } } \ No newline at end of file diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index e967bafc75..34d3ad189d 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -33,6 +33,7 @@ pub struct SumcheckEvalRecord { pub final_timestamp_increment: usize, pub register_ptrs: [F; 5], + pub registers: [F; 5], pub ctx: [F; EXT_DEG * 2], pub challenges: [F; EXT_DEG * 4], pub read_data_records: [RecordId; 7], @@ -131,7 +132,7 @@ impl InstructionExecutor for NativeSumcheckChip { memory.read_cell(register_address_space, input_register_4); let (read_result_pointer, r_ptr) = memory.read_cell(register_address_space, output_register); - + let (ctx_read, ctx): (RecordId, [F; EXT_DEG * 2]) = memory.read::<{EXT_DEG * 2}>(data_address_space, ctx_pointer); let [ @@ -161,6 +162,13 @@ impl InstructionExecutor for NativeSumcheckChip { row_type: 0, curr_timestamp_increment: curr_timestamp, register_ptrs, + registers: [ + input_register_1, + input_register_2, + input_register_3, + input_register_4, + output_register, + ], ctx, challenges, read_data_records: [ @@ -179,6 +187,8 @@ impl InstructionExecutor for NativeSumcheckChip { self.height += 1; curr_timestamp += 7; + /* + let mut i = F::ZERO; let mut i_usize = 0usize; while i < num_prod_spec { @@ -351,6 +361,7 @@ impl InstructionExecutor for NativeSumcheckChip { self.record_set.extend(observation_records); println!("=> current_height: {:?}", self.height); + */ } else { unreachable!() } diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index 82f47f1396..c4f59c2e8a 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -71,6 +71,8 @@ pub struct NativeSumcheckCols { #[repr(C)] #[derive(AlignedBorrow)] pub struct HeaderSpecificCols { + pub pc: T, + pub registers: [T; 5], /// 5 register reads + ctx read + challenges read pub read_records: [MemoryReadAuxCols; 7], /// Write the final evaluation @@ -82,8 +84,8 @@ pub struct HeaderSpecificCols { pub struct ProdSpecificCols { /// 2 extension elements pub p: [T; EXT_DEG * 2], - /// read 2 p values - pub read_record: MemoryReadAuxCols, + /// read max varibale and 2 p values + pub read_records: [MemoryReadAuxCols; 2], /// Calculated p evals pub p_evals: [T; EXT_DEG], /// write p_evals @@ -95,8 +97,8 @@ pub struct ProdSpecificCols { pub struct LogupSpecificCols { /// 4 extension elements pub pq: [T; EXT_DEG * 4], - /// read 4 values: p1, p2, q1, q2 - pub read_record: MemoryReadAuxCols, + /// read max variable and 4 values: p1, p2, q1, q2 + pub read_records: [MemoryReadAuxCols; 2], /// Calculated p evals pub p_evals: [T; EXT_DEG], /// Calculated q evals diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs new file mode 100644 index 0000000000..7a9bde1ec1 --- /dev/null +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -0,0 +1,422 @@ +use itertools::Itertools; +use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmConfig, VmExecutor, verify_single, VirtualMachine,}; +use openvm_native_circuit::{Native, NativeConfig, EXT_DEG}; +use openvm_native_compiler::{ + prelude::*, + asm::{AsmBuilder, AsmCompiler}, ir::{Felt, Ext, Usize}, + conversion::{convert_program, CompilerOptions}, +}; +use openvm_native_recursion::{testing_utils::inner::run_recursive_test, challenger::{duplex::DuplexChallengerVariable, CanObserveVariable}}; +use openvm_stark_backend::{ + config::{Domain, StarkGenericConfig}, + p3_commit::PolynomialSpace, + p3_field::{extension::BinomialExtensionField, FieldAlgebra}, +}; +use openvm_stark_sdk::{ + config::FriParameters, + p3_baby_bear::BabyBear, + utils::ProofInputForTest, + config::{ + baby_bear_poseidon2::BabyBearPoseidon2Engine, + fri_params::standard_fri_params_with_100_bits_conjectured_security, + }, + engine::StarkFriEngine, + utils::create_seeded_rng, +}; +use rand::Rng; +pub type F = BabyBear; +pub type E = BinomialExtensionField; + +#[test] +fn test_sumcheck_layer_eval() { + let mut builder = AsmBuilder::>::default(); + + build_test_program(&mut builder); + + // Fill in test program logic + builder.halt(); + + + + let compilation_options = CompilerOptions::default().with_cycle_tracker(); + let mut compiler = AsmCompiler::new(compilation_options.word_size); + compiler.build(builder.operations); + let asm_code = compiler.code(); + + // let program = Program::from_instructions(&instructions); + let program: Program<_> = convert_program(asm_code, compilation_options); + let sumcheck_max_constraint_degree = 3; + let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { + FriParameters { + // max constraint degree = 2^log_blowup + 1 + log_blowup: 1, + log_final_poly_len: 0, + num_queries: 2, + proof_of_work_bits: 0, + } + } else { + standard_fri_params_with_100_bits_conjectured_security(1) + }; + + let engine = BabyBearPoseidon2Engine::new(fri_params); + let mut config = NativeConfig::aggregation(0, sumcheck_max_constraint_degree); + config.system.memory_config.max_access_adapter_n = 16; + + + let vm = VirtualMachine::new(engine, config); + + let pk = vm.keygen(); + let result = vm.execute_and_generate(program, vec![]).unwrap(); + let proofs = vm.prove(&pk, result); + + /* + for proof in proofs { + verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); + } + */ +} + +fn build_test_program( + builder: &mut Builder, +) { + /* + let ctx_u32s = [3u32, 6, 5, 8, 2, 8, 4, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]; + let ctx: Array> = builder.dyn_array(ctx_u32s.len()); + for (idx, n) in ctx_u32s.into_iter().enumerate() { + builder.set(&ctx, idx, Usize::from(n as usize)); + } + + let challenges_u32s = [ + 548478283u32, 456436544, 1716290291, 791326976, + 1829717553, 1422025771, 1917123958, 727015942, + 183548369, 591240150, 96141963, 1286249979, + ]; + let challenges: Array> = builder.dyn_array(challenges_u32s.len() / EXT_DEG); + for (idx, n) in challenges_u32s.chunks(EXT_DEG).enumerate() { + let f1 = builder.constant(C::F::from_canonical_u32(n[0])); + let f2 = builder.constant(C::F::from_canonical_u32(n[1])); + let f3 = builder.constant(C::F::from_canonical_u32(n[2])); + let f4 = builder.constant(C::F::from_canonical_u32(n[3])); + + let e = builder.felts2ext(&[f1, f2, f3, f4]); + builder.set(&challenges, idx, e); + } + + let prod_spec_eval_u32s = [ + 1538906710u32, 637535518, 1753132406, 1395236651, + 278806441, 1722910382, 1475548665, 1117874675, + 1578586709, 1826764884, 384068476, 1852240363, + 707958906, 1960944944, 183554399, 1259273357, + 227285124, 243066436, 1718037317, 369721963, + 1752968006, 1061013677, 775617499, 1464907431, + 544300429, 871461966, 135151545, 1343592602, + 1622220528, 643966158, 3932580, 434948358, + 540553922, 1446502052, 153298741, 1191216273, + 265936762, 1463035257, 1237633339, 1797346310, + 1355791584, 389527741, 1741650463, 1728913415, + 1825739540, 1790924136, 460776743, 29536554, + 6842036, 252495270, 1968285155, 299467416, + 49085744, 1499815729, 1098802236, 644489275, + 1827273105, 1888401527, 390077051, 565528894, + 1366177188, 67441791, 958486301, 402056716, + 590379691, 462035406, 633459131, 843304872, + 584100013, 1932496508, 250656031, 146983915, + 1835173157, 939973454, 1844873638, 1916054832, + 1601784696, 167251717, 409107688, 1062925788, + 1291319514, 1790529531, 495655592, 1093359708, + 790197205, 674458164, 195988318, 399764452, + 106865258, 967050329, 350035523, 1109292118, + 1815460301, 281986036, 900636603, 1121197008, + 1228976590, 1879998708, 1924332706, 434695844, + 1159360621, 471397106, 473371067, 1009065094, + 1320176846, 168020789, 1265321929, 1901808675, + 223657700, 1480150183, 1779968584, 144416591, + 304407746, 1864498679, 1482460119, 1554376965, + 1479261548, 1657723043, 1039345063, 1053923521, + 442080513, 1964082352, 691664908, 1941008321, + 1007729002, 860529393, 849697342, 754485488, + 584295923, 1072251466, 1105105254, 996079746, + 1305909868, 1348028973, 122275988, 464050036, + 692807777, 1098809324, 397235220, 596459886, + 1663209783, 720230826, 1422510715, 1760654694, + 544197700, 1417744567, 1938716517, 1571826328, + 1591430185, 1173137446, 175285007, 1541718596, + 1715958587, 1429966110, 583013357, 1667787861, + 109891172, 668253167, 161783842, 296183397, + 1681897325, 1054396117, 264741948, 464026995, + 1907686022, 1532786783, 394869458, 1766734740, + 136047179, 536856195, 376188855, 700633625, + 515518419, 531043483, 60673499, 556496527, + 1743028981, 873954569, 1371062291, 632169731, + 1353239206, 526507035, 1894490088, 589441599, + 1610487168, 1074160583, 366366374, 247602990, + 1535354896, 894493713, 1555870413, 1389854934, + 1897251683, 1525812801, 675621735, 697919636, + 1690274072, 1466810921, 1221110784, 1741995587, + 1877169764, 390876982, 1794129810, 297662156, + 144295349, 417037264, 1290835727, 1654971513, + 1674131303, 1625667423, 1471248832, 1676797844, + 1172916558, 1707775403, 423725211, 1643279661, + 1695774264, 378140395, 1517569394, 1666625392, + 1803981250, 439036260, 247966130, 709534816, + 361144100, 1546096548, 1240886454, 1898161518, + 843262057, 1709259464, 1301015977, 1997626928, + 677153173, 1606710353, 1216038070, 435565562, + 98686333, 1773787396, 267051994, 99395396, + 545509105, 782289675, 1289865975, 1707775075, + 1158993015, 1506576588, 993215179, 1523099397, + 923914455, 1895162386, 284489994, 1444139016, + 1943825680, 466202724, 1632522710, 1384015062, + 723147188, 1284031324, 1430481515, 341213007, + 171192499, 1061688239, 808927167, 83182639, + 759209907, 1728321272, 976049976, 1652071995, + 1002877840, 69880246, 1095135165, 677588420, + 1384715290, 829619452, 170122781, 1958173727, + 13389238, 789379698, 1883383039, 1279195174, + 1618672336, 1192839317, 1348311124, 758896285, + 1939775389, 684108413, 1838340479, 1332232130, + 1070486028, 549228790, 868851698, 1678207843, + 1754321489, 637000403, 647901906, 45343322, + 1768524074, 1167955205, 1816497210, 1609414096, + 1985231742, 1540534482, 232730819, 232221968, + 1509637836, 1480860627, 884647789, 1096458024, + 163721583, 1248032262, 436419506, 1737102298, + 651105860, 452298073, 1064372507, 1792838683, + 619243471, 860127631, 721724708, 950768433, + 279913448, 339693210, 47730422, 1952683911, + 1316500770, 675944216, 386902809, 619333956, + 1194800389, 43989936, 1944372656, 666045666, + 1155873844, 522696968, 58874730, 1497238023, + 421619994, 1980672127, 1657191856, 1913792631, + 1784663131, 1118400672, 1828104993, 1637808383, + 414755472, 775410449, 747132157, 136820101, + 1082674285, 93190395, 357955402, 335652723, + 1192102705, 480365232, 1354935730, 1391829361, + 966662991, 1601510445, 569528575, 545490940, + 1753711688, 807025222, 580374183, 587718008, + 977546290, 1055719519, 1157107032, 562799608, + 859466927, 840450024, 815325134, 936576801, + 1010587056, 246624382, 1808049797, 1098183398, + 1005077390, 772432546, 1976629565, 1003772218, + 1655315418, 1767931114, 982008720, 785023351, + ]; + + let prod_spec_evals: Array> = builder.dyn_array(prod_spec_eval_u32s.len() / EXT_DEG); + for (idx, n) in prod_spec_eval_u32s.chunks(EXT_DEG).enumerate() { + let f1 = builder.constant(C::F::from_canonical_u32(n[0])); + let f2 = builder.constant(C::F::from_canonical_u32(n[1])); + let f3 = builder.constant(C::F::from_canonical_u32(n[2])); + let f4 = builder.constant(C::F::from_canonical_u32(n[3])); + + let e = builder.felts2ext(&[f1, f2, f3, f4]); + builder.set(&prod_spec_evals, idx, e); + } + + let logup_spec_eval_u32s = [ + 1522353967u32, 457603397, 421847521, 1352563318, + 1746817766, 737872688, 1087008622, 1850835028, + 456475558, 892966330, 638163666, 148568548, + 678863061, 1334386850, 1896333039, 154585769, + 433618446, 1186936470, 970218722, 1213827097, + 1798557019, 861757965, 119285527, 395360622, + 226164366, 1330279872, 66561048, 785421608, + 1950755756, 1559889596, 348449876, 1090789452, + 257578851, 273164442, 1644906, 295600924, + 1187949602, 1168249609, 469763604, 60929061, + 291163036, 403842501, 1421902433, 1700188477, + 1046093370, 921059131, 1638991894, 464012042, + 96905857, 1370999592, 271896041, 13595534, + 1489760970, 1650552701, 133367846, 25680377, + 377631580, 652729291, 645763356, 426747355, + 482475486, 1877299223, 103226636, 1333832358, + 1399609097, 458536972, 976248802, 1109365280, + 515164588, 1579426417, 1601829549, 607169702, + 852817956, 1980537127, 134138338, 913344050, + 737880920, 476360275, 61624034, 1610624252, + 264461991, 546933535, 937769429, 293346965, + 1522058041, 1012551797, 994330314, 23333322, + 1969510890, 974351570, 2012030621, 120742000, + 450250620, 180547360, 642746933, 1815029950, + 629489142, 1176992624, 723354779, 572648755, + 1218615348, 648847054, 351903235, 723149764, + 248065753, 243829448, 1283393001, 1912627886, + 581641342, 702465306, 205969758, 1061911274, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1703043252, 1467887451, 1714319214, 907866644, + 1542426838, 742609036, 1814393459, 448706641, + 1960340767, 46490834, 186512520, 363973095, + 846448854, 463742343, 2012517527, 40473617, + 9472552, 263483342, 105738598, 586389136, + 254290990, 625150844, 960233097, 1488303724, + 1700231692, 1471714612, 1540211186, 1590246915, + 945341972, 1343225515, 179976237, 34857822, + 276912528, 984309272, 1277293398, 1520924162, + 1823117694, 604836357, 1460812009, 600052559, + 970469338, 1771022707, 181855831, 1445947220, + 467514809, 1514677498, 947030389, 170390653, + 415409007, 1601463730, 204153427, 904614278, + 1855419512, 2009471607, 1352607379, 576586082, + 1343812879, 1176377580, 1166188815, 1592289048, + 761793881, 1529621462, 193034837, 344011596, + 1669461833, 1356800025, 314186361, 586497329, + 1832810846, 1288092861, 1619454491, 732529408, + 737934269, 909504928, 769680420, 1437893101, + 1727002258, 1618231110, 535125583, 153412473, + 1917760929, 588586507, 564531165, 1790797737, + 1666283994, 1366948884, 117673690, 476470378, + 2012274032, 1951406668, 1739767532, 1273142151, + 1591812317, 1900205312, 1912608761, 1734766024, + 1265002082, 1450462894, 749810837, 1329222552, + 745081805, 1231519431, 1420957967, 883846107, + 1995463911, 407795592, 161655852, 125886157, + 995318920, 484905024, 284135318, 551493419, + 406742309, 1089024446, 637339867, 1858138403, + 1230680117, 187078889, 1929517480, 1125646261, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1610035932, 462442436, 831412555, 44798862, + 1748147276, 1911945531, 1329343740, 971894393, + 362147969, 1583335926, 1528700112, 426908674, + 847905883, 447889090, 1050883911, 1883537469, + 1487501632, 964178870, 1818828551, 1980840799, + 340372118, 1697179193, 215113037, 1893217470, + 1138628493, 1788052486, 443362955, 1349213730, + 589553425, 562526667, 1006040406, 1194546769, + 1831034644, 612004157, 730213913, 1068905440, + 371983982, 502900790, 802785198, 822377635, + 1477528437, 501356237, 684668525, 1306043781, + 621032592, 1971342708, 1411586583, 733418745, + 186045462, 1559301855, 323758310, 453170140, + 498381240, 976247416, 631213663, 898017829, + 501459603, 609703046, 1379288251, 177682695, + 912381595, 121915494, 1137416430, 504054388, + 1138277238, 1603388253, 1838013301, 1700271853, + 20488607, 58775264, 217974275, 979141729, + 53136584, 1331566240, 1460303356, 525812787, + 718385521, 1477919263, 1663622276, 1089788203, + 1204483837, 54225863, 290660186, 1441441958, + 134168813, 349638823, 1867912015, 1579183319, + 55528656, 1602973359, 194297109, 949763297, + 101931919, 242300116, 1610052257, 1351823848, + 174522860, 776955925, 1706962365, 808187490, + 1487253852, 431806906, 213982593, 1170647308, + 1776840400, 295916317, 378708073, 381270341, + 457494568, 705823997, 1407301442, 1693003013, + 700310785, 1349874247, 1284363817, 1566253815, + 1014298154, 215294365, 1070968678, 871641358, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1302679751, 1121894357, 368587356, 1564724097, + 733815591, 2012670011, 1146780092, 1439780227, + 1801628424, 838692317, 932318853, 213634365, + 155292454, 1644317110, 1599846194, 978829059, + 1282095862, 1780431647, 527412087, 1024583705, + 804423802, 951808322, 689345230, 180304167, + 1784562773, 1514653374, 2009396440, 1143778943, + 235299446, 1553017484, 475425117, 758292254, + 716575432, 517083432, 1728864125, 418010549, + 43202592, 507659742, 433077118, 1268144019, + 1462778342, 1928073362, 1330130180, 1749624351, + 827401013, 1236194147, 1875519726, 1437946791, + 607293265, 309229599, 1009445595, 1725229718, + 1436309341, 1952606463, 943149111, 291680468, + 1989684076, 1944713370, 1285294139, 399758737, + 1572979232, 213817406, 214840530, 184898060, + 1483844295, 1536616777, 494816009, 217625163, + 529448032, 786640964, 1766471731, 1424140424, + 1721961711, 740275169, 169908711, 913969302, + 1359358267, 1328322971, 593228769, 771095186, + 801680440, 450930656, 1796349530, 1824428677, + 1111258504, 1741666629, 1098430204, 1792001884, + 1679003061, 590088446, 647614538, 1324461639, + 818996796, 229187928, 74288115, 1158900266, + 1512606270, 1381672753, 785927403, 493453164, + 425259497, 1367873539, 931023744, 221202218, + 669580668, 424996238, 1840425275, 1873362670, + 967642716, 263556335, 578560519, 1558449223, + 607579284, 1724012378, 333582342, 1195784167, + 1419727276, 199294290, 138807165, 1061030752, + 1, 0, 0, 0, + 1, 0, 0, 0, + 776332180, 1333076185, 1855163818, 1897408938, + 799274251, 950452503, 691904988, 1205387466, + 659107883, 434394982, 129587940, 639018629, + 659238594, 1957584892, 864291238, 589178070, + 1267157231, 48925338, 200093884, 1953762869, + 1227617341, 1471420621, 193077633, 1007876111, + 228491220, 1377349503, 1889411060, 1807513892, + 1593042934, 1240864695, 1472870721, 583021932, + 598239104, 1862008818, 1811242869, 780768026, + 520870395, 292016292, 322246659, 868240490, + 1715620331, 1183509209, 2010262726, 1003957251, + 264895455, 307755941, 201990485, 1662471178, + 1643997923, 1573129362, 277821143, 388834470, + 943361405, 1449402196, 614413575, 1504113993, + 1860552739, 1755127315, 1734129760, 1232115188, + 803035456, 360488092, 271342171, 1269544258, + 290642673, 660703582, 986842267, 870891877, + 454573044, 1999346236, 701614601, 820253867, + 883282765, 137247873, 1727164949, 1320585493, + 1738664600, 1900116905, 472215154, 1114994489, + 104218174, 1694603079, 771486383, 935361143, + 92277671, 881040480, 925124484, 1464396527, + 100625197, 65290355, 1001454341, 134627585, + 58629702, 1541542242, 568583607, 1706262052, + 530687550, 1303187245, 1010302462, 264001857, + 789816678, 561378226, 827432508, 801307507, + 1613508315, 1650822853, 1603502703, 439320335, + 15283580, 1244486577, 254345266, 1745653280, + 1648250354, 1528271018, 528366563, 1078707735, + 1430767759, 1890467731, 2001894083, 799949326, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1341839494, 1092219735, 755644898, 966729319, + 1914277278, 1545367697, 1765189119, 1693413008, + ]; + + let logup_spec_evals: Array> = builder.dyn_array(logup_spec_eval_u32s.len() / EXT_DEG); + for (idx, n) in logup_spec_eval_u32s.chunks(EXT_DEG).enumerate() { + let f1 = builder.constant(C::F::from_canonical_u32(n[0])); + let f2 = builder.constant(C::F::from_canonical_u32(n[1])); + let f3 = builder.constant(C::F::from_canonical_u32(n[2])); + let f4 = builder.constant(C::F::from_canonical_u32(n[3])); + + let e = builder.felts2ext(&[f1, f2, f3, f4]); + builder.set(&logup_spec_evals, idx, e); + } + + let r_evals_u32s = [ + 941378355u32, 1078920879, 696738840, 496039492, + 1555445457, 184545404, 905938226, 1847966044, + 1024875886, 1782716223, 1625644635, 266865456, + 465953066, 1663531470, 757423849, 1957075986, + 1919693393, 839104130, 127480221, 1527842912, + 918650796, 921462354, 575456073, 696646705, + 1585912361, 258186488, 353168830, 1111094691, + 1401166558, 1905942163, 1923083163, 393037255, + 1042127700, 1126793296, 895794165, 1124924482, + 1324266058, 722406365, 1963838171, 968504459, + 1934378800, 714588691, 6465911, 1168379648, + 903786009, 1326035939, 518289228, 418998914, + 1513133474, 1578096058, 617547414, 1658315126, + 68556894, 1697802593, 1346510664, 1709381671, + 345062962, 1254089535, 1002281845, 1882822096, + 700581748, 1431345304, 489112954, 98435728, + 1799886007, 479788390, 223111065, 631662309, + ]; + + let next_layer_evals: Array> = builder.dyn_array(r_evals_u32s.len() / EXT_DEG); + for (idx, n) in r_evals_u32s.chunks(EXT_DEG).enumerate() { + let f1 = builder.constant(C::F::from_canonical_u32(n[0])); + let f2 = builder.constant(C::F::from_canonical_u32(n[1])); + let f3 = builder.constant(C::F::from_canonical_u32(n[2])); + let f4 = builder.constant(C::F::from_canonical_u32(n[3])); + + let e = builder.felts2ext(&[f1, f2, f3, f4]); + builder.set(&next_layer_evals, idx, e); + } + */ + + // builder.sumcheck_layer_eval(&ctx, &challenges, &prod_spec_evals, &logup_spec_evals, &next_layer_evals); +} From 8f780da7dbd88bd56af735e3fac75fa39320a9c9 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 23 Sep 2025 17:05:33 -0400 Subject: [PATCH 08/41] Correct header row register reads --- extensions/native/circuit/src/sumcheck/air.rs | 19 ++++--- .../native/circuit/src/sumcheck/chip.rs | 38 +++++++++++-- .../native/circuit/src/sumcheck/trace.rs | 21 +++++++- extensions/native/recursion/tests/sumcheck.rs | 54 ++++++++++--------- 4 files changed, 94 insertions(+), 38 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 2737bd1ffc..ad7fad2495 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -92,11 +92,14 @@ impl Air registers[3].into(), ], ExecutionState::new(header_row_specific.pc, first_timestamp), - last_timestamp, + last_timestamp - first_timestamp, ) .eval(builder, header_row); - for i in 0..5usize { + // Read registers + // _debug + // for i in 0..5usize { + for i in 0..1usize { self.memory_bridge .read( MemoryAddress::new(self.address_space, registers[i]), @@ -107,14 +110,13 @@ impl Air .eval(builder, header_row); } - /* _debug - - // Read context variables + /* + // React ctx self.memory_bridge .read( MemoryAddress::new(self.address_space, register_ptrs[0]), ctx, - first_timestamp + AB::F::from_canonical_usize(6), + first_timestamp + AB::F::from_canonical_usize(5), &header_row_specific.read_records[5], ) .eval(builder, header_row); @@ -124,12 +126,13 @@ impl Air .read( MemoryAddress::new(self.address_space, register_ptrs[1]), challenges, - first_timestamp + AB::F::from_canonical_usize(7), + first_timestamp + AB::F::from_canonical_usize(6), &header_row_specific.read_records[6], ) .eval(builder, header_row); + */ - + /* _debug // Separate aggregate column clusters let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().expect(""); let c1: [_; EXT_DEG] = challenges[EXT_DEG..{EXT_DEG * 2}].try_into().expect(""); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 34d3ad189d..a3560432d2 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -124,6 +124,9 @@ impl InstructionExecutor for NativeSumcheckChip { let (read_ctx_pointer, ctx_pointer) = memory.read_cell(register_address_space, input_register_1); + + // _debug + /* let (read_cs_pointer, cs_pointer) = memory.read_cell(register_address_space, input_register_2); let (read_prod_pointer, prod_ptr) = @@ -156,6 +159,8 @@ impl InstructionExecutor for NativeSumcheckChip { let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); let register_ptrs: [F; 5] = [ctx_pointer, cs_pointer, prod_ptr, logup_ptr, r_ptr]; + + let mut header_row: SumcheckEvalRecord = SumcheckEvalRecord { from_state, instruction: instruction.clone(), @@ -183,9 +188,35 @@ impl InstructionExecutor for NativeSumcheckChip { alpha, ..Default::default() }; + */ + + + // _debug + let mut header_row = SumcheckEvalRecord { + from_state, + instruction: instruction.clone(), + row_type: 0, + curr_timestamp_increment: curr_timestamp, + registers: [ + input_register_1, + input_register_2, + input_register_3, + input_register_4, + output_register, + ], + ..Default::default() + }; + println!("=> ctx_pointer: {:?}", ctx_pointer); + header_row.register_ptrs[0] = ctx_pointer; + println!("=> read_ctx_pointer: {:?}", read_ctx_pointer); + header_row.read_data_records[0] = read_ctx_pointer; + observation_records.push(header_row); self.height += 1; - curr_timestamp += 7; + // _debug + // curr_timestamp += 7; + curr_timestamp += 1; + /* @@ -354,14 +385,15 @@ impl InstructionExecutor for NativeSumcheckChip { curr_timestamp += 1; observation_records[0].write_data_records[0] = write_r; + */ for record in &mut observation_records { record.final_timestamp_increment = curr_timestamp; - record.eval_acc = FieldExtension::subtract(eval_acc, record.eval_acc); + // _debug + // record.eval_acc = FieldExtension::subtract(eval_acc, record.eval_acc); } self.record_set.extend(observation_records); println!("=> current_height: {:?}", self.height); - */ } else { unreachable!() } diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index 71277d2025..24aff2984e 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -46,12 +46,31 @@ impl NativeSumcheckChip { let cols: &mut NativeSumcheckCols = slice.borrow_mut(); cols.first_timestamp = F::from_canonical_u32(record.from_state.timestamp); cols.start_timestamp = F::from_canonical_usize(record.from_state.timestamp as usize + record.curr_timestamp_increment); - cols.last_timestamp = F::from_canonical_usize(record.final_timestamp_increment); + cols.last_timestamp = F::from_canonical_usize(record.from_state.timestamp as usize + record.final_timestamp_increment); + cols.register_ptrs = record.register_ptrs; + cols.ctx = record.ctx; + cols.challenges = record.challenges; if record.row_type == 0 { cols.header_row = F::ONE; let header: &mut HeaderSpecificCols = cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + header.pc = F::from_canonical_u32(record.from_state.pc); + header.registers = record.registers; + + // registers, ctx, challenges + // _debug + for i in 0..1usize { + let mem_record = memory.record_by_id(record.read_data_records[i]); + aux_cols_factory.generate_read_aux(mem_record, &mut header.read_records[i]); + } + + + + + + } else if record.row_type == 1 { cols.prod_row = F::ONE; let prod: &mut ProdSpecificCols = diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index 7a9bde1ec1..48de36b5ce 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -10,7 +10,7 @@ use openvm_native_recursion::{testing_utils::inner::run_recursive_test, challeng use openvm_stark_backend::{ config::{Domain, StarkGenericConfig}, p3_commit::PolynomialSpace, - p3_field::{extension::BinomialExtensionField, FieldAlgebra}, + p3_field::{extension::BinomialExtensionField, FieldAlgebra, PackedValue, FieldExtensionAlgebra}, }; use openvm_stark_sdk::{ config::FriParameters, @@ -69,23 +69,22 @@ fn test_sumcheck_layer_eval() { let result = vm.execute_and_generate(program, vec![]).unwrap(); let proofs = vm.prove(&pk, result); - /* for proof in proofs { verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); } - */ } fn build_test_program( builder: &mut Builder, ) { - /* + let ctx_u32s = [3u32, 6, 5, 8, 2, 8, 4, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]; let ctx: Array> = builder.dyn_array(ctx_u32s.len()); for (idx, n) in ctx_u32s.into_iter().enumerate() { builder.set(&ctx, idx, Usize::from(n as usize)); } + let challenges_u32s = [ 548478283u32, 456436544, 1716290291, 791326976, 1829717553, 1422025771, 1917123958, 727015942, @@ -93,12 +92,13 @@ fn build_test_program( ]; let challenges: Array> = builder.dyn_array(challenges_u32s.len() / EXT_DEG); for (idx, n) in challenges_u32s.chunks(EXT_DEG).enumerate() { - let f1 = builder.constant(C::F::from_canonical_u32(n[0])); - let f2 = builder.constant(C::F::from_canonical_u32(n[1])); - let f3 = builder.constant(C::F::from_canonical_u32(n[2])); - let f4 = builder.constant(C::F::from_canonical_u32(n[3])); + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]) + ])); - let e = builder.felts2ext(&[f1, f2, f3, f4]); builder.set(&challenges, idx, e); } @@ -203,12 +203,13 @@ fn build_test_program( let prod_spec_evals: Array> = builder.dyn_array(prod_spec_eval_u32s.len() / EXT_DEG); for (idx, n) in prod_spec_eval_u32s.chunks(EXT_DEG).enumerate() { - let f1 = builder.constant(C::F::from_canonical_u32(n[0])); - let f2 = builder.constant(C::F::from_canonical_u32(n[1])); - let f3 = builder.constant(C::F::from_canonical_u32(n[2])); - let f4 = builder.constant(C::F::from_canonical_u32(n[3])); + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]) + ])); - let e = builder.felts2ext(&[f1, f2, f3, f4]); builder.set(&prod_spec_evals, idx, e); } @@ -377,12 +378,13 @@ fn build_test_program( let logup_spec_evals: Array> = builder.dyn_array(logup_spec_eval_u32s.len() / EXT_DEG); for (idx, n) in logup_spec_eval_u32s.chunks(EXT_DEG).enumerate() { - let f1 = builder.constant(C::F::from_canonical_u32(n[0])); - let f2 = builder.constant(C::F::from_canonical_u32(n[1])); - let f3 = builder.constant(C::F::from_canonical_u32(n[2])); - let f4 = builder.constant(C::F::from_canonical_u32(n[3])); + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]) + ])); - let e = builder.felts2ext(&[f1, f2, f3, f4]); builder.set(&logup_spec_evals, idx, e); } @@ -408,15 +410,15 @@ fn build_test_program( let next_layer_evals: Array> = builder.dyn_array(r_evals_u32s.len() / EXT_DEG); for (idx, n) in r_evals_u32s.chunks(EXT_DEG).enumerate() { - let f1 = builder.constant(C::F::from_canonical_u32(n[0])); - let f2 = builder.constant(C::F::from_canonical_u32(n[1])); - let f3 = builder.constant(C::F::from_canonical_u32(n[2])); - let f4 = builder.constant(C::F::from_canonical_u32(n[3])); + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]) + ])); - let e = builder.felts2ext(&[f1, f2, f3, f4]); builder.set(&next_layer_evals, idx, e); } - */ - // builder.sumcheck_layer_eval(&ctx, &challenges, &prod_spec_evals, &logup_spec_evals, &next_layer_evals); + builder.sumcheck_layer_eval(&ctx, &challenges, &prod_spec_evals, &logup_spec_evals, &next_layer_evals); } From f9023a5a8c286c6181edc5140f6af7b6e94f6181 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 23 Sep 2025 17:22:22 -0400 Subject: [PATCH 09/41] Correct header row register reads --- extensions/native/circuit/src/sumcheck/air.rs | 6 +- .../native/circuit/src/sumcheck/chip.rs | 55 ++++--------------- .../native/circuit/src/sumcheck/trace.rs | 3 +- 3 files changed, 14 insertions(+), 50 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index ad7fad2495..8d2d76fe6a 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -97,9 +97,7 @@ impl Air .eval(builder, header_row); // Read registers - // _debug - // for i in 0..5usize { - for i in 0..1usize { + for i in 0..5usize { self.memory_bridge .read( MemoryAddress::new(self.address_space, registers[i]), @@ -110,7 +108,6 @@ impl Air .eval(builder, header_row); } - /* // React ctx self.memory_bridge .read( @@ -130,7 +127,6 @@ impl Air &header_row_specific.read_records[6], ) .eval(builder, header_row); - */ /* _debug // Separate aggregate column clusters diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index a3560432d2..775a815928 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -124,9 +124,6 @@ impl InstructionExecutor for NativeSumcheckChip { let (read_ctx_pointer, ctx_pointer) = memory.read_cell(register_address_space, input_register_1); - - // _debug - /* let (read_cs_pointer, cs_pointer) = memory.read_cell(register_address_space, input_register_2); let (read_prod_pointer, prod_ptr) = @@ -135,6 +132,8 @@ impl InstructionExecutor for NativeSumcheckChip { memory.read_cell(register_address_space, input_register_4); let (read_result_pointer, r_ptr) = memory.read_cell(register_address_space, output_register); + let register_ptrs: [F; 5] = [ctx_pointer, cs_pointer, prod_ptr, logup_ptr, r_ptr]; + let (ctx_read, ctx): (RecordId, [F; EXT_DEG * 2]) = memory.read::<{EXT_DEG * 2}>(data_address_space, ctx_pointer); @@ -151,22 +150,12 @@ impl InstructionExecutor for NativeSumcheckChip { let (challenges_read, challenges): (RecordId, [F; EXT_DEG * 4]) = memory.read::<{EXT_DEG * 4}>(data_address_space, cs_pointer); - let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().expect(""); - let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); - let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().expect(""); - - let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); - let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); - - let register_ptrs: [F; 5] = [ctx_pointer, cs_pointer, prod_ptr, logup_ptr, r_ptr]; - - - let mut header_row: SumcheckEvalRecord = SumcheckEvalRecord { + let mut header_row = SumcheckEvalRecord { from_state, instruction: instruction.clone(), row_type: 0, - curr_timestamp_increment: curr_timestamp, - register_ptrs, + curr_timestamp_increment: curr_timestamp, + register_ptrs, registers: [ input_register_1, input_register_2, @@ -175,7 +164,7 @@ impl InstructionExecutor for NativeSumcheckChip { output_register, ], ctx, - challenges, + challenges, read_data_records: [ read_ctx_pointer, read_cs_pointer, @@ -185,38 +174,18 @@ impl InstructionExecutor for NativeSumcheckChip { ctx_read, challenges_read, ], - alpha, ..Default::default() }; - */ - - - // _debug - let mut header_row = SumcheckEvalRecord { - from_state, - instruction: instruction.clone(), - row_type: 0, - curr_timestamp_increment: curr_timestamp, - registers: [ - input_register_1, - input_register_2, - input_register_3, - input_register_4, - output_register, - ], - ..Default::default() - }; - println!("=> ctx_pointer: {:?}", ctx_pointer); - header_row.register_ptrs[0] = ctx_pointer; - println!("=> read_ctx_pointer: {:?}", read_ctx_pointer); - header_row.read_data_records[0] = read_ctx_pointer; observation_records.push(header_row); self.height += 1; - // _debug - // curr_timestamp += 7; - curr_timestamp += 1; + curr_timestamp += 7; + // let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); + // let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); + let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().expect(""); + // let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + // let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().expect(""); /* diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index 24aff2984e..84c5796509 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -60,8 +60,7 @@ impl NativeSumcheckChip { header.registers = record.registers; // registers, ctx, challenges - // _debug - for i in 0..1usize { + for i in 0..7usize { let mem_record = memory.record_by_id(record.read_data_records[i]); aux_cols_factory.generate_read_aux(mem_record, &mut header.read_records[i]); } From 1ae5182bc40e7494557fe574e051525e3229c65e Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 23 Sep 2025 17:26:29 -0400 Subject: [PATCH 10/41] Correct header row register read --- extensions/native/circuit/src/sumcheck/chip.rs | 14 +++++++------- extensions/native/circuit/src/sumcheck/trace.rs | 7 ++----- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 775a815928..298081969a 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -134,9 +134,7 @@ impl InstructionExecutor for NativeSumcheckChip { memory.read_cell(register_address_space, output_register); let register_ptrs: [F; 5] = [ctx_pointer, cs_pointer, prod_ptr, logup_ptr, r_ptr]; - let (ctx_read, ctx): (RecordId, [F; EXT_DEG * 2]) = memory.read::<{EXT_DEG * 2}>(data_address_space, ctx_pointer); - let [ round, num_prod_spec, @@ -149,6 +147,7 @@ impl InstructionExecutor for NativeSumcheckChip { ] = ctx; let (challenges_read, challenges): (RecordId, [F; EXT_DEG * 4]) = memory.read::<{EXT_DEG * 4}>(data_address_space, cs_pointer); + let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().expect(""); let mut header_row = SumcheckEvalRecord { from_state, @@ -156,6 +155,7 @@ impl InstructionExecutor for NativeSumcheckChip { row_type: 0, curr_timestamp_increment: curr_timestamp, register_ptrs, + alpha, registers: [ input_register_1, input_register_2, @@ -181,11 +181,11 @@ impl InstructionExecutor for NativeSumcheckChip { self.height += 1; curr_timestamp += 7; - // let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); - // let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); - let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().expect(""); - // let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); - // let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().expect(""); + let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); + let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); + + let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().expect(""); /* diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index 84c5796509..9d2a72bc9c 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -50,6 +50,7 @@ impl NativeSumcheckChip { cols.register_ptrs = record.register_ptrs; cols.ctx = record.ctx; cols.challenges = record.challenges; + cols.alpha = record.alpha; if record.row_type == 0 { cols.header_row = F::ONE; @@ -65,11 +66,7 @@ impl NativeSumcheckChip { aux_cols_factory.generate_read_aux(mem_record, &mut header.read_records[i]); } - - - - - + } else if record.row_type == 1 { cols.prod_row = F::ONE; let prod: &mut ProdSpecificCols = From 49f337f6f06b5566f8e91e172ccbcc738b120c59 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 23 Sep 2025 18:26:17 -0400 Subject: [PATCH 11/41] Correct rw records for prod rows --- extensions/native/circuit/src/sumcheck/air.rs | 86 +++++++++++-------- .../native/circuit/src/sumcheck/chip.rs | 11 ++- .../native/circuit/src/sumcheck/columns.rs | 6 ++ .../native/circuit/src/sumcheck/trace.rs | 50 ++++++++++- 4 files changed, 108 insertions(+), 45 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 8d2d76fe6a..03110eec25 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -63,6 +63,7 @@ impl Air alpha, challenges, max_round, + within_round_limit, should_acc, eval_acc, specific, @@ -128,47 +129,12 @@ impl Air ) .eval(builder, header_row); - /* _debug // Separate aggregate column clusters let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().expect(""); let c1: [_; EXT_DEG] = challenges[EXT_DEG..{EXT_DEG * 2}].try_into().expect(""); let c2: [_; EXT_DEG] = challenges[{EXT_DEG * 2}..{EXT_DEG * 3}].try_into().expect(""); let alpha2: [_; EXT_DEG] = challenges[{EXT_DEG * 3}..{EXT_DEG * 4}].try_into().expect(""); - // Carry along columns - assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); - assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), ctx, next.ctx); - assert_array_eq::<_, _, _, {EXT_DEG * 2}>( - &mut builder.when(next.prod_row + next.logup_row), - challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect(""), - next.challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect("") - ); - assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); - - // Row transitions - builder - .when(header_row) - .when(next.logup_row) - .assert_zero(ctx[1]); - builder - .when(next.prod_row) - .assert_eq(curr_prod_n + AB::F::ONE, next.curr_prod_n); - builder - .when(next.logup_row) - .assert_eq(curr_logup_n + AB::F::ONE, next.curr_logup_n); - builder - .when(prod_row) - .when(next.logup_row) - .assert_eq(ctx[1], curr_prod_n); - builder - .when(logup_row) - .when(not(next.logup_row)) - .assert_eq(ctx[2], curr_logup_n); - - - - - // Prod spec evaluation let prod_row_specific: &ProdSpecificCols = specific[..ProdSpecificCols::::width()].borrow(); @@ -182,18 +148,22 @@ impl Air ) .eval(builder, prod_row); + // _debug + // let p_start_ptr = register_ptrs[2] + (ctx[4] * ctx[3] * (curr_prod_n - AB::F::ONE) + ctx[4] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG); + self.memory_bridge .read( MemoryAddress::new( self.address_space, - register_ptrs[2] + (ctx[4] * ctx[3] * (curr_prod_n - AB::F::ONE) + ctx[4] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG), + register_ptrs[2] + prod_row_specific.data_ptr, ), prod_row_specific.p, start_timestamp + AB::F::ONE, &prod_row_specific.read_records[1], ) - .eval(builder, prod_row); + .eval(builder, prod_row * within_round_limit); + let p1: [_; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().expect(""); let p2: [_; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); @@ -207,7 +177,47 @@ impl Air start_timestamp + AB::F::TWO, &prod_row_specific.write_record, ) - .eval(builder, prod_row); + .eval(builder, prod_row * within_round_limit); + + + + + /* _debug + // Carry along columns + assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); + assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), ctx, next.ctx); + assert_array_eq::<_, _, _, {EXT_DEG * 2}>( + &mut builder.when(next.prod_row + next.logup_row), + challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect(""), + next.challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect("") + ); + assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); + + // Row transitions + builder + .when(header_row) + .when(next.logup_row) + .assert_zero(ctx[1]); + builder + .when(next.prod_row) + .assert_eq(curr_prod_n + AB::F::ONE, next.curr_prod_n); + builder + .when(next.logup_row) + .assert_eq(curr_logup_n + AB::F::ONE, next.curr_logup_n); + builder + .when(prod_row) + .when(next.logup_row) + .assert_eq(ctx[1], curr_prod_n); + builder + .when(logup_row) + .when(not(next.logup_row)) + .assert_eq(ctx[2], curr_logup_n); + + + + + + // Logup spec evaluation let logup_row_specific: &LogupSpecificCols = diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 298081969a..681aa9bd05 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -40,12 +40,14 @@ pub struct SumcheckEvalRecord { pub write_data_records: [RecordId; 2], pub max_round: F, + pub within_round_limit: bool, pub should_acc: bool, pub prod_spec_n: usize, pub logup_spec_n: usize, pub alpha: [F; EXT_DEG], pub alpha1: [F; EXT_DEG], pub alpha2: [F; EXT_DEG], + pub data_ptr: F, pub p1: [F; EXT_DEG], pub p2: [F; EXT_DEG], pub q1: [F; EXT_DEG], @@ -183,12 +185,9 @@ impl InstructionExecutor for NativeSumcheckChip { let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); - let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().expect(""); - /* - let mut i = F::ZERO; let mut i_usize = 0usize; while i < num_prod_spec { @@ -212,6 +211,7 @@ impl InstructionExecutor for NativeSumcheckChip { curr_timestamp += 1; if round < (max_round - F::from_canonical_usize(1)) { + prod_row.within_round_limit = true; let start = calculate_3d_ext_idx( prod_specs_inner_inner_len, prod_specs_inner_len, @@ -219,6 +219,7 @@ impl InstructionExecutor for NativeSumcheckChip { round, F::from_canonical_usize(0), ); + prod_row.data_ptr = start; let (read_p, ps) = memory.read::<{EXT_DEG * 2}>(data_address_space, prod_ptr + start); let p1: [F; 4] = ps[0..EXT_DEG].try_into().expect(""); @@ -247,6 +248,7 @@ impl InstructionExecutor for NativeSumcheckChip { prod_row.should_acc = true; prod_row.eval_acc = eval_acc.clone(); } + curr_timestamp += 2; } @@ -258,6 +260,8 @@ impl InstructionExecutor for NativeSumcheckChip { self.height += 1; } + + /* let mut i = F::ZERO; let mut i_usize = 0usize; while i < num_logup_spec { @@ -278,6 +282,7 @@ impl InstructionExecutor for NativeSumcheckChip { curr_timestamp += 1; if round < (max_round - F::from_canonical_usize(1)) { + logup_row.within_round_limit = true; let start = calculate_3d_ext_idx( logup_specs_inner_inner_len, logup_specs_inner_len, diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index c4f59c2e8a..f34aaf51e9 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -47,6 +47,8 @@ pub struct NativeSumcheckCols { // Specific to each row pub max_round: T, + // Is this round within max_round + pub within_round_limit: T, // Should the evaluation be accumualted pub should_acc: T, @@ -82,6 +84,8 @@ pub struct HeaderSpecificCols { #[repr(C)] #[derive(AlignedBorrow)] pub struct ProdSpecificCols { + /// Pointer + pub data_ptr: T, /// 2 extension elements pub p: [T; EXT_DEG * 2], /// read max varibale and 2 p values @@ -95,6 +99,8 @@ pub struct ProdSpecificCols { #[repr(C)] #[derive(AlignedBorrow)] pub struct LogupSpecificCols { + /// Pointer + pub data_ptr: T, /// 4 extension elements pub pq: [T; EXT_DEG * 4], /// read max variable and 4 values: p1, p2, q1, q2 diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index 9d2a72bc9c..fa961d48a0 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -13,7 +13,7 @@ use openvm_stark_backend::{ prover::types::AirProofInput, AirRef, Chip, ChipUsageGetter, }; -use crate::sumcheck::{chip::NativeSumcheckChip, columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}}; +use crate::{sumcheck::{chip::NativeSumcheckChip, columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}}, EXT_DEG}; impl ChipUsageGetter for NativeSumcheckChip @@ -51,6 +51,10 @@ impl NativeSumcheckChip { cols.ctx = record.ctx; cols.challenges = record.challenges; cols.alpha = record.alpha; + cols.max_round = record.max_round; + cols.within_round_limit = if record.within_round_limit { F::ONE } else { F::ZERO }; + cols.should_acc = if record.should_acc { F::ONE } else { F::ZERO }; + cols.eval_acc = record.eval_acc; if record.row_type == 0 { cols.header_row = F::ONE; @@ -60,21 +64,59 @@ impl NativeSumcheckChip { header.pc = F::from_canonical_u32(record.from_state.pc); header.registers = record.registers; - // registers, ctx, challenges for i in 0..7usize { let mem_record = memory.record_by_id(record.read_data_records[i]); aux_cols_factory.generate_read_aux(mem_record, &mut header.read_records[i]); } - - } else if record.row_type == 1 { cols.prod_row = F::ONE; let prod: &mut ProdSpecificCols = cols.specific[..ProdSpecificCols::::width()].borrow_mut(); + + cols.curr_prod_n = F::from_canonical_usize(record.prod_spec_n + 1); + cols.challenges[0..EXT_DEG].copy_from_slice(&record.alpha1); + prod.p[0..EXT_DEG].copy_from_slice(&record.p1); + prod.p[EXT_DEG..(EXT_DEG * 2)].copy_from_slice(&record.p2); + prod.data_ptr = record.data_ptr; + + // Read max_round + let mem_record = memory.record_by_id(record.read_data_records[0]); + aux_cols_factory.generate_read_aux(mem_record, &mut prod.read_records[0]); + + if record.within_round_limit { + // Read p1, p2 + let mem_record = memory.record_by_id(record.read_data_records[1]); + aux_cols_factory.generate_read_aux(mem_record, &mut prod.read_records[1]); + + // Write p eval + prod.p_evals = record.p_evals; + let mem_record = memory.record_by_id(record.write_data_records[0]); + aux_cols_factory.generate_write_aux(mem_record, &mut prod.write_record); + } } else if record.row_type == 2 { cols.logup_row = F::ONE; let logup: &mut LogupSpecificCols = cols.specific[..LogupSpecificCols::::width()].borrow_mut(); + + cols.curr_logup_n = F::from_canonical_usize(record.logup_spec_n + 1); + cols.challenges[0..EXT_DEG].copy_from_slice(&record.alpha1); + cols.challenges[(EXT_DEG * 3)..(EXT_DEG * 4)].copy_from_slice(&record.alpha2); + + +// pub p1: [F; EXT_DEG], +// pub p2: [F; EXT_DEG], +// pub q1: [F; EXT_DEG], +// pub q2: [F; EXT_DEG], +// pub p_evals: [F; EXT_DEG], +// pub q_evals: [F; EXT_DEG], +// } +// logup.data_ptr = record.data_ptr; + + + + + + } else { unreachable!() } From 5d6a236ddce9588164c957b48b010a98dba7028d Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 23 Sep 2025 20:53:23 -0400 Subject: [PATCH 12/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 120 +++++++++++------- .../native/circuit/src/sumcheck/chip.rs | 24 ++-- .../native/circuit/src/sumcheck/columns.rs | 8 ++ .../native/circuit/src/sumcheck/trace.rs | 41 ++++-- 4 files changed, 124 insertions(+), 69 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 03110eec25..6dd331b6ea 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -53,11 +53,15 @@ impl Air header_row, prod_row, logup_row, + prod_row_within_max_round, + logup_row_within_max_round, first_timestamp, start_timestamp, last_timestamp, register_ptrs, ctx, + prod_nested_len, + logup_nested_len, curr_prod_n, curr_logup_n, alpha, @@ -148,8 +152,9 @@ impl Air ) .eval(builder, prod_row); - // _debug - // let p_start_ptr = register_ptrs[2] + (ctx[4] * ctx[3] * (curr_prod_n - AB::F::ONE) + ctx[4] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG); + builder + .when(prod_row_within_max_round) + .assert_eq(prod_row_specific.data_ptr, (prod_nested_len * (curr_prod_n - AB::F::ONE) + ctx[4] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG)); self.memory_bridge .read( @@ -161,9 +166,8 @@ impl Air start_timestamp + AB::F::ONE, &prod_row_specific.read_records[1], ) - .eval(builder, prod_row * within_round_limit); + .eval(builder, prod_row_within_max_round); - let p1: [_; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().expect(""); let p2: [_; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); @@ -177,7 +181,71 @@ impl Air start_timestamp + AB::F::TWO, &prod_row_specific.write_record, ) - .eval(builder, prod_row * within_round_limit); + .eval(builder, prod_row_within_max_round); + + // Logup spec evaluation + let logup_row_specific: &LogupSpecificCols = + specific[..LogupSpecificCols::::width()].borrow(); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, register_ptrs[0] + AB::F::from_canonical_usize(EXT_DEG * 2 - 1) + ctx[1] + curr_logup_n), + [max_round], + start_timestamp, + &logup_row_specific.read_records[0], + ) + .eval(builder, logup_row); + + // _debug + builder + .when(logup_row_within_max_round) + .assert_eq(logup_row_specific.data_ptr, (logup_nested_len * (curr_logup_n - AB::F::ONE) + ctx[6] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG)); + + self.memory_bridge + .read( + MemoryAddress::new( + self.address_space, + register_ptrs[3] + logup_row_specific.data_ptr, + ), + logup_row_specific.pq, + start_timestamp + AB::F::ONE, + &logup_row_specific.read_records[1], + ) + .eval(builder, logup_row_within_max_round); + + let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().expect(""); + let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + let q1: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 2)..{EXT_DEG * 3}].try_into().expect(""); + let q2: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().expect(""); + + /* _debug + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + register_ptrs[4] + (ctx[1] + curr_prod_n) * AB::F::from_canonical_usize(EXT_DEG), + ), + logup_row_specific.p_evals, + start_timestamp + AB::F::TWO, + &logup_row_specific.write_records[0], + ) + .eval(builder, logup_row_within_max_round); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + register_ptrs[4] + (ctx[1] + ctx[2] + curr_prod_n) * AB::F::from_canonical_usize(EXT_DEG), + ), + logup_row_specific.q_evals, + start_timestamp + AB::F::from_canonical_usize(3), + &logup_row_specific.write_records[1], + ) + .eval(builder, logup_row_within_max_round); + */ + + + @@ -219,47 +287,7 @@ impl Air - // Logup spec evaluation - let logup_row_specific: &LogupSpecificCols = - specific[..LogupSpecificCols::::width()].borrow(); - - self.memory_bridge - .read( - MemoryAddress::new(self.address_space, register_ptrs[0] + ctx[1] + AB::F::from_canonical_usize(EXT_DEG * 2 - 1) + curr_logup_n), - [max_round], - start_timestamp, - &prod_row_specific.read_records[0], - ) - .eval(builder, prod_row); - - self.memory_bridge - .read( - MemoryAddress::new( - self.address_space, - register_ptrs[2] + (ctx[4] * ctx[3] * (curr_prod_n - AB::F::ONE) + ctx[4] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG), - ), - prod_row_specific.p, - start_timestamp + AB::F::ONE, - &prod_row_specific.read_records[1], - ) - .eval(builder, prod_row); - - let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().expect(""); - let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); - let q1: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 2)..{EXT_DEG * 3}].try_into().expect(""); - let q2: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().expect(""); - - self.memory_bridge - .write( - MemoryAddress::new( - self.address_space, - register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), - ), - prod_row_specific.p_evals, - start_timestamp + AB::F::TWO, - &prod_row_specific.write_record, - ) - .eval(builder, prod_row); + // Termination condition diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 681aa9bd05..5a4ebdacf3 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -260,8 +260,6 @@ impl InstructionExecutor for NativeSumcheckChip { self.height += 1; } - - /* let mut i = F::ZERO; let mut i_usize = 0usize; while i < num_logup_spec { @@ -273,10 +271,13 @@ impl InstructionExecutor for NativeSumcheckChip { register_ptrs, ctx, challenges, + alpha, logup_spec_n: i_usize, ..Default::default() }; - let (read_max_round, max_round) = memory.read_cell(data_address_space, ctx_pointer + num_prod_spec + F::from_canonical_usize(EXT_DEG * 2) + i); + logup_row.alpha1 = alpha_acc; + + let (read_max_round, max_round) = memory.read_cell(data_address_space, ctx_pointer + F::from_canonical_usize(EXT_DEG * 2) + num_prod_spec + i); logup_row.max_round = max_round; logup_row.read_data_records[0] = read_max_round; curr_timestamp += 1; @@ -290,6 +291,7 @@ impl InstructionExecutor for NativeSumcheckChip { round, F::from_canonical_usize(0), ); + logup_row.data_ptr = start; let (read_pqs, pqs) = memory.read::<{EXT_DEG * 4}>(data_address_space, logup_ptr + start); let p1: [F; 4] = pqs[0..EXT_DEG].try_into().expect(""); @@ -327,11 +329,13 @@ impl InstructionExecutor for NativeSumcheckChip { logup_row.p_evals = p_evals; logup_row.q_evals = q_evals; + /* _debug let (write_slice_eval_1, _) = memory.write::(data_address_space, r_ptr + (F::ONE + num_prod_spec + i) * F::from_canonical_usize(EXT_DEG), p_evals); let (write_slice_eval_2, _) = memory.write::(data_address_space, r_ptr + (F::ONE + num_prod_spec + num_logup_spec + i) * F::from_canonical_usize(EXT_DEG), q_evals); - + logup_row.write_data_records[0] = write_slice_eval_1; logup_row.write_data_records[1] = write_slice_eval_2; + */ let not_in_round = F::ONE - in_round; if (round + not_in_round) < (max_round - F::from_canonical_usize(1)) { @@ -340,11 +344,13 @@ impl InstructionExecutor for NativeSumcheckChip { eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_denominator, q_evals)); logup_row.should_acc = true; - logup_row.alpha1 = alpha_acc; logup_row.alpha2 = alpha_denominator; logup_row.eval_acc = eval_acc.clone(); } - curr_timestamp += 3; + + // _debug + // curr_timestamp += 3; + curr_timestamp += 1; } alpha_acc = FieldExtension::multiply(FieldExtension::multiply(alpha_acc, alpha), alpha); @@ -355,15 +361,15 @@ impl InstructionExecutor for NativeSumcheckChip { self.height += 1; } + /* _debug let (write_r, _) = memory.write::(data_address_space, r_ptr, eval_acc); curr_timestamp += 1; observation_records[0].write_data_records[0] = write_r; - */ + for record in &mut observation_records { record.final_timestamp_increment = curr_timestamp; - // _debug - // record.eval_acc = FieldExtension::subtract(eval_acc, record.eval_acc); + record.eval_acc = FieldExtension::subtract(eval_acc, record.eval_acc); } self.record_set.extend(observation_records); diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index f34aaf51e9..d3b9e40ca3 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -17,6 +17,11 @@ pub struct NativeSumcheckCols { /// Indicates that this row is a step for logup_spec in the layer sum operation pub logup_row: T, + /// Indicates that the prod row is within maximum round + pub prod_row_within_max_round: T, + /// Indicates that the logup row is within maximum round + pub logup_row_within_max_round: T, + /// Timestamps pub first_timestamp: T, pub start_timestamp: T, @@ -38,6 +43,9 @@ pub struct NativeSumcheckCols { // ] pub ctx: [T; EXT_DEG * 2], + pub prod_nested_len: T, + pub logup_nested_len: T, + pub curr_prod_n: T, pub curr_logup_n: T, diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index fa961d48a0..2a95e8a018 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -49,6 +49,8 @@ impl NativeSumcheckChip { cols.last_timestamp = F::from_canonical_usize(record.from_state.timestamp as usize + record.final_timestamp_increment); cols.register_ptrs = record.register_ptrs; cols.ctx = record.ctx; + cols.prod_nested_len = record.ctx[4] * record.ctx[3]; + cols.logup_nested_len = record.ctx[6] * record.ctx[5]; cols.challenges = record.challenges; cols.alpha = record.alpha; cols.max_round = record.max_round; @@ -70,6 +72,7 @@ impl NativeSumcheckChip { } } else if record.row_type == 1 { cols.prod_row = F::ONE; + cols.prod_row_within_max_round = if record.within_round_limit { F::ONE } else { F::ZERO }; let prod: &mut ProdSpecificCols = cols.specific[..ProdSpecificCols::::width()].borrow_mut(); @@ -95,28 +98,38 @@ impl NativeSumcheckChip { } } else if record.row_type == 2 { cols.logup_row = F::ONE; + cols.logup_row_within_max_round = if record.within_round_limit { F::ONE } else { F::ZERO }; let logup: &mut LogupSpecificCols = cols.specific[..LogupSpecificCols::::width()].borrow_mut(); cols.curr_logup_n = F::from_canonical_usize(record.logup_spec_n + 1); cols.challenges[0..EXT_DEG].copy_from_slice(&record.alpha1); cols.challenges[(EXT_DEG * 3)..(EXT_DEG * 4)].copy_from_slice(&record.alpha2); + logup.pq[0..EXT_DEG].copy_from_slice(&record.p1); + logup.pq[EXT_DEG..(EXT_DEG * 2)].copy_from_slice(&record.p2); + logup.pq[(EXT_DEG * 2)..(EXT_DEG * 3)].copy_from_slice(&record.q1); + logup.pq[(EXT_DEG * 3)..(EXT_DEG * 4)].copy_from_slice(&record.q2); + logup.data_ptr = record.data_ptr; + // Read max_round + let mem_record = memory.record_by_id(record.read_data_records[0]); + aux_cols_factory.generate_read_aux(mem_record, &mut logup.read_records[0]); -// pub p1: [F; EXT_DEG], -// pub p2: [F; EXT_DEG], -// pub q1: [F; EXT_DEG], -// pub q2: [F; EXT_DEG], -// pub p_evals: [F; EXT_DEG], -// pub q_evals: [F; EXT_DEG], -// } -// logup.data_ptr = record.data_ptr; - - - - - - + if record.within_round_limit { + // Read p1, p2, q1, q2 + let mem_record = memory.record_by_id(record.read_data_records[1]); + aux_cols_factory.generate_read_aux(mem_record, &mut logup.read_records[1]); + + // Write p and q eval + /* _debug + logup.p_evals = record.p_evals; + logup.q_evals = record.q_evals; + for i in 0..2usize { + let mem_record = memory.record_by_id(record.write_data_records[i]); + aux_cols_factory.generate_write_aux(mem_record, &mut logup.write_records[i]); + } + */ + } } else { unreachable!() } From 368cb99a535fd1df8ace21a80f94ad68ce1fd107 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 23 Sep 2025 20:56:45 -0400 Subject: [PATCH 13/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 12 +++--------- extensions/native/circuit/src/sumcheck/chip.rs | 6 +----- extensions/native/circuit/src/sumcheck/trace.rs | 2 -- 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 6dd331b6ea..b467cc2ac4 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -217,13 +217,12 @@ impl Air let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); let q1: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 2)..{EXT_DEG * 3}].try_into().expect(""); let q2: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().expect(""); - - /* _debug + self.memory_bridge .write( MemoryAddress::new( self.address_space, - register_ptrs[4] + (ctx[1] + curr_prod_n) * AB::F::from_canonical_usize(EXT_DEG), + register_ptrs[4] + (ctx[1] + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.p_evals, start_timestamp + AB::F::TWO, @@ -235,18 +234,13 @@ impl Air .write( MemoryAddress::new( self.address_space, - register_ptrs[4] + (ctx[1] + ctx[2] + curr_prod_n) * AB::F::from_canonical_usize(EXT_DEG), + register_ptrs[4] + (ctx[1] + ctx[2] + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.q_evals, start_timestamp + AB::F::from_canonical_usize(3), &logup_row_specific.write_records[1], ) .eval(builder, logup_row_within_max_round); - */ - - - - diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 5a4ebdacf3..3a80767b77 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -329,13 +329,11 @@ impl InstructionExecutor for NativeSumcheckChip { logup_row.p_evals = p_evals; logup_row.q_evals = q_evals; - /* _debug let (write_slice_eval_1, _) = memory.write::(data_address_space, r_ptr + (F::ONE + num_prod_spec + i) * F::from_canonical_usize(EXT_DEG), p_evals); let (write_slice_eval_2, _) = memory.write::(data_address_space, r_ptr + (F::ONE + num_prod_spec + num_logup_spec + i) * F::from_canonical_usize(EXT_DEG), q_evals); logup_row.write_data_records[0] = write_slice_eval_1; logup_row.write_data_records[1] = write_slice_eval_2; - */ let not_in_round = F::ONE - in_round; if (round + not_in_round) < (max_round - F::from_canonical_usize(1)) { @@ -348,9 +346,7 @@ impl InstructionExecutor for NativeSumcheckChip { logup_row.eval_acc = eval_acc.clone(); } - // _debug - // curr_timestamp += 3; - curr_timestamp += 1; + curr_timestamp += 3; } alpha_acc = FieldExtension::multiply(FieldExtension::multiply(alpha_acc, alpha), alpha); diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index 2a95e8a018..857c118b64 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -121,14 +121,12 @@ impl NativeSumcheckChip { aux_cols_factory.generate_read_aux(mem_record, &mut logup.read_records[1]); // Write p and q eval - /* _debug logup.p_evals = record.p_evals; logup.q_evals = record.q_evals; for i in 0..2usize { let mem_record = memory.record_by_id(record.write_data_records[i]); aux_cols_factory.generate_write_aux(mem_record, &mut logup.write_records[i]); } - */ } } else { unreachable!() From 0785a3e91b5164ef5797733326dac1c1a84edf56 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 23 Sep 2025 21:01:39 -0400 Subject: [PATCH 14/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 13 +++++++++++++ extensions/native/circuit/src/sumcheck/chip.rs | 2 -- extensions/native/circuit/src/sumcheck/trace.rs | 4 ++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index b467cc2ac4..667afcb921 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -133,6 +133,19 @@ impl Air ) .eval(builder, header_row); + // Write final result + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + register_ptrs[4], + ), + eval_acc, + last_timestamp - AB::F::ONE, + &header_row_specific.write_records, + ) + .eval(builder, header_row); + // Separate aggregate column clusters let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().expect(""); let c1: [_; EXT_DEG] = challenges[EXT_DEG..{EXT_DEG * 2}].try_into().expect(""); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 3a80767b77..325c6b3035 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -357,11 +357,9 @@ impl InstructionExecutor for NativeSumcheckChip { self.height += 1; } - /* _debug let (write_r, _) = memory.write::(data_address_space, r_ptr, eval_acc); curr_timestamp += 1; observation_records[0].write_data_records[0] = write_r; - */ for record in &mut observation_records { record.final_timestamp_increment = curr_timestamp; diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index 857c118b64..2bcb692a0e 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -70,6 +70,10 @@ impl NativeSumcheckChip { let mem_record = memory.record_by_id(record.read_data_records[i]); aux_cols_factory.generate_read_aux(mem_record, &mut header.read_records[i]); } + + // write the final result + let mem_record = memory.record_by_id(record.write_data_records[0]); + aux_cols_factory.generate_write_aux(mem_record, &mut header.write_records); } else if record.row_type == 1 { cols.prod_row = F::ONE; cols.prod_row_within_max_round = if record.within_round_limit { F::ONE } else { F::ZERO }; From de7917473e86bea149d53c63fe5419cb5b287436 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 23 Sep 2025 21:08:56 -0400 Subject: [PATCH 15/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 62 ++++++++++--------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 667afcb921..46968d3390 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -79,6 +79,37 @@ impl Air let enabled = header_row + prod_row + logup_row; builder.assert_bool(enabled.clone()); + // Carry along columns + assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); + assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), ctx, next.ctx); + assert_array_eq::<_, _, _, {EXT_DEG * 2}>( + &mut builder.when(next.prod_row + next.logup_row), + challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect(""), + next.challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect("") + ); + builder.when(next.prod_row + next.logup_row).assert_eq(prod_nested_len, next.prod_nested_len); + builder.when(next.prod_row + next.logup_row).assert_eq(logup_nested_len, next.logup_nested_len); + + // Row transition + builder + .when(next.prod_row) + .assert_eq(curr_prod_n + AB::F::ONE, next.curr_prod_n); + builder + .when(next.logup_row) + .assert_eq(curr_logup_n + AB::F::ONE, next.curr_logup_n); + builder + .when(header_row) + .when(next.logup_row) + .assert_zero(ctx[1]); + builder + .when(prod_row) + .when(next.logup_row) + .assert_eq(ctx[1], curr_prod_n); + builder + .when(logup_row) + .when(not(next.logup_row)) + .assert_eq(ctx[2], curr_logup_n); + // Header let header_row_specific: &HeaderSpecificCols = specific[..HeaderSpecificCols::::width()].borrow(); @@ -209,7 +240,6 @@ impl Air ) .eval(builder, logup_row); - // _debug builder .when(logup_row_within_max_round) .assert_eq(logup_row_specific.data_ptr, (logup_nested_len * (curr_logup_n - AB::F::ONE) + ctx[6] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG)); @@ -258,35 +288,10 @@ impl Air /* _debug - // Carry along columns - assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); - assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), ctx, next.ctx); - assert_array_eq::<_, _, _, {EXT_DEG * 2}>( - &mut builder.when(next.prod_row + next.logup_row), - challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect(""), - next.challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect("") - ); - assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); + // Row transitions - builder - .when(header_row) - .when(next.logup_row) - .assert_zero(ctx[1]); - builder - .when(next.prod_row) - .assert_eq(curr_prod_n + AB::F::ONE, next.curr_prod_n); - builder - .when(next.logup_row) - .assert_eq(curr_logup_n + AB::F::ONE, next.curr_logup_n); - builder - .when(prod_row) - .when(next.logup_row) - .assert_eq(ctx[1], curr_prod_n); - builder - .when(logup_row) - .when(not(next.logup_row)) - .assert_eq(ctx[2], curr_logup_n); + @@ -297,6 +302,7 @@ impl Air // Termination condition + // Timestamp transition */ } From e1194c2d1d5f02957bdc5312ff22c8fd8f05f1aa Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 23 Sep 2025 21:21:37 -0400 Subject: [PATCH 16/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 46968d3390..92248d9955 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -110,6 +110,20 @@ impl Air .when(not(next.logup_row)) .assert_eq(ctx[2], curr_logup_n); + // Timestamp transition + builder + .when(header_row) + .when(next.prod_row + next.logup_row) + .assert_eq(next.start_timestamp, start_timestamp + AB::F::from_canonical_usize(7)); + builder + .when(prod_row) + .when(next.prod_row + next.logup_row) + .assert_eq(next.start_timestamp, start_timestamp + AB::F::ONE + within_round_limit * AB::F::TWO); + builder + .when(logup_row) + .when(next.prod_row + next.logup_row) + .assert_eq(next.start_timestamp, start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3)); + // Header let header_row_specific: &HeaderSpecificCols = specific[..HeaderSpecificCols::::width()].borrow(); @@ -290,18 +304,7 @@ impl Air /* _debug - // Row transitions - - - - - - - - - - // Termination condition // Timestamp transition */ From 6aaea67ea8e46f2b1bd67a3c239f2c574e456a8a Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 23 Sep 2025 21:22:12 -0400 Subject: [PATCH 17/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 92248d9955..1873e4401d 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -298,15 +298,5 @@ impl Air &logup_row_specific.write_records[1], ) .eval(builder, logup_row_within_max_round); - - - - /* _debug - - - - // Timestamp transition - - */ } } \ No newline at end of file From 75df4f33e4f89c754c41baba44145238b96f90d4 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 23 Sep 2025 21:48:08 -0400 Subject: [PATCH 18/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 26 ++++++++++++++----- .../native/circuit/src/sumcheck/chip.rs | 8 +++++- .../native/circuit/src/sumcheck/columns.rs | 7 +++++ .../native/circuit/src/sumcheck/trace.rs | 3 +++ 4 files changed, 36 insertions(+), 8 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 1873e4401d..2446847634 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -14,7 +14,7 @@ use openvm_stark_backend::{ p3_matrix::Matrix, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, }; -use crate::{sumcheck::columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}, EXT_DEG}; +use crate::{sumcheck::columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}, FieldExtension, EXT_DEG}; #[derive(Clone, Debug)] pub struct NativeSumcheckAir { @@ -53,6 +53,9 @@ impl Air header_row, prod_row, logup_row, + header_continuation, + prod_continuation, + logup_continuation, prod_row_within_max_round, logup_row_within_max_round, first_timestamp, @@ -124,6 +127,21 @@ impl Air .when(next.prod_row + next.logup_row) .assert_eq(next.start_timestamp, start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3)); + + // Randomness transition + let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().expect(""); + let c1: [_; EXT_DEG] = challenges[EXT_DEG..{EXT_DEG * 2}].try_into().expect(""); + let c2: [_; EXT_DEG] = challenges[{EXT_DEG * 2}..{EXT_DEG * 3}].try_into().expect(""); + let alpha2: [_; EXT_DEG] = challenges[{EXT_DEG * 3}..{EXT_DEG * 4}].try_into().expect(""); + let next_alpha1: [_; EXT_DEG] = next.challenges[0..EXT_DEG].try_into().expect(""); + let next_alpha2: [_; EXT_DEG] = next.challenges[{EXT_DEG * 3}..{EXT_DEG * 4}].try_into().expect(""); + + let alpha_denominator = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), alpha_denominator.clone(), next_alpha1); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); + let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); + // Header let header_row_specific: &HeaderSpecificCols = specific[..HeaderSpecificCols::::width()].borrow(); @@ -191,12 +209,6 @@ impl Air ) .eval(builder, header_row); - // Separate aggregate column clusters - let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().expect(""); - let c1: [_; EXT_DEG] = challenges[EXT_DEG..{EXT_DEG * 2}].try_into().expect(""); - let c2: [_; EXT_DEG] = challenges[{EXT_DEG * 2}..{EXT_DEG * 3}].try_into().expect(""); - let alpha2: [_; EXT_DEG] = challenges[{EXT_DEG * 3}..{EXT_DEG * 4}].try_into().expect(""); - // Prod spec evaluation let prod_row_specific: &ProdSpecificCols = specific[..ProdSpecificCols::::width()].borrow(); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 325c6b3035..78275c0c88 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -31,6 +31,7 @@ pub struct SumcheckEvalRecord { pub row_type: usize, // 0 - header; 1 - prod; 2 - logup pub curr_timestamp_increment: usize, pub final_timestamp_increment: usize, + pub continuation: bool, pub register_ptrs: [F; 5], pub registers: [F; 5], @@ -151,10 +152,11 @@ impl InstructionExecutor for NativeSumcheckChip { let (challenges_read, challenges): (RecordId, [F; EXT_DEG * 4]) = memory.read::<{EXT_DEG * 4}>(data_address_space, cs_pointer); let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().expect(""); - let mut header_row = SumcheckEvalRecord { + let mut header_row = SumcheckEvalRecord { from_state, instruction: instruction.clone(), row_type: 0, + continuation: true, curr_timestamp_increment: curr_timestamp, register_ptrs, alpha, @@ -195,6 +197,7 @@ impl InstructionExecutor for NativeSumcheckChip { from_state, instruction: instruction.clone(), row_type: 1, + continuation: true, curr_timestamp_increment: curr_timestamp, register_ptrs, ctx, @@ -267,6 +270,7 @@ impl InstructionExecutor for NativeSumcheckChip { from_state, instruction: instruction.clone(), row_type: 2, + continuation: true, curr_timestamp_increment: curr_timestamp, register_ptrs, ctx, @@ -365,6 +369,8 @@ impl InstructionExecutor for NativeSumcheckChip { record.final_timestamp_increment = curr_timestamp; record.eval_acc = FieldExtension::subtract(eval_acc, record.eval_acc); } + let last_idx = observation_records.len() - 1; + observation_records[last_idx].continuation = false; self.record_set.extend(observation_records); println!("=> current_height: {:?}", self.height); diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index d3b9e40ca3..d8b175c43e 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -16,6 +16,13 @@ pub struct NativeSumcheckCols { pub prod_row: T, /// Indicates that this row is a step for logup_spec in the layer sum operation pub logup_row: T, + + /// Indicates that there are valid operations following this header row + pub header_continuation: T, + /// Indicates that there are valid operations following this product evaluation row + pub prod_continuation: T, + /// Indicates that there are valid operations following this logup row + pub logup_continuation: T, /// Indicates that the prod row is within maximum round pub prod_row_within_max_round: T, diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index 2bcb692a0e..c95ff696f8 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -60,6 +60,7 @@ impl NativeSumcheckChip { if record.row_type == 0 { cols.header_row = F::ONE; + cols.header_continuation = if record.continuation { F::ONE } else { F::ZERO }; let header: &mut HeaderSpecificCols = cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); @@ -76,6 +77,7 @@ impl NativeSumcheckChip { aux_cols_factory.generate_write_aux(mem_record, &mut header.write_records); } else if record.row_type == 1 { cols.prod_row = F::ONE; + cols.prod_continuation = if record.continuation { F::ONE } else { F::ZERO }; cols.prod_row_within_max_round = if record.within_round_limit { F::ONE } else { F::ZERO }; let prod: &mut ProdSpecificCols = cols.specific[..ProdSpecificCols::::width()].borrow_mut(); @@ -102,6 +104,7 @@ impl NativeSumcheckChip { } } else if record.row_type == 2 { cols.logup_row = F::ONE; + cols.logup_continuation = if record.continuation { F::ONE } else { F::ZERO }; cols.logup_row_within_max_round = if record.within_round_limit { F::ONE } else { F::ZERO }; let logup: &mut LogupSpecificCols = cols.specific[..LogupSpecificCols::::width()].borrow_mut(); From 2cbbb78e09d6eabf3020170ed0ae66d0d3d90c13 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 24 Sep 2025 17:53:14 -0400 Subject: [PATCH 19/41] Degree reduction --- .../circuit/src/field_extension/core.rs | 4 +- extensions/native/circuit/src/sumcheck/air.rs | 62 ++++++++++++++++--- .../native/circuit/src/sumcheck/columns.rs | 6 ++ .../native/circuit/src/sumcheck/trace.rs | 4 ++ 4 files changed, 65 insertions(+), 11 deletions(-) diff --git a/extensions/native/circuit/src/field_extension/core.rs b/extensions/native/circuit/src/field_extension/core.rs index d8c83fabdd..ea06ecb4c9 100644 --- a/extensions/native/circuit/src/field_extension/core.rs +++ b/extensions/native/circuit/src/field_extension/core.rs @@ -245,10 +245,10 @@ impl FieldExtension { pub(crate) fn add(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] where - V: Copy, + V: Clone, V: Add, { - array::from_fn(|i| x[i] + y[i]) + array::from_fn(|i| x[i].clone() + y[i].clone()) } pub(crate) fn subtract(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 2446847634..1bcdd88407 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -1,6 +1,6 @@ use std::{array::from_fn, borrow::Borrow, sync::Arc}; use openvm_circuit::{ - arch::{ExecutionBridge, ExecutionState}, + arch::{ContinuationVmProof, ExecutionBridge, ExecutionState}, system::memory::{offline_checker::MemoryBridge, MemoryAddress}, }; use openvm_circuit_primitives::utils::{assert_array_eq, not}; @@ -50,6 +50,7 @@ impl Air let next: &NativeSumcheckCols = (*next).borrow(); let &NativeSumcheckCols { + // Row indicators header_row, prod_row, logup_row, @@ -58,17 +59,30 @@ impl Air logup_continuation, prod_row_within_max_round, logup_row_within_max_round, + + prod_in_round_evaluation, + prod_next_round_evaluation, + logup_in_round_evaluation, + logup_next_round_evaluation, + + // Timestamps first_timestamp, start_timestamp, last_timestamp, + + // Results from reading registers register_ptrs, ctx, prod_nested_len, logup_nested_len, - curr_prod_n, - curr_logup_n, + + // Challenges alpha, challenges, + + curr_prod_n, + curr_logup_n, + max_round, within_round_limit, should_acc, @@ -79,8 +93,16 @@ impl Air builder.assert_bool(header_row); builder.assert_bool(prod_row); builder.assert_bool(logup_row); + builder.assert_bool(header_continuation); + builder.assert_bool(prod_continuation); + builder.assert_bool(logup_continuation); + builder.assert_bool(prod_row_within_max_round); + builder.assert_bool(logup_row_within_max_round); + builder.assert_bool(prod_in_round_evaluation); + builder.assert_bool(logup_in_round_evaluation); let enabled = header_row + prod_row + logup_row; builder.assert_bool(enabled.clone()); + let in_round = ctx[7]; // Carry along columns assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); @@ -108,11 +130,20 @@ impl Air .when(prod_row) .when(next.logup_row) .assert_eq(ctx[1], curr_prod_n); + builder + .when(prod_row) + .when(not(prod_continuation)) + .assert_eq(ctx[1], curr_prod_n); builder .when(logup_row) - .when(not(next.logup_row)) + .when(not(logup_continuation)) .assert_eq(ctx[2], curr_logup_n); + // Termination condition + let continuation = header_continuation + prod_continuation + logup_continuation; + builder.assert_bool(continuation.clone()); + assert_array_eq(&mut builder.when::(not(continuation)), eval_acc, [AB::F::ZERO; 4]); + // Timestamp transition builder .when(header_row) @@ -127,14 +158,12 @@ impl Air .when(next.prod_row + next.logup_row) .assert_eq(next.start_timestamp, start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3)); - // Randomness transition let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().expect(""); let c1: [_; EXT_DEG] = challenges[EXT_DEG..{EXT_DEG * 2}].try_into().expect(""); let c2: [_; EXT_DEG] = challenges[{EXT_DEG * 2}..{EXT_DEG * 3}].try_into().expect(""); let alpha2: [_; EXT_DEG] = challenges[{EXT_DEG * 3}..{EXT_DEG * 4}].try_into().expect(""); let next_alpha1: [_; EXT_DEG] = next.challenges[0..EXT_DEG].try_into().expect(""); - let next_alpha2: [_; EXT_DEG] = next.challenges[{EXT_DEG * 3}..{EXT_DEG * 4}].try_into().expect(""); let alpha_denominator = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), alpha_denominator.clone(), next_alpha1); @@ -225,7 +254,11 @@ impl Air builder .when(prod_row_within_max_round) .assert_eq(prod_row_specific.data_ptr, (prod_nested_len * (curr_prod_n - AB::F::ONE) + ctx[4] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG)); - + builder + .assert_eq(prod_row * prod_row_within_max_round * in_round, prod_in_round_evaluation); + builder + .assert_eq(prod_row * prod_row_within_max_round * not(in_round), prod_next_round_evaluation); + self.memory_bridge .read( MemoryAddress::new( @@ -238,8 +271,8 @@ impl Air ) .eval(builder, prod_row_within_max_round); - let p1: [_; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().expect(""); - let p2: [_; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().expect(""); + let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); self.memory_bridge .write( @@ -253,6 +286,15 @@ impl Air ) .eval(builder, prod_row_within_max_round); + // Calculate evaluations + let next_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, c1), + FieldExtension::multiply::(p2, c2), + ); + let in_round_p_evals = FieldExtension::multiply::(p1, p2); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_in_round_evaluation), in_round_p_evals, prod_row_specific.p_evals); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_next_round_evaluation), next_round_p_evals, prod_row_specific.p_evals); + // Logup spec evaluation let logup_row_specific: &LogupSpecificCols = specific[..LogupSpecificCols::::width()].borrow(); @@ -269,6 +311,8 @@ impl Air builder .when(logup_row_within_max_round) .assert_eq(logup_row_specific.data_ptr, (logup_nested_len * (curr_logup_n - AB::F::ONE) + ctx[6] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG)); + builder + .assert_eq(logup_row * logup_row_within_max_round * in_round, logup_in_round_evaluation); self.memory_bridge .read( diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index d8b175c43e..0911166b72 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -29,6 +29,12 @@ pub struct NativeSumcheckCols { /// Indicates that the logup row is within maximum round pub logup_row_within_max_round: T, + /// Indicates what type of evaluation constraints should be applied + pub prod_in_round_evaluation: T, + pub prod_next_round_evaluation: T, + pub logup_in_round_evaluation: T, + pub logup_next_round_evaluation: T, + /// Timestamps pub first_timestamp: T, pub start_timestamp: T, diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index c95ff696f8..bcf52c0c68 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -79,6 +79,8 @@ impl NativeSumcheckChip { cols.prod_row = F::ONE; cols.prod_continuation = if record.continuation { F::ONE } else { F::ZERO }; cols.prod_row_within_max_round = if record.within_round_limit { F::ONE } else { F::ZERO }; + cols.prod_in_round_evaluation = if record.within_round_limit { record.ctx[7] } else { F::ZERO }; + cols.prod_next_round_evaluation = if record.within_round_limit { F::ONE - record.ctx[7] } else { F::ZERO }; let prod: &mut ProdSpecificCols = cols.specific[..ProdSpecificCols::::width()].borrow_mut(); @@ -106,6 +108,8 @@ impl NativeSumcheckChip { cols.logup_row = F::ONE; cols.logup_continuation = if record.continuation { F::ONE } else { F::ZERO }; cols.logup_row_within_max_round = if record.within_round_limit { F::ONE } else { F::ZERO }; + cols.logup_in_round_evaluation = if record.within_round_limit { record.ctx[7] } else { F::ZERO }; + cols.logup_next_round_evaluation = if record.within_round_limit { F::ONE - record.ctx[7] } else { F::ZERO }; let logup: &mut LogupSpecificCols = cols.specific[..LogupSpecificCols::::width()].borrow_mut(); From 7e569280712ca797295f622211380e050e03f439 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 24 Sep 2025 18:12:03 -0400 Subject: [PATCH 20/41] Reduce degree --- extensions/native/circuit/src/sumcheck/air.rs | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 1bcdd88407..f2ff5104c4 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -313,6 +313,8 @@ impl Air .assert_eq(logup_row_specific.data_ptr, (logup_nested_len * (curr_logup_n - AB::F::ONE) + ctx[6] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG)); builder .assert_eq(logup_row * logup_row_within_max_round * in_round, logup_in_round_evaluation); + builder + .assert_eq(logup_row * logup_row_within_max_round * not(in_round), logup_next_round_evaluation); self.memory_bridge .read( @@ -354,5 +356,26 @@ impl Air &logup_row_specific.write_records[1], ) .eval(builder, logup_row_within_max_round); + + // Calculate evaluations + let next_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, c1), + FieldExtension::multiply::(p2, c2), + ); + let in_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, q2), + FieldExtension::multiply::(p2, q1), + ); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_in_round_evaluation), in_round_p_evals, logup_row_specific.p_evals); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_next_round_evaluation), next_round_p_evals, logup_row_specific.p_evals); + + let next_round_q_evals = FieldExtension::add( + FieldExtension::multiply::(q1, c1), + FieldExtension::multiply::(q2, c2), + ); + let in_round_q_evals = FieldExtension::multiply::(q1, q2); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_in_round_evaluation), in_round_q_evals, logup_row_specific.q_evals); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_next_round_evaluation), next_round_q_evals, logup_row_specific.q_evals); + } } \ No newline at end of file From 18f0b755754aed537775ade393ca94a8f334d9e0 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 24 Sep 2025 19:07:38 -0400 Subject: [PATCH 21/41] Reduce degree --- extensions/native/circuit/src/sumcheck/air.rs | 39 +++++++++++++++++++ .../native/circuit/src/sumcheck/chip.rs | 14 +++++-- .../native/circuit/src/sumcheck/columns.rs | 9 ++++- .../native/circuit/src/sumcheck/trace.rs | 4 ++ 4 files changed, 61 insertions(+), 5 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index f2ff5104c4..0488eda8e8 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -54,17 +54,26 @@ impl Air header_row, prod_row, logup_row, + + // Whether valid prod/logup row operations follow this row header_continuation, prod_continuation, logup_continuation, + + // Round limit prod_row_within_max_round, logup_row_within_max_round, + // What type of evaluation is performed prod_in_round_evaluation, prod_next_round_evaluation, logup_in_round_evaluation, logup_next_round_evaluation, + // Indicates whether the round evaluations should be added to the accumulator + prod_acc, + logup_acc, + // Timestamps first_timestamp, start_timestamp, @@ -241,6 +250,8 @@ impl Air // Prod spec evaluation let prod_row_specific: &ProdSpecificCols = specific[..ProdSpecificCols::::width()].borrow(); + let next_prod_row_specific: &ProdSpecificCols = + next.specific[..ProdSpecificCols::::width()].borrow(); self.memory_bridge .read( @@ -258,6 +269,8 @@ impl Air .assert_eq(prod_row * prod_row_within_max_round * in_round, prod_in_round_evaluation); builder .assert_eq(prod_row * prod_row_within_max_round * not(in_round), prod_next_round_evaluation); + builder + .assert_eq(prod_row * should_acc, prod_acc); self.memory_bridge .read( @@ -295,9 +308,21 @@ impl Air assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_in_round_evaluation), in_round_p_evals, prod_row_specific.p_evals); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_next_round_evaluation), next_round_p_evals, prod_row_specific.p_evals); + // Accumulate evaluation + let acc_eval = FieldExtension::multiply::(prod_row_specific.p_evals, alpha1); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_acc), prod_row_specific.acc_eval, acc_eval); + + let next_acc = FieldExtension::subtract( + eval_acc, + next_prod_row_specific.acc_eval, + ); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(next.prod_acc), next.eval_acc, next_acc); + // Logup spec evaluation let logup_row_specific: &LogupSpecificCols = specific[..LogupSpecificCols::::width()].borrow(); + let next_logup_row_specfic: &LogupSpecificCols = + next.specific[..LogupSpecificCols::::width()].borrow(); self.memory_bridge .read( @@ -315,6 +340,8 @@ impl Air .assert_eq(logup_row * logup_row_within_max_round * in_round, logup_in_round_evaluation); builder .assert_eq(logup_row * logup_row_within_max_round * not(in_round), logup_next_round_evaluation); + builder + .assert_eq(logup_row * should_acc, logup_acc); self.memory_bridge .read( @@ -376,6 +403,18 @@ impl Air let in_round_q_evals = FieldExtension::multiply::(q1, q2); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_in_round_evaluation), in_round_q_evals, logup_row_specific.q_evals); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_next_round_evaluation), next_round_q_evals, logup_row_specific.q_evals); + + // Accumulate evaluation + let acc_eval = FieldExtension::add( + FieldExtension::multiply::(logup_row_specific.p_evals, alpha1), + FieldExtension::multiply::(logup_row_specific.q_evals, alpha2), + ); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_acc), logup_row_specific.acc_eval, acc_eval); + let next_acc = FieldExtension::subtract( + eval_acc, + next_logup_row_specfic.acc_eval, + ); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(next.logup_acc), next.eval_acc, next_acc); } } \ No newline at end of file diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 78275c0c88..7ccff6b88a 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -56,6 +56,7 @@ pub struct SumcheckEvalRecord { pub p_evals: [F; EXT_DEG], pub q_evals: [F; EXT_DEG], pub eval_acc: [F; EXT_DEG], + pub acc_eval: [F; EXT_DEG], } fn calculate_3d_ext_idx( @@ -247,7 +248,9 @@ impl InstructionExecutor for NativeSumcheckChip { let not_in_round = F::ONE - in_round; if (round + not_in_round) < (max_round - F::from_canonical_usize(1)) { - eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, evals)); + let acc_eval = FieldExtension::multiply(alpha_acc, evals); + prod_row.acc_eval = acc_eval; + eval_acc = FieldExtension::add(eval_acc, acc_eval); prod_row.should_acc = true; prod_row.eval_acc = eval_acc.clone(); } @@ -341,10 +344,13 @@ impl InstructionExecutor for NativeSumcheckChip { let not_in_round = F::ONE - in_round; if (round + not_in_round) < (max_round - F::from_canonical_usize(1)) { - eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, p_evals)); let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); - eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_denominator, q_evals)); - + let acc_eval = FieldExtension::add( + FieldExtension::multiply(alpha_acc, p_evals), + FieldExtension::multiply(alpha_denominator, q_evals), + ); + logup_row.acc_eval = acc_eval; + eval_acc = FieldExtension::add(eval_acc, acc_eval); logup_row.should_acc = true; logup_row.alpha2 = alpha_denominator; logup_row.eval_acc = eval_acc.clone(); diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index 0911166b72..9d8cd072f1 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -35,6 +35,10 @@ pub struct NativeSumcheckCols { pub logup_in_round_evaluation: T, pub logup_next_round_evaluation: T, + /// Indicates if evaluations are accumulated + pub prod_acc: T, + pub logup_acc: T, + /// Timestamps pub first_timestamp: T, pub start_timestamp: T, @@ -115,6 +119,8 @@ pub struct ProdSpecificCols { pub p_evals: [T; EXT_DEG], /// write p_evals pub write_record: MemoryWriteAuxCols, + /// Evaluation for the accumulator + pub acc_eval: [T; EXT_DEG], } #[repr(C)] @@ -130,7 +136,8 @@ pub struct LogupSpecificCols { pub p_evals: [T; EXT_DEG], /// Calculated q evals pub q_evals: [T; EXT_DEG], - /// write both p_evals and q_evals pub write_records: [MemoryWriteAuxCols; 2], + /// Evaluation for the accumulator + pub acc_eval: [T; EXT_DEG], } \ No newline at end of file diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index bcf52c0c68..1c5fced3d3 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -81,6 +81,7 @@ impl NativeSumcheckChip { cols.prod_row_within_max_round = if record.within_round_limit { F::ONE } else { F::ZERO }; cols.prod_in_round_evaluation = if record.within_round_limit { record.ctx[7] } else { F::ZERO }; cols.prod_next_round_evaluation = if record.within_round_limit { F::ONE - record.ctx[7] } else { F::ZERO }; + cols.prod_acc = if record.should_acc { F::ONE } else { F::ZERO }; let prod: &mut ProdSpecificCols = cols.specific[..ProdSpecificCols::::width()].borrow_mut(); @@ -89,6 +90,7 @@ impl NativeSumcheckChip { prod.p[0..EXT_DEG].copy_from_slice(&record.p1); prod.p[EXT_DEG..(EXT_DEG * 2)].copy_from_slice(&record.p2); prod.data_ptr = record.data_ptr; + prod.acc_eval = record.acc_eval; // Read max_round let mem_record = memory.record_by_id(record.read_data_records[0]); @@ -110,6 +112,7 @@ impl NativeSumcheckChip { cols.logup_row_within_max_round = if record.within_round_limit { F::ONE } else { F::ZERO }; cols.logup_in_round_evaluation = if record.within_round_limit { record.ctx[7] } else { F::ZERO }; cols.logup_next_round_evaluation = if record.within_round_limit { F::ONE - record.ctx[7] } else { F::ZERO }; + cols.logup_acc = if record.should_acc { F::ONE } else { F::ZERO }; let logup: &mut LogupSpecificCols = cols.specific[..LogupSpecificCols::::width()].borrow_mut(); @@ -121,6 +124,7 @@ impl NativeSumcheckChip { logup.pq[(EXT_DEG * 2)..(EXT_DEG * 3)].copy_from_slice(&record.q1); logup.pq[(EXT_DEG * 3)..(EXT_DEG * 4)].copy_from_slice(&record.q2); logup.data_ptr = record.data_ptr; + logup.acc_eval = record.acc_eval; // Read max_round let mem_record = memory.record_by_id(record.read_data_records[0]); From 7096b8b2db74287b6cfc831a3f47c0386defb14e Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 24 Sep 2025 21:31:11 -0400 Subject: [PATCH 22/41] Disable debug flag --- extensions/native/circuit/src/sumcheck/chip.rs | 8 ++++---- extensions/native/recursion/tests/sumcheck.rs | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 7ccff6b88a..fa4abc9227 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -121,10 +121,10 @@ impl InstructionExecutor for NativeSumcheckChip { let mut curr_timestamp: usize = 0; // _debug - println!("=> column width: {:?}", NativeSumcheckCols::::width()); - println!("=> header width: {:?}", HeaderSpecificCols::::width()); - println!("=> prod width: {:?}", ProdSpecificCols::::width()); - println!("=> logup width: {:?}", LogupSpecificCols::::width()); + // println!("=> column width: {:?}", NativeSumcheckCols::::width()); + // println!("=> header width: {:?}", HeaderSpecificCols::::width()); + // println!("=> prod width: {:?}", ProdSpecificCols::::width()); + // println!("=> logup width: {:?}", LogupSpecificCols::::width()); let (read_ctx_pointer, ctx_pointer) = memory.read_cell(register_address_space, input_register_1); diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index 48de36b5ce..975f63b211 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -77,14 +77,12 @@ fn test_sumcheck_layer_eval() { fn build_test_program( builder: &mut Builder, ) { - let ctx_u32s = [3u32, 6, 5, 8, 2, 8, 4, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]; let ctx: Array> = builder.dyn_array(ctx_u32s.len()); for (idx, n) in ctx_u32s.into_iter().enumerate() { builder.set(&ctx, idx, Usize::from(n as usize)); } - let challenges_u32s = [ 548478283u32, 456436544, 1716290291, 791326976, 1829717553, 1422025771, 1917123958, 727015942, From 25a72ba9c6ba1e3aee91fd77c025787afae9e8a7 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 13:53:23 -0400 Subject: [PATCH 23/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 0488eda8e8..e3c76f8411 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -113,6 +113,7 @@ impl Air builder.assert_bool(enabled.clone()); let in_round = ctx[7]; + /* _debug // Carry along columns assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), ctx, next.ctx); @@ -179,6 +180,7 @@ impl Air assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); + */ // Header let header_row_specific: &HeaderSpecificCols = @@ -299,6 +301,7 @@ impl Air ) .eval(builder, prod_row_within_max_round); + /* _debug // Calculate evaluations let next_round_p_evals = FieldExtension::add( FieldExtension::multiply::(p1, c1), @@ -317,6 +320,7 @@ impl Air next_prod_row_specific.acc_eval, ); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(next.prod_acc), next.eval_acc, next_acc); + */ // Logup spec evaluation let logup_row_specific: &LogupSpecificCols = @@ -384,6 +388,7 @@ impl Air ) .eval(builder, logup_row_within_max_round); + /* _debug // Calculate evaluations let next_round_p_evals = FieldExtension::add( FieldExtension::multiply::(p1, c1), @@ -416,5 +421,6 @@ impl Air next_logup_row_specfic.acc_eval, ); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(next.logup_acc), next.eval_acc, next_acc); + */ } } \ No newline at end of file From a43bd22d30cba78a1951fa93c6db7c9e9e571a54 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 14:10:57 -0400 Subject: [PATCH 24/41] Recover constraints --- extensions/native/circuit/src/sumcheck/air.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index e3c76f8411..13dacd6b3e 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -113,7 +113,6 @@ impl Air builder.assert_bool(enabled.clone()); let in_round = ctx[7]; - /* _debug // Carry along columns assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), ctx, next.ctx); @@ -180,7 +179,6 @@ impl Air assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); - */ // Header let header_row_specific: &HeaderSpecificCols = From 5b50b5bb77900ee8d9033c86a00962b5730ed766 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 14:39:29 -0400 Subject: [PATCH 25/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 13dacd6b3e..aaa11fdad2 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -113,6 +113,14 @@ impl Air builder.assert_bool(enabled.clone()); let in_round = ctx[7]; + // Randomness transition + let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().expect(""); + let c1: [_; EXT_DEG] = challenges[EXT_DEG..{EXT_DEG * 2}].try_into().expect(""); + let c2: [_; EXT_DEG] = challenges[{EXT_DEG * 2}..{EXT_DEG * 3}].try_into().expect(""); + let alpha2: [_; EXT_DEG] = challenges[{EXT_DEG * 3}..{EXT_DEG * 4}].try_into().expect(""); + let next_alpha1: [_; EXT_DEG] = next.challenges[0..EXT_DEG].try_into().expect(""); + + /* _debug // Carry along columns assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), ctx, next.ctx); @@ -167,18 +175,12 @@ impl Air .when(next.prod_row + next.logup_row) .assert_eq(next.start_timestamp, start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3)); - // Randomness transition - let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().expect(""); - let c1: [_; EXT_DEG] = challenges[EXT_DEG..{EXT_DEG * 2}].try_into().expect(""); - let c2: [_; EXT_DEG] = challenges[{EXT_DEG * 2}..{EXT_DEG * 3}].try_into().expect(""); - let alpha2: [_; EXT_DEG] = challenges[{EXT_DEG * 3}..{EXT_DEG * 4}].try_into().expect(""); - let next_alpha1: [_; EXT_DEG] = next.challenges[0..EXT_DEG].try_into().expect(""); - let alpha_denominator = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), alpha_denominator.clone(), next_alpha1); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); + */ // Header let header_row_specific: &HeaderSpecificCols = @@ -299,7 +301,6 @@ impl Air ) .eval(builder, prod_row_within_max_round); - /* _debug // Calculate evaluations let next_round_p_evals = FieldExtension::add( FieldExtension::multiply::(p1, c1), @@ -318,7 +319,6 @@ impl Air next_prod_row_specific.acc_eval, ); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(next.prod_acc), next.eval_acc, next_acc); - */ // Logup spec evaluation let logup_row_specific: &LogupSpecificCols = From 12fe6fdc673650c15b86b1ecedfdb632047fc6ee Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 14:58:56 -0400 Subject: [PATCH 26/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index aaa11fdad2..4338e84e34 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -386,7 +386,6 @@ impl Air ) .eval(builder, logup_row_within_max_round); - /* _debug // Calculate evaluations let next_round_p_evals = FieldExtension::add( FieldExtension::multiply::(p1, c1), @@ -419,6 +418,5 @@ impl Air next_logup_row_specfic.acc_eval, ); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(next.logup_acc), next.eval_acc, next_acc); - */ } } \ No newline at end of file From 21a89711b3a2fd2aa7ca5315ec12d4e713ba3248 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 15:15:17 -0400 Subject: [PATCH 27/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 4338e84e34..9f9315c3f8 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -112,6 +112,8 @@ impl Air let enabled = header_row + prod_row + logup_row; builder.assert_bool(enabled.clone()); let in_round = ctx[7]; + let continuation = header_continuation + prod_continuation + logup_continuation; + builder.assert_bool(continuation.clone()); // Randomness transition let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().expect(""); @@ -120,7 +122,6 @@ impl Air let alpha2: [_; EXT_DEG] = challenges[{EXT_DEG * 3}..{EXT_DEG * 4}].try_into().expect(""); let next_alpha1: [_; EXT_DEG] = next.challenges[0..EXT_DEG].try_into().expect(""); - /* _debug // Carry along columns assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), ctx, next.ctx); @@ -132,6 +133,7 @@ impl Air builder.when(next.prod_row + next.logup_row).assert_eq(prod_nested_len, next.prod_nested_len); builder.when(next.prod_row + next.logup_row).assert_eq(logup_nested_len, next.logup_nested_len); + /* _debug // Row transition builder .when(next.prod_row) @@ -157,8 +159,6 @@ impl Air .assert_eq(ctx[2], curr_logup_n); // Termination condition - let continuation = header_continuation + prod_continuation + logup_continuation; - builder.assert_bool(continuation.clone()); assert_array_eq(&mut builder.when::(not(continuation)), eval_acc, [AB::F::ZERO; 4]); // Timestamp transition From 8ad853ae3ade299216056675ad9d05f9a5073603 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 15:27:52 -0400 Subject: [PATCH 28/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 9f9315c3f8..4caac41a84 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -132,8 +132,7 @@ impl Air ); builder.when(next.prod_row + next.logup_row).assert_eq(prod_nested_len, next.prod_nested_len); builder.when(next.prod_row + next.logup_row).assert_eq(logup_nested_len, next.logup_nested_len); - - /* _debug + // Row transition builder .when(next.prod_row) @@ -158,6 +157,7 @@ impl Air .when(not(logup_continuation)) .assert_eq(ctx[2], curr_logup_n); + /* _debug // Termination condition assert_array_eq(&mut builder.when::(not(continuation)), eval_acc, [AB::F::ZERO; 4]); From 4d28b2144aa757da7fcac779ab31ed19f4c904ee Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 15:37:23 -0400 Subject: [PATCH 29/41] Debug constraints: --- extensions/native/circuit/src/sumcheck/air.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 4caac41a84..9d3eb6305b 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -157,10 +157,6 @@ impl Air .when(not(logup_continuation)) .assert_eq(ctx[2], curr_logup_n); - /* _debug - // Termination condition - assert_array_eq(&mut builder.when::(not(continuation)), eval_acc, [AB::F::ZERO; 4]); - // Timestamp transition builder .when(header_row) @@ -175,6 +171,10 @@ impl Air .when(next.prod_row + next.logup_row) .assert_eq(next.start_timestamp, start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3)); + /* _debug + // Termination condition + assert_array_eq(&mut builder.when::(not(continuation)), eval_acc, [AB::F::ZERO; 4]); + let alpha_denominator = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), alpha_denominator.clone(), next_alpha1); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); From 7be1fd8b8b4346963eadc268f9f868cdf95782ba Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 15:53:37 -0400 Subject: [PATCH 30/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 9d3eb6305b..e2ff748a93 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -174,13 +174,14 @@ impl Air /* _debug // Termination condition assert_array_eq(&mut builder.when::(not(continuation)), eval_acc, [AB::F::ZERO; 4]); + */ + // Randomness transition let alpha_denominator = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), alpha_denominator.clone(), next_alpha1); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); - */ // Header let header_row_specific: &HeaderSpecificCols = From 964cee369f24e3c17206d081ff0913351ad5e1e1 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 16:05:59 -0400 Subject: [PATCH 31/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index e2ff748a93..1e7d553e9a 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -171,17 +171,17 @@ impl Air .when(next.prod_row + next.logup_row) .assert_eq(next.start_timestamp, start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3)); - /* _debug // Termination condition assert_array_eq(&mut builder.when::(not(continuation)), eval_acc, [AB::F::ZERO; 4]); - */ + /* _debug // Randomness transition let alpha_denominator = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), alpha_denominator.clone(), next_alpha1); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); + */ // Header let header_row_specific: &HeaderSpecificCols = From 2da96e0074460afd9c6caaed5aa47d77633ee059 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 16:21:53 -0400 Subject: [PATCH 32/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 1e7d553e9a..d4b0ffeeb7 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -174,11 +174,13 @@ impl Air // Termination condition assert_array_eq(&mut builder.when::(not(continuation)), eval_acc, [AB::F::ZERO; 4]); - /* _debug // Randomness transition let alpha_denominator = FieldExtension::multiply(alpha1, alpha); - assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), alpha_denominator.clone(), next_alpha1); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); + + /* _debug + let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), prod_next_alpha, next_alpha1); let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); */ From a3ab084556a628d2735c11347d173f3cfc9adca3 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 16:30:35 -0400 Subject: [PATCH 33/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index d4b0ffeeb7..e45a5f44c4 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -177,10 +177,10 @@ impl Air // Randomness transition let alpha_denominator = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); - - /* _debug let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), prod_next_alpha, next_alpha1); + + /* _debug let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); */ From 39afd1aace719d403c4576404caec6a4240c3699 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 18:17:35 -0400 Subject: [PATCH 34/41] Debug constraints --- extensions/native/circuit/src/sumcheck/air.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index e45a5f44c4..242613c57a 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -177,12 +177,13 @@ impl Air // Randomness transition let alpha_denominator = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); + let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); + + /* _debug let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), prod_next_alpha, next_alpha1); - /* _debug - let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); - assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); */ // Header From 892d1d058d7cf5c62dcecf58272470c2778e37c0 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 21:17:32 -0400 Subject: [PATCH 35/41] Add debug flag --- extensions/native/circuit/src/sumcheck/air.rs | 6 +++--- extensions/native/circuit/src/sumcheck/trace.rs | 14 +++++++++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 242613c57a..fe34f3de94 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -177,13 +177,13 @@ impl Air // Randomness transition let alpha_denominator = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); - let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); - assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); + /* _debug + let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), prod_next_alpha, next_alpha1); - */ // Header diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index 1c5fced3d3..a15438aeea 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -13,6 +13,7 @@ use openvm_stark_backend::{ prover::types::AirProofInput, AirRef, Chip, ChipUsageGetter, }; +use rand::distributions::Alphanumeric; use crate::{sumcheck::{chip::NativeSumcheckChip, columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}}, EXT_DEG}; impl ChipUsageGetter @@ -146,7 +147,18 @@ impl NativeSumcheckChip { } else { unreachable!() } - + + // _debug + println!("=> header_row: {:?}, prod_row: {:?}, logup_row: {:?}, prod_row_continuation: {:?}, logup-row_continuation: {:?}, alpha: {:?}, challenges: {:?}", + cols.header_row, + cols.prod_row, + cols.logup_row, + cols.prod_continuation, + cols.logup_continuation, + cols.alpha, + cols.challenges, + ); + used_cells += width; } From a817f6f2b1a17c87c1286f764df15706d6c65800 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 25 Sep 2025 21:40:09 -0400 Subject: [PATCH 36/41] Remove debug flags --- extensions/native/circuit/src/sumcheck/air.rs | 8 +++----- extensions/native/circuit/src/sumcheck/trace.rs | 5 +++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index fe34f3de94..2a9969ccbe 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -174,16 +174,14 @@ impl Air // Termination condition assert_array_eq(&mut builder.when::(not(continuation)), eval_acc, [AB::F::ZERO; 4]); + /* _debug // Randomness transition + let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), prod_next_alpha, next_alpha1); let alpha_denominator = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); - - - /* _debug let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); - let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); - assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), prod_next_alpha, next_alpha1); */ // Header diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index a15438aeea..6f94b24291 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -148,7 +148,7 @@ impl NativeSumcheckChip { unreachable!() } - // _debug + /* _debug println!("=> header_row: {:?}, prod_row: {:?}, logup_row: {:?}, prod_row_continuation: {:?}, logup-row_continuation: {:?}, alpha: {:?}, challenges: {:?}", cols.header_row, cols.prod_row, @@ -157,7 +157,8 @@ impl NativeSumcheckChip { cols.logup_continuation, cols.alpha, cols.challenges, - ); + ); + */ used_cells += width; } From de7724265d2abdcfd69c76877e40c4e93536a273 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 28 Sep 2025 15:57:01 -0400 Subject: [PATCH 37/41] Debug constraint --- extensions/native/circuit/src/sumcheck/air.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 2a9969ccbe..03ba0584a1 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -173,6 +173,7 @@ impl Air // Termination condition assert_array_eq(&mut builder.when::(not(continuation)), eval_acc, [AB::F::ZERO; 4]); + assert_array_eq(&mut builder.when(header_continuation), next.challenges[0..EXT_DEG].try_into().expect(""), [AB::F::ONE, AB::F::ZERO, AB::F::ZERO, AB::F::ZERO]); /* _debug // Randomness transition From af6fc63af088c9a324cfe1b520275ee0ce626e12 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 28 Sep 2025 17:27:05 -0400 Subject: [PATCH 38/41] Correct constraints --- extensions/native/circuit/src/sumcheck/air.rs | 12 +++++------- extensions/native/circuit/src/sumcheck/chip.rs | 10 ++++++---- extensions/native/circuit/src/sumcheck/trace.rs | 11 +++++++---- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 03ba0584a1..d6cdd763d1 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -173,18 +173,16 @@ impl Air // Termination condition assert_array_eq(&mut builder.when::(not(continuation)), eval_acc, [AB::F::ZERO; 4]); - assert_array_eq(&mut builder.when(header_continuation), next.challenges[0..EXT_DEG].try_into().expect(""), [AB::F::ONE, AB::F::ZERO, AB::F::ZERO, AB::F::ZERO]); - - /* _debug + // Randomness transition - let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); - assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), prod_next_alpha, next_alpha1); + assert_array_eq(&mut builder.when(header_continuation), next.challenges[0..EXT_DEG].try_into().expect(""), [AB::F::ONE, AB::F::ZERO, AB::F::ZERO, AB::F::ZERO]); let alpha_denominator = FieldExtension::multiply(alpha1, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); + let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), prod_next_alpha, next_alpha1); let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); - */ - + // Header let header_row_specific: &HeaderSpecificCols = specific[..HeaderSpecificCols::::width()].borrow(); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index fa4abc9227..c854ff1f18 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -247,9 +247,10 @@ impl InstructionExecutor for NativeSumcheckChip { prod_row.write_data_records[0] = write_slice_eval_1; let not_in_round = F::ONE - in_round; + let acc_eval = FieldExtension::multiply(alpha_acc, evals); + prod_row.acc_eval = acc_eval; + if (round + not_in_round) < (max_round - F::from_canonical_usize(1)) { - let acc_eval = FieldExtension::multiply(alpha_acc, evals); - prod_row.acc_eval = acc_eval; eval_acc = FieldExtension::add(eval_acc, acc_eval); prod_row.should_acc = true; prod_row.eval_acc = eval_acc.clone(); @@ -343,8 +344,10 @@ impl InstructionExecutor for NativeSumcheckChip { logup_row.write_data_records[1] = write_slice_eval_2; let not_in_round = F::ONE - in_round; + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); + logup_row.alpha2 = alpha_denominator; + if (round + not_in_round) < (max_round - F::from_canonical_usize(1)) { - let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); let acc_eval = FieldExtension::add( FieldExtension::multiply(alpha_acc, p_evals), FieldExtension::multiply(alpha_denominator, q_evals), @@ -352,7 +355,6 @@ impl InstructionExecutor for NativeSumcheckChip { logup_row.acc_eval = acc_eval; eval_acc = FieldExtension::add(eval_acc, acc_eval); logup_row.should_acc = true; - logup_row.alpha2 = alpha_denominator; logup_row.eval_acc = eval_acc.clone(); } diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index 6f94b24291..49235a8951 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -14,7 +14,7 @@ use openvm_stark_backend::{ AirRef, Chip, ChipUsageGetter, }; use rand::distributions::Alphanumeric; -use crate::{sumcheck::{chip::NativeSumcheckChip, columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}}, EXT_DEG}; +use crate::{FieldExtension, sumcheck::{chip::NativeSumcheckChip, columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}}, EXT_DEG}; impl ChipUsageGetter for NativeSumcheckChip @@ -148,8 +148,10 @@ impl NativeSumcheckChip { unreachable!() } - /* _debug - println!("=> header_row: {:?}, prod_row: {:?}, logup_row: {:?}, prod_row_continuation: {:?}, logup-row_continuation: {:?}, alpha: {:?}, challenges: {:?}", + // /* _debug + let alpha1: [_; EXT_DEG] = cols.challenges[0..EXT_DEG].try_into().expect(""); + let calculated_alpha_denominator = FieldExtension::multiply(alpha1, cols.alpha); + println!("=> header_row: {:?}, prod_row: {:?}, logup_row: {:?}, prod_row_continuation: {:?}, logup-row_continuation: {:?}, alpha: {:?}, challenges: {:?}, calculated_alpha_denominator: {:?}", cols.header_row, cols.prod_row, cols.logup_row, @@ -157,8 +159,9 @@ impl NativeSumcheckChip { cols.logup_continuation, cols.alpha, cols.challenges, + calculated_alpha_denominator, ); - */ + // */ used_cells += width; } From b0c1e4d0bad4c7a5ea6a5e17dfc872d715171575 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 28 Sep 2025 17:27:48 -0400 Subject: [PATCH 39/41] Remove debug flags --- extensions/native/circuit/src/sumcheck/chip.rs | 10 +--------- extensions/native/circuit/src/sumcheck/trace.rs | 15 --------------- 2 files changed, 1 insertion(+), 24 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index c854ff1f18..b5fd474eea 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -74,8 +74,6 @@ pub struct NativeSumcheckChip { pub(super) air: NativeSumcheckAir, pub(super) offline_memory: Arc>>, pub record_set: Vec>, - // _debug - // pub(super) streams: Arc>>, } impl NativeSumcheckChip { @@ -119,13 +117,7 @@ impl InstructionExecutor for NativeSumcheckChip { if op == SUMCHECK_LAYER_EVAL.global_opcode() { let mut observation_records: Vec> = vec![]; let mut curr_timestamp: usize = 0; - - // _debug - // println!("=> column width: {:?}", NativeSumcheckCols::::width()); - // println!("=> header width: {:?}", HeaderSpecificCols::::width()); - // println!("=> prod width: {:?}", ProdSpecificCols::::width()); - // println!("=> logup width: {:?}", LogupSpecificCols::::width()); - + let (read_ctx_pointer, ctx_pointer) = memory.read_cell(register_address_space, input_register_1); let (read_cs_pointer, cs_pointer) = diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs index 49235a8951..1d3b9f0941 100644 --- a/extensions/native/circuit/src/sumcheck/trace.rs +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -148,21 +148,6 @@ impl NativeSumcheckChip { unreachable!() } - // /* _debug - let alpha1: [_; EXT_DEG] = cols.challenges[0..EXT_DEG].try_into().expect(""); - let calculated_alpha_denominator = FieldExtension::multiply(alpha1, cols.alpha); - println!("=> header_row: {:?}, prod_row: {:?}, logup_row: {:?}, prod_row_continuation: {:?}, logup-row_continuation: {:?}, alpha: {:?}, challenges: {:?}, calculated_alpha_denominator: {:?}", - cols.header_row, - cols.prod_row, - cols.logup_row, - cols.prod_continuation, - cols.logup_continuation, - cols.alpha, - cols.challenges, - calculated_alpha_denominator, - ); - // */ - used_cells += width; } From 72b6c5ee7e1924e1cc2a6c166445b7793e3fd5c5 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 8 Oct 2025 21:28:07 -0400 Subject: [PATCH 40/41] Prover field access --- crates/sdk/src/prover/agg.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/sdk/src/prover/agg.rs b/crates/sdk/src/prover/agg.rs index aa8fc843cb..5c6faac19a 100644 --- a/crates/sdk/src/prover/agg.rs +++ b/crates/sdk/src/prover/agg.rs @@ -23,11 +23,11 @@ use crate::{ }; pub struct AggStarkProver> { - leaf_prover: VmLocalProver, - leaf_controller: LeafProvingController, + pub leaf_prover: VmLocalProver, + pub leaf_controller: LeafProvingController, - internal_prover: VmLocalProver, - root_prover: RootVerifierLocalProver, + pub internal_prover: VmLocalProver, + pub root_prover: RootVerifierLocalProver, pub num_children_internal: usize, pub max_internal_wrapper_layers: usize, From 2a4331b7026c4b61db4ca4da2f5d241376137c45 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 13 Oct 2025 15:57:30 -0400 Subject: [PATCH 41/41] Remove debug flag --- extensions/native/circuit/src/sumcheck/chip.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index b5fd474eea..f97a9ff379 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -373,7 +373,6 @@ impl InstructionExecutor for NativeSumcheckChip { observation_records[last_idx].continuation = false; self.record_set.extend(observation_records); - println!("=> current_height: {:?}", self.height); } else { unreachable!() }