source: main/waeup.sirp/trunk/src/waeup/sirp/utils/batching.py @ 7310

Last change on this file since 7310 was 7273, checked in by Henrik Bettermann, 13 years ago

Add test for student data migration to be sure that student_ids provided in import files are correctly produced and that the random id generator is thus neutralized.

  • Property svn:keywords set to Id
File size: 12.3 KB
Line 
1## $Id: batching.py 7273 2011-12-04 21:08:53Z henrik $
2##
3## Copyright (C) 2011 Uli Fouquet & Henrik Bettermann
4## This program is free software; you can redistribute it and/or modify
5## it under the terms of the GNU General Public License as published by
6## the Free Software Foundation; either version 2 of the License, or
7## (at your option) any later version.
8##
9## This program is distributed in the hope that it will be useful,
10## but WITHOUT ANY WARRANTY; without even the implied warranty of
11## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12## GNU General Public License for more details.
13##
14## You should have received a copy of the GNU General Public License
15## along with this program; if not, write to the Free Software
16## Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17##
18"""WAeUP components for batch processing.
19
20Batch processors eat CSV files to add, update or remove large numbers
21of certain kinds of objects at once.
22"""
23import grok
24import copy
25import csv
26import os
27import sys
28import tempfile
29import time
30from zope.component import createObject
31from zope.interface import Interface
32from zope.schema import getFields
33from waeup.sirp.interfaces import (
34    IBatchProcessor, FatalCSVError, DuplicationError, IObjectConverter)
35
class BatchProcessor(grok.GlobalUtility):
    """A processor to add, update, or remove data.

    This is a non-active baseclass.

    The hooks ``parentsExist``, ``entryExists``, ``getParent``,
    ``getEntry``, ``addEntry`` and ``delEntry`` raise
    ``NotImplementedError`` here and have to be provided by derived
    processors. ``doImport`` drives a complete import run.
    """
    grok.provides(IBatchProcessor)
    grok.context(Interface)
    grok.baseclass()

    # Name used in pages and forms...
    name = u'Non-registered base importer'

    # Internal name...
    util_name = 'baseimporter'

    # Items for this processor need an interface with zope.schema fields.
    iface = Interface

    # The name must be the same as the util_name attribute in order to
    # register this utility correctly.
    grok.name(util_name)

    # Headers needed to locate items...
    location_fields = ['code', 'faculty_code']

    # A factory with this name must be registered...
    factory_name = 'waeup.Department'

    @property
    def required_fields(self):
        """Required fields that have no default.

        A list of names of field, whose value cannot be set if not
        given during creation. Therefore these fields must exist in
        input.

        Fields with a default != missing_value do not belong to this
        category. Names listed in ``location_fields`` are left out as
        well.
        """
        return [
            key for key, field in getFields(self.iface).items()
            if key not in self.location_fields
            and field.default is field.missing_value
            and field.required]

    @property
    def req(self):
        """Mapping from import mode to the list of headers it requires.

        ``create`` needs location fields plus ``required_fields``;
        ``update`` and ``remove`` only need the location fields.
        """
        return dict(
            create=self.location_fields + self.required_fields,
            update=self.location_fields,
            remove=self.location_fields,
        )

    @property
    def available_fields(self):
        """Sorted list of all header names this processor knows about.

        The union of ``location_fields`` and the field names of
        ``iface``, without duplicates.
        """
        return sorted(set(
            self.location_fields + list(getFields(self.iface).keys())))

    def getHeaders(self, mode='create'):
        """Return the available header names.

        ``mode`` is accepted but not evaluated here; the result is
        always ``available_fields``.
        """
        return self.available_fields

    def checkHeaders(self, headerfields, mode='create'):
        """Check that ``headerfields`` are usable for ``mode``.

        Returns ``True`` if all headers required for ``mode`` (see
        ``req``) are contained in ``headerfields`` and no header, except
        those starting with ``--``, appears more than once.

        Raises ``FatalCSVError`` otherwise. A ``mode`` not contained in
        ``req`` results in a ``KeyError``.
        """
        req = self.req[mode]
        # Check for required fields...
        for field in req:
            if field not in headerfields:
                raise FatalCSVError(
                    "Need at least columns %s for import!" %
                    ', '.join(["'%s'" % x for x in req]))
        # Check for double fields. Cannot happen because this error is
        # already caught in views
        not_ignored_fields = [x for x in headerfields
                              if not x.startswith('--')]
        if len(set(not_ignored_fields)) < len(not_ignored_fields):
            raise FatalCSVError(
                "Double headers: each column name may only appear once.")
        return True

    def applyMapping(self, row, mapping):
        """Apply mapping to a row of CSV data.

        ``mapping`` maps keys of ``row`` to the names to use instead.
        Returns a new dict with the renamed keys; entries mapped to
        ``--IGNORE--`` are dropped. ``row`` itself is not modified.
        """
        result = dict()
        for key, replacement in mapping.items():
            if replacement == u'--IGNORE--':
                # Skip ignored columns in failed and finished data files.
                continue
            result[replacement] = row[key]
        return result

    def getMapping(self, path, headerfields, mode):
        """Get a mapping from CSV file headerfields to actually used fieldnames.

        Reads the first line of the CSV file at ``path`` and pairs each
        raw header with the entry of ``headerfields`` at the same
        position. Positions marked ``--IGNORE--`` are skipped; in
        ``remove`` mode all fields not in ``location_fields`` are
        skipped as well.
        """
        source = open(path, 'rb')
        try:
            # Only the header line is needed; close the file right away
            # instead of leaving that to garbage collection.
            raw_header = next(csv.reader(source))
        finally:
            source.close()
        result = dict()
        for num, field in enumerate(headerfields):
            if field not in self.location_fields and mode == 'remove':
                # Skip non-location fields when removing.
                continue
            if field == u'--IGNORE--':
                # Skip ignored columns in failed and finished data files.
                continue
            result[raw_header[num]] = field
        return result

    def stringFromErrs(self, errors, inv_errors):
        """Join conversion errors into a single string.

        ``errors`` is an iterable of ``(fieldname, message)`` pairs,
        ``inv_errors`` an iterable of invariant error messages. The
        single messages are separated by ``'; '``.
        """
        result = ["%s: %s" % (fieldname, message)
                  for fieldname, message in errors]
        result.extend("invariant: %s" % err for err in inv_errors)
        return '; '.join(result)

    def callFactory(self, *args, **kw):
        """Create a new object with the factory named ``factory_name``.

        Positional and keyword arguments are accepted but not passed on
        to the factory.
        """
        return createObject(self.factory_name)

    def parentsExist(self, row, site):
        """Tell whether the parent object for data in ``row`` exists.
        """
        raise NotImplementedError('method not implemented')

    def entryExists(self, row, site):
        """Tell whether there already exists an entry for ``row`` data.
        """
        raise NotImplementedError('method not implemented')

    def getParent(self, row, site):
        """Get the parent object for the entry in ``row``.
        """
        raise NotImplementedError('method not implemented')

    def getEntry(self, row, site):
        """Get the object for the entry in ``row``.

        ``doImport`` treats a return value of ``None`` as 'no such
        entry' when updating.
        """
        raise NotImplementedError('method not implemented')

    def addEntry(self, obj, row, site):
        """Add the entry given by ``row`` data.

        A ``KeyError`` raised here is reported by ``doImport`` as a
        failed row.
        """
        raise NotImplementedError('method not implemented')

    def delEntry(self, row, site):
        """Delete entry given by ``row`` data.
        """
        raise NotImplementedError('method not implemented')

    def updateEntry(self, obj, row, site):
        """Update obj to the values given in row.

        Only keys of ``row`` for which ``obj`` already has an attribute
        are set; all other keys are ignored.
        """
        for key, value in row.items():
            # Skip fields the object does not know about.
            if hasattr(obj, key):
                setattr(obj, key, value)
        return

    def createLogfile(self, path, fail_path, num, warnings, mode, user,
                      timedelta, logger=None):
        """Write to log file.

        Emits a summary of a batch run via ``logger.info``. ``num`` is
        the number of processed rows, ``warnings`` the number of failed
        ones and ``timedelta`` the processing time in seconds. The run
        is reported as ``FAILED`` if ``warnings`` is greater than zero.

        Does nothing if ``logger`` is ``None``.
        """
        if logger is None:
            return
        status = 'OK'
        if warnings > 0:
            status = 'FAILED'
        logger.info("-" * 20)
        logger.info("%s: Batch processing finished: %s" % (user, status))
        logger.info("%s: Source: %s" % (user, path))
        logger.info("%s: Mode: %s" % (user, mode))
        logger.info("%s: User: %s" % (user, user))
        if warnings > 0:
            logger.info("%s: Failed datasets: %s" % (
                    user, os.path.basename(fail_path)))
        # ``num or 1`` avoids a division by zero for empty input files.
        logger.info("%s: Processing time: %0.3f s (%0.4f s/item)" % (
                user, timedelta, timedelta/(num or 1)))
        logger.info("%s: Processed: %s lines (%s successful/ %s failed)" % (
                user, num, num - warnings, warnings
                ))
        logger.info("-" * 20)
        return

    def writeFailedRow(self, writer, row, warnings):
        """Write a row with error messages to error CSV.

        If warnings is a list of strings, they will be concatenated.

        Note that ``row`` is modified: the message is stored under the
        key ``--ERRORS--`` before the row is passed to ``writer``.
        """
        error_col = warnings
        if isinstance(warnings, list):
            error_col = ' / '.join(warnings)
        row['--ERRORS--'] = error_col
        writer.writerow(row)
        return

    def checkConversion(self, row, mode='ignore'):
        """Validates all values in row.

        Returns the tuple ``(errs, inv_errs, conv_dict)`` delivered by
        the ``IObjectConverter`` for ``iface``. ``mode`` is not
        evaluated here.
        """
        converter = IObjectConverter(self.iface)
        errs, inv_errs, conv_dict = converter.fromStringDict(
            row, self.factory_name)
        return errs, inv_errs, conv_dict

    def _processRow(self, string_row, mode, site):
        """Convert a single mapped row and apply it according to ``mode``.

        Returns ``None`` if the row was processed successfully,
        otherwise the message to store with the row in the file of
        failed rows.
        """
        # Shallow copy: ``string_row`` keeps the raw strings for the
        # output files, ``row`` receives the converted values.
        row = dict(string_row)
        errs, inv_errs, conv_dict = self.checkConversion(string_row, mode)
        if errs or inv_errs:
            return self.stringFromErrs(errs, inv_errs)
        row.update(conv_dict)

        if mode == 'create':
            if not self.parentsExist(row, site):
                return "Not all parents do exist yet. Skipping"
            if self.entryExists(row, site):
                return ("This object already exists in the same "
                        "container. Skipping.")
            obj = self.callFactory()
            # Override all values in row, also
            # student_ids and applicant_ids which have been
            # generated in the respective __init__ methods before.
            self.updateEntry(obj, row, site)
            try:
                self.addEntry(obj, row, site)
            except KeyError as error:
                message = getattr(error, 'message', None)
                if message is None:
                    # ``message`` is deprecated and not available on
                    # every Python version; mimic it: the sole argument
                    # or an empty string.
                    message = ''
                    if len(error.args) == 1:
                        message = error.args[0]
                return "%s Skipping." % message
        elif mode == 'remove':
            if not self.entryExists(row, site):
                return "Cannot remove: no such entry."
            self.delEntry(row, site)
        elif mode == 'update':
            obj = self.getEntry(row, site)
            if obj is None:
                return "Cannot update: no such entry."
            self.updateEntry(obj, row, site)
        return None

    def doImport(self, path, headerfields, mode='create', user='Unknown',
                 logger=None):
        """Perform actual import.

        Processes the CSV file at ``path`` row by row in the given
        ``mode`` (``create``, ``update`` or ``remove``).
        ``headerfields`` gives the field name to use for each column of
        the file. ``user`` and ``logger`` are used for the summary
        written by ``createLogfile``.

        Successfully processed rows are written to a ``.finished`` file
        and failed rows, together with an ``--ERRORS--`` column, to a
        ``.pending`` file. Both files are created in a fresh temporary
        directory.

        Returns the tuple ``(num, num_warns, finished_path,
        failed_path)``: the number of processed rows, the number of
        failed rows and the absolute paths of both result files.
        ``failed_path`` is ``None`` (and the pending file is removed)
        if no row failed.

        Raises ``FatalCSVError`` if the headers are not usable (see
        ``checkHeaders``).
        """
        time_start = time.time()
        self.checkHeaders(headerfields, mode)
        mapping = self.getMapping(path, headerfields, mode)

        temp_dir = tempfile.mkdtemp()
        base, ext = os.path.splitext(os.path.basename(path))
        failed_path = os.path.join(temp_dir, "%s.pending%s" % (base, ext))
        finished_path = os.path.join(
            temp_dir, "%s.finished%s" % (base, ext))
        finished_headers = list(mapping.values())
        failed_headers = finished_headers + ['--ERRORS--']

        num = 0
        num_warns = 0
        site = grok.getSite()

        # All files are closed explicitly, so that the result files are
        # completely written before their paths are handed out and the
        # pending file is not open anymore when it gets removed.
        opened = []
        try:
            source = open(path, 'rb')
            opened.append(source)
            failed_file = open(failed_path, 'wb')
            opened.append(failed_file)
            finished_file = open(finished_path, 'wb')
            opened.append(finished_file)

            reader = csv.DictReader(source)
            failed_writer = csv.DictWriter(failed_file, failed_headers)
            failed_writer.writerow(
                dict([(x, x) for x in failed_headers]))
            finished_writer = csv.DictWriter(
                finished_file, finished_headers)
            finished_writer.writerow(
                dict([(x, x) for x in finished_headers]))

            for raw_row in reader:
                num += 1
                string_row = self.applyMapping(raw_row, mapping)
                warning = self._processRow(string_row, mode, site)
                if warning is None:
                    finished_writer.writerow(string_row)
                    continue
                num_warns += 1
                self.writeFailedRow(failed_writer, string_row, warning)
        finally:
            for handle in opened:
                handle.close()

        timedelta = time.time() - time_start

        self.createLogfile(path, failed_path, num, num_warns, mode, user,
                           timedelta, logger=logger)
        failed_path = os.path.abspath(failed_path)
        if num_warns == 0:
            # Nothing failed: do not leave an empty pending file behind.
            os.unlink(failed_path)
            failed_path = None
        return (num, num_warns,
                os.path.abspath(finished_path), failed_path)
Note: See TracBrowser for help on using the repository browser.