source: main/waeup.sirp/trunk/src/waeup/sirp/utils/batching.py @ 7310

Last change on this file since 7310 was 7273, checked in by Henrik Bettermann, 13 years ago

Add test for student data migration to be sure that student_ids provided in import files are correctly produced and that the random id generator is thus neutralized.

  • Property svn:keywords set to Id
File size: 12.3 KB
RevLine 
## $Id: batching.py 7273 2011-12-04 21:08:53Z henrik $
##
## Copyright (C) 2011 Uli Fouquet & Henrik Bettermann
## This program is free software; you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation; either version 2 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program; if not, write to the Free Software
## Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
##
"""WAeUP components for batch processing.

Batch processors eat CSV files to add, update or remove large numbers
of certain kinds of objects at once.
"""
23import grok
[4870]24import copy
[4806]25import csv
[4821]26import os
27import sys
[4900]28import tempfile
[4821]29import time
[4806]30from zope.component import createObject
31from zope.interface import Interface
32from zope.schema import getFields
[5005]33from waeup.sirp.interfaces import (
[6276]34    IBatchProcessor, FatalCSVError, DuplicationError, IObjectConverter)
[4806]35
class BatchProcessor(grok.GlobalUtility):
    """A processor to add, update, or remove data.

    This is a non-active baseclass. Concrete processors are expected to
    override the class attributes below and to implement the
    ``parentsExist``, ``entryExists``, ``getParent``, ``getEntry``,
    ``addEntry`` and ``delEntry`` hooks, which raise
    ``NotImplementedError`` here.
    """
    grok.provides(IBatchProcessor)
    grok.context(Interface)
    grok.baseclass()

    # Name used in pages and forms...
    name = u'Non-registered base importer'

    # Internal name...
    util_name = 'baseimporter'

    # Items for this processor need an interface with zope.schema fields.
    iface = Interface

    # The name must be the same as the util_name attribute in order to
    # register this utility correctly.
    grok.name(util_name)

    # Headers needed to locate items...
    location_fields = ['code', 'faculty_code']

    # A factory with this name must be registered...
    factory_name = 'waeup.Department'

    @property
    def required_fields(self):
        """Required fields that have no default.

        A list of names of field, whose value cannot be set if not
        given during creation. Therefore these fields must exist in
        input.

        Fields with a default != missing_value do not belong to this
        category. Location fields are reported separately (see
        ``req``) and are therefore skipped here as well.
        """
        result = []
        for key, field in getFields(self.iface).items():
            if key in self.location_fields:
                continue
            if field.default is not field.missing_value:
                continue
            if field.required:
                result.append(key)
        return result

    @property
    def req(self):
        """Dict mapping each import mode to the column names it requires.

        Keys are ``'create'``, ``'update'`` and ``'remove'``.
        """
        return dict(
            create=self.location_fields + self.required_fields,
            update=self.location_fields,
            remove=self.location_fields,
        )

    @property
    def available_fields(self):
        """Sorted list of all column names this processor knows about.

        This is the union of ``location_fields`` and the field names of
        ``iface``, without duplicates.
        """
        iface_fields = list(getFields(self.iface).keys())
        return sorted(set(self.location_fields + iface_fields))

    def getHeaders(self, mode='create'):
        """Return the headers available for import.

        ``mode`` is accepted for interface compatibility but not
        evaluated by this base implementation.
        """
        return self.available_fields

    def checkHeaders(self, headerfields, mode='create'):
        """Check that ``headerfields`` are usable for an import in ``mode``.

        Returns ``True`` on success. Raises ``FatalCSVError`` if a
        column required for ``mode`` is missing or if a column name
        (other than those starting with ``--``) appears more than once.
        An unknown ``mode`` results in a ``KeyError`` from the lookup in
        ``self.req``.
        """
        req = self.req[mode]
        # Check for required fields...
        missing = [field for field in req if field not in headerfields]
        if missing:
            raise FatalCSVError(
                "Need at least columns %s for import!" %
                ', '.join(["'%s'" % x for x in req]))
        # Check for double fields. Cannot happen because this error is
        # already caught in views.
        not_ignored_fields = [x for x in headerfields
                              if not x.startswith('--')]
        if len(set(not_ignored_fields)) < len(not_ignored_fields):
            raise FatalCSVError(
                "Double headers: each column name may only appear once.")
        return True

    def applyMapping(self, row, mapping):
        """Apply mapping to a row of CSV data.

        ``mapping`` maps raw CSV column names to the field names used
        internally. Returns a new dict keyed by the internal field
        names; ``row`` itself is not modified.
        """
        result = dict()
        for key, replacement in mapping.items():
            if replacement == u'--IGNORE--':
                # Skip ignored columns in failed and finished data files.
                continue
            result[replacement] = row[key]
        return result

    def getMapping(self, path, headerfields, mode):
        """Get a mapping from CSV file headerfields to actually used fieldnames.

        The first line of the CSV file at ``path`` is read and paired,
        by position, with ``headerfields``. The result maps the raw
        header names found in the file to the names in ``headerfields``.
        """
        result = dict()
        # Only the header line is needed; make sure the file handle is
        # released again instead of waiting for garbage collection.
        csv_file = open(path, 'rb')
        try:
            raw_header = next(csv.reader(csv_file))
        finally:
            csv_file.close()
        for num, field in enumerate(headerfields):
            if field not in self.location_fields and mode == 'remove':
                # Skip non-location fields when removing.
                continue
            if field == u'--IGNORE--':
                # Skip ignored columns in failed and finished data files.
                continue
            result[raw_header[num]] = field
        return result

    def stringFromErrs(self, errors, inv_errors):
        """Turn conversion errors into a single, ``; ``-separated string.

        ``errors`` is an iterable of ``(fieldname, message)`` pairs,
        ``inv_errors`` an iterable of invariant error messages.
        """
        result = []
        for fieldname, message in errors:
            result.append("%s: %s" % (fieldname, message))
        for err in inv_errors:
            result.append("invariant: %s" % err)
        return '; '.join(result)

    def callFactory(self, *args, **kw):
        """Create a new object using the factory named ``factory_name``.

        Positional and keyword arguments are accepted but not passed on
        to the factory.
        """
        return createObject(self.factory_name)

    def parentsExist(self, row, site):
        """Tell whether the parent object for data in ``row`` exists.
        """
        raise NotImplementedError('method not implemented')

    def entryExists(self, row, site):
        """Tell whether there already exists an entry for ``row`` data.
        """
        raise NotImplementedError('method not implemented')

    def getParent(self, row, site):
        """Get the parent object for the entry in ``row``.
        """
        raise NotImplementedError('method not implemented')

    def getEntry(self, row, site):
        """Get the object for the entry in ``row``.

        ``doImport`` treats a return value of ``None`` as "no such
        entry" when updating.
        """
        raise NotImplementedError('method not implemented')

    def addEntry(self, obj, row, site):
        """Add the entry given by ``row`` data.

        A ``KeyError`` raised here is reported by ``doImport`` as a
        failed row instead of aborting the import.
        """
        raise NotImplementedError('method not implemented')

    def delEntry(self, row, site):
        """Delete entry given by ``row`` data.
        """
        raise NotImplementedError('method not implemented')

    def updateEntry(self, obj, row, site):
        """Update obj to the values given in row.
        """
        for key, value in row.items():
            # Skip fields not declared in interface.
            if hasattr(obj, key):
                setattr(obj, key, value)
        return

    def createLogfile(self, path, fail_path, num, warnings, mode, user,
                      timedelta, logger=None):
        """Write a summary of a finished import to ``logger``.

        Does nothing if no ``logger`` is given. ``num`` is the number
        of processed rows, ``warnings`` the number of failed rows and
        ``timedelta`` the processing time in seconds.
        """
        if logger is None:
            return
        status = 'OK'
        if warnings > 0:
            status = 'FAILED'
        logger.info("-" * 20)
        logger.info("%s: Batch processing finished: %s" % (user, status))
        logger.info("%s: Source: %s" % (user, path))
        logger.info("%s: Mode: %s" % (user, mode))
        logger.info("%s: User: %s" % (user, user))
        if warnings > 0:
            logger.info("%s: Failed datasets: %s" % (
                    user, os.path.basename(fail_path)))
        # ``num or 1`` avoids a division by zero for empty input files.
        logger.info("%s: Processing time: %0.3f s (%0.4f s/item)" % (
                user, timedelta, timedelta/(num or 1)))
        logger.info("%s: Processed: %s lines (%s successful/ %s failed)" % (
                user, num, num - warnings, warnings
                ))
        logger.info("-" * 20)
        return

    def writeFailedRow(self, writer, row, warnings):
        """Write a row with error messages to error CSV.

        If warnings is a list of strings, they will be concatenated.

        Note that ``row`` is modified in place: the error text is
        stored under the ``--ERRORS--`` key before the row is written.
        """
        error_col = warnings
        if isinstance(warnings, list):
            error_col = ' / '.join(warnings)
        row['--ERRORS--'] = error_col
        writer.writerow(row)
        return

    def checkConversion(self, row, mode='ignore'):
        """Validates all values in row.

        Returns a tuple ``(errs, inv_errs, conv_dict)`` as delivered by
        the ``IObjectConverter`` for ``iface``. ``mode`` is not
        evaluated by this base implementation.
        """
        converter = IObjectConverter(self.iface)
        errs, inv_errs, conv_dict = converter.fromStringDict(
            row, self.factory_name)
        return errs, inv_errs, conv_dict

    def _processRow(self, row, site, mode):
        """Apply a single converted ``row`` according to ``mode``.

        Returns ``None`` on success or the message to be written to the
        failed-rows file if the row could not be processed.
        """
        if mode == 'create':
            if not self.parentsExist(row, site):
                return "Not all parents do exist yet. Skipping"
            if self.entryExists(row, site):
                return ("This object already exists in the same "
                        "container. Skipping.")
            obj = self.callFactory()
            # Override all values in row, also
            # student_ids and applicant_ids which have been
            # generated in the respective __init__ methods before.
            self.updateEntry(obj, row, site)
            try:
                self.addEntry(obj, row, site)
            except KeyError as error:
                return "%s Skipping." % error.message
        elif mode == 'remove':
            if not self.entryExists(row, site):
                return "Cannot remove: no such entry."
            self.delEntry(row, site)
        elif mode == 'update':
            obj = self.getEntry(row, site)
            if obj is None:
                return "Cannot update: no such entry."
            self.updateEntry(obj, row, site)
        return None

    def doImport(self, path, headerfields, mode='create', user='Unknown',
                 logger=None):
        """Perform actual import.

        Reads the CSV file at ``path``, processes every row according
        to ``mode`` (``'create'``, ``'update'`` or ``'remove'``) and
        writes two CSV files into a fresh temporary directory: one with
        the successfully processed rows (``<name>.finished<ext>``) and
        one with the failed rows plus an ``--ERRORS--`` column
        (``<name>.pending<ext>``).

        Returns a tuple ``(num, num_warns, finished_path, failed_path)``
        where ``num`` is the number of rows read, ``num_warns`` the
        number of failed rows and both paths are absolute. If no row
        failed, the pending file is removed and ``failed_path`` is
        ``None``.

        Raises ``FatalCSVError`` if ``headerfields`` are not sufficient
        for ``mode`` (see ``checkHeaders``).
        """
        time_start = time.time()
        self.checkHeaders(headerfields, mode)
        mapping = self.getMapping(path, headerfields, mode)

        num = 0
        num_warns = 0
        # All files opened below are collected here so that they are
        # reliably flushed and closed, even if processing a row raises.
        open_files = []
        try:
            source_file = open(path, 'rb')
            open_files.append(source_file)
            reader = csv.DictReader(source_file)

            temp_dir = tempfile.mkdtemp()

            base = os.path.basename(path)
            (base, ext) = os.path.splitext(base)

            failed_path = os.path.join(
                temp_dir, "%s.pending%s" % (base, ext))
            failed_headers = list(mapping.values())
            failed_headers.append('--ERRORS--')
            failed_file = open(failed_path, 'wb')
            open_files.append(failed_file)
            failed_writer = csv.DictWriter(failed_file, failed_headers)
            # Write the header line.
            failed_writer.writerow(dict((x, x) for x in failed_headers))

            finished_path = os.path.join(
                temp_dir, "%s.finished%s" % (base, ext))
            finished_headers = list(mapping.values())
            finished_file = open(finished_path, 'wb')
            open_files.append(finished_file)
            finished_writer = csv.DictWriter(finished_file, finished_headers)
            # Write the header line.
            finished_writer.writerow(
                dict((x, x) for x in finished_headers))

            site = grok.getSite()

            for raw_row in reader:
                num += 1
                string_row = self.applyMapping(raw_row, mapping)
                # Shallow copy: ``string_row`` keeps the raw strings for
                # the result files, ``row`` receives converted values.
                row = dict(string_row)
                errs, inv_errs, conv_dict = self.checkConversion(
                    string_row, mode)
                if errs or inv_errs:
                    num_warns += 1
                    conv_warnings = self.stringFromErrs(errs, inv_errs)
                    self.writeFailedRow(
                        failed_writer, string_row, conv_warnings)
                    continue
                row.update(conv_dict)

                failure = self._processRow(row, site, mode)
                if failure is not None:
                    num_warns += 1
                    self.writeFailedRow(failed_writer, string_row, failure)
                    continue
                finished_writer.writerow(string_row)
        finally:
            for handle in open_files:
                handle.close()

        time_end = time.time()
        timedelta = time_end - time_start

        self.createLogfile(path, failed_path, num, num_warns, mode, user,
                           timedelta, logger=logger)
        failed_path = os.path.abspath(failed_path)
        if num_warns == 0:
            # Nothing failed: do not leave an empty pending file behind.
            os.unlink(failed_path)
            failed_path = None
        return (num, num_warns,
                os.path.abspath(finished_path), failed_path)
Note: See TracBrowser for help on using the repository browser.