diff --git a/decoder.go b/decoder.go index 866fe1d..8aa825e 100644 --- a/decoder.go +++ b/decoder.go @@ -9,6 +9,7 @@ import ( "github.com/pion/opus/internal/bitdepth" "github.com/pion/opus/internal/celt" + "github.com/pion/opus/internal/rangecoding" silkresample "github.com/pion/opus/internal/resample/silk" "github.com/pion/opus/internal/silk" ) @@ -19,6 +20,8 @@ const ( maxSilkFrameSampleCount = 320 maxCeltFrameSampleCount = 960 celtSampleRate = 48000 + hybridRedundantFrameSampleCount = celtSampleRate / 200 + hybridFadeSampleCount = celtSampleRate / 400 ) // Decoder decodes the Opus bitstream into PCM. @@ -27,18 +30,39 @@ type Decoder struct { silkBuffer []float32 celtDecoder celt.Decoder celtBuffer []float32 + rangeDecoder rangecoding.Decoder + rangeFinal uint32 previousMode configurationMode + previousRedundancy bool resampleBuffer []float32 resampleChannelIn [2][]float32 resampleChannelOut [2][]float32 silkResampler [2]silkresample.Resampler silkResamplerBandwidth Bandwidth silkResamplerChannels int + hybridSilkResampler [2]silkresample.Resampler + hybridSilkChannels int + silkRedundancyFades []silkRedundancyFade + silkCeltAdditions []silkCeltAddition floatBuffer []float32 sampleRate int channels int } +type silkRedundancyFade struct { + celtToSilk bool + audio []float32 + startSample int + frameSampleCount int + channelCount int +} + +type silkCeltAddition struct { + audio []float32 + startSample int + channelCount int +} + // NewDecoder creates a new Opus Decoder. 
func NewDecoder() Decoder { decoder, _ := NewDecoderWithOutput(BandwidthFullband.SampleRate(), 1) @@ -76,9 +100,19 @@ func (d *Decoder) Init(sampleRate, channels int) error { d.sampleRate = sampleRate d.channels = channels d.silkDecoder = silk.NewDecoder() + d.celtDecoder.Reset() + d.celtBuffer = d.celtBuffer[:0] + d.rangeDecoder = rangecoding.Decoder{} d.silkResampler = [2]silkresample.Resampler{} d.silkResamplerBandwidth = 0 d.silkResamplerChannels = 0 + d.hybridSilkResampler = [2]silkresample.Resampler{} + d.hybridSilkChannels = 0 + d.silkRedundancyFades = d.silkRedundancyFades[:0] + d.silkCeltAdditions = d.silkCeltAdditions[:0] + d.rangeFinal = 0 + d.previousMode = 0 + d.previousRedundancy = false return nil } @@ -150,6 +184,8 @@ func (d *Decoder) resampleSilkChannel( return nil } +// resetModeState applies the decoder resets required by RFC 6716 Section 4.5.2 +// before the first frame decoded in a new operating mode. func (d *Decoder) resetModeState(mode configurationMode) { if d.previousMode == mode { return @@ -157,17 +193,52 @@ func (d *Decoder) resetModeState(mode configurationMode) { switch mode { case configurationModeSilkOnly: - d.silkDecoder = silk.NewDecoder() + if d.previousMode == configurationModeCELTOnly { + d.silkDecoder = silk.NewDecoder() + } + if d.previousMode == configurationModeHybrid { + d.copyHybridSilkResamplerToSilk() + } case configurationModeCELTOnly: - d.celtDecoder.Reset() - clear(d.celtBuffer) + if !d.previousRedundancy { + d.celtDecoder.Reset() + clear(d.celtBuffer) + } case configurationModeHybrid: - d.silkDecoder = silk.NewDecoder() - d.celtDecoder.Reset() - clear(d.celtBuffer) + if d.previousMode == configurationModeCELTOnly { + d.silkDecoder = silk.NewDecoder() + d.hybridSilkResampler = [2]silkresample.Resampler{} + d.hybridSilkChannels = 0 + } + if d.previousMode == configurationModeSilkOnly { + d.copySilkResamplerToHybrid() + } + } +} + +// copySilkResamplerToHybrid preserves the WB SILK resampler history across the +// 
normatively continuous WB SILK -> Hybrid transition in RFC 6716 Section 4.5. +func (d *Decoder) copySilkResamplerToHybrid() { + if d.sampleRate != celtSampleRate || d.silkResamplerBandwidth != BandwidthWideband || d.silkResamplerChannels == 0 { + return } + for i := range d.hybridSilkResampler { + d.hybridSilkResampler[i].CopyStateFrom(&d.silkResampler[i]) + } + d.hybridSilkChannels = d.silkResamplerChannels +} - d.previousMode = mode +// copyHybridSilkResamplerToSilk preserves the same WB SILK history for the +// reverse Hybrid -> WB SILK transition described by RFC 6716 Section 4.5. +func (d *Decoder) copyHybridSilkResamplerToSilk() { + if d.sampleRate != celtSampleRate || d.hybridSilkChannels == 0 { + return + } + for i := range d.silkResampler { + d.silkResampler[i].CopyStateFrom(&d.hybridSilkResampler[i]) + } + d.silkResamplerBandwidth = BandwidthWideband + d.silkResamplerChannels = d.hybridSilkChannels } func (c Configuration) silkFrameSampleCount() int { @@ -200,11 +271,19 @@ func (c Configuration) celtFrameSampleCount() int { return int(int64(c.frameDuration().nanoseconds()) * int64(celtSampleRate) / 1000000000) } +func (c Configuration) hybridFrameSampleCount() int { + if c.mode() != configurationModeHybrid { + return 0 + } + + return int(int64(c.frameDuration().nanoseconds()) * int64(celtSampleRate) / 1000000000) +} + func (c Configuration) decodedSampleRate() int { switch c.mode() { case configurationModeSilkOnly: return c.bandwidth().SampleRate() - case configurationModeCELTOnly: + case configurationModeCELTOnly, configurationModeHybrid: return celtSampleRate default: return 0 @@ -425,9 +504,12 @@ func parsePacketFrames(in []byte, tocHeader tableOfContentsHeader) ([][]byte, er } } -func (d *Decoder) decode(in []byte, out []float32) (bandwidth Bandwidth, isStereo bool, sampleCount int, err error) { +func (d *Decoder) decode( + in []byte, + out []float32, +) (bandwidth Bandwidth, isStereo bool, sampleCount int, decodedChannelCount int, err error) { if 
len(in) < 1 { - return 0, false, 0, errTooShortForTableOfContentsHeader + return 0, false, 0, 0, errTooShortForTableOfContentsHeader } tocHeader := tableOfContentsHeader(in[0]) @@ -435,7 +517,7 @@ func (d *Decoder) decode(in []byte, out []float32) (bandwidth Bandwidth, isStere encodedFrames, err := parsePacketFrames(in, tocHeader) if err != nil { - return 0, false, 0, err + return 0, false, 0, 0, err } switch cfg.mode() { @@ -446,26 +528,444 @@ func (d *Decoder) decode(in []byte, out []float32) (bandwidth Bandwidth, isStere case configurationModeCELTOnly: d.resetModeState(configurationModeCELTOnly) - return 0, false, 0, fmt.Errorf("%w: %d", errUnsupportedConfigurationMode, cfg.mode()) + return d.decodeCeltFrames(cfg, tocHeader, encodedFrames, out) case configurationModeHybrid: d.resetModeState(configurationModeHybrid) - return 0, false, 0, fmt.Errorf("%w: %d", errUnsupportedConfigurationMode, cfg.mode()) + return d.decodeHybridFrames(cfg, tocHeader, encodedFrames, out) default: - return 0, false, 0, fmt.Errorf("%w: %d", errUnsupportedConfigurationMode, cfg.mode()) + return 0, false, 0, 0, fmt.Errorf("%w: %d", errUnsupportedConfigurationMode, cfg.mode()) + } +} + +// decodeCeltFrames decodes the CELT-only path at CELT's internal 48 kHz rate. 
+func (d *Decoder) decodeCeltFrames( + cfg Configuration, + tocHeader tableOfContentsHeader, + encodedFrames [][]byte, + out []float32, +) (bandwidth Bandwidth, isStereo bool, sampleCount int, decodedChannelCount int, err error) { + frameSampleCount := cfg.celtFrameSampleCount() + streamChannelCount := 1 + if tocHeader.isStereo() { + streamChannelCount = 2 + } + decodedChannelCount = d.channels + if decodedChannelCount == 0 { + decodedChannelCount = streamChannelCount + } + requiredSamples := frameSampleCount * len(encodedFrames) * decodedChannelCount + if cap(out) < requiredSamples { + d.silkBuffer = make([]float32, requiredSamples) + out = d.silkBuffer + } + out = out[:requiredSamples] + for i := range out { + out[i] = 0 + } + + startBand, endBand, err := d.celtDecoder.Mode().BandRangeForSampleRate(cfg.bandwidth().SampleRate()) + if err != nil { + return 0, false, 0, 0, err + } + frameOutputSamples := frameSampleCount * decodedChannelCount + for i, encodedFrame := range encodedFrames { + frameOut := out[i*frameOutputSamples : (i+1)*frameOutputSamples] + if err = d.celtDecoder.Decode( + encodedFrame, + frameOut, + tocHeader.isStereo(), + decodedChannelCount, + frameSampleCount, + startBand, + endBand, + ); err != nil { + return 0, false, 0, 0, err + } + d.previousMode = configurationModeCELTOnly + d.previousRedundancy = false + if len(encodedFrame) <= 1 { + d.rangeFinal = 0 + } else { + d.rangeFinal = d.celtDecoder.FinalRange() + } + } + + return cfg.bandwidth(), tocHeader.isStereo(), requiredSamples, decodedChannelCount, nil +} + +// decodeHybridFrames combines the SILK and CELT layers for Hybrid packets. 
+func (d *Decoder) decodeHybridFrames( + cfg Configuration, + tocHeader tableOfContentsHeader, + encodedFrames [][]byte, + out []float32, +) (bandwidth Bandwidth, isStereo bool, sampleCount int, decodedChannelCount int, err error) { + frameSampleCount := cfg.hybridFrameSampleCount() + streamChannelCount := 1 + if tocHeader.isStereo() { + streamChannelCount = 2 + } + decodedChannelCount = d.channels + if decodedChannelCount == 0 { + decodedChannelCount = streamChannelCount + } + requiredSamples := frameSampleCount * len(encodedFrames) * decodedChannelCount + if cap(out) < requiredSamples { + d.silkBuffer = make([]float32, requiredSamples) + out = d.silkBuffer + } + out = out[:requiredSamples] + for i := range out { + out[i] = 0 + } + + startBand, endBand, err := d.celtDecoder.Mode().HybridBandRange(cfg.bandwidth().SampleRate()) + if err != nil { + return 0, false, 0, 0, err + } + frameOutputSamples := frameSampleCount * decodedChannelCount + silkSamplesPerChannel := frameSampleCount * BandwidthWideband.SampleRate() / celtSampleRate + for i, encodedFrame := range encodedFrames { + frameOut := out[i*frameOutputSamples : (i+1)*frameOutputSamples] + if err = d.decodeHybridFrame( + encodedFrame, + frameOut, + tocHeader.isStereo(), + streamChannelCount, + decodedChannelCount, + frameSampleCount, + silkSamplesPerChannel, + cfg.frameDuration().nanoseconds(), + startBand, + endBand, + ); err != nil { + return 0, false, 0, 0, err + } + } + + return cfg.bandwidth(), tocHeader.isStereo(), requiredSamples, decodedChannelCount, nil +} + +type hybridRedundancy struct { + present bool + celtToSilk bool + celtDataLen int + endBand int + data []byte + audio []float32 + rng uint32 +} + +// decodeHybridFrame follows RFC 6716 Sections 4.5.1 and 4.5.2 for one Hybrid +// frame: decode shared-range SILK, split optional CELT redundancy, decode CELT, +// then apply the required transition cross-lap when redundancy is present. 
//
// Parameters:
//   - encodedFrame: the bytes of a single Hybrid frame (SILK, CELT and any
//     trailing redundant CELT data).
//   - out: destination for frameSampleCount*outputChannelCount samples; the
//     CELT layer is decoded into it and the resampled SILK layer is added.
//   - isStereo / streamChannelCount: channel layout of the coded stream.
//   - outputChannelCount: channel layout of out.
//   - frameSampleCount: samples per channel at 48 kHz.
//   - silkSamplesPerChannel: samples per channel of the wideband SILK layer.
//   - frameNanoseconds: frame duration passed to the SILK decoder.
//   - startBand, endBand: CELT band range for the main frame; endBand is also
//     used for the redundant frame.
//
// On success it updates d.rangeFinal, d.previousMode and d.previousRedundancy.
// Any error from the SILK decoder, CELT decoder or resampler is returned
// unchanged.
//
//nolint:cyclop
func (d *Decoder) decodeHybridFrame(
	encodedFrame []byte,
	out []float32,
	isStereo bool,
	streamChannelCount int,
	outputChannelCount int,
	frameSampleCount int,
	silkSamplesPerChannel int,
	frameNanoseconds int,
	startBand int,
	endBand int,
) error {
	// A single range decoder is initialised for the frame and then handed to
	// both the SILK and the CELT decoder below.
	d.rangeDecoder.Init(encodedFrame)

	// The SILK layer of a Hybrid frame is decoded as wideband into a scratch
	// buffer. NOTE(review): this buffer and silk48 below are allocated on
	// every frame; consider reusing decoder-owned scratch space.
	silkInternal := make([]float32, silkSamplesPerChannel*streamChannelCount)
	if err := d.silkDecoder.DecodeWithRange(
		&d.rangeDecoder,
		silkInternal,
		isStereo,
		frameNanoseconds,
		silk.Bandwidth(BandwidthWideband),
	); err != nil {
		return err
	}

	var err error
	// The redundancy header is read from the shared range decoder right after
	// the SILK payload.
	redundancy := d.decodeHybridRedundancyHeader(encodedFrame)
	// A CELT->SILK redundant frame is decoded before the main CELT frame.
	if redundancy.present && redundancy.celtToSilk {
		if err = d.decodeHybridRedundantFrame(&redundancy, isStereo, outputChannelCount, endBand); err != nil {
			return err
		}
	}
	// Reset CELT when the previous frame was decoded in another (known) mode,
	// unless that frame ended with a SILK->CELT redundant frame
	// (previousRedundancy is only set for that direction).
	if d.previousMode != configurationModeHybrid && d.previousMode != 0 && !d.previousRedundancy {
		d.celtDecoder.Reset()
		clear(d.celtBuffer)
	}
	// The main CELT frame only sees the bytes before the redundant data.
	if err = d.celtDecoder.DecodeWithRange(
		encodedFrame[:redundancy.celtDataLen],
		out,
		isStereo,
		outputChannelCount,
		frameSampleCount,
		startBand,
		endBand,
		&d.rangeDecoder,
	); err != nil {
		return err
	}

	// Bring the SILK layer to 48 kHz and sum it onto the CELT output.
	silk48 := make([]float32, frameSampleCount*streamChannelCount)
	if err = d.resampleHybridSilkTo48(silkInternal, silk48, streamChannelCount); err != nil {
		return err
	}
	d.addHybridSilk(out, silk48, streamChannelCount, outputChannelCount, frameSampleCount)
	// SILK->CELT: reset CELT, decode the redundant frame after the main frame,
	// then blend its second half with the last hybridFadeSampleCount samples
	// per channel of out.
	// NOTE(review): the fade direction is determined by celt.SmoothFade's
	// argument order, which is not visible here - confirm.
	if redundancy.present && !redundancy.celtToSilk {
		d.celtDecoder.Reset()
		clear(d.celtBuffer)
		if err = d.decodeHybridRedundantFrame(&redundancy, isStereo, outputChannelCount, endBand); err != nil {
			return err
		}
		fadeStart := (frameSampleCount - hybridFadeSampleCount) * outputChannelCount
		redundantStart := hybridFadeSampleCount * outputChannelCount
		celt.SmoothFade(
			out[fadeStart:],
			redundancy.audio[redundantStart:],
			out[fadeStart:],
			hybridFadeSampleCount,
			outputChannelCount,
		)
	}
	// CELT->SILK: the first hybridFadeSampleCount samples per channel are
	// replaced by the redundant audio, and the following hybridFadeSampleCount
	// samples are blended between the redundant audio and the decoded frame.
	if redundancy.present && redundancy.celtToSilk {
		for sample := range hybridFadeSampleCount {
			for channel := range outputChannelCount {
				index := sample*outputChannelCount + channel
				out[index] = redundancy.audio[index]
			}
		}
		fadeStart := hybridFadeSampleCount * outputChannelCount
		celt.SmoothFade(
			redundancy.audio[fadeStart:],
			out[fadeStart:],
			out[fadeStart:],
			hybridFadeSampleCount,
			outputChannelCount,
		)
	}
	// Frames of at most one byte report a final range of zero; otherwise the
	// shared decoder's final range is XORed with the redundant frame's
	// (redundancy.rng stays zero when no redundant frame was decoded).
	if len(encodedFrame) <= 1 {
		d.rangeFinal = 0
	} else {
		d.rangeFinal = d.rangeDecoder.FinalRange() ^ redundancy.rng
	}
	d.previousMode = configurationModeHybrid
	d.previousRedundancy = redundancy.present && !redundancy.celtToSilk

	return nil
}

// decodeHybridRedundancyHeader parses the Hybrid transition side information
// from RFC 6716 Sections 4.5.1.1 through 4.5.1.3.
//
// It reads from d.rangeDecoder, which must be positioned right after the SILK
// payload of encodedFrame. The result always has celtDataLen set; present,
// celtToSilk and data are only filled in when a consistent redundant frame was
// signalled. In that case the range decoder's storage size is set to
// celtDataLen.
//
//nolint:gosec
func (d *Decoder) decodeHybridRedundancyHeader(
	encodedFrame []byte,
) hybridRedundancy {
	redundancy := hybridRedundancy{celtDataLen: len(encodedFrame)}
	// Without at least 17+20 unread bits the header is not coded at all.
	if int(d.rangeDecoder.Tell())+17+20 > 8*len(encodedFrame) {
		return redundancy
	}
	// Redundancy flag.
	if d.rangeDecoder.DecodeSymbolLogP(12) == 0 {
		return redundancy
	}

	// Direction flag, then the redundant frame size coded as (bytes - 2).
	celtToSilk := d.rangeDecoder.DecodeSymbolLogP(1) != 0
	// NOTE(review): DecodeUniform's second result is discarded - confirm it
	// cannot signal a failure that matters here.
	redundancyBytesRaw, _ := d.rangeDecoder.DecodeUniform(256)
	redundancyBytes := int(redundancyBytesRaw) + 2
	redundancy.celtDataLen -= redundancyBytes
	// A size that does not fit the frame, or that overlaps bits already
	// consumed, is treated as "no redundancy" with the full frame length.
	if redundancy.celtDataLen < 0 || redundancy.celtDataLen*8 < int(d.rangeDecoder.Tell()) {
		return hybridRedundancy{celtDataLen: len(encodedFrame)}
	}
	d.rangeDecoder.SetStorageSize(redundancy.celtDataLen)
	redundancy.present = true
	redundancy.celtToSilk = celtToSilk
	redundancy.data = encodedFrame[redundancy.celtDataLen:]

	return redundancy
}

// decodeSilkOnlyRedundancyHeader parses the SILK-only variant of the transition
// side information defined by RFC 6716 Sections 4.5.1.1 and 4.5.1.2.
+// +//nolint:gosec +func (d *Decoder) decodeSilkOnlyRedundancyHeader( + encodedFrame []byte, + bandwidth Bandwidth, +) (hybridRedundancy, error) { + redundancy := hybridRedundancy{celtDataLen: len(encodedFrame)} + if int(d.rangeDecoder.Tell())+17 > 8*len(encodedFrame) { + return redundancy, nil + } + + celtToSilk := d.rangeDecoder.DecodeSymbolLogP(1) != 0 + redundancyBytes := len(encodedFrame) - int((d.rangeDecoder.Tell()+7)>>3) + redundancy.celtDataLen -= redundancyBytes + if redundancyBytes <= 0 || redundancy.celtDataLen < 0 || redundancy.celtDataLen*8 < int(d.rangeDecoder.Tell()) { + return hybridRedundancy{celtDataLen: len(encodedFrame)}, nil + } + d.rangeDecoder.SetStorageSize(redundancy.celtDataLen) + + endBand, err := d.celtEndBandForSilkBandwidth(bandwidth) + if err != nil { + return redundancy, err + } + redundancy.present = true + redundancy.celtToSilk = celtToSilk + redundancy.data = encodedFrame[redundancy.celtDataLen:] + redundancy.endBand = endBand + + return redundancy, nil +} + +// celtEndBandForSilkBandwidth selects the redundant CELT bandwidth required by +// RFC 6716 Section 4.5.1.4; MB SILK transitions use WB CELT bandwidth. +func (d *Decoder) celtEndBandForSilkBandwidth(bandwidth Bandwidth) (int, error) { + sampleRate := bandwidth.SampleRate() + if bandwidth == BandwidthMediumband { + sampleRate = BandwidthWideband.SampleRate() + } + _, endBand, err := d.celtDecoder.Mode().BandRangeForSampleRate(sampleRate) + + return endBand, err } +// decodeHybridRedundantFrame decodes the fixed 5 ms redundant CELT frame from +// RFC 6716 Section 4.5.1.4. 
+func (d *Decoder) decodeHybridRedundantFrame( + redundancy *hybridRedundancy, + isStereo bool, + outputChannelCount int, + endBand int, +) error { + redundancy.audio = make([]float32, hybridRedundantFrameSampleCount*outputChannelCount) + if err := d.celtDecoder.Decode( + redundancy.data, + redundancy.audio, + isStereo, + outputChannelCount, + hybridRedundantFrameSampleCount, + 0, + endBand, + ); err != nil { + return err + } + redundancy.rng = d.celtDecoder.FinalRange() + + return nil +} + +// resampleHybridSilkTo48 lifts the Hybrid packet's WB SILK layer to the 48 kHz +// CELT domain before the two layers are summed. +func (d *Decoder) resampleHybridSilkTo48(in []float32, out []float32, channelCount int) error { + if d.hybridSilkChannels == 0 { + for i := range d.hybridSilkResampler { + if err := d.hybridSilkResampler[i].Init(BandwidthWideband.SampleRate(), celtSampleRate); err != nil { + return err + } + } + } + if channelCount == 2 && d.hybridSilkChannels == 1 { + d.hybridSilkResampler[1].CopyStateFrom(&d.hybridSilkResampler[0]) + } + d.hybridSilkChannels = channelCount + + samplesPerChannel := len(in) / channelCount + resampledSamplesPerChannel := len(out) / channelCount + for channelIndex := range channelCount { + if err := d.resampleHybridSilkChannel( + in, + out, + channelIndex, + channelCount, + samplesPerChannel, + resampledSamplesPerChannel, + ); err != nil { + return err + } + } + + return nil +} + +func (d *Decoder) resampleHybridSilkChannel( + in []float32, + out []float32, + channelIndex, channelCount, samplesPerChannel, resampledSamplesPerChannel int, +) error { + if cap(d.resampleChannelIn[channelIndex]) < samplesPerChannel { + d.resampleChannelIn[channelIndex] = make([]float32, samplesPerChannel) + } + if cap(d.resampleChannelOut[channelIndex]) < resampledSamplesPerChannel { + d.resampleChannelOut[channelIndex] = make([]float32, resampledSamplesPerChannel) + } + channelIn := d.resampleChannelIn[channelIndex][:samplesPerChannel] + channelOut := 
d.resampleChannelOut[channelIndex][:resampledSamplesPerChannel] + for i := range samplesPerChannel { + channelIn[i] = in[(i*channelCount)+channelIndex] + } + if err := d.hybridSilkResampler[channelIndex].Resample(channelIn, channelOut); err != nil { + return err + } + for i := range resampledSamplesPerChannel { + out[(i*channelCount)+channelIndex] = channelOut[i] + } + + return nil +} + +// addHybridSilk combines the decoded WB SILK contribution with the CELT layer +// after both are represented at 48 kHz. +func (d *Decoder) addHybridSilk( + out []float32, + silk48 []float32, + streamChannelCount int, + outputChannelCount int, + samplesPerChannel int, +) { + for i := range silk48 { + silk48[i] = float32(bitdepth.Float32ToSigned16(silk48[i])) / 32768 + } + for sample := range samplesPerChannel { + silkIndex := sample * streamChannelCount + outIndex := sample * outputChannelCount + switch { + case streamChannelCount == outputChannelCount: + for channel := range outputChannelCount { + out[outIndex+channel] += silk48[silkIndex+channel] + } + case streamChannelCount == 1 && outputChannelCount == 2: + out[outIndex] += silk48[silkIndex] + out[outIndex+1] += silk48[silkIndex] + case streamChannelCount == 2 && outputChannelCount == 1: + out[outIndex] += 0.5 * (silk48[silkIndex] + silk48[silkIndex+1]) + } + } +} + +// decodeSilkFrames handles ordinary SILK packets plus the redundant CELT side +// data that RFC 6716 Section 4.5.1 allows on mode transitions. 
+// +//nolint:cyclop func (d *Decoder) decodeSilkFrames( cfg Configuration, tocHeader tableOfContentsHeader, encodedFrames [][]byte, out []float32, -) (bandwidth Bandwidth, isStereo bool, sampleCount int, err error) { - frameSampleCount := cfg.silkFrameSampleCount() +) (bandwidth Bandwidth, isStereo bool, sampleCount int, decodedChannelCount int, err error) { + frameSamplesPerChannel := cfg.silkFrameSampleCount() + frameSampleCount := frameSamplesPerChannel + decodedChannelCount = 1 if tocHeader.isStereo() { frameSampleCount *= 2 + decodedChannelCount = 2 } + d.silkRedundancyFades = d.silkRedundancyFades[:0] + d.silkCeltAdditions = d.silkCeltAdditions[:0] requiredSamples := frameSampleCount * len(encodedFrames) if cap(out) < requiredSamples { d.silkBuffer = make([]float32, requiredSamples) @@ -477,22 +977,81 @@ func (d *Decoder) decodeSilkFrames( } for i, encodedFrame := range encodedFrames { + previousMode := d.previousMode + previousRedundancy := d.previousRedundancy frameOut := out[i*frameSampleCount : (i+1)*frameSampleCount] - err := d.silkDecoder.Decode( - encodedFrame, + d.rangeDecoder.Init(encodedFrame) + err := d.silkDecoder.DecodeWithRange( + &d.rangeDecoder, frameOut, tocHeader.isStereo(), cfg.frameDuration().nanoseconds(), silk.Bandwidth(cfg.bandwidth()), ) if err != nil { - return 0, false, 0, err + return 0, false, 0, 0, err + } + redundancy, err := d.decodeSilkOnlyRedundancyHeader(encodedFrame, cfg.bandwidth()) + if err != nil { + return 0, false, 0, 0, err + } + if redundancy.present { + if !redundancy.celtToSilk { + d.celtDecoder.Reset() + clear(d.celtBuffer) + } + if err = d.decodeHybridRedundantFrame( + &redundancy, + tocHeader.isStereo(), + decodedChannelCount, + redundancy.endBand, + ); err != nil { + return 0, false, 0, 0, err + } + d.silkRedundancyFades = append(d.silkRedundancyFades, silkRedundancyFade{ + celtToSilk: redundancy.celtToSilk, + audio: redundancy.audio, + startSample: i * frameSamplesPerChannel * celtSampleRate / 
cfg.bandwidth().SampleRate(), + frameSampleCount: frameSamplesPerChannel * celtSampleRate / cfg.bandwidth().SampleRate(), + channelCount: decodedChannelCount, + }) + } + if previousMode == configurationModeHybrid && + (!redundancy.present || !redundancy.celtToSilk || !previousRedundancy) { + endBand, err := d.celtEndBandForSilkBandwidth(cfg.bandwidth()) + if err != nil { + return 0, false, 0, 0, err + } + transitionAudio := make([]float32, hybridFadeSampleCount*decodedChannelCount) + if err = d.celtDecoder.Decode( + []byte{0xff, 0xff}, + transitionAudio, + tocHeader.isStereo(), + decodedChannelCount, + hybridFadeSampleCount, + 0, + endBand, + ); err != nil { + return 0, false, 0, 0, err + } + d.silkCeltAdditions = append(d.silkCeltAdditions, silkCeltAddition{ + audio: transitionAudio, + startSample: i * frameSamplesPerChannel * celtSampleRate / cfg.bandwidth().SampleRate(), + channelCount: decodedChannelCount, + }) } + if len(encodedFrame) <= 1 { + d.rangeFinal = 0 + } else { + d.rangeFinal = d.rangeDecoder.FinalRange() ^ redundancy.rng + } + d.previousMode = configurationModeSilkOnly + d.previousRedundancy = redundancy.present && !redundancy.celtToSilk } sampleCount = requiredSamples - return cfg.bandwidth(), tocHeader.isStereo(), sampleCount, nil + return cfg.bandwidth(), tocHeader.isStereo(), sampleCount, decodedChannelCount, nil } func (d *Decoder) decodeToFloat32( @@ -506,35 +1065,97 @@ func (d *Decoder) decodeToFloat32( return 0, 0, false, errInvalidChannelCount } - bandwidth, isStereo, sampleCount, err := d.decode(in, d.silkBuffer) + bandwidth, isStereo, sampleCount, decodedChannelCount, err := d.decode(in, d.silkBuffer) if err != nil { return 0, 0, false, err } - channelCount := 1 - if isStereo { - channelCount = 2 - } - - samplesPerChannel = (sampleCount / channelCount) * d.sampleRate / bandwidth.SampleRate() - requiredSamples := samplesPerChannel * channelCount + samplesPerChannel = (sampleCount / decodedChannelCount) * d.sampleRate / 
bandwidth.SampleRate() + requiredSamples := samplesPerChannel * decodedChannelCount if cap(d.resampleBuffer) < requiredSamples { d.resampleBuffer = make([]float32, requiredSamples) } d.resampleBuffer = d.resampleBuffer[:requiredSamples] - if err = d.resampleSilk(d.silkBuffer[:sampleCount], d.resampleBuffer, channelCount, bandwidth); err != nil { - return 0, 0, false, err + if d.sampleRate == bandwidth.SampleRate() { + copy(d.resampleBuffer, d.silkBuffer[:sampleCount]) + } else { + if err = d.resampleSilk(d.silkBuffer[:sampleCount], d.resampleBuffer, decodedChannelCount, bandwidth); err != nil { + return 0, 0, false, err + } } + d.applySilkRedundancyFades(decodedChannelCount) if len(out) < samplesPerChannel*d.channels { return 0, 0, false, errOutBufferTooSmall } - d.copyResampledSamples(out, channelCount) + d.copyResampledSamples(out, decodedChannelCount) return samplesPerChannel, bandwidth, isStereo, nil } +// applySilkRedundancyFades applies the leading/trailing 2.5 ms cross-laps from +// RFC 6716 Section 4.5.1.4 after SILK output has been resampled to 48 kHz. 
+// +//nolint:cyclop +func (d *Decoder) applySilkRedundancyFades(channelCount int) { + fades := d.silkRedundancyFades + additions := d.silkCeltAdditions + d.silkRedundancyFades = d.silkRedundancyFades[:0] + d.silkCeltAdditions = d.silkCeltAdditions[:0] + if d.sampleRate != celtSampleRate { + return + } + for _, addition := range additions { + if addition.channelCount != channelCount { + continue + } + start := addition.startSample * channelCount + if start < 0 || start+len(addition.audio) > len(d.resampleBuffer) { + continue + } + for i, sample := range addition.audio { + d.resampleBuffer[start+i] += sample + } + } + for _, fade := range fades { + if fade.channelCount != channelCount { + continue + } + frameStart := fade.startSample * channelCount + if fade.celtToSilk { + copyCount := hybridFadeSampleCount * channelCount + if frameStart+2*copyCount > len(d.resampleBuffer) || copyCount > len(fade.audio) { + continue + } + copy(d.resampleBuffer[frameStart:frameStart+copyCount], fade.audio[:copyCount]) + celt.SmoothFade( + fade.audio[copyCount:], + d.resampleBuffer[frameStart+copyCount:], + d.resampleBuffer[frameStart+copyCount:], + hybridFadeSampleCount, + channelCount, + ) + + continue + } + + fadeStart := (fade.startSample + fade.frameSampleCount - hybridFadeSampleCount) * channelCount + redundantStart := hybridFadeSampleCount * channelCount + if fadeStart < 0 || fadeStart+hybridFadeSampleCount*channelCount > len(d.resampleBuffer) || + redundantStart+hybridFadeSampleCount*channelCount > len(fade.audio) { + continue + } + celt.SmoothFade( + d.resampleBuffer[fadeStart:], + fade.audio[redundantStart:], + d.resampleBuffer[fadeStart:], + hybridFadeSampleCount, + channelCount, + ) + } +} + func (d *Decoder) copyResampledSamples(out []float32, channelCount int) { outIndex := 0 for i := 0; i < len(d.resampleBuffer); i += channelCount { diff --git a/decoder_test.go b/decoder_test.go index 0f53ec2..06d2941 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -123,6 +123,29 
@@ func TestNewDecoderWithOutput(t *testing.T) { assert.ErrorIs(t, err, errInvalidChannelCount) } +func TestInitResetsCeltState(t *testing.T) { + decoder := NewDecoder() + _, stereo, sampleCount, decodedChannelCount, err := decoder.decode( + []byte{byte(16<<3) | byte(frameCodeOneFrame), 0xff, 0xff}, + nil, + ) + assert.NoError(t, err) + assert.False(t, stereo) + assert.Positive(t, sampleCount) + assert.Equal(t, 1, decodedChannelCount) + assert.NotZero(t, decoder.celtDecoder.FinalRange()) + + decoder.celtBuffer = []float32{1} + decoder.rangeFinal = 42 + + err = decoder.Init(48000, 1) + + assert.NoError(t, err) + assert.Zero(t, decoder.celtDecoder.FinalRange()) + assert.Empty(t, decoder.celtBuffer) + assert.Zero(t, decoder.rangeFinal) +} + func TestDecodeToFloat32(t *testing.T) { decoder, err := NewDecoderWithOutput(16000, 2) assert.NoError(t, err) @@ -159,7 +182,7 @@ func TestDecodeSilkFrameDurations(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { decoder := NewDecoder() - _, _, _, err := decoder.decode([]byte{byte(test.configuration<<3) | byte(frameCodeOneFrame)}, nil) + _, _, _, _, err := decoder.decode([]byte{byte(test.configuration<<3) | byte(frameCodeOneFrame)}, nil) assert.NoError(t, err) assert.Len(t, decoder.silkBuffer, test.sampleCount) }) @@ -189,29 +212,232 @@ func TestDecodedSampleRate(t *testing.T) { assert.Equal(t, 16000, Configuration(8).decodedSampleRate()) assert.Equal(t, celtSampleRate, Configuration(16).decodedSampleRate()) assert.Equal(t, celtSampleRate, Configuration(31).decodedSampleRate()) - assert.Equal(t, 0, Configuration(12).decodedSampleRate()) + assert.Equal(t, celtSampleRate, Configuration(12).decodedSampleRate()) } -func TestDecodeCeltOnlyStillUnsupported(t *testing.T) { +func TestDecodeCeltOnly(t *testing.T) { decoder := NewDecoder() - bandwidth, isStereo, sampleCount, err := decoder.decode([]byte{byte(16<<3) | byte(frameCodeOneFrame)}, nil) + bandwidth, isStereo, sampleCount, _, err := decoder.decode([]byte{byte(16<<3) | 
byte(frameCodeOneFrame)}, nil) - assert.ErrorIs(t, err, errUnsupportedConfigurationMode) - assert.Zero(t, bandwidth) + assert.NoError(t, err) + assert.Equal(t, BandwidthNarrowband, bandwidth) assert.False(t, isStereo) - assert.Zero(t, sampleCount) + assert.Equal(t, 120, sampleCount) + assert.Zero(t, decoder.rangeFinal) assert.Equal(t, configurationModeCELTOnly, decoder.previousMode) } -func TestDecodeHybridStillUnsupported(t *testing.T) { +func TestDecodeHybrid(t *testing.T) { decoder := NewDecoder() - bandwidth, isStereo, sampleCount, err := decoder.decode([]byte{byte(12<<3) | byte(frameCodeOneFrame)}, nil) + bandwidth, isStereo, sampleCount, _, err := decoder.decode([]byte{byte(12<<3) | byte(frameCodeOneFrame)}, nil) - assert.ErrorIs(t, err, errUnsupportedConfigurationMode) - assert.Zero(t, bandwidth) + assert.NoError(t, err) + assert.Equal(t, BandwidthSuperwideband, bandwidth) assert.False(t, isStereo) - assert.Zero(t, sampleCount) + assert.Equal(t, 480, sampleCount) assert.Equal(t, configurationModeHybrid, decoder.previousMode) } + +func TestResetModeStateCopiesSilkResamplerAcrossHybridTransitions(t *testing.T) { + decoder := NewDecoder() + assert.NoError(t, decoder.silkResampler[0].Init(BandwidthWideband.SampleRate(), celtSampleRate)) + decoder.silkResamplerBandwidth = BandwidthWideband + decoder.silkResamplerChannels = 1 + decoder.previousMode = configurationModeSilkOnly + + decoder.resetModeState(configurationModeHybrid) + + assert.Equal(t, 1, decoder.hybridSilkChannels) + + decoder.previousMode = configurationModeHybrid + decoder.silkResamplerBandwidth = 0 + decoder.silkResamplerChannels = 0 + + decoder.resetModeState(configurationModeSilkOnly) + + assert.Equal(t, BandwidthWideband, decoder.silkResamplerBandwidth) + assert.Equal(t, 1, decoder.silkResamplerChannels) +} + +func TestDecodeHybridRedundancyHeader(t *testing.T) { + // These deterministic payloads drive the RFC 6716 Section 4.5.1 Hybrid + // redundancy parser through both valid transition 
directions. + for _, test := range []struct { + name string + frame []byte + celtToSilk bool + celtData int + }{ + { + name: "silk to celt", + frame: []byte{ + 255, 240, 20, 244, 193, 153, 114, 153, 174, 176, 113, 79, 114, 176, 30, 111, + 78, 251, 135, 241, 38, 152, 99, 238, 115, 216, 157, 159, 172, 149, 251, 21, + }, + celtToSilk: false, + celtData: 28, + }, + { + name: "celt to silk", + frame: []byte{ + 255, 248, 200, 183, 233, 107, 204, 67, 193, 228, 222, 25, 186, 202, 13, 26, + 79, 90, 131, 149, 102, 178, 120, 213, 146, 125, 92, 227, 83, 96, 134, 146, + }, + celtToSilk: true, + celtData: 5, + }, + } { + t.Run(test.name, func(t *testing.T) { + decoder := NewDecoder() + decoder.rangeDecoder.Init(test.frame) + + redundancy := decoder.decodeHybridRedundancyHeader(test.frame) + + assert.True(t, redundancy.present) + assert.Equal(t, test.celtToSilk, redundancy.celtToSilk) + assert.Equal(t, test.celtData, redundancy.celtDataLen) + assert.Equal(t, test.frame[test.celtData:], redundancy.data) + }) + } +} + +func TestDecodeSilkOnlyRedundancyHeader(t *testing.T) { + decoder := NewDecoder() + frame := make([]byte, 32) + // Start after the SILK payload so the remaining bytes become redundant CELT + // data, as described by RFC 6716 Section 4.5.1.2. 
+ decoder.rangeDecoder.SetInternalValues(frame, 32, 1<<30, 0) + + redundancy, err := decoder.decodeSilkOnlyRedundancyHeader(frame, BandwidthMediumband) + + assert.NoError(t, err) + assert.True(t, redundancy.present) + assert.True(t, redundancy.celtToSilk) + assert.Equal(t, 1, redundancy.celtDataLen) + assert.Equal(t, frame[1:], redundancy.data) + + _, expectedEndBand, err := decoder.celtDecoder.Mode().BandRangeForSampleRate(BandwidthWideband.SampleRate()) + assert.NoError(t, err) + assert.Equal(t, expectedEndBand, redundancy.endBand) +} + +func TestDecodeHybridRedundantFrame(t *testing.T) { + decoder := NewDecoder() + redundancy := hybridRedundancy{data: []byte{0xff, 0xff}} + endBand, err := decoder.celtEndBandForSilkBandwidth(BandwidthWideband) + assert.NoError(t, err) + + err = decoder.decodeHybridRedundantFrame(&redundancy, false, 1, endBand) + + assert.NoError(t, err) + assert.Len(t, redundancy.audio, hybridRedundantFrameSampleCount) + assert.NotZero(t, redundancy.rng) +} + +func TestAddHybridSilkMapsChannels(t *testing.T) { + for _, test := range []struct { + name string + streamChannelCount int + outputChannelCount int + silk48 []float32 + expected []float32 + }{ + { + name: "mono", + streamChannelCount: 1, + outputChannelCount: 1, + silk48: []float32{0.25}, + expected: []float32{0.25}, + }, + { + name: "mono to stereo", + streamChannelCount: 1, + outputChannelCount: 2, + silk48: []float32{0.25}, + expected: []float32{0.25, 0.25}, + }, + { + name: "stereo to mono", + streamChannelCount: 2, + outputChannelCount: 1, + silk48: []float32{0.25, 0.5}, + expected: []float32{0.375}, + }, + } { + t.Run(test.name, func(t *testing.T) { + decoder := NewDecoder() + out := make([]float32, len(test.expected)) + + decoder.addHybridSilk(out, test.silk48, test.streamChannelCount, test.outputChannelCount, 1) + + assert.Equal(t, test.expected, out) + }) + } +} + +func TestDecodeSilkFramesAddsHybridTransitionAudio(t *testing.T) { + decoder := NewDecoder() + decoder.previousMode = 
configurationModeHybrid + + bandwidth, isStereo, sampleCount, decodedChannelCount, err := decoder.decodeSilkFrames( + Configuration(8), + tableOfContentsHeader(byte(8<<3)|byte(frameCodeOneFrame)), + [][]byte{nil}, + nil, + ) + + assert.NoError(t, err) + assert.Equal(t, BandwidthWideband, bandwidth) + assert.False(t, isStereo) + assert.Equal(t, 160, sampleCount) + assert.Equal(t, 1, decodedChannelCount) + assert.Len(t, decoder.silkCeltAdditions, 1) + assert.Len(t, decoder.silkCeltAdditions[0].audio, hybridFadeSampleCount) +} + +func TestApplySilkRedundancyFades(t *testing.T) { + decoder := NewDecoder() + decoder.resampleBuffer = make([]float32, 600) + for i := range decoder.resampleBuffer { + decoder.resampleBuffer[i] = 0.25 + } + + leadingAudio := make([]float32, 2*hybridFadeSampleCount) + trailingAudio := make([]float32, 2*hybridFadeSampleCount) + for i := range leadingAudio { + leadingAudio[i] = 0.5 + trailingAudio[i] = 0.75 + } + decoder.silkCeltAdditions = append(decoder.silkCeltAdditions, silkCeltAddition{ + audio: []float32{0.125}, + startSample: 0, + channelCount: 1, + }) + decoder.silkRedundancyFades = append( + decoder.silkRedundancyFades, + silkRedundancyFade{ + celtToSilk: true, + audio: leadingAudio, + startSample: 1, + frameSampleCount: 2 * hybridFadeSampleCount, + channelCount: 1, + }, + silkRedundancyFade{ + audio: trailingAudio, + startSample: 2 * hybridFadeSampleCount, + frameSampleCount: 2 * hybridFadeSampleCount, + channelCount: 1, + }, + ) + + decoder.applySilkRedundancyFades(1) + + assert.Equal(t, float32(0.375), decoder.resampleBuffer[0]) + assert.Equal(t, float32(0.5), decoder.resampleBuffer[1]) + assert.NotEqual(t, float32(0.25), decoder.resampleBuffer[1+hybridFadeSampleCount]) + assert.NotEqual(t, float32(0.25), decoder.resampleBuffer[3*hybridFadeSampleCount+60]) + assert.Empty(t, decoder.silkCeltAdditions) + assert.Empty(t, decoder.silkRedundancyFades) +} diff --git a/internal/silk/decoder.go b/internal/silk/decoder.go index 
32e0475..c9349dc 100644 --- a/internal/silk/decoder.go +++ b/internal/silk/decoder.go @@ -2361,6 +2361,36 @@ func (d *Decoder) Decode( isStereo bool, nanoseconds int, bandwidth Bandwidth, +) error { + d.rangeDecoder.Init(in) + + return d.decodeWithInitializedRange(out, isStereo, nanoseconds, bandwidth) +} + +// DecodeWithRange decodes one SILK frame from an Opus range decoder shared +// with the CELT layer, as required by RFC 6716 hybrid packets. +func (d *Decoder) DecodeWithRange( + rangeDecoder *rangecoding.Decoder, + out []float32, + isStereo bool, + nanoseconds int, + bandwidth Bandwidth, +) error { + if rangeDecoder == nil { + return errOutBufferTooSmall + } + d.rangeDecoder = *rangeDecoder + err := d.decodeWithInitializedRange(out, isStereo, nanoseconds, bandwidth) + *rangeDecoder = d.rangeDecoder + + return err +} + +func (d *Decoder) decodeWithInitializedRange( + out []float32, + isStereo bool, + nanoseconds int, + bandwidth Bandwidth, ) error { frameCount := silkFrameCount(nanoseconds) silkFrameNanoseconds := min(nanoseconds, nanoseconds20Ms) @@ -2378,8 +2408,6 @@ func (d *Decoder) Decode( return errOutBufferTooSmall } - d.rangeDecoder.Init(in) - midVoiceActivityDetected, midLowBitRateRedundancy := d.decodeHeaderBits(frameCount) if midLowBitRateRedundancy { return errUnsupportedSilkLowBitrateRedundancy diff --git a/internal/silk/decoder_test.go b/internal/silk/decoder_test.go index f83eba2..c4495c0 100644 --- a/internal/silk/decoder_test.go +++ b/internal/silk/decoder_test.go @@ -154,6 +154,23 @@ func TestDecodeBufferSize(t *testing.T) { assert.Equal(t, errOutBufferTooSmall, err) } +func TestDecodeWithRange(t *testing.T) { + decoder := NewDecoder() + assert.ErrorIs( + t, + decoder.DecodeWithRange(nil, make([]float32, 320), false, nanoseconds20Ms, BandwidthWideband), + errOutBufferTooSmall, + ) + + rangeDecoder := rangecoding.Decoder{} + rangeDecoder.Init(testSilkFrame()) + + err := decoder.DecodeWithRange(&rangeDecoder, make([]float32, 320), false, 
nanoseconds20Ms, BandwidthWideband) + + assert.NoError(t, err) + assert.NotZero(t, rangeDecoder.FinalRange()) +} + func TestNormalizeLineSpectralFrequencyStageOne(t *testing.T) { d := &Decoder{rangeDecoder: createRangeDecoder(testSilkFrame(), 47, 722810880, 387065757)} diff --git a/packet_test.go b/packet_test.go index 41ccf9d..aacb81a 100644 --- a/packet_test.go +++ b/packet_test.go @@ -115,7 +115,7 @@ func TestDecodeRejectsEmptyPacket(t *testing.T) { decoder := NewDecoder() - bandwidth, isStereo, sampleCount, err := decoder.decode(nil, make([]float32, 0)) + bandwidth, isStereo, sampleCount, _, err := decoder.decode(nil, make([]float32, 0)) require.Error(t, err) assert.Zero(t, bandwidth) @@ -207,7 +207,7 @@ func TestDecodePacketFrames(t *testing.T) { t.Parallel() decoder := NewDecoder() - _, _, _, err := decoder.decode([]byte{tocByte(frameCodeTwoEqualFrames) | 0b100}, decoder.silkBuffer) + _, _, _, _, err := decoder.decode([]byte{tocByte(frameCodeTwoEqualFrames) | 0b100}, decoder.silkBuffer) require.NoError(t, err) assert.Equal(t, maxSilkFrameSampleCount*4, len(decoder.silkBuffer))