source: main/waeup.sirp/trunk/src/waeup/sirp/utils/batching.py @ 6404

Last change on this file since 6404 was 6276, checked in by uli, 14 years ago

Remove old converters and disable tests for them.

File size: 10.8 KB
Line 
1"""WAeUP components for batch processing.
2
3Batch processors eat CSV files to add, update or remove large numbers
4of certain kinds of objects at once.
5"""
6import grok
7import copy
8import csv
9import os
10import sys
11import tempfile
12import time
13from zope.component import createObject
14from zope.interface import Interface
15from zope.schema import getFields
16from waeup.sirp.interfaces import (
17    IBatchProcessor, FatalCSVError, DuplicationError, IObjectConverter)
18
class BatchProcessor(grok.GlobalUtility):
    """A processor to add, update, or remove data.

    This is a non-active baseclass. Subclasses are expected to set
    ``iface``, ``util_name`` (plus a matching ``grok.name``),
    ``location_fields`` and ``factory_name``, and to implement the
    hooks ``parentsExist``, ``entryExists``, ``getParent``,
    ``getEntry``, ``addEntry`` and ``delEntry``, which only raise
    :exc:`NotImplementedError` here.
    """
    grok.provides(IBatchProcessor)
    grok.context(Interface)
    grok.baseclass()

    # Name used in pages and forms...
    name = u'Non-registered base importer'

    # Internal name...
    util_name = 'baseimporter'

    # Items for this processor need an interface with zope.schema fields.
    iface = Interface

    # The name must be the same as the util_name attribute in order to
    # register this utility correctly.
    grok.name(util_name)

    # Headers needed to locate items...
    location_fields = ['code', 'faculty_code']

    # A factory with this name must be registered...
    factory_name = 'waeup.Department'

    @property
    def required_fields(self):
        """Required fields that have no default.

        A list of names of fields whose value cannot be set if not
        given during creation. Therefore these fields must exist in
        input.

        Location fields and fields with a default != missing_value do
        not belong to this category.
        """
        result = []
        for key, field in getFields(self.iface).items():
            if key in self.location_fields:
                continue
            if field.default is not field.missing_value:
                continue
            if field.required:
                result.append(key)
        return result

    @property
    def req(self):
        """Map each import mode to the column names it requires."""
        result = dict(
            create = self.location_fields + self.required_fields,
            update = self.location_fields,
            remove = self.location_fields,
        )
        return result

    @property
    def available_fields(self):
        """Sorted, de-duplicated names of location and interface fields."""
        return sorted(set(
            self.location_fields + list(getFields(self.iface).keys())))

    def getHeaders(self, mode='create'):
        """Return the column names this processor can handle.

        ``mode`` is currently not taken into account.
        """
        return self.available_fields

    def checkHeaders(self, headerfields, mode='create'):
        """Check that ``headerfields`` can be used for ``mode``.

        Raises :exc:`FatalCSVError` if a column required for ``mode`` is
        missing, or if a non-ignored column name (one not starting with
        ``--``) appears more than once. An unknown ``mode`` raises
        ``KeyError``. Returns ``True`` otherwise.
        """
        req = self.req[mode]
        # Check for required fields...
        for field in req:
            if field not in headerfields:
                raise FatalCSVError(
                    "Need at least columns %s for import!" %
                    ', '.join(["'%s'" % x for x in req]))
        # Check for double fields. Columns starting with '--' are
        # ignored and may therefore appear several times.
        not_ignored_fields = [x for x in headerfields
                              if not x.startswith('--')]
        if len(set(not_ignored_fields)) < len(not_ignored_fields):
            raise FatalCSVError(
                "Double headers: each column name may only appear once.")
        return True

    def applyMapping(self, row, mapping):
        """Apply mapping to a row of CSV data.

        Returns a new dict keyed by the mapped field names.
        """
        result = dict()
        for key, replacement in mapping.items():
            result[replacement] = row[key]
        return result

    def getMapping(self, path, headerfields, mode):
        """Get a mapping from CSV file headerfields to actually used
           fieldnames.

        The n-th column of the raw CSV header row in ``path`` is mapped
        to the n-th entry of ``headerfields``. In ``'remove'`` mode,
        every non-location field is mapped to ``'--IGNORE--'``.
        """
        # Read only the header line and close the file right away;
        # the previous implementation leaked this file handle.
        source = open(path, 'rb')
        try:
            raw_header = csv.reader(source).next()
        finally:
            source.close()
        result = dict()
        for num, field in enumerate(headerfields):
            if mode == 'remove' and field not in self.location_fields:
                # Ignore non-location fields when removing...
                field = '--IGNORE--'
            result[raw_header[num]] = field
        return result

    def stringFromErrs(self, errors, inv_errors):
        """Join field errors and invariant errors into one message.

        ``errors`` is a sequence of ``(fieldname, message)`` pairs,
        ``inv_errors`` a sequence of invariant error messages.
        """
        result = []
        for err in errors:
            fieldname, message = err
            result.append("%s: %s" % (fieldname, message))
        for err in inv_errors:
            result.append("invariant: %s" % err)
        return '; '.join(result)

    def callFactory(self, *args, **kw):
        """Create a new object using the factory named ``factory_name``.

        Any arguments are ignored.
        """
        return createObject(self.factory_name)

    def parentsExist(self, row, site):
        """Tell whether the parent object for data in ``row`` exists.
        """
        raise NotImplementedError('method not implemented')

    def entryExists(self, row, site):
        """Tell whether there already exists an entry for ``row`` data.
        """
        raise NotImplementedError('method not implemented')

    def getParent(self, row, site):
        """Get the parent object for the entry in ``row``.
        """
        raise NotImplementedError('method not implemented')

    def getEntry(self, row, site):
        """Get the entry object for the data in ``row``.

        Callers treat a ``None`` result as "no such entry".
        """
        raise NotImplementedError('method not implemented')

    def addEntry(self, obj, row, site):
        """Add the entry given by ``row`` data.
        """
        raise NotImplementedError('method not implemented')

    def delEntry(self, row, site):
        """Delete entry given by ``row`` data.
        """
        raise NotImplementedError('method not implemented')

    def updateEntry(self, obj, row, site):
        """Update obj to the values given in row.

        Columns whose name starts with ``--`` (ignored columns such as
        ``--IGNORE--``) are skipped; they cannot be schema fields.
        """
        for key, value in row.items():
            if key.startswith('--'):
                continue
            setattr(obj, key, value)
        return

    def createLogfile(self, path, fail_path, num, warnings, mode, user,
                      timedelta, logger=None):
        """Write a summary of an import run to ``logger``.

        Does nothing if ``logger`` is ``None``. ``warnings`` is the
        number of failed rows, ``num`` the number of rows processed and
        ``timedelta`` the elapsed time in seconds.
        """
        if logger is None:
            return
        status = 'OK'
        if warnings > 0:
            status = 'FAILED'
        logger.info("-" * 20)
        logger.info("%s: Batch processing finished: %s" % (user, status))
        logger.info("%s: Source: %s" % (user, path))
        logger.info("%s: Mode: %s" % (user, mode))
        logger.info("%s: User: %s" % (user, user))
        if warnings > 0:
            logger.info("%s: Failed datasets: %s" % (
                    user, os.path.basename(fail_path)))
        # `num or 1` avoids a division by zero for empty input.
        logger.info("%s: Processing time: %0.3f s (%0.4f s/item)" % (
                user, timedelta, timedelta/(num or 1)))
        logger.info("%s: Processed: %s lines (%s successful/ %s failed)" % (
                user, num, num - warnings, warnings
                ))
        logger.info("-" * 20)
        return

    def writeFailedRow(self, writer, row, warnings):
        """Write a row with error messages to error CSV.

        If warnings is a list of strings, they will be concatenated.
        Note that ``row`` is modified in place: the message is stored
        under the ``'--ERRORS--'`` key.
        """
        error_col = warnings
        if isinstance(warnings, list):
            error_col = ' / '.join(warnings)
        row['--ERRORS--'] = error_col
        writer.writerow(row)
        return

    def _processRow(self, row, mode, site):
        """Apply one converted data ``row`` to ``site`` according to ``mode``.

        Returns ``None`` on success, or a message string explaining why
        the row was skipped. An unknown ``mode`` touches nothing and
        counts as success.
        """
        if mode == 'create':
            if not self.parentsExist(row, site):
                return "Not all parents do exist yet. Skipping"
            if self.entryExists(row, site):
                return (
                    "This object already exists in the same container. "
                    "Skipping.")
            obj = self.callFactory()
            for key, value in row.items():
                # Ignored columns ('--IGNORE--' etc.) must not end up as
                # junk attributes on the new object.
                if key.startswith('--'):
                    continue
                setattr(obj, key, value)
            try:
                self.addEntry(obj, row, site)
            except KeyError:
                # Fetched via sys.exc_info() so the code stays valid
                # syntax on every Python version. The message mirrors
                # the deprecated ``BaseException.message``: the sole
                # argument, or '' otherwise.
                error = sys.exc_info()[1]
                if len(error.args) == 1:
                    message = error.args[0]
                else:
                    message = ''
                return "%s Skipping." % (message,)
        elif mode == 'remove':
            if not self.entryExists(row, site):
                return "Cannot remove: no such entry."
            self.delEntry(row, site)
        elif mode == 'update':
            obj = self.getEntry(row, site)
            if obj is None:
                return "Cannot update: no such entry."
            self.updateEntry(obj, row, site)
        return None

    def doImport(self, path, headerfields, mode='create', user='Unknown',
                 logger=None):
        """Perform actual import.

        Reads the CSV file at ``path``, maps its columns to
        ``headerfields`` and processes every data row according to
        ``mode`` (``'create'``, ``'update'`` or ``'remove'``).

        Rows that cannot be processed are written, with an extra
        ``--ERRORS--`` column, to ``<base>.pending<ext>``. Successfully
        processed rows go to ``<base>.finished<ext>``. Both files are
        created in a fresh temporary directory and are closed before
        this method returns.

        Returns ``(num, num_warns, finished_path, failed_path)``, where
        ``num`` is the number of data rows read and ``num_warns`` the
        number of failed rows. ``failed_path`` is ``None`` (and the
        pending file is removed) if no row failed.

        Raises :exc:`FatalCSVError` via :meth:`checkHeaders` if the
        headers are unusable for ``mode``.
        """
        time_start = time.time()
        self.checkHeaders(headerfields, mode)
        mapping = self.getMapping(path, headerfields, mode)

        temp_dir = tempfile.mkdtemp()
        base, ext = os.path.splitext(os.path.basename(path))
        failed_path = os.path.join(temp_dir, "%s.pending%s" % (base, ext))
        finished_path = os.path.join(temp_dir, "%s.finished%s" % (base, ext))

        num = 0
        num_warns = 0
        # Every file opened here is closed (and thereby flushed) in the
        # `finally` clause: before we log, delete or hand out the paths,
        # and also when processing a row raises.
        open_files = []
        try:
            for file_path, file_mode in ((path, 'rb'),
                                         (failed_path, 'wb'),
                                         (finished_path, 'wb')):
                open_files.append(open(file_path, file_mode))
            source_file, failed_file, finished_file = open_files

            reader = csv.DictReader(source_file)

            failed_headers = list(mapping.keys()) + ['--ERRORS--']
            failed_writer = csv.DictWriter(failed_file, failed_headers)
            # First line of the pending file: raw column name -> field name.
            first_row = dict(mapping)
            first_row['--ERRORS--'] = '--ERRORS--'
            failed_writer.writerow(first_row)

            finished_headers = list(mapping.values())
            finished_writer = csv.DictWriter(finished_file, finished_headers)
            finished_writer.writerow(
                dict((x, x) for x in finished_headers))

            site = grok.getSite()
            converter = IObjectConverter(self.iface)
            for raw_row in reader:
                num += 1
                string_row = self.applyMapping(raw_row, mapping)
                # Shallow copy: converted values are merged into `row`
                # while `string_row` keeps the raw strings for output.
                row = dict(string_row)
                errs, inv_errs, conv_dict = converter.fromStringDict(
                    string_row, self.factory_name)
                if errs or inv_errs:
                    message = self.stringFromErrs(errs, inv_errs)
                else:
                    row.update(conv_dict)
                    message = self._processRow(row, mode, site)
                if message is not None:
                    num_warns += 1
                    self.writeFailedRow(failed_writer, raw_row, message)
                    continue
                finished_writer.writerow(string_row)
        finally:
            for file_obj in open_files:
                file_obj.close()

        time_end = time.time()
        timedelta = time_end - time_start

        self.createLogfile(path, failed_path, num, num_warns, mode, user,
                           timedelta, logger=logger)
        failed_path = os.path.abspath(failed_path)
        if num_warns == 0:
            # The file is already closed, so unlinking is safe everywhere.
            os.unlink(failed_path)
            failed_path = None
        return (num, num_warns,
                os.path.abspath(finished_path), failed_path)
Note: See TracBrowser for help on using the repository browser.