1   
  2   
  3  """ 
  4  Tests for thread usage in lxml.etree. 
  5  """ 
  6   
  7  import re 
  8  import sys 
  9  import os.path 
 10  import unittest 
 11  import threading 
 12   
 13  this_dir = os.path.dirname(__file__) 
 14  if this_dir not in sys.path: 
 15      sys.path.insert(0, this_dir)  
 16   
 17  from common_imports import etree, HelperTestCase, BytesIO, _bytes 
 18   
 19  try: 
 20      from Queue import Queue 
 21  except ImportError: 
 22      from queue import Queue  
 23   
 24   
 26      """Threading tests""" 
 27      etree = etree 
 28   
 30          thread = threading.Thread(target=func) 
 31          thread.start() 
 32          thread.join() 
  33   
 35          sync = threading.Event() 
 36          lock = threading.Lock() 
 37          counter = dict(started=0, finished=0, failed=0) 
 38   
 39          def sync_start(func): 
 40              with lock: 
 41                  started = counter['started'] + 1 
 42                  counter['started'] = started 
 43              if started < count + (main_func is not None): 
 44                  sync.wait(4)   
 45                  assert sync.is_set() 
 46              sync.set()   
 47              try: 
 48                  func() 
 49              except: 
 50                  with lock: 
 51                      counter['failed'] += 1 
 52                  raise 
 53              else: 
 54                  with lock: 
 55                      counter['finished'] += 1 
  56   
 57          threads = [threading.Thread(target=sync_start, args=(func,)) for _ in range(count)] 
 58          for thread in threads: 
 59              thread.start() 
 60          if main_func is not None: 
 61              sync_start(main_func) 
 62          for thread in threads: 
 63              thread.join() 
 64   
 65          self.assertEqual(0, counter['failed']) 
 66          self.assertEqual(counter['finished'], counter['started']) 
  67   
 78   
 79          self._run_thread(run_thread) 
 80          self.assertEqual(xml, tostring(main_root)) 
 81   
 83          XML = self.etree.XML 
 84          style = XML(_bytes('''\ 
 85  <xsl:stylesheet version="1.0" 
 86      xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
 87    <xsl:template match="*"> 
 88      <foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo> 
 89    </xsl:template> 
 90  </xsl:stylesheet>''')) 
 91          st = etree.XSLT(style) 
 92   
 93          result = [] 
 94   
 95          def run_thread(): 
 96              root = XML(_bytes('<a><b>B</b><c>C</c></a>')) 
 97              result.append( st(root) ) 
  98   
 99          self._run_thread(run_thread) 
100          self.assertEqual('''\ 
101  <?xml version="1.0"?> 
102  <foo><a>B</a></foo> 
103  ''', 
104                            str(result[0])) 
105   
121   
122          self._run_thread(run_thread) 
123          self.assertEqual(_bytes('<a><b>B</b><c>C</c><foo><a>B</a></foo></a>'), 
124                            tostring(root)) 
125   
127          style = self.parse('''\ 
128  <xsl:stylesheet version="1.0" 
129      xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
130      <xsl:template match="tag" /> 
131      <!-- extend time for parsing + transform --> 
132  ''' + '\n'.join('<xsl:template match="tag%x" />' % i for i in range(200)) + ''' 
133      <xsl:UnExpectedElement /> 
134  </xsl:stylesheet>''') 
135          self.assertRaises(etree.XSLTParseError, 
136                            etree.XSLT, style) 
137   
138          error_logs = [] 
139   
140          def run_thread(): 
141              try: 
142                  etree.XSLT(style) 
143              except etree.XSLTParseError as e: 
144                  error_logs.append(e.error_log) 
145              else: 
146                  self.assertFalse(True, "XSLT parsing should have failed but didn't") 
 147   
148          self._run_threads(16, run_thread) 
149   
150          self.assertEqual(16, len(error_logs)) 
151          last_log = None 
152          for log in error_logs: 
153              self.assertTrue(len(log)) 
154              if last_log is not None: 
155                  self.assertEqual(len(last_log), len(log)) 
156              self.assertTrue(len(log) >= 2, len(log)) 
157              for error in log: 
158                  self.assertTrue(':ERROR:XSLT:' in str(error), str(error)) 
159              self.assertTrue(any('UnExpectedElement' in str(error) for error in log), log) 
160              last_log = log 
161   
163          tree = self.parse('<tagFF/>') 
164          style = self.parse('''\ 
165  <xsl:stylesheet version="1.0" 
166      xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
167      <xsl:template name="tag0"> 
168          <xsl:message terminate="yes">FAIL</xsl:message> 
169      </xsl:template> 
170      <!-- extend time for parsing + transform --> 
171  ''' + '\n'.join('<xsl:template match="tag%X" name="tag%x"> <xsl:call-template name="tag%x" /> </xsl:template>' % (i, i, i-1) 
172                  for i in range(1, 256)) + ''' 
173  </xsl:stylesheet>''') 
174          self.assertRaises(etree.XSLTApplyError, 
175                            etree.XSLT(style), tree) 
176   
177          error_logs = [] 
178   
179          def run_thread(): 
180              transform = etree.XSLT(style) 
181              try: 
182                  transform(tree) 
183              except etree.XSLTApplyError: 
184                  error_logs.append(transform.error_log) 
185              else: 
186                  self.assertFalse(True, "XSLT parsing should have failed but didn't") 
 187   
