blob: e758c2af66bd785ee930a1bb653836c56cd4cdf5 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
|
(** Bitstream utilities for jam/cue serialization *)
(** A bitstream writer: a growable byte buffer plus a bit cursor.
Bit [k] of the stream is stored as bit [k mod 8] of byte [k / 8], i.e. bits
fill each byte starting from its least significant bit. *)
type writer = {
buf: bytes ref; (** Backing buffer; held in a [ref] so it can be replaced by a larger one when it fills up *)
mutable bit_pos: int; (** Index of the next bit to be written *)
}
(** A bitstream reader: a byte buffer, a bit cursor, and a length limit.
Bit [k] of the stream is read from bit [k mod 8] of byte [k / 8], i.e. the
least significant bit of each byte comes first. *)
type reader = {
buf: bytes; (** Buffer to read from *)
mutable bit_pos: int; (** Index of the next bit to be read *)
len: int; (** Length of the stream in bits; single-bit reads at or past this index fail *)
}
(** [writer_create ()] returns a fresh writer positioned at bit 0, backed by
    a zero-filled 1024-byte buffer. *)
let writer_create () : writer =
  let initial_capacity = 1024 in
  let storage = Bytes.make initial_capacity '\x00' in
  { bit_pos = 0; buf = ref storage }
(** [writer_ensure w bits_needed] makes sure the buffer of [w] can hold
    [bits_needed] more bits past the current position.

    When the buffer is too small it is replaced by a zero-filled one that is
    at least twice as large, and the existing contents are copied over. *)
let writer_ensure (w : writer) (bits_needed : int) : unit =
  let required = (w.bit_pos + bits_needed + 7) / 8 in
  let data = !(w.buf) in
  let capacity = Bytes.length data in
  if required > capacity then begin
    (* Double whichever is larger so repeated small writes stay cheap. *)
    let grown = Bytes.make (max (required * 2) (capacity * 2)) '\x00' in
    Bytes.blit data 0 grown 0 capacity;
    w.buf := grown
  end
(** [write_bit w bit] appends one bit to [w] and advances the position by
    one. Bits fill each byte starting from its least significant bit. *)
let write_bit w bit =
  writer_ensure w 1;
  let pos = w.bit_pos in
  (* The buffer is zero-filled on allocation, so only a set bit needs a
     store. *)
  if bit then begin
    let data = !(w.buf) in
    let idx = pos / 8 in
    let mask = 1 lsl (pos mod 8) in
    Bytes.set_uint8 data idx (Bytes.get_uint8 data idx lor mask)
  end;
  w.bit_pos <- pos + 1
(** [write_bits w value nbits] appends bits [0 .. nbits - 1] of the [Z.t]
    [value] to [w], bit 0 first. *)
let write_bits w value nbits =
  (* Reserve room for the whole run up front. *)
  writer_ensure w nbits;
  let rec emit i =
    if i < nbits then begin
      write_bit w (Z.testbit value i);
      emit (i + 1)
    end
  in
  emit 0
(** [writer_to_bytes w] returns a fresh copy of the bytes written so far: the
    buffer prefix up to the current bit position, rounded up to a whole
    byte. *)
let writer_to_bytes (w : writer) : bytes =
  let used_bytes = (w.bit_pos + 7) / 8 in
  Bytes.sub !(w.buf) 0 used_bytes
(** [reader_create buf] returns a reader positioned at bit 0 of [buf] whose
    length covers every bit of [buf]. The buffer is used as-is, not copied. *)
let reader_create buf =
  let len = 8 * Bytes.length buf in
  { buf; len; bit_pos = 0 }
(* [trailing_zeros.(b)] is the number of zero bits below the lowest set bit
   of the byte value [b] (0..255). The entry for 0 is 8, so callers can tell
   that the byte contains no one-bit at all. *)
let trailing_zeros =
  let rec lowest_set_bit idx v =
    if v land 1 = 1 then idx else lowest_set_bit (idx + 1) (v lsr 1)
  in
  Array.init 256 (fun byte -> if byte = 0 then 8 else lowest_set_bit 0 byte)
(** [read_bit r] returns the bit at the current position of [r] and advances
    the position by one.

    @raise Invalid_argument if the position is at or past [r.len]. *)
