|
22 | 22 | import os |
23 | 23 | import tempfile |
24 | 24 | import unittest |
| 25 | +from typing import Iterable |
| 26 | +from typing import Literal |
25 | 27 | from typing import TypeVar |
26 | 28 |
|
27 | 29 | import pytest |
28 | 30 |
|
29 | 31 | import apache_beam as beam |
30 | 32 | from apache_beam.coders import coders |
| 33 | +from apache_beam.options.pipeline_options import PipelineOptions |
31 | 34 | from apache_beam.testing.util import assert_that |
32 | 35 | from apache_beam.testing.util import equal_to |
33 | 36 | from apache_beam.transforms.resources import ResourceHint |
@@ -515,6 +518,222 @@ class TagHint(ResourceHint): |
515 | 518 | ) |
516 | 519 |
|
517 | 520 |
|
| 521 | +class ExceptionHandlingWithOutputsTest(unittest.TestCase): |
| 522 | + """Tests for combining with_exception_handling() and with_outputs().""" |
| 523 | + def _create_dofn_with_tagged_outputs(self): |
| 524 | + """A DoFn that yields tagged outputs and can raise on even numbers.""" |
| 525 | + class DoWithFailures(beam.DoFn): |
| 526 | + def process( |
| 527 | + self, element: int |
| 528 | + ) -> Iterable[int |
| 529 | + | beam.pvalue.TaggedOutput[Literal['threes'], int] |
| 530 | + | beam.pvalue.TaggedOutput[Literal['fives'], str]]: |
| 531 | + if element % 2 == 0: |
| 532 | + raise ValueError(f'Even numbers not allowed {element}') |
| 533 | + if element % 3 == 0: |
| 534 | + yield beam.pvalue.TaggedOutput('threes', element) # type: ignore[misc] |
| 535 | + elif element % 5 == 0: |
| 536 | + yield beam.pvalue.TaggedOutput('fives', str(element)) # type: ignore[misc] |
| 537 | + else: |
| 538 | + yield element |
| 539 | + |
| 540 | + return DoWithFailures() |
| 541 | + |
| 542 | + def test_with_exception_handling_then_with_outputs(self): |
| 543 | + """Direction 1: .with_exception_handling().with_outputs()""" |
| 544 | + |
| 545 | + with beam.Pipeline() as p: |
| 546 | + results = ( |
| 547 | + p |
| 548 | + | beam.Create([1, 2, 3, 4, 5, 6, 7]) |
| 549 | + | beam.ParDo(self._create_dofn_with_tagged_outputs()). |
| 550 | + with_exception_handling().with_outputs( |
| 551 | + 'threes', 'fives', main='main')) |
| 552 | + |
| 553 | + assert_that(results.main, equal_to([1, 7]), 'main') |
| 554 | + assert_that(results.threes, equal_to([3]), 'threes') |
| 555 | + assert_that(results.fives, equal_to(['5']), 'fives') |
| 556 | + bad_elements = results.bad | beam.Keys() |
| 557 | + assert_that(bad_elements, equal_to([2, 4, 6]), 'bad') |
| 558 | + # Verify type hints from annotations are propagated |
| 559 | + self.assertEqual(results.main.element_type, int) |
| 560 | + self.assertEqual(results.threes.element_type, int) |
| 561 | + self.assertEqual(results.fives.element_type, str) |
| 562 | + self.assertEqual( |
| 563 | + results.bad.element_type, |
| 564 | + typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]]) |
| 565 | + |
| 566 | + def test_with_outputs_then_with_exception_handling(self): |
| 567 | + """Direction 2: .with_outputs().with_exception_handling()""" |
| 568 | + |
| 569 | + with beam.Pipeline() as p: |
| 570 | + results = ( |
| 571 | + p |
| 572 | + | beam.Create([1, 2, 3, 4, 5, 6, 7]) |
| 573 | + | beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs( |
| 574 | + 'threes', 'fives', main='main').with_exception_handling()) |
| 575 | + |
| 576 | + assert_that(results.main, equal_to([1, 7]), 'main') |
| 577 | + assert_that(results.threes, equal_to([3]), 'threes') |
| 578 | + assert_that(results.fives, equal_to(['5']), 'fives') |
| 579 | + bad_elements = results.bad | beam.Keys() |
| 580 | + assert_that(bad_elements, equal_to([2, 4, 6]), 'bad') |
| 581 | + # Verify type hints from annotations are propagated |
| 582 | + self.assertEqual(results.main.element_type, int) |
| 583 | + self.assertEqual(results.threes.element_type, int) |
| 584 | + self.assertEqual(results.fives.element_type, str) |
| 585 | + self.assertEqual( |
| 586 | + results.bad.element_type, |
| 587 | + typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]]) |
| 588 | + |
| 589 | + def test_with_outputs_then_with_exception_handling_custom_dead_letter_tag( |
| 590 | + self): |
| 591 | + """Direction 2 with custom dead_letter_tag.""" |
| 592 | + |
| 593 | + with beam.Pipeline() as p: |
| 594 | + results = ( |
| 595 | + p |
| 596 | + | beam.Create([1, 2, 3]) |
| 597 | + | beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs( |
| 598 | + 'threes', |
| 599 | + main='main').with_exception_handling(dead_letter_tag='errors')) |
| 600 | + |
| 601 | + assert_that(results.main, equal_to([1]), 'main') |
| 602 | + assert_that(results.threes, equal_to([3]), 'threes') |
| 603 | + bad_elements = results.errors | beam.Keys() |
| 604 | + assert_that(bad_elements, equal_to([2]), 'errors') |
| 605 | + self.assertEqual(results.threes.element_type, int) |
| 606 | + self.assertEqual( |
| 607 | + results.errors.element_type, |
| 608 | + typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]]) |
| 609 | + |
| 610 | + def test_with_exception_handling_then_with_outputs_custom_dead_letter_tag( |
| 611 | + self): |
| 612 | + """Direction 1 with custom dead_letter_tag.""" |
| 613 | + |
| 614 | + with beam.Pipeline() as p: |
| 615 | + results = ( |
| 616 | + p |
| 617 | + | beam.Create([1, 2, 3]) |
| 618 | + | beam.ParDo( |
| 619 | + self._create_dofn_with_tagged_outputs()).with_exception_handling( |
| 620 | + dead_letter_tag='errors').with_outputs('threes', main='main')) |
| 621 | + |
| 622 | + assert_that(results.main, equal_to([1]), 'main') |
| 623 | + assert_that(results.threes, equal_to([3]), 'threes') |
| 624 | + bad_elements = results.errors | beam.Keys() |
| 625 | + assert_that(bad_elements, equal_to([2]), 'errors') |
| 626 | + self.assertEqual(results.threes.element_type, int) |
| 627 | + self.assertEqual( |
| 628 | + results.errors.element_type, |
| 629 | + typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]]) |
| 630 | + |
| 631 | + def test_exception_handling_no_with_outputs_backward_compat(self): |
| 632 | + """Without with_outputs(), behavior is unchanged.""" |
| 633 | + |
| 634 | + with beam.Pipeline() as p: |
| 635 | + good, bad = ( |
| 636 | + p |
| 637 | + | beam.Create([1, 2, 7]) |
| 638 | + | beam.ParDo(self._create_dofn_with_tagged_outputs()) |
| 639 | + .with_exception_handling()) |
| 640 | + |
| 641 | + assert_that(good, equal_to([1, 7]), 'good') |
| 642 | + bad_elements = bad | beam.Keys() |
| 643 | + assert_that(bad_elements, equal_to([2]), 'bad') |
| 644 | + |
| 645 | + def test_exception_handling_compat_version_uses_old_behavior(self): |
| 646 | + """With compat version < 2.73.0, old expand path is used.""" |
| 647 | + options = PipelineOptions(update_compatibility_version="2.72.0") |
| 648 | + with beam.Pipeline(options=options) as p: |
| 649 | + good, bad = ( |
| 650 | + p |
| 651 | + | beam.Create([1, 2, 7]) |
| 652 | + | beam.ParDo(self._create_dofn_with_tagged_outputs()) |
| 653 | + .with_exception_handling()) |
| 654 | + |
| 655 | + assert_that(good, equal_to([1, 7]), 'good') |
| 656 | + bad_elements = bad | beam.Keys() |
| 657 | + assert_that(bad_elements, equal_to([2]), 'bad') |
| 658 | + |
| 659 | + def test_exception_handling_compat_version_element_type_set_manually(self): |
| 660 | + """With compat version < 2.73.0, element_type is set via manual override |
| 661 | + (the old behavior) rather than via with_output_types.""" |
| 662 | + |
| 663 | + options = PipelineOptions(update_compatibility_version="2.72.0") |
| 664 | + with beam.Pipeline(options=options) as p: |
| 665 | + results = ( |
| 666 | + p |
| 667 | + | beam.Create([1, 2, 3]) |
| 668 | + | beam.ParDo(self._create_dofn_with_tagged_outputs()). |
| 669 | + with_exception_handling().with_outputs('threes', main='main')) |
| 670 | + |
| 671 | + # In old path, dead letter type is Any (no with_output_types call) |
| 672 | + self.assertEqual(results.bad.element_type, typehints.Any) |
| 673 | + # Tagged outputs still get types from DoFn Literal annotations |
| 674 | + # (via DoOutputsTuple.__getitem__ reading tagged_output_types) |
| 675 | + self.assertEqual(results.threes.element_type, int) |
| 676 | + # Main output type should still be inferred via manual override |
| 677 | + assert_that(results.main, equal_to([1]), 'main') |
| 678 | + |
| 679 | + def test_with_outputs_then_exception_handling_with_map(self): |
| 680 | + """with_outputs().with_exception_handling() also works on Map.""" |
| 681 | + with beam.Pipeline() as p: |
| 682 | + results = ( |
| 683 | + p |
| 684 | + | beam.Create([1, 2, 3, 4, 5]) |
| 685 | + | beam.Map(lambda x: x if x % 2 != 0 else 1 / 0).with_outputs( |
| 686 | + main='main').with_exception_handling()) |
| 687 | + assert_that(results.main, equal_to([1, 3, 5]), 'main') |
| 688 | + bad_elements = results.bad | beam.Keys() |
| 689 | + assert_that(bad_elements, equal_to([2, 4]), 'bad') |
| 690 | + |
| 691 | + def test_with_output_types_chained_on_pardo(self): |
| 692 | + """When type hints are chained on the ParDo (not annotations on the DoFn), |
| 693 | + tagged output types should still be propagated through |
| 694 | + with_exception_handling().with_outputs().""" |
| 695 | + class DoWithFailuresNoAnnotations(beam.DoFn): |
| 696 | + def process(self, element): |
| 697 | + if element % 2 == 0: |
| 698 | + raise ValueError(f'Even numbers not allowed {element}') |
| 699 | + if element % 3 == 0: |
| 700 | + yield beam.pvalue.TaggedOutput('threes', element) |
| 701 | + else: |
| 702 | + yield element |
| 703 | + |
| 704 | + with beam.Pipeline() as p: |
| 705 | + results = ( |
| 706 | + p |
| 707 | + | beam.Create([1, 2, 3, 7]) |
| 708 | + | beam.ParDo(DoWithFailuresNoAnnotations()).with_output_types( |
| 709 | + int, threes=int).with_exception_handling().with_outputs( |
| 710 | + 'threes', main='main')) |
| 711 | + |
| 712 | + assert_that(results.main, equal_to([1, 7]), 'main') |
| 713 | + assert_that(results.threes, equal_to([3]), 'threes') |
| 714 | + bad_elements = results.bad | beam.Keys() |
| 715 | + assert_that(bad_elements, equal_to([2]), 'bad') |
| 716 | + self.assertEqual(results.main.element_type, int) |
| 717 | + self.assertEqual(results.threes.element_type, int) |
| 718 | + |
| 719 | + def test_with_outputs_and_error_handler(self): |
| 720 | + """with_outputs() + error_handler should return DoOutputsTuple, not a |
| 721 | + bare PCollection.""" |
| 722 | + from apache_beam.transforms.error_handling import ErrorHandler |
| 723 | + with beam.Pipeline() as p: |
| 724 | + with ErrorHandler(beam.Map(lambda x: x)) as handler: |
| 725 | + results = ( |
| 726 | + p |
| 727 | + | beam.Create([1, 2, 3, 4, 5, 6, 7]) |
| 728 | + | beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs( |
| 729 | + 'threes', 'fives', |
| 730 | + main='main').with_exception_handling(error_handler=handler)) |
| 731 | + |
| 732 | + assert_that(results.main, equal_to([1, 7]), 'main') |
| 733 | + assert_that(results.threes, equal_to([3]), 'threes') |
| 734 | + assert_that(results.fives, equal_to(['5']), 'fives') |
| 735 | + |
| 736 | + |
518 | 737 | def test_callablewrapper_typehint(): |
519 | 738 | T = TypeVar("T") |
520 | 739 |
|
|
0 commit comments