summaryrefslogtreecommitdiff
path: root/ocaml/lib/bitstream.ml
blob: 39bfd6aef5d0966569171c17b0565d02c2eda5ad (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
(** Bitstream utilities for jam/cue serialization.

    Bit [i] of a stream lives in byte [i / 8] at bit position [i mod 8],
    counting from the least-significant bit of that byte. Multi-bit values
    are written and read least-significant bit first. *)

(** A bitstream writer.

    [writer] and [reader] share the field names [buf] and [bit_pos].
    [reader] is defined later, so a field access such as [x.bit_pos] on a
    value whose type is not yet known resolves to [reader]. Code working on
    a [writer] must make its type known first (e.g. with an annotation, or
    by passing the value to a function that takes a [writer]). *)
type writer = {
  buf: bytes ref;          (** Backing buffer; the ref is re-pointed at a larger zero-filled copy when more room is needed *)
  mutable bit_pos: int;    (** Number of bits written so far, i.e. the index of the next bit to write *)
}

(** A bitstream reader over an existing byte buffer. *)
type reader = {
  buf: bytes;              (** Buffer to read from; [reader_create] stores it without copying *)
  mutable bit_pos: int;    (** Index of the next bit to read *)
  len: int;                (** Number of readable bits, counted from the start of [buf] *)
}

(** [writer_create ()] returns an empty writer positioned at bit 0. It is
    backed by a zero-filled 1024-byte buffer that [writer_ensure] grows on
    demand. *)
let writer_create () : writer =
  let initial = Bytes.make 1024 '\x00' in
  { bit_pos = 0; buf = ref initial }

(** [writer_ensure w bits_needed] makes sure the backing buffer of [w] can
    hold [bits_needed] more bits past the current write position.

    If it cannot, [w.buf] is re-pointed at a new zero-filled buffer that
    already holds the old contents. The new size is twice the required byte
    count, and never less than double the old capacity. Keeping the spare
    space zero-filled matters: [write_bit] only ever sets bits. *)
let writer_ensure (w : writer) (bits_needed : int) : unit =
  let current = !(w.buf) in
  let capacity = Bytes.length current in
  let required = (w.bit_pos + bits_needed + 7) / 8 in
  if required > capacity then begin
    let grown = Bytes.make (max (2 * required) (2 * capacity)) '\x00' in
    Bytes.blit current 0 grown 0 capacity;
    w.buf := grown
  end

(** [write_bit w bit] appends one bit to [w] and advances the position by 1.

    Bits fill each byte starting from its least-significant end. A [false]
    bit only advances the position, because the buffer is zero-filled when
    it is created or grown. *)
let write_bit (w : writer) bit =
  writer_ensure w 1;
  let pos = w.bit_pos in
  (if bit then
     let bytes = !(w.buf) in
     let idx = pos / 8 in
     let mask = 1 lsl (pos mod 8) in
     Bytes.set_uint8 bytes idx (Bytes.get_uint8 bytes idx lor mask));
  w.bit_pos <- pos + 1

(** [write_bits w value nbits] appends the low [nbits] bits of [value] to
    [w], least-significant bit first.

    Room for all of them is reserved up front, so the buffer grows at most
    once. A non-positive [nbits] writes nothing. *)
let write_bits (w : writer) value nbits =
  writer_ensure w nbits;
  let rec emit i =
    if i < nbits then begin
      write_bit w (Z.testbit value i);
      emit (i + 1)
    end
  in
  emit 0

(** [writer_to_bytes w] returns a fresh copy of the bytes written so far.

    A trailing partial byte is included; its unused high bits are zero. The
    writer's own buffer is not shared with the result. *)
let writer_to_bytes (w : writer) : bytes =
  let used = (w.bit_pos + 7) / 8 in
  Bytes.sub !(w.buf) 0 used

(** [reader_create buf] returns a reader over every bit of [buf], starting
    at bit 0.

    [buf] is stored as-is (not copied), so later mutation of [buf] is
    visible to the reader. *)
let reader_create buf : reader =
  let len = 8 * Bytes.length buf in
  { len; bit_pos = 0; buf }

(** [read_bit r] returns the bit at the current position of [r] and advances
    past it.

    Bits are taken from the least-significant end of each byte first.
    @raise Invalid_argument if the position is at or beyond [r.len]. *)
let read_bit (r : reader) =
  let pos = r.bit_pos in
  if pos >= r.len then invalid_arg "read_bit: end of stream";
  let byte = Bytes.get_uint8 r.buf (pos / 8) in
  r.bit_pos <- pos + 1;
  byte land (1 lsl (pos mod 8)) <> 0

(** [read_bits r nbits] reads the next [nbits] bits of [r] and returns them
    as a non-negative [Z.t]. The first bit read is the least significant.
    The reader advances by [nbits].

    A count of zero or less returns [Z.zero] and leaves the reader where it
    is. (Previously a negative multiple of 8 on a byte-aligned reader moved
    the position backwards.)

    Reads of 8 bits or more are done in one step, whatever the alignment:
    the bytes covering the requested range are turned into a number with
    [Z.of_bits] (little-endian), and the range is cut out with [Z.extract].
    This is linear in [nbits]. It replaces a per-byte [Z.logor] loop, which
    was quadratic for large reads.

    @raise Invalid_argument if fewer than [nbits] bits remain before
    [r.len]. The check happens before anything is consumed, so the reader is
    unchanged when it fires. *)
let read_bits (r : reader) nbits =
  if nbits <= 0 then Z.zero
  else if nbits > r.len - r.bit_pos then
    (* Respect [r.len] on every path; the old byte-aligned fast path read
       the buffer directly and ignored it. *)
    invalid_arg "read_bits: end of stream"
  else if nbits < 8 then begin
    (* Short reads: cheaper bit by bit than allocating a byte window. *)
    let acc = ref 0 in
    for i = 0 to nbits - 1 do
      if read_bit r then acc := !acc lor (1 lsl i)
    done;
    Z.of_int !acc
  end else begin
    let first_byte = r.bit_pos / 8 in
    let last_byte = (r.bit_pos + nbits - 1) / 8 in
    let window =
      Bytes.sub_string r.buf first_byte (last_byte - first_byte + 1)
    in
    (* Stream bit [r.bit_pos] is bit [r.bit_pos mod 8] of the window. *)
    let value = Z.extract (Z.of_bits window) (r.bit_pos mod 8) nbits in
    (* Advance only after the read succeeded. *)
    r.bit_pos <- r.bit_pos + nbits;
    value
  end

(** [peek_bit r] returns the bit at the current position of [r] without
    advancing. The bit is taken from the same place [read_bit] would use.

    @raise Invalid_argument if the position is at or beyond [r.len]. *)
let peek_bit (r : reader) =
  let pos = r.bit_pos in
  if pos >= r.len then invalid_arg "peek_bit: end of stream"
  else Bytes.get_uint8 r.buf (pos / 8) land (1 lsl (pos mod 8)) <> 0

(** [reader_pos r] is the index of the next bit [r] will read, counted from
    the start of its buffer. *)
let reader_pos ({ bit_pos; _ } : reader) = bit_pos

(** [reader_at_end r] is [true] once the position of [r] has reached (or
    passed) [r.len], i.e. no readable bits remain. *)
let reader_at_end (r : reader) = not (r.bit_pos < r.len)