summaryrefslogtreecommitdiff
path: root/ocaml/lib/bitstream.ml
blob: 748e204f49bacdd059d7e487e2931b2edec0b012 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
(** Append-only bit writer. Bits are packed LSB-first within each byte.
    - [buf_ref]: the current backing buffer; [writer_ensure] swaps in a
      larger zero-filled copy when it runs out of room.
    - [bit_pos]: number of bits written so far, which is also the position
      of the next bit to be written. *)
type writer = {
  buf_ref : bytes Stdlib.ref;
  mutable bit_pos : int;
}

(** Sequential bit reader over a byte buffer.
    - [buf]: the bytes being read.
    - [bit_pos]: index of the next bit to read.
    - [len]: total number of readable bits ([8 * Bytes.length buf] when
      built by [reader_create]). *)
type reader = {
  buf : bytes;
  mutable bit_pos : int;
  len : int;
}

(** [writer_create ()] returns an empty writer whose backing buffer starts as
    1024 zero bytes; [writer_ensure] enlarges it when more room is needed. *)
let writer_create () =
  let initial = Bytes.make 1024 '\x00' in
  { buf_ref = ref initial; bit_pos = 0 }

(** [writer_ensure w bits_needed] makes sure [bits_needed] more bits fit after
    the current write position. When the buffer is too small it is replaced by
    a zero-filled one of twice the larger of (bytes required, current size),
    with the old contents copied over. *)
let writer_ensure (w : writer) bits_needed =
  let current = !(w.buf_ref) in
  let capacity = Bytes.length current in
  let required = (w.bit_pos + bits_needed + 7) / 8 in
  if required > capacity then begin
    (* New bytes must be zero: [write_bit] only ever ORs bits in. *)
    let grown = Bytes.make (max (required * 2) (capacity * 2)) '\x00' in
    Bytes.blit current 0 grown 0 capacity;
    w.buf_ref := grown
  end

(** [write_bit w bit] appends one bit to [w], growing the buffer if needed.
    Each byte is filled starting from its least-significant bit. Writing
    [false] only advances the cursor; this relies on the buffer being
    zero-filled. *)
let write_bit (w : writer) bit =
  writer_ensure w 1;
  let pos = w.bit_pos in
  if bit then begin
    (* Fetch the buffer only after [writer_ensure], which may replace it. *)
    let bytes = !(w.buf_ref) in
    let idx = pos / 8 in
    let mask = 1 lsl (pos mod 8) in
    Bytes.set_uint8 bytes idx (mask lor Bytes.get_uint8 bytes idx)
  end;
  w.bit_pos <- pos + 1

(** [write_bits w value nbits] appends the low [nbits] bits of [value], bit 0
    first. Any higher bits of [value] are ignored. *)
let write_bits (w : writer) value nbits =
  writer_ensure w nbits;
  let rec emit i =
    if i < nbits then begin
      write_bit w (Z.testbit value i);
      emit (i + 1)
    end
  in
  emit 0

(** [writer_to_bytes w] returns a fresh copy of everything written so far,
    rounded up to a whole number of bytes. Unused high bits of the last byte
    are zero, because the buffer is zero-filled. *)
let writer_to_bytes (w : writer) =
  let used_bits = w.bit_pos in
  Bytes.sub !(w.buf_ref) 0 ((used_bits + 7) / 8)

(** [writer_pos w] is the number of bits written to [w] so far. *)
let writer_pos ({ bit_pos; _ } : writer) = bit_pos

(** [reader_create buf] starts reading [buf] from its first bit. All
    [8 * Bytes.length buf] bits are readable. [buf] is not copied. *)
let reader_create buf =
  let bit_len = 8 * Bytes.length buf in
  { buf; bit_pos = 0; len = bit_len }

(** [read_bit r] consumes and returns the next bit of [r] ([true] for 1).
    Bits are taken from each byte starting at its least-significant bit.
    @raise Invalid_argument ["read_bit: end of stream"] if no bits remain. *)
let read_bit r =
  let pos = r.bit_pos in
  if pos >= r.len then invalid_arg "read_bit: end of stream";
  let byte = Bytes.get_uint8 r.buf (pos / 8) in
  r.bit_pos <- pos + 1;
  (byte lsr (pos mod 8)) land 1 = 1

(** [read_bits r nbits] consumes [nbits] bits from [r] and returns them as a
    non-negative integer. Bit [i] of the result is the [i]-th bit read
    (LSB-first), which is the layout [write_bits] produces.

    Reads of 64 bits or more use a fast path. Bits are consumed one at a time
    only until the cursor reaches a byte boundary. All whole bytes are then
    converted in one step with [Z.of_bits], which treats its string as
    little-endian bytes. Any trailing 0..7 bits are read individually.

    Returns [Z.zero] without consuming anything when [nbits <= 0].

    @raise Invalid_argument ["read_bit: end of stream"] if fewer than [nbits]
    bits remain. The check is made before any bit is consumed, so on failure
    the reader position is unchanged. *)
let read_bits r nbits =
  (* OR [n] individually-read bits into [acc], the first one landing at bit
     [shift] of the result. *)
  let read_slow n ~shift acc =
    let acc = ref acc in
    for i = 0 to n - 1 do
      if read_bit r then acc := Z.logor !acc (Z.shift_left Z.one (shift + i))
    done;
    !acc
  in
  if nbits <= 0 then Z.zero
  else if nbits > r.len - r.bit_pos then
    (* Fail up front with the same exception as [read_bit], so a short stream
       never leaves the reader half-advanced and the fast path never reaches
       an out-of-range [Bytes] access. *)
    raise (Invalid_argument "read_bit: end of stream")
  else if nbits < 64 then read_slow nbits ~shift:0 Z.zero
  else begin
    (* Consume 0..7 bits to bring the cursor onto a byte boundary. *)
    let head = (8 - (r.bit_pos mod 8)) mod 8 in
    let acc = read_slow head ~shift:0 Z.zero in
    (* Byte-aligned: convert every whole byte at once. Since nbits >= 64 and
       head <= 7 there is at least one full byte, and the bounds check above
       guarantees all of them lie inside [r.buf]. *)
    let full_bytes = (nbits - head) / 8 in
    let chunk = Bytes.sub_string r.buf (r.bit_pos / 8) full_bytes in
    r.bit_pos <- r.bit_pos + (full_bytes * 8);
    let acc = Z.logor acc (Z.shift_left (Z.of_bits chunk) head) in
    (* Trailing 0..7 bits that do not fill a whole byte. *)
    let consumed = head + (full_bytes * 8) in
    read_slow (nbits - consumed) ~shift:consumed acc
  end

(** [reader_pos r] is the index of the next bit [r] will read, which equals
    the number of bits consumed so far. *)
let reader_pos ({ bit_pos; _ } : reader) = bit_pos

(** [count_zero_bits_until_one r] consumes bits up to and including the next
    1 bit and returns how many 0 bits preceded it.
    @raise Invalid_argument (from [read_bit]) if the stream ends before a 1
    bit is found. The zeros seen so far remain consumed. *)
let count_zero_bits_until_one r =
  let zeros = ref 0 in
  while not (read_bit r) do
    incr zeros
  done;
  !zeros