source: main/waeup.sirp/trunk/src/waeup/sirp/utils/batching.py @ 6814

Last change on this file since 6814 was 6276, checked in by uli, 14 years ago

Remove old converters and disable tests for them.

File size: 10.8 KB
RevLine 
[4806]1"""WAeUP components for batch processing.
2
3Batch processors eat CSV files to add, update or remove large numbers
4of certain kinds of objects at once.
5"""
6import grok
[4870]7import copy
[4806]8import csv
[4821]9import os
10import sys
[4900]11import tempfile
[4821]12import time
[4806]13from zope.component import createObject
14from zope.interface import Interface
15from zope.schema import getFields
[5005]16from waeup.sirp.interfaces import (
[6276]17    IBatchProcessor, FatalCSVError, DuplicationError, IObjectConverter)
[4806]18
19class BatchProcessor(grok.GlobalUtility):
20    """A processor to add, update, or remove data.
21
22    This is a non-active baseclass.
23    """
[4831]24    grok.provides(IBatchProcessor)
[4806]25    grok.context(Interface)
26    grok.baseclass()
27
28    # Name used in pages and forms...
[5009]29    name = u'Non-registered base importer'
[6259]30
[4806]31    # Internal name...
[5009]32    util_name = 'baseimporter'
[6259]33
[4806]34    # Items for this processor need an interface with zope.schema fields.
[5009]35    iface = Interface
[6259]36
[4806]37    # The name must be the same as the util_name attribute in order to
38    # register this utility correctly.
39    grok.name(util_name)
40
41    # Headers needed to locate items...
42    location_fields = ['code', 'faculty_code']
[6259]43
[4806]44    # A factory with this name must be registered...
45    factory_name = 'waeup.Department'
46
47    @property
48    def required_fields(self):
[4829]49        """Required fields that have no default.
50
51        A list of names of field, whose value cannot be set if not
52        given during creation. Therefore these fields must exist in
53        input.
54
55        Fields with a default != missing_value do not belong to this
56        category.
57        """
[4806]58        result = []
59        for key, field in getFields(self.iface).items():
60            if key in self.location_fields:
61                continue
[4829]62            if field.default is not field.missing_value:
63                continue
[4806]64            if field.required:
65                result.append(key)
66        return result
[6259]67
[4806]68    @property
69    def req(self):
70        result = dict(
71            create = self.location_fields + self.required_fields,
72            update = self.location_fields,
73            remove = self.location_fields,
74        )
75        return result
76
77    @property
78    def available_fields(self):
79        result = []
80        return sorted(list(set(
81                    self.location_fields + getFields(self.iface).keys())))
[6259]82
[4806]83    def getHeaders(self, mode='create'):
84        return self.available_fields
85
86    def checkHeaders(self, headerfields, mode='create'):
87        req = self.req[mode]
88        # Check for required fields...
89        for field in req:
90            if not field in headerfields:
91                raise FatalCSVError(
92                    "Need at least columns %s for import!" %
93                    ', '.join(["'%s'" % x for x in req]))
94        # Check for double fields...
95        not_ignored_fields = [x for x in headerfields
96                              if not x.startswith('--')]
97        if len(set(not_ignored_fields)) < len(not_ignored_fields):
98            raise FatalCSVError(
99                "Double headers: each column name may only appear once.")
100        return True
101
102    def applyMapping(self, row, mapping):
[4811]103        """Apply mapping to a row of CSV data.
104        """
[4806]105        result = dict()
106        for key, replacement in mapping.items():
107            result[replacement] = row[key]
108        return result
[6259]109
[4832]110    def getMapping(self, path, headerfields, mode):
[4811]111        """Get a mapping from CSV file headerfields to actually used
112           fieldnames.
113        """
[4832]114        result = dict()
[4806]115        reader = csv.reader(open(path, 'rb'))
116        raw_header = reader.next()
[4832]117        for num, field in enumerate(headerfields):
118            if field not in self.location_fields and mode == 'remove':
119                # Ignore non-location fields when removing...
120                field = '--IGNORE--'
121            result[raw_header[num]] = field
122        return result
[4806]123
[6273]124    def stringFromErrs(self, errors, inv_errors):
125        result = []
126        for err in errors:
127            fieldname, message = err
128            result.append("%s: %s" % (fieldname, message))
129        for err in inv_errors:
130            result.append("invariant: %s" % err)
131        return '; '.join(result)
132
[4806]133    def callFactory(self, *args, **kw):
134        return createObject(self.factory_name)
135
136    def parentsExist(self, row, site):
[4811]137        """Tell whether the parent object for data in ``row`` exists.
138        """
[4806]139        raise NotImplementedError('method not implemented')
140
141    def entryExists(self, row, site):
[4811]142        """Tell whether there already exists an entry for ``row`` data.
143        """
[4806]144        raise NotImplementedError('method not implemented')
145
146    def getParent(self, row, site):
[4811]147        """Get the parent object for the entry in ``row``.
148        """
[4806]149        raise NotImplementedError('method not implemented')
[6259]150
[5009]151    def getEntry(self, row, site):
152        """Get the parent object for the entry in ``row``.
153        """
154        raise NotImplementedError('method not implemented')
[6259]155
[4806]156    def addEntry(self, obj, row, site):
[4811]157        """Add the entry given given by ``row`` data.
158        """
[4806]159        raise NotImplementedError('method not implemented')
160
161    def delEntry(self, row, site):
[4811]162        """Delete entry given by ``row`` data.
163        """
[6259]164        raise NotImplementedError('method not implemented')
[4806]165
166    def updateEntry(self, obj, row, site):
[4984]167        """Update obj to the values given in row.
168        """
[4829]169        for key, value in row.items():
170            setattr(obj, key, value)
171        return
[4821]172
[4832]173    def createLogfile(self, path, fail_path, num, warnings, mode, user,
[4885]174                      timedelta, logger=None):
175        """Write to log file.
[4821]176        """
[4885]177        if logger is None:
178            return
179        status = 'OK'
180        if warnings > 0:
181            status = 'FAILED'
182        logger.info("-" * 20)
183        logger.info("%s: Batch processing finished: %s" % (user, status))
184        logger.info("%s: Source: %s" % (user, path))
185        logger.info("%s: Mode: %s" % (user, mode))
186        logger.info("%s: User: %s" % (user, user))
187        if warnings > 0:
[4900]188            logger.info("%s: Failed datasets: %s" % (
189                    user, os.path.basename(fail_path)))
[4885]190        logger.info("%s: Processing time: %0.3f s (%0.4f s/item)" % (
191                user, timedelta, timedelta/(num or 1)))
192        logger.info("%s: Processed: %s lines (%s successful/ %s failed)" % (
193                user, num, num - warnings, warnings
[4821]194                ))
[4885]195        logger.info("-" * 20)
[4821]196        return
[4877]197
198    def writeFailedRow(self, writer, row, warnings):
199        """Write a row with error messages to error CSV.
200
201        If warnings is a list of strings, they will be concatenated.
202        """
203        error_col = warnings
204        if isinstance(warnings, list):
205            error_col = ' / '.join(warnings)
206        row['--ERRORS--'] = error_col
207        writer.writerow(row)
208        return
[6259]209
[4885]210    def doImport(self, path, headerfields, mode='create', user='Unknown',
211                 logger=None):
[4811]212        """Perform actual import.
213        """
[4832]214        time_start = time.time()
[4806]215        self.checkHeaders(headerfields, mode)
[4832]216        mapping = self.getMapping(path, headerfields, mode)
[4806]217        reader = csv.DictReader(open(path, 'rb'))
[4889]218
[4900]219        temp_dir = tempfile.mkdtemp()
[6259]220
[6273]221        base = os.path.basename(path)
222        (base, ext) = os.path.splitext(base)
[4900]223        failed_path = os.path.join(temp_dir, "%s.pending%s" % (base, ext))
[4911]224        failed_headers = mapping.keys()
[4877]225        failed_headers.append('--ERRORS--')
[4821]226        failed_writer = csv.DictWriter(open(failed_path, 'wb'),
227                                       failed_headers)
[4911]228        first_row = mapping.items()
229        first_row.append(("--ERRORS--", "--ERRORS--"),)
230        failed_writer.writerow(dict(first_row))
[4891]231
[4900]232        finished_path = os.path.join(temp_dir, "%s.finished%s" % (base, ext))
[4891]233        finished_headers = [x for x in mapping.values()]
234        finished_writer = csv.DictWriter(open(finished_path, 'wb'),
235                                         finished_headers)
236        finished_writer.writerow(dict([(x,x) for x in finished_headers]))
[6259]237
[4806]238        num =0
[4878]239        num_warns = 0
[4806]240        site = grok.getSite()
[6273]241        converter = IObjectConverter(self.iface)
[4806]242        for raw_row in reader:
243            num += 1
244            string_row = self.applyMapping(raw_row, mapping)
[6273]245            row = dict(string_row.items()) # create deep copy
246            errs, inv_errs, conv_dict =  converter.fromStringDict(
247                string_row, self.factory_name)
248            if errs or inv_errs:
[4878]249                num_warns += 1
[6273]250                conv_warnings = self.stringFromErrs(errs, inv_errs)
251                self.writeFailedRow(
252                    failed_writer, raw_row, conv_warnings)
[4821]253                continue
[6273]254            row.update(conv_dict)
[6259]255
[4806]256            if mode == 'create':
257                if not self.parentsExist(row, site):
[4878]258                    num_warns += 1
[4877]259                    self.writeFailedRow(
[4911]260                        failed_writer, raw_row,
[4877]261                        "Not all parents do exist yet. Skipping")
[4806]262                    continue
263                if self.entryExists(row, site):
[4878]264                    num_warns += 1
[4877]265                    self.writeFailedRow(
[4911]266                        failed_writer, raw_row,
[6219]267                        "This object already exists in the same container. Skipping.")
[4806]268                    continue
269                obj = self.callFactory()
270                for key, value in row.items():
271                    setattr(obj, key, value)
[6243]272                try:
273                    self.addEntry(obj, row, site)
[6273]274                except KeyError, error:
[6219]275                    num_warns += 1
276                    self.writeFailedRow(
277                        failed_writer, raw_row,
[6273]278                        "%s Skipping." % error.message)
[6219]279                    continue
[4806]280            elif mode == 'remove':
281                if not self.entryExists(row, site):
[4878]282                    num_warns += 1
[4877]283                    self.writeFailedRow(
[4911]284                        failed_writer, raw_row,
[4877]285                        "Cannot remove: no such entry.")
[4806]286                    continue
287                self.delEntry(row, site)
288            elif mode == 'update':
289                obj = self.getEntry(row, site)
290                if obj is None:
[4878]291                    num_warns += 1
[4877]292                    self.writeFailedRow(
[4911]293                        failed_writer, raw_row,
[4877]294                        "Cannot update: no such entry.")
[4806]295                    continue
296                self.updateEntry(obj, row, site)
[4891]297            finished_writer.writerow(string_row)
[4821]298
[4832]299        time_end = time.time()
300        timedelta = time_end - time_start
[6259]301
[4878]302        self.createLogfile(path, failed_path, num, num_warns, mode, user,
[4885]303                           timedelta, logger=logger)
[4894]304        failed_path = os.path.abspath(failed_path)
[4878]305        if num_warns == 0:
[4821]306            del failed_writer
307            os.unlink(failed_path)
[4894]308            failed_path = None
309        return (num, num_warns,
310                os.path.abspath(finished_path), failed_path)
Note: See TracBrowser for help on using the repository browser.