let read_bit (r : reader) : bool =
  let pos = r.bit_pos in
  if pos >= r.len then invalid_arg "read_bit: end of stream"
  else begin
    let byte = Bytes.get_uint8 r.buf (pos / 8) in
    r.bit_pos <- pos + 1;
    (byte lsr (pos mod 8)) land 1 = 1
  end
(** [read_bits r nbits] reads the next [nbits] bits of [r] and returns them
    as a non-negative [Z.t]. The first bit read becomes bit 0 (the least
    significant bit) of the result, which is the order [write_bits] emits.

    On success the reader advances by exactly [nbits] bits. On failure the
    reader position is left unchanged. A non-positive [nbits] reads nothing
    and yields [Z.zero].

    @raise Invalid_argument if fewer than [nbits] bits remain before
    [r.len]. *)
let read_bits (r : reader) (nbits : int) : Z.t =
  if nbits <= 0 then
    (* Guarding negatives too: a negative count must never move the cursor
       backwards. *)
    Z.zero
  else if nbits > r.len - r.bit_pos then
    (* Checked once, before touching the cursor, so every read size honours
       [r.len] and fails the same way. *)
    invalid_arg "read_bits: end of stream"
  else begin
    let first_byte = r.bit_pos / 8 in
    let bit_off = r.bit_pos mod 8 in
    (* Number of bytes holding at least one of the requested bits. *)
    let span = (bit_off + nbits + 7) / 8 in
    (* [Z.of_bits] decodes the bytes as a little-endian unsigned integer,
       which matches the stream's least-significant-bit-first layout. *)
    let window = Z.of_bits (Bytes.sub_string r.buf first_byte span) in
    r.bit_pos <- r.bit_pos + nbits;
    (* Drop the [bit_off] leading bits that precede the cursor and keep
       exactly [nbits] bits. *)
    Z.extract window bit_off nbits
  end
(** [peek_bit r] returns the bit at the current position of [r] without
    moving the position.

    @raise Invalid_argument if the position is at or past [r.len]. *)
let peek_bit (r : reader) : bool =
  let pos = r.bit_pos in
  if pos >= r.len then invalid_arg "peek_bit: end of stream"
  else begin
    let byte = Bytes.get_uint8 r.buf (pos / 8) in
    (byte lsr (pos mod 8)) land 1 = 1
  end
(** [reader_pos r] is the index of the next bit [r] will read. *)
let reader_pos (r : reader) : int = r.bit_pos
(** [reader_at_end r] is [true] once the position of [r] has reached (or
    passed) its length in bits. *)
let reader_at_end (r : reader) : bool = not (r.bit_pos < r.len)
(** [count_zero_bits_until_one r] counts the consecutive zero bits starting
    at the current position of [r] and returns that count. The reader is
    advanced past those zeros and past the one bit that terminates them.

    The scan works a byte at a time, using the [trailing_zeros] table to
    locate the lowest set bit in the rest of each byte.

    @raise Invalid_argument if no one bit occurs before [r.len]; the reader
    position is not modified in that case. *)
let count_zero_bits_until_one (r : reader) : int =
  let end_of_stream () =
    invalid_arg "count_zero_bits_until_one: end of stream"
  in
  let rec scan zeros pos =
    if pos >= r.len then end_of_stream ()
    else begin
      let bit_off = pos land 7 in
      (* Shift out the bits of this byte that lie before [pos]. *)
      let rest = Bytes.get_uint8 r.buf (pos lsr 3) lsr bit_off in
      if rest = 0 then begin
        (* No one bit in the remainder of this byte: move to the next. *)
        let skipped = 8 - bit_off in
        scan (zeros + skipped) (pos + skipped)
      end else begin
        let tz = trailing_zeros.(rest) in
        let one_pos = pos + tz in
        (* When [r.len] is not a multiple of 8, the last byte may hold bits
           beyond the stream; a one bit found there does not count. *)
        if one_pos >= r.len then end_of_stream ()
        else begin
          (* Skip the zeros and the terminating one bit. *)
          r.bit_pos <- one_pos + 1;
          zeros + tz
        end
      end
    end
  in
  scan 0 r.bit_pos
|