The Context
My daily workflow largely consists of producing, styling, and circulating plots from a dataset to my advisor and collaborators. We use the C++ framework ROOT to generate and store histograms and I am writing my code in Python to take advantage of its Python bindings (PyROOT).
Since a ROOT file is the fundamental unit of our datasets, I wrote a simple context manager to facilitate the common task of opening a ROOT file, retrieving some histograms, and then closing the file.
import ROOT

class HistogramFile(object):

    def __init__(self, filename):
        self.filename = filename

    def __enter__(self):
        self.file = ROOT.TFile.Open(self.filename, 'read')
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        self.file.Close()

    def get_histogram(self, name):
        """Return the histogram identified by name from the file.
        """
        # The TFile::Get() method returns a pointer to an object stored in a ROOT file.
        hist = self.file.Get(name)
        if hist:
            return hist
        else:
            raise RuntimeError('Unable to retrieve histogram named {0} from {1}'.format(name, self.filename))
This allows me to write the following snippet (imports of necessary modules implied)
f = ROOT.TFile.Open('dataset.root', 'read')
# Set up a canvas for plotting. The arguments are a name, an optional title, and the width and height in pixels.
canvas = ROOT.TCanvas('canvas', '', 500, 500)
hist = f.Get('electron_momentum')
hist.Draw()
canvas.SaveAs('plot.pdf')
f.Close()
in a more idiomatic fashion
with HistogramFile('dataset.root') as f:
    canvas = ROOT.TCanvas('canvas', '', 500, 500)
    hist = f.get_histogram('electron_momentum')
    hist.Draw()
    canvas.SaveAs('plot.pdf')
A dataset is often a collection of multiple ROOT files, so to make a plot I need to sum the histograms with the same name from each of its files together. The following snippet does not work, and demonstrated to me that the files must remain open in order to access their histograms.
with HistogramFile('dataset_part1.root') as f:
    hist_1 = f.get_histogram('electron_momentum')
with HistogramFile('dataset_part2.root') as f:
    hist_2 = f.get_histogram('electron_momentum')
# The next line causes a " *** Break *** segmentation violation"
hist_total = hist_1 + hist_2
After some searching, I discovered contextlib2's ExitStack, which allows me to programmatically handle a dynamic number of ROOT files.
import contextlib2
class Dataset(object):

    def __init__(self, *filenames):
        self.filenames = filenames

    def __enter__(self):
        with contextlib2.ExitStack() as stack:
            self.files = [stack.enter_context(HistogramFile(fname)) for fname in self.filenames]
            self.close = stack.pop_all().close
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        self.close()

    def get_histogram(self, name):
        """Return the sum of the histograms identified by name from all files.
        """
        return sum(f.get_histogram(name) for f in self.files)
I now have a context manager of context managers that allows me to work with multiple files in a similar manner as before.
# A list of files that could be the result of globbing or os.listdir
dataset_files = ['dataset_part1.root', 'dataset_part2.root', 'dataset_part3.root']
with Dataset(*dataset_files) as dataset:
    canvas = ROOT.TCanvas('canvas', '', 500, 500)
    hist = dataset.get_histogram('electron_momentum')
    hist.Draw()
    canvas.SaveAs('plot.pdf')
The Issues
Are there problems with my code that I haven't addressed? Any subtleties or technicalities I have neglected?
I personally have not seen a context manager of context managers when perusing other people's code during my internet adventures. Does it suggest a poor design choice? Suppose I want to go one level higher in abstraction and design a new class that facilitates working with a collection of Dataset objects. How would I do that without making my code context managers all the way down?
1 Answer
Given the documentation of both Python 3's contextlib and contextlib2, I'd say your usage is pretty standard for the tools at play.
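To make that concrete, here is a minimal sketch of the pop_all() idiom. It assumes the standard-library contextlib.ExitStack (Python 3.3+) in place of contextlib2, and it uses hypothetical FakeFile and FakeDataset stand-ins for HistogramFile and Dataset so that it runs without ROOT. It illustrates why the idiom is worth keeping: if one file fails to open partway through __enter__, the files that were already opened get closed instead of leaking.

```python
import contextlib


class FakeFile(object):
    """Hypothetical stand-in for HistogramFile; records open/close events."""

    log = []

    def __init__(self, filename):
        self.filename = filename

    def __enter__(self):
        # Simulate a file that cannot be opened.
        if self.filename.startswith('missing'):
            raise IOError('cannot open {0}'.format(self.filename))
        FakeFile.log.append('open ' + self.filename)
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        FakeFile.log.append('close ' + self.filename)


class FakeDataset(object):
    """Same structure as Dataset, but built on FakeFile."""

    def __init__(self, *filenames):
        self.filenames = filenames

    def __enter__(self):
        with contextlib.ExitStack() as stack:
            self.files = [stack.enter_context(FakeFile(fname))
                          for fname in self.filenames]
            # Only reached if every file opened: hand the cleanup callbacks
            # to a fresh stack so leaving this with block closes nothing.
            self.close = stack.pop_all().close
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        self.close()


def run(*filenames):
    """Open and close a FakeDataset, returning the recorded events."""
    FakeFile.log = []
    try:
        with FakeDataset(*filenames):
            pass
    except IOError:
        FakeFile.log.append('failed')
    return FakeFile.log
```

Calling run('a.root', 'missing.root') records that a.root is opened and then closed again before the error reaches the caller, which is the behaviour a hand-written loop of opens would have to reimplement.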
However, there is something bothering me a bit in your code:
def get_histogram(self, name):
    """Return the histogram identified by name from the file.
    """
    # The TFile::Get() method returns a pointer to an object stored in a ROOT file.
    hist = self.file.Get(name)
    if hist:
        return hist
    else:
        raise RuntimeError('Unable to retrieve histogram named {0} from {1}'.format(name, self.filename))
Why raise a general-purpose RuntimeError? Anyone who wants to use your code and handle failures may end up catching more than they should.
As PEP8 says:
Derive exceptions from Exception rather than BaseException. Direct inheritance from BaseException is reserved for exceptions where catching them is almost always the wrong thing to do.
Design exception hierarchies based on the distinctions that code catching the exceptions is likely to need, rather than the locations where the exceptions are raised. Aim to answer the question "What went wrong?" programmatically, rather than only stating that "A problem occurred" (see PEP 3151 for an example of this lesson being learned for the builtin exception hierarchy)
Class naming conventions apply here, although you should add the suffix "Error" to your exception classes if the exception is an error. Non-error exceptions that are used for non-local flow control or other forms of signaling need no special suffix.
So I’d rather write:
class HistogramNotFoundError(KeyError):
    pass

def get_histogram(self, name):
    hist = self.file.Get(name)
    if not hist:
        raise HistogramNotFoundError(name)
    return hist
The choice of KeyError as a base is a bit arbitrary, but I feel it fits nicely.
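As a rough illustration of what the dedicated exception buys the caller, here is a sketch that runs without ROOT. FakeHistogramFile and get_or_default are hypothetical names introduced only for this example; the dict lookup stands in for TFile::Get() returning a null pointer.

```python
class HistogramNotFoundError(KeyError):
    """Raised when a file holds no histogram with the requested name."""


class FakeHistogramFile(object):
    """Hypothetical stand-in for HistogramFile backed by a plain dict."""

    def __init__(self, histograms):
        self.histograms = histograms

    def get_histogram(self, name):
        # dict.get returns None for a missing key, mimicking the falsy
        # null pointer that the original code checks for.
        hist = self.histograms.get(name)
        if not hist:
            raise HistogramNotFoundError(name)
        return hist


def get_or_default(hist_file, name, default=None):
    """Return the named histogram, or default if it is absent."""
    try:
        return hist_file.get_histogram(name)
    except HistogramNotFoundError:
        # Only the "missing histogram" case is handled here; an unrelated
        # RuntimeError raised elsewhere would still propagate.
        return default
```

Because the class derives from KeyError, existing code that already catches KeyError keeps working, while new code can catch the narrower type.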
One last thing: if you intend to build a lot of canvases to draw on, you may also be interested in wrapping that in a context manager, either by writing a class as you do (but checking for the presence of an exception in the __exit__ method before drawing) or by using the @contextlib.contextmanager decorator:
from contextlib import contextmanager

@contextmanager
def canvas(name, filename, idunno, width, height):
    canvas = ROOT.TCanvas(name, idunno, width, height)
    yield canvas
    canvas.SaveAs(filename)
I’m not using a try: ... finally: here to avoid generating a file if the canvas was not properly drawn.
Usage being:
dataset_files = ['dataset_part1.root', 'dataset_part2.root', 'dataset_part3.root']
with Dataset(*dataset_files) as dataset, canvas('canvas', 'plot.pdf', '', 500, 500):
    hist = dataset.get_histogram('electron_momentum')
    hist.Draw()
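For completeness, the class-based alternative mentioned above could look roughly like the following. This is only a sketch: SavedCanvas, its canvas_factory parameter, and FakeCanvas are hypothetical names introduced so the example runs without ROOT. With PyROOT one would presumably pass ROOT.TCanvas as the factory.

```python
class SavedCanvas(object):
    """Create a canvas on entry and save it on a clean exit only.

    canvas_factory is a hypothetical hook: any callable returning an
    object with a SaveAs(filename) method.
    """

    def __init__(self, canvas_factory, filename, *canvas_args):
        self.canvas_factory = canvas_factory
        self.filename = filename
        self.canvas_args = canvas_args

    def __enter__(self):
        self.canvas = self.canvas_factory(*self.canvas_args)
        return self.canvas

    def __exit__(self, exception_type, exception_value, traceback):
        # exception_type is None when the with block finished normally.
        if exception_type is None:
            self.canvas.SaveAs(self.filename)
        # Returning False lets any exception propagate to the caller.
        return False


class FakeCanvas(object):
    """Stand-in canvas that records the files it was asked to write."""

    saved = []

    def __init__(self, name, title, width, height):
        self.name = name

    def SaveAs(self, filename):
        FakeCanvas.saved.append(filename)
```

The explicit check on exception_type plays the same role as leaving out try: ... finally: in the decorator version: no file is written when the body of the with block fails.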