Skip to content

Commit 0c436e7

Browse files
committed
detect untimely calls to entry.get_blocks()
1 parent af6b2eb commit 0c436e7

3 files changed

Lines changed: 54 additions & 1 deletion

File tree

libarchive/entry.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def get_blocks(self, block_size=ffi.page_size):
148148
if r == 0:
149149
break
150150
yield buf.raw[0:r]
151+
self.__class__ = ConsumedArchiveEntry
151152

152153
@property
153154
def isblk(self):
@@ -387,3 +388,19 @@ def rdevminor(self, value):
387388
@property
388389
def format_name(self):
389390
return ffi.format_name(self._pointer)
391+
392+
393+
class ConsumedArchiveEntry(ArchiveEntry):
394+
395+
__slots__ = ()
396+
397+
def get_blocks(self, **kw):
398+
raise TypeError("the content of this entry has already been read")
399+
400+
401+
class PassedArchiveEntry(ArchiveEntry):
402+
403+
__slots__ = ()
404+
405+
def get_blocks(self, **kw):
406+
raise TypeError("this entry is passed, it's too late to read its content")

libarchive/read.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
ARCHIVE_EOF, OPEN_CALLBACK, READ_CALLBACK, CLOSE_CALLBACK, SEEK_CALLBACK,
88
NO_OPEN_CB, NO_CLOSE_CB, page_size,
99
)
10-
from .entry import ArchiveEntry
10+
from .entry import ArchiveEntry, PassedArchiveEntry
1111

1212

1313
class ArchiveRead:
@@ -26,6 +26,7 @@ def __iter__(self):
2626
if r == ARCHIVE_EOF:
2727
return
2828
yield entry
29+
entry.__class__ = PassedArchiveEntry
2930

3031
@property
3132
def bytes_read(self):

tests/test_entry.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
from os.path import join
88
import unicodedata
99

10+
import pytest
11+
1012
from libarchive import memory_reader, memory_writer
13+
from libarchive.entry import ArchiveEntry, ConsumedArchiveEntry, PassedArchiveEntry
1114

1215
from . import data_dir, get_entries, get_tarinfos
1316

@@ -100,3 +103,35 @@ def check_entries(test_file, regen=False, ignore=''):
100103
if isinstance(d[key], text_type):
101104
d[key] = unicodedata.normalize('NFC', d[key])
102105
assert e1 == e2
106+
107+
108+
def test_the_life_cycle_of_archive_entries():
109+
"""Check that the `get_blocks` method only works on the current entry, and only once.
110+
"""
111+
# Create a test archive in memory
112+
buf = bytes(bytearray(10_000_000))
113+
with memory_writer(buf, 'gnutar') as archive:
114+
archive.add_files(
115+
'README.rst',
116+
'libarchive/__init__.py',
117+
'libarchive/entry.py',
118+
)
119+
# Read multiple entries of the test archive and check how the evolve
120+
with memory_reader(buf) as archive:
121+
archive_iter = iter(archive)
122+
entry1 = next(archive_iter)
123+
assert type(entry1) is ArchiveEntry
124+
for block in entry1.get_blocks():
125+
pass
126+
assert type(entry1) is ConsumedArchiveEntry
127+
with pytest.raises(TypeError):
128+
entry1.get_blocks()
129+
entry2 = next(archive_iter)
130+
assert type(entry2) is ArchiveEntry
131+
assert type(entry1) is PassedArchiveEntry
132+
with pytest.raises(TypeError):
133+
entry1.get_blocks()
134+
entry3 = next(archive_iter)
135+
assert type(entry3) is ArchiveEntry
136+
assert type(entry2) is PassedArchiveEntry
137+
assert type(entry1) is PassedArchiveEntry

0 commit comments

Comments
 (0)