source: main/waeup.sirp/trunk/src/waeup/sirp/utils/batching.py @ 6265

Last change on this file since 6265 was 6259, checked in by uli, 14 years ago

Remove trailing whitespace.

File size: 11.3 KB
Line 
1"""WAeUP components for batch processing.
2
3Batch processors eat CSV files to add, update or remove large numbers
4of certain kinds of objects at once.
5"""
6import grok
7import copy
8import csv
9import os
10import sys
11import tempfile
12import time
13from zope.component import createObject
14from zope.interface import Interface
15from zope.schema import getFields
16from waeup.sirp.interfaces import (
17    IBatchProcessor, ISchemaTypeConverter, FatalCSVError, DuplicationError)
18
class BatchProcessor(grok.GlobalUtility):
    """A processor to add, update, or remove data.

    This is a non-active baseclass.

    Concrete processors are expected to override the class attributes
    below and to implement the hooks that only raise
    ``NotImplementedError`` here: ``parentsExist``, ``entryExists``,
    ``getParent``, ``getEntry``, ``addEntry`` and ``delEntry``.
    """
    grok.provides(IBatchProcessor)
    grok.context(Interface)
    grok.baseclass()

    # Name used in pages and forms...
    name = u'Non-registered base importer'

    # Internal name...
    util_name = 'baseimporter'

    # Items for this processor need an interface with zope.schema fields.
    # Its fields determine the available and the required columns (see
    # ``available_fields`` and ``required_fields``) and the converters
    # looked up in ``getFieldConverters``.
    iface = Interface

    # The name must be the same as the util_name attribute in order to
    # register this utility correctly.
    grok.name(util_name)

    # Headers needed to locate items... These columns are required in
    # every import mode (see ``req``).
    location_fields = ['code', 'faculty_code']

    # A factory with this name must be registered... It is looked up
    # with ``createObject`` in ``callFactory``.
    factory_name = 'waeup.Department'
46
47    @property
48    def required_fields(self):
49        """Required fields that have no default.
50
51        A list of names of field, whose value cannot be set if not
52        given during creation. Therefore these fields must exist in
53        input.
54
55        Fields with a default != missing_value do not belong to this
56        category.
57        """
58        result = []
59        for key, field in getFields(self.iface).items():
60            if key in self.location_fields:
61                continue
62            if field.default is not field.missing_value:
63                continue
64            if field.required:
65                result.append(key)
66        return result
67
68    @property
69    def req(self):
70        result = dict(
71            create = self.location_fields + self.required_fields,
72            update = self.location_fields,
73            remove = self.location_fields,
74        )
75        return result
76
77    @property
78    def available_fields(self):
79        result = []
80        return sorted(list(set(
81                    self.location_fields + getFields(self.iface).keys())))
82
83    def getHeaders(self, mode='create'):
84        return self.available_fields
85
86    def checkHeaders(self, headerfields, mode='create'):
87        req = self.req[mode]
88        # Check for required fields...
89        for field in req:
90            if not field in headerfields:
91                raise FatalCSVError(
92                    "Need at least columns %s for import!" %
93                    ', '.join(["'%s'" % x for x in req]))
94        # Check for double fields...
95        not_ignored_fields = [x for x in headerfields
96                              if not x.startswith('--')]
97        if len(set(not_ignored_fields)) < len(not_ignored_fields):
98            raise FatalCSVError(
99                "Double headers: each column name may only appear once.")
100        return True
101
102    def applyMapping(self, row, mapping):
103        """Apply mapping to a row of CSV data.
104        """
105        result = dict()
106        for key, replacement in mapping.items():
107            result[replacement] = row[key]
108        return result
109
110    def getMapping(self, path, headerfields, mode):
111        """Get a mapping from CSV file headerfields to actually used
112           fieldnames.
113        """
114        result = dict()
115        reader = csv.reader(open(path, 'rb'))
116        raw_header = reader.next()
117        for num, field in enumerate(headerfields):
118            if field not in self.location_fields and mode == 'remove':
119                # Ignore non-location fields when removing...
120                field = '--IGNORE--'
121            result[raw_header[num]] = field
122        return result
123
124    def getFieldConverters(self, fieldnames):
125        """Get converters for fieldnames.
126        """
127        result = dict()
128        for key, field in getFields(self.iface).items():
129            if key not in fieldnames:
130                continue
131            converter = ISchemaTypeConverter(field)
132            result[key] = converter
133        return result
134
135    def convertToTypes(self, row, converter_dict):
136        """Convert values in given row to destination type.
137        """
138        if '--IGNORE--' in row.keys():
139            del row['--IGNORE--']
140        warnings = []
141        for key, value in row.items():
142            converter = converter_dict.get(key, None)
143            if converter:
144                try:
145                    row.update({key:converter_dict[key].fromString(value)})
146                except:
147                    msg = "conversion error: field %s: %s %r" % (
148                        key, sys.exc_info()[0], sys.exc_info()[1])
149                    warnings.append(msg)
150        return (row, warnings)
151
152    def callFactory(self, *args, **kw):
153        return createObject(self.factory_name)
154
    def parentsExist(self, row, site):
        """Tell whether the parent object for data in ``row`` exists.

        Must be implemented by subclasses; this base implementation
        always raises ``NotImplementedError``. In ``'create'`` mode
        ``doImport`` skips rows for which a false value is returned.
        """
        raise NotImplementedError('method not implemented')
