diff -r a728056347ec Lib/test/test_wave.py --- a/Lib/test/test_wave.py Fri Nov 23 19:46:52 2012 +0200 +++ b/Lib/test/test_wave.py Sun Nov 25 18:02:39 2012 +0100 @@ -3,17 +3,42 @@ import wave import struct import unittest +import array +import math nchannels = 2 -sampwidth = 2 framerate = 8000 -nframes = 100 +nframes = 50 class TestWave(unittest.TestCase): def setUp(self): self.f = None + values = range(0, nframes*nchannels) + samples_b8 = array.array('b', values).tobytes() + samples_s16 = array.array('h', values).tobytes() + samples_i32 = array.array('i', values).tobytes() + + samples_f32 = array.array('f', values).tobytes() + samples_d64 = array.array('d', values).tobytes() + self.sampwidths_list = [ 1, 2, 4, 4, 8] + self.samples_list = [ + samples_b8, + samples_s16, + samples_i32, + samples_f32, + samples_d64 + ] + + self.wavformats_list = [ + wave.WAVE_FORMAT_PCM, + wave.WAVE_FORMAT_PCM, + wave.WAVE_FORMAT_PCM, + wave.WAVE_FORMAT_IEEE_FLOAT, + wave.WAVE_FORMAT_IEEE_FLOAT + ] + def tearDown(self): if self.f is not None: self.f.close() @@ -23,24 +48,27 @@ pass def test_it(self, test_rounding=False): - self.f = wave.open(TESTFN, 'wb') - self.f.setnchannels(nchannels) - self.f.setsampwidth(sampwidth) - if test_rounding: - self.f.setframerate(framerate - 0.1) - else: - self.f.setframerate(framerate) - self.f.setnframes(nframes) - output = b'\0' * nframes * nchannels * sampwidth - self.f.writeframes(output) - self.f.close() - - self.f = wave.open(TESTFN, 'rb') - self.assertEqual(nchannels, self.f.getnchannels()) - self.assertEqual(sampwidth, self.f.getsampwidth()) - self.assertEqual(framerate, self.f.getframerate()) - self.assertEqual(nframes, self.f.getnframes()) - self.assertEqual(self.f.readframes(nframes), output) + + for i, samples in enumerate(self.samples_list): + self.f = wave.open(TESTFN, 'wb') + self.f.setnchannels(nchannels) + self.f.setsampwidth(self.sampwidths_list[i]) + if self.wavformats_list[i] == wave.WAVE_FORMAT_IEEE_FLOAT: + self.f.setwavformat(wave.WAVE_FORMAT_IEEE_FLOAT) + if test_rounding: + self.f.setframerate(framerate - 0.1) + else: + self.f.setframerate(framerate) + self.f.setnframes(nframes) + self.f.writeframes(samples) + self.f.close() + + self.f = wave.open(TESTFN, 'rb') + self.assertEqual(nchannels, self.f.getnchannels()) + self.assertEqual(self.sampwidths_list[i], self.f.getsampwidth()) + self.assertEqual(framerate, self.f.getframerate()) + self.assertEqual(nframes, self.f.getnframes()) + self.assertEqual(self.f.readframes(nframes), samples) def test_fractional_framerate(self): """ @@ -50,6 +78,7 @@ self.test_it(test_rounding=True) def test_issue7681(self): + sampwidth = 2 self.f = wave.open(TESTFN, 'wb') self.f.setnchannels(nchannels) self.f.setsampwidth(sampwidth) diff -r a728056347ec Lib/wave.py --- a/Lib/wave.py Fri Nov 23 19:46:52 2012 +0200 +++ b/Lib/wave.py Sun Nov 25 18:02:39 2012 +0100 @@ -78,9 +78,30 @@ class Error(Exception): pass -WAVE_FORMAT_PCM = 0x0001 +# +# Sample and wave formats +# -_array_fmts = None, 'b', 'h', None, 'l' +WAVE_FORMAT_PCM = 0x0001 +WAVE_FORMAT_IEEE_FLOAT = 0x0003 + +SAMPLE_FORMAT_PCM8 = 0x0101 +SAMPLE_FORMAT_PCM16 = 0x0201 +SAMPLE_FORMAT_PCM24 = 0x0301 +SAMPLE_FORMAT_PCM32 = 0x0401 + +SAMPLE_FORMAT_FLT32 = 0x0403 +SAMPLE_FORMAT_FLT64 = 0x0803 + +_sample_array_fmts = { + SAMPLE_FORMAT_PCM8: 'b', + SAMPLE_FORMAT_PCM16: 'h', + SAMPLE_FORMAT_PCM24: None, + SAMPLE_FORMAT_PCM32: 'i', + + SAMPLE_FORMAT_FLT32: 'f', + SAMPLE_FORMAT_FLT64: 'd' + } # Determine endian-ness import struct @@ -195,7 +216,10 @@ def getsampwidth(self): return self._sampwidth - + + def getwavformat(self): + return self._wavformat + def getframerate(self): return self._framerate @@ -232,12 +256,16 @@ if nframes == 0: return b'' if self._sampwidth > 1 and big_endian: + + # TODO: problems may occur with 24bit PCM wave files on big endian systems + # ByteSwap only works for 16 and 32bit + # # unfortunately the fromfile() method does not take # something that only looks like a file object, so - # we have to reach into the innards of the chunk object + # we have to reach into the innards of the chunk object import array chunk = self._data_chunk - data = array.array(_array_fmts[self._sampwidth]) + data = array.array(_sample_array_fmts[self._sampformat]) nitems = nframes * self._nchannels if nitems * self._sampwidth > chunk.chunksize - chunk.size_read: nitems = (chunk.chunksize - chunk.size_read) // self._sampwidth @@ -251,6 +279,7 @@ data = data.tobytes() else: data = self._data_chunk.read(nframes * self._framesize) + if self._convert and data: data = self._convert(data) self._soundpos = self._soundpos + len(data) // (self._nchannels * self._sampwidth) @@ -261,12 +290,19 @@ # def _read_fmt_chunk(self, chunk): - wFormatTag, self._nchannels, self._framerate, dwAvgBytesPerSec, wBlockAlign = struct.unpack_from(' 16: + extFmtsize = struct.unpack_from(' 0: + raise Warning('WAVE_FORMAT_IEEE_FLOAT should have a zero sized format chunk') + self._framesize = self._nchannels * self._sampwidth self._comptype = 'NONE' self._compname = 'not compressed' @@ -319,7 +355,9 @@ self._nframeswritten = 0 self._datawritten = 0 self._datalength = 0 - self._headerwritten = False + self._headerwritten = False + # assume WAVE_FORMAT_PCM as default for backward compatibility + self._wavformat = WAVE_FORMAT_PCM def __del__(self): self.close() @@ -342,7 +380,7 @@ def setsampwidth(self, sampwidth): if self._datawritten: raise Error('cannot change parameters after starting to write') - if sampwidth < 1 or sampwidth > 4: + if sampwidth < 1 or sampwidth > 8: raise Error('bad sample width') self._sampwidth = sampwidth @@ -351,6 +389,18 @@ raise Error('sample width not set') return self._sampwidth + def setwavformat(self, wavformat): + if self._datawritten: + raise Error('cannot change parameters after starting to write') + if wavformat != WAVE_FORMAT_PCM and wavformat != WAVE_FORMAT_IEEE_FLOAT: + raise Error('bad wave format') + self._wavformat = wavformat + + def getwavformat(self): + if not self._wavformat: + raise Error('wave format not set') + return self._wavformat + def setframerate(self, framerate): if self._datawritten: raise Error('cannot change parameters after starting to write') @@ -374,7 +424,7 @@ def setcomptype(self, comptype, compname): if self._datawritten: raise Error('cannot change parameters after starting to write') - if comptype not in ('NONE',): + if comptype not in ('NONE'): raise Error('unsupported compression type') self._comptype = comptype self._compname = compname @@ -386,7 +436,13 @@ return self._compname def setparams(self, params): - nchannels, sampwidth, framerate, nframes, comptype, compname = params + if len(params) == 6: + nchannels, sampwidth, framerate, nframes, comptype, compname = params + wavformat = WAVE_FORMAT_PCM + elif len(params) == 7: + nchannels, sampwidth, framerate, nframes, comptype, compname, wavformat = params + else: + raise Error('wrong number of parameters') if self._datawritten: raise Error('cannot change parameters after starting to write') self.setnchannels(nchannels) @@ -394,6 +450,7 @@ self.setframerate(framerate) self.setnframes(nframes) self.setcomptype(comptype, compname) + self.setwavformat(wavformat) def getparams(self): if not self._nchannels or not self._sampwidth or not self._framerate: @@ -415,12 +472,13 @@ def writeframesraw(self, data): self._ensure_header_written(len(data)) - nframes = len(data) // (self._sampwidth * self._nchannels) + nframes = len(data) // (self._sampwidth * self._nchannels) if self._convert: data = self._convert(data) if self._sampwidth > 1 and big_endian: import array - data = array.array(_array_fmts[self._sampwidth], data) + self._sampformat = self._formattag | (self._sampwidth<<8) + data = array.array(_sample_array_fmts[self._sampformat], data) data.byteswap() data.tofile(self._file) self._datawritten = self._datawritten + len(data) * self._sampwidth @@ -466,12 +524,22 @@ self._nframes = initlength // (self._nchannels * self._sampwidth) self._datalength = self._nframes * self._nchannels * self._sampwidth self._form_length_pos = self._file.tell() - self._file.write(struct.pack('