188          self._run_threads(16, run_thread) 
189   
190          self.assertEqual(16, len(error_logs)) 
191          last_log = None 
192          for log in error_logs: 
193              self.assertTrue(len(log)) 
194              if last_log is not None: 
195                  self.assertEqual(len(last_log), len(log)) 
196              self.assertEqual(1, len(log)) 
197              for error in log: 
198                  self.assertTrue(':ERROR:XSLT:' in str(error)) 
199              last_log = log 
200   
202           
203           
204          XML = self.etree.XML 
205          tostring = self.etree.tostring 
206          style = self.etree.XSLT(XML(_bytes('''\ 
207      <xsl:stylesheet version="1.0" 
208          xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
209        <xsl:template match="*"> 
210          <root class="abc"> 
211            <xsl:copy-of select="@class" /> 
212            <xsl:attribute name="class">xyz</xsl:attribute>  
213          </root> 
214        </xsl:template> 
215      </xsl:stylesheet>'''))) 
216   
217          result = [] 
218          def run_thread(): 
219              root = XML(_bytes('<ROOT class="ABC" />')) 
220              result.append( style(root).getroot() ) 
 221   
222          self._run_thread(run_thread) 
223          self.assertEqual(_bytes('<root class="xyz"/>'), 
224                            tostring(result[0])) 
225   
227          XML = self.etree.XML 
228          tostring = self.etree.tostring 
229          root = XML(_bytes('<a><b>B</b><c>C</c></a>')) 
230   
231          stylesheets = [] 
232   
233          def run_thread(): 
234              style = XML(_bytes('''\ 
235      <xsl:stylesheet 
236          xmlns:xsl="http://www.w3.org/1999/XSL/Transform" 
237          version="1.0"> 
238        <xsl:output method="xml" /> 
239        <xsl:template match="/"> 
240           <div id="test"> 
241             <xsl:apply-templates/> 
242           </div> 
243        </xsl:template> 
244      </xsl:stylesheet>''')) 
245              stylesheets.append( etree.XSLT(style) ) 
 246   
247          self._run_thread(run_thread) 
248   
249          st = stylesheets[0] 
250          result = tostring( st(root) ) 
251   
252          self.assertEqual(_bytes('<div id="test">BC</div>'), 
253                            result) 
254   
277   
278          self.etree.clear_error_log() 
279          threads = [] 
280          for thread_no in range(1, 10): 
281              t = threading.Thread(target=parse_error_test, 
282                                   args=(thread_no,)) 
283              threads.append(t) 
284              t.start() 
285   
286          parse_error_test(0) 
287   
288          for t in threads: 
289              t.join() 
290   
306   
307          def run_parse(): 
308              thread_root = self.etree.parse(BytesIO(xml)).getroot() 
309              result.append(thread_root[0]) 
310              result.append(thread_root[-1]) 
311   
312          def run_move_main(): 
313              result.append(fragment[0]) 
314   
315          def run_build(): 
316              result.append( 
317                  Element("{myns}foo", attrib={'{test}attr':'val'})) 
318              SubElement(result, "{otherns}tasty") 
319   
320          def run_xslt(): 
321              style = XML(_bytes('''\ 
322      <xsl:stylesheet version="1.0" 
323          xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
324        <xsl:template match="*"> 
325          <xsl:copy><foo><xsl:value-of select="/a/b/text()" /></foo></xsl:copy> 
326        </xsl:template> 
327      </xsl:stylesheet>''')) 
328              st = etree.XSLT(style) 
329              result.append( st(root).getroot() ) 
330   
331          for test in (run_XML, run_parse, run_move_main, run_xslt, run_build): 
332              tostring(result) 
333              self._run_thread(test) 
334   
335          self.assertEqual( 
336              _bytes('<ns0:root xmlns:ns0="myns" att="someval"><b>B</b>' 
337                     '<c xmlns="test">C</c><b>B</b><c xmlns="test">C</c><tags/>' 
338                     '<a><foo>B</foo></a>' 
339                     '<ns0:foo xmlns:ns1="test" ns1:attr="val"/>' 
340                     '<ns1:tasty xmlns:ns1="otherns"/></ns0:root>'), 
341              tostring(result)) 
342   
343          def strip_first(): 
344              root = Element("newroot") 
345              root.append(result[0]) 
346   
347          while len(result): 
348              self._run_thread(strip_first) 
349   
350          self.assertEqual( 
351              _bytes('<ns0:root xmlns:ns0="myns" att="someval"/>'), 
352              tostring(result)) 
353   
355          SubElement = self.etree.SubElement 
356          names = list('abcdefghijklmnop') 
357          runs_per_name = range(50) 
358          result_matches = re.compile( 
359              br'<thread_root>' 
360              br'(?:<[a-p]{5} thread_attr_[a-p]="value" thread_attr2_[a-p]="value2"\s?/>)+' 
361              br'</thread_root>').match 
362   
363          def testrun(): 
364              for _ in range(3): 
365                  root = self.etree.Element('thread_root') 
366                  for name in names: 
367                      tag_name = name * 5 
368                      new = [] 
369                      for _ in runs_per_name: 
370                          el = SubElement(root, tag_name, {'thread_attr_' + name: 'value'}) 
371                          new.append(el) 
372                      for el in new: 
373                          el.set('thread_attr2_' + name, 'value2') 
374                  s = etree.tostring(root) 
375                  self.assertTrue(result_matches(s)) 
 376   