159
    def entryExists(self, row, site):
        """Tell whether there already exists an entry for ``row`` data.

        Must be implemented by subclasses; this base implementation
        always raises ``NotImplementedError``. ``doImport`` uses the
        result to skip existing entries in ``'create'`` mode and
        missing entries in ``'remove'`` mode.
        """
        raise NotImplementedError('method not implemented')
164
    def getParent(self, row, site):
        """Get the parent object for the entry in ``row``.

        Must be implemented by subclasses; this base implementation
        always raises ``NotImplementedError``.
        """
        raise NotImplementedError('method not implemented')
169
    def getEntry(self, row, site):
        """Get the object for the entry in ``row``.

        Must be implemented by subclasses; this base implementation
        always raises ``NotImplementedError``. In ``'update'`` mode
        ``doImport`` treats a result of ``None`` as "no such entry".
        """
        raise NotImplementedError('method not implemented')
174
    def addEntry(self, obj, row, site):
        """Add the entry given by ``row`` data.

        ``obj`` is the object to add. Must be implemented by
        subclasses; this base implementation always raises
        ``NotImplementedError``. A ``DuplicationError`` raised by an
        implementation is caught in ``doImport`` and reported as a
        failed row.
        """
        raise NotImplementedError('method not implemented')
179
    def delEntry(self, row, site):
        """Delete entry given by ``row`` data.

        Must be implemented by subclasses; this base implementation
        always raises ``NotImplementedError``. ``doImport`` calls it in
        ``'remove'`` mode only after ``entryExists`` returned a true
        value for the same row.
        """
        raise NotImplementedError('method not implemented')
