252 lines
8.3 KiB
Python

# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2011 Citrix Systems, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Command-line flag library.
Wraps gflags.
Global flags should be defined here, the rest are defined where they're used.
"""
import getopt
import os
import string
import sys
import gflags
class FlagValues(gflags.FlagValues):
    """Extension of gflags.FlagValues that allows undefined and runtime flags.

    Unknown flags will be ignored when parsing the command line, but the
    command line will be kept so that it can be replayed if new flags are
    defined after the initial parsing.

    Bookkeeping state is stored directly in ``self.__dict__`` under string
    keys ('__dirty', '__was_already_parsed', '__stored_argv',
    '__extra_context'), presumably so that it does not go through the
    flag-oriented attribute hooks of gflags.FlagValues -- TODO confirm.
    """

    def __init__(self, extra_context=None):
        """Create an empty, not-yet-parsed flag container.

        Args:
          extra_context: optional object consulted (after this instance) when
              expanding ``$name`` templates inside string flag values.
        """
        gflags.FlagValues.__init__(self)
        self.__dict__['__dirty'] = []
        self.__dict__['__was_already_parsed'] = False
        self.__dict__['__stored_argv'] = []
        self.__dict__['__extra_context'] = extra_context

    def __call__(self, argv):
        """Parse argv, tolerating flags that have not been defined (yet).

        Args:
          argv: the full argument list, program name first.

        Returns:
          The program name followed by the arguments left unparsed.
        """
        # We're doing some hacky stuff here so that we don't have to copy
        # out all the code of the original verbatim and then tweak a few
        # lines. We're hijacking the output of getopt so we can still return
        # the leftover args at the end.
        sneaky_unparsed_args = {"value": None}
        original_argv = list(argv)

        if self.IsGnuGetOpt():
            orig_name = 'gnu_getopt'
        else:
            orig_name = 'getopt'
        orig_getopt = getattr(getopt, orig_name)

        def _sneaky(*args, **kw):
            optlist, unparsed_args = orig_getopt(*args, **kw)
            sneaky_unparsed_args['value'] = unparsed_args
            return optlist, unparsed_args

        try:
            setattr(getopt, orig_name, _sneaky)
            args = gflags.FlagValues.__call__(self, argv)
        except gflags.UnrecognizedFlagError:
            # Undefined args were found, for now we don't care so just
            # act like everything went well
            # (these three lines are copied pretty much verbatim from the end
            # of the __call__ function we are wrapping)
            unparsed_args = sneaky_unparsed_args['value']
            if unparsed_args:
                if self.IsGnuGetOpt():
                    args = argv[:1] + unparsed_args
                else:
                    args = argv[:1] + original_argv[-len(unparsed_args):]
            else:
                args = argv[:1]
        finally:
            # Always restore the real getopt function, even on failure.
            setattr(getopt, orig_name, orig_getopt)

        # Store the arguments for later, we'll need them for new flags
        # added at runtime
        self.__dict__['__stored_argv'] = original_argv
        self.__dict__['__was_already_parsed'] = True
        self.ClearDirty()
        return args

    def Reset(self):
        """Reset the flags and forget any previous parse of the command line."""
        gflags.FlagValues.Reset(self)
        self.__dict__['__dirty'] = []
        self.__dict__['__was_already_parsed'] = False
        self.__dict__['__stored_argv'] = []

    def SetDirty(self, name):
        """Mark a flag as dirty so that accessing it will cause a reparse."""
        self.__dict__['__dirty'].append(name)

    def IsDirty(self, name):
        """Return True if the named flag was defined after the last parse."""
        return name in self.__dict__['__dirty']

    def ClearDirty(self):
        """Forget all dirty flags.

        This must reset the same '__dirty' key that SetDirty() and IsDirty()
        use; writing to any other key leaves flags dirty forever and forces
        a full reparse on every access.
        """
        self.__dict__['__dirty'] = []

    def WasAlreadyParsed(self):
        """Return True once the command line has been parsed via __call__."""
        return self.__dict__['__was_already_parsed']

    def ParseNewFlags(self):
        """Replay the stored command line for flags defined after parsing.

        The dirty flags are copied into a temporary FlagValues (with this
        instance as its extra template context), the stored argv is parsed
        there, and the resulting values are copied back before the dirty
        list is cleared.
        """
        if '__stored_argv' not in self.__dict__:
            return
        dirty = self.__dict__['__dirty']
        new_flags = FlagValues(self)
        for k in dirty:
            new_flags[k] = gflags.FlagValues.__getitem__(self, k)

        new_flags(self.__dict__['__stored_argv'])
        for k in dirty:
            setattr(self, k, getattr(new_flags, k))
        self.ClearDirty()

    def __setitem__(self, name, flag):
        """Register a flag; mark it dirty if parsing has already happened."""
        gflags.FlagValues.__setitem__(self, name, flag)
        if self.WasAlreadyParsed():
            self.SetDirty(name)

    def __getitem__(self, name):
        """Return the flag object for name, reparsing first if it is dirty."""
        if self.IsDirty(name):
            self.ParseNewFlags()
        return gflags.FlagValues.__getitem__(self, name)

    def __getattr__(self, name):
        """Return the value of flag name, reparsing first if it is dirty.

        String values are treated as string.Template templates and expanded
        against this instance and then the extra context, so a flag value
        may reference other flags as ``$other_flag``.
        """
        if self.IsDirty(name):
            self.ParseNewFlags()
        val = gflags.FlagValues.__getattr__(self, name)
        if isinstance(val, str):
            tmpl = string.Template(val)
            context = [self, self.__dict__['__extra_context']]
            return tmpl.substitute(StrWrapper(context))
        return val
class StrWrapper(object):
    """Mapping-style adapter over a list of context objects.

    string.Template looks names up with ``mapping[name]``; this adapter
    answers such lookups by reading the attribute of the same name from the
    wrapped objects, always handing back a string.
    """

    def __init__(self, context_objs):
        self.context_objs = context_objs

    def __getitem__(self, name):
        # Objects are consulted in order; an object "has" the name only if
        # its attribute is truthy, otherwise the next object is tried.
        for candidate_source in self.context_objs:
            found = getattr(candidate_source, name, False)
            if not found:
                continue
            return str(found)
        raise KeyError(name)
# Copied from gflags with small modifications to get the naming correct.
# The original gflags version looks for the first module in the call chain
# that is not gflags; here we want the first module that is neither gflags
# nor this module.
def _GetCallingModule():
    """Return the name of the module that is calling into this module.

    Typically used to find out which module invoked one of the
    DEFINE_foo... functions. Frames belonging to this module, frames whose
    module resolves to 'gflags', and frames whose module cannot be
    identified are all skipped.
    """
    own_globals = globals()
    # Walk the stack outwards, starting at our immediate caller.
    for level in range(1, sys.getrecursionlimit()):
        frame_globals = sys._getframe(level).f_globals
        if frame_globals is own_globals:
            continue
        caller = __GetModuleName(frame_globals)
        if caller is None or caller == 'gflags':
            continue
        return caller
    raise AssertionError("No module was found")
# Copied from gflags because it is a private function
def __GetModuleName(globals_dict):
    """Given a globals dict, returns the name of the module that defines it.

    Args:
      globals_dict: A dictionary that should correspond to an environment
        providing the values of the globals.

    Returns:
      A string (the name of the module) or None (if the module could not
      be identified). For the '__main__' module, sys.argv[0] is returned
      instead of the module name.
    """
    # Iterate over a snapshot: sys.modules may be modified while we scan it,
    # which would break iteration over the live dict, and items() (unlike
    # iteritems()) exists on both Python 2 and Python 3.
    for name, module in list(sys.modules.items()):
        if getattr(module, '__dict__', None) is globals_dict:
            if name == '__main__':
                return sys.argv[0]
            return name
    return None
def _wrapper(func):
    """Wrap a flag-defining function so it defaults to this module's FLAGS.

    Args:
      func: a callable accepting a ``flag_values`` keyword argument.

    Returns:
      A function with the same name and docstring as func that passes
      ``flag_values=FLAGS`` unless the caller supplied ``flag_values``
      explicitly, and returns whatever func returns.
    """
    def _wrapped(*args, **kw):
        # FLAGS is only looked up when the caller did not choose a container.
        if 'flag_values' not in kw:
            kw['flag_values'] = FLAGS
        return func(*args, **kw)

    # __name__ is available on both Python 2 (where func_name is an alias
    # for it) and Python 3; keep the docstring too so help() stays useful.
    _wrapped.__name__ = func.__name__
    _wrapped.__doc__ = func.__doc__
    return _wrapped
# The module-wide flag container. It is also installed as gflags.FLAGS so
# that code reaching for gflags.FLAGS gets this same object.
FLAGS = FlagValues()
gflags.FLAGS = FLAGS
# Replace gflags' calling-module lookup with the version defined in this
# module, which also skips frames belonging to this module.
gflags._GetCallingModule = _GetCallingModule
# Re-export the gflags DEFINE_* helpers; each wrapper passes
# flag_values=FLAGS unless the caller provides flag_values explicitly.
DEFINE = _wrapper(gflags.DEFINE)
DEFINE_string = _wrapper(gflags.DEFINE_string)
DEFINE_integer = _wrapper(gflags.DEFINE_integer)
DEFINE_bool = _wrapper(gflags.DEFINE_bool)
DEFINE_boolean = _wrapper(gflags.DEFINE_boolean)
DEFINE_float = _wrapper(gflags.DEFINE_float)
DEFINE_enum = _wrapper(gflags.DEFINE_enum)
DEFINE_list = _wrapper(gflags.DEFINE_list)
DEFINE_spaceseplist = _wrapper(gflags.DEFINE_spaceseplist)
DEFINE_multistring = _wrapper(gflags.DEFINE_multistring)
DEFINE_multi_int = _wrapper(gflags.DEFINE_multi_int)
DEFINE_flag = _wrapper(gflags.DEFINE_flag)
# The help flag classes are re-exported unchanged.
HelpFlag = gflags.HelpFlag
HelpshortFlag = gflags.HelpshortFlag
HelpXMLFlag = gflags.HelpXMLFlag
def DECLARE(name, module_string, flag_values=FLAGS):
    """Ensure that flag `name` has been defined by module `module_string`.

    The module is imported first if it is not already present in
    sys.modules, so that its flag definitions get a chance to run.

    Raises:
      gflags.UnrecognizedFlag: if `name` is still not in `flag_values`
          afterwards.
    """
    if not (module_string in sys.modules):
        __import__(module_string, globals(), locals())
    if name in flag_values:
        return
    message = "%s not defined by %s" % (name, module_string)
    raise gflags.UnrecognizedFlag(message)
# __GLOBAL FLAGS ONLY__
# Define any app-specific flags in their own files, docs at:
# http://code.google.com/p/python-gflags/source/browse/trunk/gflags.py#a9
# The default for state_path is the directory two levels above the one
# containing this file.
DEFINE_string('state_path', os.path.join(os.path.dirname(__file__), '../../'),
"Top-level directory for maintaining quantum's state")