(** Bitstream utilities for jam/cue serialization.

    Bits are packed least-significant-bit first: stream bit [p] is stored in
    byte [p / 8] at bit offset [p mod 8]. Multi-bit values are likewise
    written and read starting from their least-significant bit. *)

(** A bitstream writer *)
type writer = {
  buf: bytes ref; (** Buffer for bits; replaced by a larger one when it fills *)
  mutable bit_pos: int; (** Current bit position (number of bits written) *)
}

(** A bitstream reader *)
type reader = {
  buf: bytes; (** Buffer to read from *)
  mutable bit_pos: int; (** Current bit position *)
  len: int; (** Length in bits *)
}

(** [writer_create ()] returns an empty writer backed by a zero-filled
    1024-byte buffer that grows on demand. *)
let writer_create () : writer =
  { buf = ref (Bytes.make 1024 '\x00'); bit_pos = 0 }

(** [writer_ensure w bits_needed] grows [w]'s buffer, if necessary, so that
    [bits_needed] more bits fit after the current position.

    The replacement buffer is zero-filled, because bit writes only ever set
    bits (see [write_bit]). Its size at least doubles, so a sequence of
    appends only triggers a logarithmic number of copies. *)
let writer_ensure (w : writer) (bits_needed : int) : unit =
  let bytes_needed = (w.bit_pos + bits_needed + 7) / 8 in
  let old_buf = !(w.buf) in
  let old_len = Bytes.length old_buf in
  if bytes_needed > old_len then begin
    let new_buf = Bytes.make (max (bytes_needed * 2) (old_len * 2)) '\x00' in
    Bytes.blit old_buf 0 new_buf 0 old_len;
    w.buf := new_buf
  end

(** [write_bit w bit] appends a single bit to [w]. *)
let write_bit (w : writer) bit =
  (* Must run before dereferencing [w.buf]: it may swap in a new buffer. *)
  writer_ensure w 1;
  if bit then begin
    let buf = !(w.buf) in
    let byte_pos = w.bit_pos / 8 in
    (* Unwritten bits are already 0, so only set bits need touching. *)
    Bytes.set_uint8 buf byte_pos
      (Bytes.get_uint8 buf byte_pos lor (1 lsl (w.bit_pos mod 8)))
  end;
  w.bit_pos <- w.bit_pos + 1

(** [write_bits w value nbits] appends the low [nbits] bits of [value] to
    [w], least-significant bit first: output bit [i] is [Z.testbit value i].
    Does nothing when [nbits <= 0]. *)
let write_bits (w : writer) value nbits =
  if nbits > 0 then begin
    writer_ensure w nbits;
    (* Capacity for all [nbits] bits now exists, so the buffer can be fetched
       once instead of re-checking it for every bit. *)
    let buf = !(w.buf) in
    let start = w.bit_pos in
    for i = 0 to nbits - 1 do
      if Z.testbit value i then begin
        let pos = start + i in
        let byte_pos = pos / 8 in
        Bytes.set_uint8 buf byte_pos
          (Bytes.get_uint8 buf byte_pos lor (1 lsl (pos mod 8)))
      end
    done;
    w.bit_pos <- start + nbits
  end

(** [writer_to_bytes w] returns a fresh copy of everything written so far,
    rounded up to a whole number of bytes. Padding bits in the final byte
    are zero. *)
let writer_to_bytes (w : writer) : bytes =
  Bytes.sub !(w.buf) 0 ((w.bit_pos + 7) / 8)

(** [reader_create buf] returns a reader positioned at the first bit of
    [buf]. [buf] is shared, not copied. *)
let reader_create buf : reader =
  { buf; bit_pos = 0; len = Bytes.length buf * 8 }

(* Value of stream bit [pos] in [buf]. Callers are responsible for checking
   that [pos] is within the stream. *)
let get_bit_at (buf : bytes) (pos : int) : bool =
  (Bytes.get_uint8 buf (pos / 8) lsr (pos mod 8)) land 1 = 1

(** [read_bit r] returns the next bit of [r] and advances past it.
    @raise Invalid_argument if [r] is at the end of the stream. *)
let read_bit (r : reader) =
  if r.bit_pos >= r.len then invalid_arg "read_bit: end of stream";
  let bit = get_bit_at r.buf r.bit_pos in
  r.bit_pos <- r.bit_pos + 1;
  bit

(** [read_bits r nbits] reads the next [nbits] bits of [r] as a non-negative
    integer (the first bit read is the least significant) and advances past
    them. Returns [Z.zero] without advancing when [nbits <= 0].

    The position is only updated on success: if fewer than [nbits] bits
    remain, nothing is consumed.
    @raise Invalid_argument if fewer than [nbits] bits remain. *)
let read_bits (r : reader) nbits =
  if nbits <= 0 then Z.zero
  else begin
    if nbits > r.len - r.bit_pos then invalid_arg "read_bits: end of stream";
    let start = r.bit_pos in
    let first_byte = start / 8 in
    let shift = start mod 8 in
    let src_len = Bytes.length r.buf in
    let out_len = (nbits + 7) / 8 in
    (* Assemble the result little-endian, one output byte at a time. Each
       output byte is the high part of source byte [i] merged with the low
       part of byte [i + 1]. This is O(nbits), unlike OR-ing shifted
       values into a growing [Z.t]. *)
    let out = Bytes.create out_len in
    for k = 0 to out_len - 1 do
      let i = first_byte + k in
      let lo = Bytes.get_uint8 r.buf i lsr shift in
      let hi =
        if shift = 0 || i + 1 >= src_len then 0
        else Bytes.get_uint8 r.buf (i + 1) lsl (8 - shift)
      in
      Bytes.set_uint8 out k ((lo lor hi) land 0xff)
    done;
    (* Clear any bits past [nbits] that were pulled into the final byte. *)
    let tail = nbits mod 8 in
    if tail <> 0 then begin
      let last = out_len - 1 in
      Bytes.set_uint8 out last (Bytes.get_uint8 out last land ((1 lsl tail) - 1))
    end;
    r.bit_pos <- start + nbits;
    (* [out] is not modified after this point, so sharing it is safe. *)
    Z.of_bits (Bytes.unsafe_to_string out)
  end

(** [peek_bit r] returns the next bit of [r] without advancing.
    @raise Invalid_argument if [r] is at the end of the stream. *)
let peek_bit (r : reader) =
  if r.bit_pos >= r.len then invalid_arg "peek_bit: end of stream";
  get_bit_at r.buf r.bit_pos

(** Get current bit position *)
let reader_pos (r : reader) = r.bit_pos

(** Check if at end of stream *)
let reader_at_end (r : reader) = r.bit_pos >= r.len