184
185    def updateEntry(self, obj, row, site):
186        """Update obj to the values given in row.
187        """
188        for key, value in row.items():
189            setattr(obj, key, value)
190        return
191
192    def createLogfile(self, path, fail_path, num, warnings, mode, user,
193                      timedelta, logger=None):
194        """Write to log file.
195        """
196        if logger is None:
197            return
198        status = 'OK'
199        if warnings > 0:
200            status = 'FAILED'
201        logger.info("-" * 20)
202        logger.info("%s: Batch processing finished: %s" % (user, status))
203        logger.info("%s: Source: %s" % (user, path))
204        logger.info("%s: Mode: %s" % (user, mode))
205        logger.info("%s: User: %s" % (user, user))
206        if warnings > 0:
207            logger.info("%s: Failed datasets: %s" % (
208                    user, os.path.basename(fail_path)))
209        logger.info("%s: Processing time: %0.3f s (%0.4f s/item)" % (
210                user, timedelta, timedelta/(num or 1)))
211        logger.info("%s: Processed: %s lines (%s successful/ %s failed)" % (
212                user, num, num - warnings, warnings
213                ))
214        logger.info("-" * 20)
215        return
216
217    def writeFailedRow(self, writer, row, warnings):
218        """Write a row with error messages to error CSV.
219
220        If warnings is a list of strings, they will be concatenated.
221        """
222        error_col = warnings
223        if isinstance(warnings, list):
224            error_col = ' / '.join(warnings)
225        row['--ERRORS--'] = error_col
226        writer.writerow(row)
227        return
228
229    def doImport(self, path, headerfields, mode='create', user='Unknown',
230                 logger=None):
231        """Perform actual import.
232        """
233        time_start = time.time()
234        self.checkHeaders(headerfields, mode)
235        mapping = self.getMapping(path, headerfields, mode)
236        converters = self.getFieldConverters(headerfields)
237        reader = csv.DictReader(open(path, 'rb'))
238
239        temp_dir = tempfile.mkdtemp()
240
241        (base, ext) = os.path.splitext(path)
242        failed_path = os.path.join(temp_dir, "%s.pending%s" % (base, ext))
243        failed_headers = mapping.keys()
244        failed_headers.append('--ERRORS--')
245        failed_writer = csv.DictWriter(open(failed_path, 'wb'),
246                                       failed_headers)
247        first_row = mapping.items()
248        first_row.append(("--ERRORS--", "--ERRORS--"),)
249        failed_writer.writerow(dict(first_row))
250
251        finished_path = os.path.join(temp_dir, "%s.finished%s" % (base, ext))
252        finished_headers = [x for x in mapping.values()]
253        finished_writer = csv.DictWriter(open(finished_path, 'wb'),
254                                         finished_headers)
255        finished_writer.writerow(dict([(x,x) for x in finished_headers]))
256
257        num =0
258        num_warns = 0
259        site = grok.getSite()
260        for raw_row in reader:
261            num += 1
262            string_row = self.applyMapping(raw_row, mapping)
263            row, conv_warnings = self.convertToTypes(
264                copy.deepcopy(string_row), converters)
265            if len(conv_warnings):
266                num_warns += 1
267                self.writeFailedRow(failed_writer, raw_row, conv_warnings)
268                continue
269
270            if mode == 'create':
271                if not self.parentsExist(row, site):
272                    num_warns += 1
273                    self.writeFailedRow(
274                        failed_writer, raw_row,
275                        "Not all parents do exist yet. Skipping")
276                    continue
277                if self.entryExists(row, site):
278                    num_warns += 1
279                    self.writeFailedRow(
280                        failed_writer, raw_row,
281                        "This object already exists in the same container. Skipping.")
282                    continue
283                obj = self.callFactory()
284                for key, value in row.items():
285                    setattr(obj, key, value)
286                try:
287                    self.addEntry(obj, row, site)
288                except DuplicationError, error:
289                    num_warns += 1
290                    self.writeFailedRow(
291                        failed_writer, raw_row,
292                        "%s Skipping." % error.msg)
293                    continue
294            elif mode == 'remove':
295                if not self.entryExists(row, site):
296                    num_warns += 1
297                    self.writeFailedRow(
298                        failed_writer, raw_row,
299                        "Cannot remove: no such entry.")
300                    continue
301                self.delEntry(row, site)
302            elif mode == 'update':
303                obj = self.getEntry(row, site)
304                if obj is None:
305                    num_warns += 1
306                    self.writeFailedRow(
307                        failed_writer, raw_row,
308                        "Cannot update: no such entry.")
309                    continue
310                self.updateEntry(obj, row, site)
311            finished_writer.writerow(string_row)
312
313        time_end = time.time()
314        timedelta = time_end - time_start
315
316        self.createLogfile(path, failed_path, num, num_warns, mode, user,
317                           timedelta, logger=logger)
318        failed_path = os.path.abspath(failed_path)
319        if num_warns == 0:
320            del failed_writer
321            os.unlink(failed_path)
322            failed_path = None
323        return (num, num_warns,
324                os.path.abspath(finished_path), failed_path)
Note: See TracBrowser for help on using the repository browser.