377           
378          self._run_threads(10, testrun) 
379   
380           
381          self._run_threads(10, testrun, main_func=testrun) 
382   
384          XML = self.etree.XML 
385          root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>')) 
386          child_count = len(root) 
387          def testrun(): 
388              for i in range(10000): 
389                  el = root[i%child_count] 
390                  del el 
 391          self._run_threads(10, testrun) 
392   
394          XML = self.etree.XML 
395   
396          class TestElement(etree.ElementBase): 
397              pass 
 398   
399          class MyLookup(etree.CustomElementClassLookup): 
400              repeat = range(100) 
401              def lookup(self, t, d, ns, name): 
402                  count = 0 
403                  for i in self.repeat: 
404                       
405                      count += 1 
406                  return TestElement 
407   
408          parser = self.etree.XMLParser() 
409          parser.set_element_class_lookup(MyLookup()) 
410   
411          root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'), 
412                     parser) 
413   
414          child_count = len(root) 
415          def testrun(): 
416              for i in range(1000): 
417                  el = root[i%child_count] 
418                  del el 
419          self._run_threads(10, testrun) 
420   
421   
423      """Threading tests based on a thread worker pipeline. 
424      """ 
425      etree = etree 
426      item_count = 40 
427   
428 -    class Worker(threading.Thread): 
 429 -        def __init__(self, in_queue, in_count, **kwargs): 
 430              threading.Thread.__init__(self) 
431              self.in_queue = in_queue 
432              self.in_count = in_count 
433              self.out_queue = Queue(in_count) 
434              self.__dict__.update(kwargs) 
 435   
437              get, put = self.in_queue.get, self.out_queue.put 
438              handle = self.handle 
439              for _ in range(self.in_count): 
440                  put(handle(get())) 
 441   
443              raise NotImplementedError() 
 447              return _fromstring(xml) 
 520          item_count = self.item_count 
521          xml = self.xml.replace(b'thread', b'THREAD')   
522   
523           
524          in_queue, start, last = self._build_pipeline( 
525              item_count, 
526              self.ParseWorker, 
527              self.RotateWorker, 
528              self.ReverseWorker, 
529              self.ParseAndExtendWorker, 
530              self.Validate, 
531              self.ParseAndInjectWorker, 
532              self.SerialiseWorker, 
533              xml=xml) 
534   
535           
536          put = start.in_queue.put 
537          for _ in range(item_count): 
538              put(xml) 
539   
540           
541          start.start() 
542           
543          last.join(60)   
544          self.assertEqual(item_count, last.out_queue.qsize()) 
545           
546          get = last.out_queue.get 
547          results = [get() for _ in range(item_count)] 
548   
549          comparison = results[0] 
550          for i, result in enumerate(results[1:]): 
551              self.assertEqual(comparison, result) 
 552   
554          item_count = self.item_count 
555          xml = self.xml.replace(b'thread', b'GLOBAL')   
556          XML = self.etree.XML 
557           
558          in_queue, start, last = self._build_pipeline( 
559              item_count, 
560              self.RotateWorker, 
561              self.ReverseWorker, 
562              self.ParseAndExtendWorker, 
563              self.Validate, 
564              self.SerialiseWorker, 
565              xml=xml) 
566   
567           
568          put = start.in_queue.put 
569          for _ in range(item_count): 
570              put(XML(xml)) 
571   
572           
573          start.start() 
574           
575          last.join(60)   
576          self.assertEqual(item_count, last.out_queue.qsize()) 
577           
578          get = last.out_queue.get 
579          results = [get() for _ in range(item_count)] 
580   
581          comparison = results[0] 
582          for i, result in enumerate(results[1:]): 
583              self.assertEqual(comparison, result) 
  584   
585   
591   
if __name__ == '__main__':
    # This module is collected by the project's test driver rather than run
    # directly; tell the user how to invoke it instead.
    usage = 'to test use test.py %s' % (__file__,)
    print(usage)
594