diff -r c41c68a18bb6 Doc/library/wave.rst --- a/Doc/library/wave.rst Thu Sep 05 18:02:31 2013 +0300 +++ b/Doc/library/wave.rst Sat Sep 07 00:40:07 2013 +0300 @@ -19,7 +19,7 @@ .. function:: open(file, mode=None) If *file* is a string, open the file by that name, otherwise treat it as a - seekable file-like object. *mode* can be: + file-like object. *mode* can be: ``'rb'`` Read only mode. @@ -43,6 +43,8 @@ ` or :meth:`Wave_write.close() ` method is called. + .. versionchanged:: 3.4 + Added support of unseekable files. .. function:: openfp(file, mode) @@ -154,7 +156,8 @@ .. method:: Wave_write.close() Make sure *nframes* is correct, and close the file if it was opened by - :mod:`wave`. This method is called upon object collection. + :mod:`wave`. This method is called upon object collection. Can raise an + exception if *nframes* is not correct and a file is not seekable. .. method:: Wave_write.setnchannels(n) @@ -208,7 +211,8 @@ .. method:: Wave_write.writeframes(data) - Write audio frames and make sure *nframes* is correct. + Write audio frames and make sure *nframes* is correct. Can raise an + exception if a file is not seekable. Note that it is invalid to set any parameters after calling :meth:`writeframes` diff -r c41c68a18bb6 Lib/test/test_wave.py --- a/Lib/test/test_wave.py Thu Sep 05 18:02:31 2013 +0300 +++ b/Lib/test/test_wave.py Sat Sep 07 00:40:07 2013 +0300 @@ -1,4 +1,5 @@ from test.support import TESTFN, unlink +import io import wave import pickle import unittest @@ -7,6 +8,15 @@ sampwidth = 2 framerate = 8000 nframes = 100 +frames = bytes(i & 0xff for i in range(nframes * nchannels * sampwidth)) + +class UnseekableIO(io.BytesIO): + def tell(self): + raise io.UnsupportedOperation + + def seek(self, *args, **kwargs): + raise io.UnsupportedOperation + class TestWave(unittest.TestCase): @@ -18,6 +28,14 @@ self.f.close() unlink(TESTFN) + def check_file(self, file, nchannels, sampwidth, framerate, nframes, frames): + with wave.open(file, 'rb') as f: + self.assertEqual(f.getnchannels(), nchannels) + self.assertEqual(f.getsampwidth(), sampwidth) + self.assertEqual(f.getframerate(), framerate) + self.assertEqual(f.getnframes(), nframes) + self.assertEqual(f.readframes(nframes), frames) + def test_it(self, test_rounding=False): self.f = wave.open(TESTFN, 'wb') self.f.setnchannels(nchannels) @@ -27,16 +45,113 @@ else: self.f.setframerate(framerate) self.f.setnframes(nframes) - output = b'\0' * nframes * nchannels * sampwidth - self.f.writeframes(output) + self.f.writeframes(frames) self.f.close() - self.f = wave.open(TESTFN, 'rb') - self.assertEqual(nchannels, self.f.getnchannels()) - self.assertEqual(sampwidth, self.f.getsampwidth()) - self.assertEqual(framerate, self.f.getframerate()) - self.assertEqual(nframes, self.f.getnframes()) - self.assertEqual(self.f.readframes(nframes), output) + self.check_file(TESTFN, + nchannels, sampwidth, framerate, nframes, frames) + + def test_incompleted_output(self): + file = io.BytesIO() + file.write(b'ababagalamaga') + with wave.open(file, 'wb') as f: + f.setnchannels(nchannels) + f.setsampwidth(sampwidth) + f.setframerate(framerate) + f.setnframes(nframes) + f.writeframes(frames[:-1]) + + file.seek(0) + self.assertEqual(file.read(13), b'ababagalamaga') + self.check_file(file, + nchannels, sampwidth, framerate, nframes - 1, + frames[:-nchannels * sampwidth]) + + def test_multiple_output(self): + file = io.BytesIO() + file.write(b'ababagalamaga') + with wave.open(file, 'wb') as f: + f.setnchannels(nchannels) + f.setsampwidth(sampwidth) + f.setframerate(framerate) + f.setnframes(nframes) + f.writeframes(frames[:-1]) + f.writeframes(frames[-1:]) + + file.seek(0) + self.assertEqual(file.read(13), b'ababagalamaga') + self.check_file(file, + nchannels, sampwidth, framerate, nframes, frames) + + def test_overflowed_output(self): + file = io.BytesIO() + file.write(b'ababagalamaga') + with wave.open(file, 'wb') as f: + f.setnchannels(nchannels) + f.setsampwidth(sampwidth) + f.setframerate(framerate) + f.setnframes(nframes - 1) + f.writeframes(frames + b'\0') + + file.seek(0) + self.assertEqual(file.read(13), b'ababagalamaga') + self.check_file(file, + nchannels, sampwidth, framerate, nframes, frames) + + def test_unseekable_input(self): + file = io.BytesIO() + with wave.open(file, 'wb') as f: + f.setnchannels(nchannels) + f.setsampwidth(sampwidth) + f.setframerate(framerate) + f.setnframes(nframes) + f.writeframes(frames) + + self.check_file(UnseekableIO(file.getvalue()), + nchannels, sampwidth, framerate, nframes, frames) + + def test_unseekable_output(self): + file = UnseekableIO() + with wave.open(file, 'wb') as f: + f.setnchannels(nchannels) + f.setsampwidth(sampwidth) + f.setframerate(framerate) + f.setnframes(nframes) + f.writeframes(frames) + + self.check_file(io.BytesIO(file.getvalue()), + nchannels, sampwidth, framerate, nframes, frames) + + def test_unseekable_incompleted_output(self): + file = UnseekableIO() + with wave.open(file, 'wb') as f: + f.setnchannels(nchannels) + f.setsampwidth(sampwidth) + f.setframerate(framerate) + f.setnframes(nframes) + with self.assertRaises(io.UnsupportedOperation): + f.writeframes(frames[:-1]) + with self.assertRaises(io.UnsupportedOperation): + f.close() + + self.check_file(io.BytesIO(file.getvalue()), + nchannels, sampwidth, framerate, nframes, frames[:-1]) + + def test_unseekable_overflowed_output(self): + file = UnseekableIO() + with wave.open(file, 'wb') as f: + f.setnchannels(nchannels) + f.setsampwidth(sampwidth) + f.setframerate(framerate) + f.setnframes(nframes - 1) + with self.assertRaises(io.UnsupportedOperation): + f.writeframes(frames) + with self.assertRaises(io.UnsupportedOperation): + f.close() + + self.check_file(io.BytesIO(file.getvalue()), + nchannels, sampwidth, framerate, nframes - 1, + frames[:(nframes - 1) * nchannels * sampwidth]) def test_fractional_framerate(self): """ diff -r c41c68a18bb6 Lib/wave.py --- a/Lib/wave.py Thu Sep 05 18:02:31 2013 +0300 +++ b/Lib/wave.py Sat Sep 07 00:40:07 2013 +0300 @@ -479,14 +479,18 @@ if not self._nframes: self._nframes = initlength // (self._nchannels * self._sampwidth) self._datalength = self._nframes * self._nchannels * self._sampwidth - self._form_length_pos = self._file.tell() + try: + self._form_length_pos = self._file.tell() + except (AttributeError, OSError): + self._form_length_pos = None self._file.write(struct.pack('