summaryrefslogtreecommitdiff
path: root/ocaml/lib/bitstream.ml
blob: e758c2af66bd785ee930a1bb653836c56cd4cdf5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
(** Bitstream utilities for jam/cue serialization.

    Bits are packed least-significant-bit first: stream bit [i] is bit
    [i mod 8] of byte [i / 8]. *)

(** A bitstream writer. Bits are appended least-significant-bit first
    within each byte (see {!write_bit}). *)
type writer = {
  buf: bytes ref;          (** Backing buffer; [writer_ensure] swaps in a
                               larger zero-filled copy when room runs out *)
  mutable bit_pos: int;    (** Number of bits written so far, i.e. the index
                               of the next bit to write *)
}

(** A bitstream reader. Bits are consumed least-significant-bit first
    within each byte; the reader functions never modify [buf].

    Note: this record reuses the field names [buf] and [bit_pos] from
    {!writer}. Because [reader] is declared later, an unannotated [x.buf]
    or [x.bit_pos] resolves to [reader], so code handling writers must make
    the [writer] type known (e.g. with an annotation) before such accesses. *)
type reader = {
  buf: bytes;              (** Buffer to read from *)
  mutable bit_pos: int;    (** Index of the next bit to read *)
  len: int;                (** Length in bits; [reader_create] sets it to
                               [8 * Bytes.length buf] *)
}

(** Make an empty writer whose initial buffer is 1024 zero bytes. The
    buffer grows on demand (see [writer_ensure]). *)
let writer_create () : writer =
  let initial = Bytes.make 1024 '\x00' in
  { bit_pos = 0; buf = ref initial }

(** [writer_ensure w bits_needed] guarantees that the buffer of [w] can
    hold [bits_needed] more bits past [w.bit_pos]. When it cannot, the
    buffer is replaced by a zero-filled one of size
    [max (2 * required_bytes) (2 * current_bytes)], with the old contents
    copied to the front. *)
let writer_ensure (w : writer) (bits_needed : int) : unit =
  let required = (w.bit_pos + bits_needed + 7) / 8 in
  let current = !(w.buf) in
  let capacity = Bytes.length current in
  if capacity < required then begin
    let grown = Bytes.make (max (required * 2) (capacity * 2)) '\x00' in
    Bytes.blit current 0 grown 0 capacity;
    w.buf := grown
  end

(** Append one bit to [w] and advance [w.bit_pos] by one. [true] sets bit
    [bit_pos mod 8] of byte [bit_pos / 8]; [false] leaves that
    zero-initialized bit untouched. *)
let write_bit (w : writer) bit =
  writer_ensure w 1;
  let pos = w.bit_pos in
  if bit then begin
    let data = !(w.buf) in
    let idx = pos / 8 in
    let mask = 1 lsl (pos mod 8) in
    Bytes.set_uint8 data idx (mask lor Bytes.get_uint8 data idx)
  end;
  w.bit_pos <- pos + 1

(** Append the low [nbits] bits of [value], least significant first, each
    obtained with [Z.testbit]. A non-positive [nbits] writes nothing. *)
let write_bits (w : writer) value nbits =
  writer_ensure w nbits;
  let rec emit i =
    if i < nbits then begin
      write_bit w (Z.testbit value i);
      emit (i + 1)
    end
  in
  emit 0

(** Copy the written bits out as a fresh [bytes] value, rounded up to a
    whole number of bytes. The copy does not alias the writer buffer. *)
let writer_to_bytes (w : writer) : bytes =
  Bytes.sub !(w.buf) 0 ((w.bit_pos + 7) / 8)

(** Wrap [buf] in a reader positioned at bit 0 and spanning all
    [8 * Bytes.length buf] bits. The buffer is shared, not copied. *)
let reader_create (buf : bytes) : reader =
  let len = 8 * Bytes.length buf in
  { bit_pos = 0; len; buf }

(* Lookup table: [trailing_zeros.(b)] is the number of low-order zero bits
   in byte value [b] (0..255). The entry for 0 is 8, so callers can detect
   the "no one-bit present" case. *)
let trailing_zeros =
  let rec count_tz acc v =
    if v land 1 = 1 then acc else count_tz (acc + 1) (v lsr 1)
  in
  Array.init 256 (fun b -> if b = 0 then 8 else count_tz 0 b)

(** Consume and return the next bit ([true] for 1).
    @raise Invalid_argument if the reader is at the end of the stream. *)
let read_bit (r : reader) =
  let pos = r.bit_pos in
  if pos >= r.len then invalid_arg "read_bit: end of stream";
  let byte = Bytes.get_uint8 r.buf (pos / 8) in
  r.bit_pos <- pos + 1;
  (byte lsr (pos mod 8)) land 1 = 1

