Skip to content

Commit 72fe5c1

Browse files
committed
Decode CELT PVQ bands
1 parent db0d870 commit 72fe5c1

8 files changed

Lines changed: 1436 additions & 0 deletions

File tree

internal/celt/bands.go

Lines changed: 786 additions & 0 deletions
Large diffs are not rendered by default.

internal/celt/bands_test.go

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
// SPDX-FileCopyrightText: 2026 The Pion community <https://pion.ly>
2+
// SPDX-License-Identifier: MIT
3+
4+
package celt
5+
6+
import (
7+
"testing"
8+
9+
"github.com/pion/opus/internal/rangecoding"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestQuantBandSingleBin(t *testing.T) {
15+
decoder := rangeDecoderWithRawBits(0b00000001)
16+
state := bandDecodeState{rangeDecoder: &decoder}
17+
x := []float32{0}
18+
y := []float32{0}
19+
remainingBits := 2 << bitResolution
20+
21+
mask := quantBand(
22+
0, x, y, 1, 2<<bitResolution, spreadNormal, 1, maxBands, 0, nil,
23+
&remainingBits, 0, nil, 0, 1, nil, 1, &state,
24+
)
25+
26+
assert.Equal(t, uint(1), mask)
27+
assert.Equal(t, float32(-1), x[0])
28+
assert.Equal(t, float32(1), y[0])
29+
}
30+
31+
func TestQuantBandFoldedAndPulsePaths(t *testing.T) {
32+
t.Run("zeros when no fold source is available", func(t *testing.T) {
33+
state := bandDecodeState{rangeDecoder: rangeDecoderForBandTests(), seed: 1}
34+
x := []float32{7, 7, 7, 7}
35+
remainingBits := 0
36+
37+
mask := quantBand(
38+
0, x, nil, len(x), 0, spreadNormal, 1, maxBands, 0, nil,
39+
&remainingBits, 0, nil, 1, 1, nil, 0, &state,
40+
)
41+
42+
assert.Zero(t, mask)
43+
assert.Equal(t, []float32{0, 0, 0, 0}, x)
44+
})
45+
46+
t.Run("fills from deterministic noise", func(t *testing.T) {
47+
state := bandDecodeState{rangeDecoder: rangeDecoderForBandTests(), seed: 1}
48+
x := make([]float32, 4)
49+
remainingBits := 0
50+
51+
mask := quantBand(
52+
0, x, nil, len(x), 0, spreadNormal, 1, maxBands, 0, nil,
53+
&remainingBits, 0, nil, 1, 1, nil, 1, &state,
54+
)
55+
56+
assert.Equal(t, uint(1), mask)
57+
assert.InDelta(t, 1, vectorEnergy(x), 0.001)
58+
})
59+
60+
t.Run("folds from a previous lowband", func(t *testing.T) {
61+
state := bandDecodeState{rangeDecoder: rangeDecoderForBandTests(), seed: 1}
62+
x := make([]float32, 4)
63+
lowband := []float32{0.25, -0.25, 0.5, -0.5}
64+
remainingBits := 0
65+
66+
mask := quantBand(
67+
0, x, nil, len(x), 0, spreadNormal, 1, maxBands, 0, lowband,
68+
&remainingBits, 0, nil, 1, 1, nil, 1, &state,
69+
)
70+
71+
assert.Equal(t, uint(1), mask)
72+
assert.InDelta(t, 1, vectorEnergy(x), 0.001)
73+
})
74+
75+
t.Run("decodes algebraic pulses", func(t *testing.T) {
76+
decoder := rangeDecoderWithCDFSymbol(0, cwrsUrow(4, 1)[1]+cwrsUrow(4, 1)[2])
77+
state := bandDecodeState{rangeDecoder: &decoder}
78+
x := make([]float32, 4)
79+
remainingBits := 16
80+
81+
mask := quantBand(
82+
0, x, nil, len(x), 8, spreadNormal, 1, maxBands, 0, nil,
83+
&remainingBits, 0, nil, 1, 1, nil, 1, &state,
84+
)
85+
86+
assert.Equal(t, uint(1), mask)
87+
assert.InDelta(t, 1, vectorEnergy(x), 0.001)
88+
})
89+
}
90+
91+
func TestQuantBandSplits(t *testing.T) {
92+
t.Run("mono split", func(t *testing.T) {
93+
state := bandDecodeState{rangeDecoder: rangeDecoderForBandTests(), seed: 1}
94+
x := make([]float32, 8)
95+
lowbandOut := make([]float32, 8)
96+
scratch := make([]float32, 8)
97+
remainingBits := 512
98+
99+
mask := quantBand(
100+
4, x, nil, len(x), 320, spreadNormal, 1, maxBands, 0, nil,
101+
&remainingBits, 2, lowbandOut, 0, 1, scratch, 1, &state,
102+
)
103+
104+
assert.NotZero(t, mask)
105+
assert.InDelta(t, 1, vectorEnergy(x), 0.001)
106+
assert.NotZero(t, vectorEnergy(lowbandOut))
107+
})
108+
109+
t.Run("stereo split", func(t *testing.T) {
110+
state := bandDecodeState{rangeDecoder: rangeDecoderForBandTests(), seed: 1}
111+
x := make([]float32, 4)
112+
y := make([]float32, 4)
113+
lowbandOut := make([]float32, 4)
114+
scratch := make([]float32, 4)
115+
remainingBits := 512
116+
117+
mask := quantBand(
118+
4, x, y, len(x), 320, spreadNormal, 1, maxBands, 0, nil,
119+
&remainingBits, 2, lowbandOut, 0, 1, scratch, 1, &state,
120+
)
121+
122+
assert.NotZero(t, mask)
123+
assert.InDelta(t, 1, vectorEnergy(x), 0.001)
124+
assert.InDelta(t, 1, vectorEnergy(y), 0.001)
125+
})
126+
}
127+
128+
func TestQuantAllBands(t *testing.T) {
129+
decoder := rangeDecoderWithCDFSymbol(0, 64)
130+
state := bandDecodeState{rangeDecoder: &decoder, seed: 1}
131+
info := frameSideInfo{
132+
lm: 0,
133+
totalBits: 128,
134+
startBand: 0,
135+
endBand: 4,
136+
channelCount: 2,
137+
spread: spreadNormal,
138+
allocation: allocationState{
139+
codedBands: 4,
140+
intensity: 3,
141+
dualStereo: 1,
142+
},
143+
}
144+
for band := info.startBand; band < info.endBand; band++ {
145+
info.allocation.pulses[band] = 8
146+
}
147+
x := make([]float32, int(bandEdges[maxBands]))
148+
y := make([]float32, int(bandEdges[maxBands]))
149+
150+
masks := quantAllBands(&info, x, y, 128<<bitResolution, &state)
151+
152+
require.Len(t, masks, 2*maxBands)
153+
assert.NotZero(t, masks[0])
154+
assert.NotZero(t, masks[1])
155+
assert.NotZero(t, vectorEnergy(x[:int(bandEdges[info.endBand])]))
156+
assert.NotZero(t, vectorEnergy(y[:int(bandEdges[info.endBand])]))
157+
}
158+
159+
func TestBandMathHelpers(t *testing.T) {
160+
assert.False(t, shouldSplitBand(0, 0, 0))
161+
assert.True(t, shouldSplitBand(4, 2, 320))
162+
163+
assert.Equal(t, 1, computeQN(4, 0, 0, 0, false))
164+
assert.Greater(t, computeQN(4, 320, 0, 0, false), 1)
165+
assert.Greater(t, computeQN(2, 320, 0, 0, true), 1)
166+
167+
assert.Equal(t, 32768, bitexactCos(0))
168+
assert.InDelta(t, 23171, bitexactCos(8192), 2)
169+
assert.Equal(t, 0, bitexactLog2Tan(16384, 16384))
170+
assert.Equal(t, -2, fracMul16(2<<14, 2))
171+
assert.Equal(t, uint32(0), isqrt32(0))
172+
assert.Equal(t, uint32(12), isqrt32(144))
173+
assert.Equal(t, uint32(1015568748), lcgRand(1))
174+
assert.Equal(t, uint(0b0001), bitInterleave(0b0011))
175+
assert.Equal(t, uint(0x0C), bitDeinterleave(0b0010))
176+
}
177+
178+
func TestHadamardHelpers(t *testing.T) {
179+
vector := []float32{1, 2, 3, 4}
180+
haar1(vector, len(vector), 1)
181+
assert.InDelta(t, 2.12132, vector[0], 0.0001)
182+
assert.InDelta(t, -0.7, vector[1], 0.1)
183+
184+
vector = []float32{1, 2, 3, 4}
185+
state := bandDecodeState{}
186+
deinterleaveHadamard(vector, 2, 2, false, &state)
187+
assert.Equal(t, []float32{1, 3, 2, 4}, vector)
188+
interleaveHadamard(vector, 2, 2, false, &state)
189+
assert.Equal(t, []float32{1, 2, 3, 4}, vector)
190+
assert.Len(t, state.tmpScratch, len(vector))
191+
192+
vector = []float32{1, 2, 3, 4}
193+
deinterleaveHadamard(vector, 2, 2, true, &state)
194+
interleaveHadamard(vector, 2, 2, true, &state)
195+
assert.Equal(t, []float32{1, 2, 3, 4}, vector)
196+
}
197+
198+
func TestQuantAllBandsIgnoresDualStereoWithoutSecondChannel(t *testing.T) {
199+
decoder := rangeDecoderWithCDFSymbol(0, 64)
200+
state := bandDecodeState{rangeDecoder: &decoder, seed: 1}
201+
info := frameSideInfo{
202+
lm: 0,
203+
totalBits: 128,
204+
startBand: 0,
205+
endBand: 2,
206+
channelCount: 1,
207+
spread: spreadNormal,
208+
allocation: allocationState{
209+
codedBands: 2,
210+
intensity: 1,
211+
dualStereo: 1,
212+
},
213+
}
214+
for band := info.startBand; band < info.endBand; band++ {
215+
info.allocation.pulses[band] = 8
216+
}
217+
x := make([]float32, int(bandEdges[maxBands]))
218+
219+
masks := quantAllBands(&info, x, nil, 128<<bitResolution, &state)
220+
221+
require.Len(t, masks, maxBands)
222+
assert.NotZero(t, masks[0])
223+
}
224+
225+
func TestDecodeBandTheta(t *testing.T) {
226+
decoder := rangeDecoderWithCDFSymbol(0, 7)
227+
assert.Equal(t, 0, decodeBandTheta(4, 4, true, 1, &decoder))
228+
229+
decoder = rangeDecoderWithCDFSymbol(2, 5)
230+
assert.Equal(t, 2, decodeBandTheta(4, 2, false, 2, &decoder))
231+
232+
decoder = rangeDecoderWithCDFSymbol(0, 9)
233+
assert.Equal(t, 0, decodeBandTheta(4, 4, false, 1, &decoder))
234+
}
235+
236+
func TestYBandSlice(t *testing.T) {
237+
assert.Nil(t, yBandSlice(nil, 0, 1))
238+
assert.Equal(t, []float32{2, 3}, yBandSlice([]float32{1, 2, 3, 4}, 1, 3))
239+
}
240+
241+
func rangeDecoderForBandTests() *rangecoding.Decoder {
242+
decoder := rangecoding.Decoder{}
243+
decoder.SetInternalValues(make([]byte, 16), 0, 1<<31, 0)
244+
245+
return &decoder
246+
}

internal/celt/cwrs.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// SPDX-FileCopyrightText: 2026 The Pion community <https://pion.ly>
2+
// SPDX-License-Identifier: MIT
3+
4+
//nolint:varnamelen // CWRS notation follows RFC/reference vector names.
5+
package celt
6+
7+
import "github.com/pion/opus/internal/rangecoding"
8+
9+
// decodePulses implements the RFC 6716 Section 4.3.4.2 CWRS index decode for
10+
// the PVQ pulse vector. The row buffer stores one recurrence row of V(N,K).
11+
func decodePulses(y []int, n, k int, rangeDecoder *rangecoding.Decoder) {
12+
for i := range n {
13+
y[i] = 0
14+
}
15+
if k <= 0 {
16+
return
17+
}
18+
19+
u := cwrsUrow(n, k)
20+
total := u[k] + u[k+1]
21+
index, _ := rangeDecoder.DecodeUniform(total)
22+
cwrsDecode(y, n, k, index, u)
23+
}
24+
25+
// cwrsUrow initializes the recurrence row needed to count PVQ codewords for a
26+
// vector of n dimensions and up to k pulses.
27+
func cwrsUrow(n, k int) []uint32 {
28+
row := make([]uint32, k+2)
29+
if n == 0 {
30+
row[0] = 1
31+
32+
return row
33+
}
34+
row[0] = 0
35+
if len(row) > 1 {
36+
row[1] = 1
37+
}
38+
if n == 1 {
39+
for i := 2; i < len(row); i++ {
40+
row[i] = 1
41+
}
42+
43+
return row
44+
}
45+
for pulses := 2; pulses < len(row); pulses++ {
46+
row[pulses] = uint32((pulses << 1) - 1)
47+
}
48+
for rowIndex := 2; rowIndex < n; rowIndex++ {
49+
cwrsNextRow(row[1:], 1)
50+
}
51+
52+
return row
53+
}
54+
55+
// cwrsNextRow advances the V(N,K) recurrence by one dimension.
56+
func cwrsNextRow(u []uint32, value0 uint32) {
57+
value := value0
58+
for j := 1; j < len(u); j++ {
59+
next := u[j] + u[j-1] + value
60+
u[j-1] = value
61+
value = next
62+
}
63+
u[len(u)-1] = value
64+
}
65+
66+
// cwrsDecode walks the recurrence row to recover signs and pulse magnitudes
67+
// from the uniformly decoded codeword index.
68+
func cwrsDecode(y []int, n, k int, index uint32, u []uint32) {
69+
for j := range n {
70+
p := u[k+1]
71+
negative := index >= p
72+
if negative {
73+
index -= p
74+
}
75+
76+
yj := k
77+
p = u[k]
78+
for p > index {
79+
k--
80+
p = u[k]
81+
}
82+
index -= p
83+
yj -= k
84+
if negative {
85+
y[j] = -yj
86+
} else {
87+
y[j] = yj
88+
}
89+
cwrsPreviousRow(u, k+2, 0)
90+
}
91+
}
92+
93+
// cwrsPreviousRow rewinds the recurrence after one coefficient has been
94+
// decoded, matching the row update used by the reference CWRS decoder.
95+
func cwrsPreviousRow(u []uint32, n int, value0 uint32) {
96+
value := value0
97+
for j := 1; j < n; j++ {
98+
next := u[j] - u[j-1] - value
99+
u[j-1] = value
100+
value = next
101+
}
102+
u[n-1] = value
103+
}

internal/celt/cwrs_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// SPDX-FileCopyrightText: 2026 The Pion community <https://pion.ly>
2+
// SPDX-License-Identifier: MIT
3+
4+
package celt
5+
6+
import (
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
func TestCWRSRows(t *testing.T) {
13+
assert.Equal(t, []uint32{0, 1, 3, 5, 7}, cwrsUrow(2, 3))
14+
15+
row := []uint32{1, 3, 5, 7}
16+
cwrsNextRow(row, 1)
17+
assert.Equal(t, []uint32{1, 5, 13, 25}, row)
18+
19+
cwrsPreviousRow(row, 4, 1)
20+
assert.Equal(t, []uint32{1, 3, 5, 7}, row)
21+
}
22+
23+
func TestCWRSDecode(t *testing.T) {
24+
y := []int{99, 99, 99}
25+
decodePulses(y, len(y), 0, nil)
26+
assert.Equal(t, []int{0, 0, 0}, y)
27+
28+
row := cwrsUrow(3, 2)
29+
cwrsDecode(y, len(y), 2, 0, row)
30+
assert.Equal(t, []int{2, 0, 0}, y)
31+
32+
decoder := rangeDecoderWithCDFSymbol(0, cwrsUrow(3, 2)[2]+cwrsUrow(3, 2)[3])
33+
decodePulses(y, len(y), 2, &decoder)
34+
assert.Equal(t, []int{2, 0, 0}, y)
35+
}

0 commit comments

Comments
 (0)