diff -r 44f455e6163d Lib/argparse.py --- a/Lib/argparse.py Thu Jun 27 12:23:29 2013 +0200 +++ b/Lib/argparse.py Fri Feb 21 13:33:45 2014 -0800 @@ -1551,6 +1551,9 @@ super(_MutuallyExclusiveGroup, self).__init__(container) self.required = required self._container = container + # register the test for this type of group + mytest = _MutuallyExclusiveGroup.test_mut_ex_groups + self.register('cross_tests', 'mxg', mytest) def _add_action(self, action): if action.required: @@ -1564,6 +1567,28 @@ self._container._remove_action(action) self._group_actions.remove(action) + @staticmethod + def test_mut_ex_groups(parser, seen_non_default_actions, *vargs): + # alternative mutually_exclusive_groups test + # performed once at end of parse_args rather than with each entry + # the arguments listed in the error message may differ + # this gives a small speed improvement + # more importantly it is easier to customize and expand + + for group in parser._mutually_exclusive_groups: + group_seen = seen_non_default_actions.intersection(group._group_actions) + cnt = len(group_seen) + if cnt > 1: + msg = 'only one the arguments %s is allowed' + elif cnt == 0 and group.required: + msg = 'one of the arguments %s is required' + else: + msg = None + if msg: + names = [_get_action_name(action) + for action in group._group_actions + if action.help is not SUPPRESS] + parser.error(msg % ' '.join(names)) class ArgumentParser(_AttributeHolder, _ActionsContainer): """Object for parsing command line strings into Python objects. @@ -1623,6 +1648,10 @@ return string self.register('type', None, identity) + # initialize cross_tests + # self.register('cross_tests', ?,?) + self._registries['cross_tests'] = {} + # add help argument if necessary # (using explicit default to override global argument_default) default_prefix = '-' if '-' in prefix_chars else prefix_chars[0] @@ -1757,15 +1786,9 @@ if self.fromfile_prefix_chars is not None: arg_strings = self._read_args_from_files(arg_strings) - # map all mutually exclusive arguments to the other arguments - # they can't occur with - action_conflicts = {} - for mutex_group in self._mutually_exclusive_groups: - group_actions = mutex_group._group_actions - for i, mutex_action in enumerate(mutex_group._group_actions): - conflicts = action_conflicts.setdefault(mutex_action, []) - conflicts.extend(group_actions[:i]) - conflicts.extend(group_actions[i + 1:]) + """ + remove action_conflicts collection + """ # find all option indices, and determine the arg_string_pattern # which has an 'O' if there is an option at an index, @@ -1808,11 +1831,9 @@ # value don't really count as "present" if argument_values is not action.default: seen_non_default_actions.add(action) - for conflict_action in action_conflicts.get(action, []): - if conflict_action in seen_non_default_actions: - msg = _('not allowed with argument %s') - action_name = _get_action_name(conflict_action) - raise ArgumentError(action, msg % action_name) + """ + remove action_conflicts use + """ # take the action if we didn't receive a SUPPRESS value # (e.g. from a default) @@ -1980,20 +2001,13 @@ self.error(_('the following arguments are required: %s') % ', '.join(required_actions)) - # make sure all required groups had one option present - for group in self._mutually_exclusive_groups: - if group.required: - for action in group._group_actions: - if action in seen_non_default_actions: - break - - # if no actions were used, report the error - else: - names = [_get_action_name(action) - for action in group._group_actions - if action.help is not SUPPRESS] - msg = _('one of the arguments %s is required') - self.error(msg % ' '.join(names)) + """ + remove required mutually_exclusive_groups test + """ + # give user a hook to run more general tests on arguments + # its primary purpose is to give the user access to seen(_non_default)_actions + for testfn in self._get_cross_tests(): + testfn(self, seen_non_default_actions, seen_actions, namespace, extras) # return the updated namespace and the extra arguments return namespace, extras @@ -2207,6 +2221,21 @@ # return the pattern return nargs_pattern + def _get_cross_tests(self): + # fetch a list (possibly empty) of tests to be run at the end of parsing + # for example, the mutually_exclusive_group tests + # or user supplied tests + + # this could be in a 'parser.cross_tests' attribute + # tests = getattr(self, 'cross_tests', []) + # but here I am looking in the _registries + # _registries is already shared among groups + # allowing me to define the group tests in the group class itself + # This use of _registries is slight non_standard since I am + # ignoring the 2nd level keys + tests = self._registries['cross_tests'].values() + return tests + # ======================== # Value conversion methods # ========================