(** [read_bits r nbits] consumes [nbits] bits and returns them as a
    non-negative [Z.t]: bit [i] of the result is the stream bit at position
    [reader_pos r + i] (bits are taken least-significant first within each
    byte).

    The read is all-or-nothing. The remaining length is checked before any
    bit is consumed, so on failure the reader position is unchanged. A
    non-positive [nbits] returns [Z.zero] without moving the cursor.

    @raise Invalid_argument if fewer than [nbits] bits remain. The message
    is the same one [read_bit] uses, so existing handlers keep matching. *)
let read_bits (r : reader) nbits =
  if nbits <= 0 then Z.zero
  else begin
    let start = r.bit_pos in
    (* Check the whole range up front. Written as a subtraction so a huge
       [nbits] cannot overflow [start + nbits]. *)
    if nbits > r.len - start then
      raise (Invalid_argument "read_bit: end of stream");
    if nbits < Sys.int_size then begin
      (* Small read: gather up to one byte per step into a native int.
         With nbits <= Sys.int_size - 1 the accumulator never reaches the
         sign bit, so it stays non-negative and converts exactly. *)
      let acc = ref 0 in
      let got = ref 0 in
      while !got < nbits do
        let pos = start + !got in
        let off = pos land 7 in
        let take = min (8 - off) (nbits - !got) in
        let byte = Bytes.get_uint8 r.buf (pos lsr 3) in
        let bits = (byte lsr off) land ((1 lsl take) - 1) in
        acc := !acc lor (bits lsl !got);
        got := !got + take
      done;
      r.bit_pos <- start + nbits;
      Z.of_int !acc
    end else begin
      (* Large read at any alignment: copy the covering bytes once, decode
         them with [Z.of_bits] (little-endian byte order), then keep exactly
         [nbits] bits starting at the offset of [start] within its byte. *)
      let first = start lsr 3 in
      let last = (start + nbits - 1) lsr 3 in
      let chunk = Bytes.sub_string r.buf first (last - first + 1) in
      r.bit_pos <- start + nbits;
      Z.extract (Z.of_bits chunk) (start land 7) nbits
    end
  end

(** Return the next bit ([true] for 1) without consuming it.
    @raise Invalid_argument if the reader is at the end of the stream. *)
let peek_bit (r : reader) =
  let pos = r.bit_pos in
  if pos >= r.len then invalid_arg "peek_bit: end of stream";
  let byte = Bytes.get_uint8 r.buf (pos / 8) in
  byte land (1 lsl (pos mod 8)) <> 0

(** Index of the next bit to be read, counted from the start of the
    buffer. *)
let reader_pos (r : reader) : int = r.bit_pos

(** [true] once the position has reached (or passed) the [len] bits of
    the reader. *)
let reader_at_end (r : reader) : bool = not (r.bit_pos < r.len)

(** Consume a run of zero bits terminated by a one bit.

    Starting at the current position, counts consecutive [0] bits, consumes
    them together with the terminating [1] bit, and returns the number of
    zeros (the cursor therefore advances by the result plus one). The scan
    works a byte at a time using {!trailing_zeros}.

    On failure the reader position is left unchanged.

    @raise Invalid_argument if no one bit occurs within the first [r.len]
    bits of the stream. *)
let count_zero_bits_until_one (r : reader) =
  let buf = r.buf in
  let len_bits = r.len in
  let rec scan count bit_pos =
    if bit_pos >= len_bits then
      raise (Invalid_argument "count_zero_bits_until_one: end of stream")
    else begin
      let byte_idx = bit_pos lsr 3 in
      let bit_off = bit_pos land 7 in
      (* Drop the bits of this byte that precede [bit_pos]; the result is
         still a byte value, so it indexes the table directly. *)
      let masked = Bytes.get_uint8 buf byte_idx lsr bit_off in
      if masked <> 0 then begin
        let tz = trailing_zeros.(masked) in
        let one_pos = bit_pos + tz in
        (* [reader_create] always sets [len] to the whole buffer, but a
           reader built directly may have a shorter [len]. A one bit in
           that tail is not part of the stream, matching how [read_bit]
           and [peek_bit] bound every access by [len]. *)
        if one_pos >= len_bits then
          raise (Invalid_argument "count_zero_bits_until_one: end of stream");
        r.bit_pos <- one_pos + 1; (* skip zeros and the terminating 1 bit *)
        count + tz
      end else
        let remaining = 8 - bit_off in
        scan (count + remaining) (bit_pos + remaining)
    end
  in
  scan 0 r.bit_pos