Skip to content

Commit b090e22

Browse files
claudevdm and Claude authored
Allow tagged types in with_exception_handling. (#37590)
* Allow tagged types in with_exception_handling.
* mypy
* comments.
* change compat version

---------

Co-authored-by: Claude <cvandermerwe@google.com>
1 parent b39a95f commit b090e22

2 files changed

Lines changed: 287 additions & 10 deletions

File tree

sdks/python/apache_beam/transforms/core.py

Lines changed: 68 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1687,7 +1687,8 @@ def with_exception_handling(
16871687
error_handler,
16881688
on_failure_callback,
16891689
allow_unsafe_userstate_in_process,
1690-
self.get_resource_hints())
1690+
self.get_resource_hints(),
1691+
self.get_type_hints())
16911692

16921693
def with_error_handler(self, error_handler, **exception_handling_kwargs):
16931694
"""An alias for `with_exception_handling(error_handler=error_handler, ...)`
@@ -1979,6 +1980,15 @@ def expand(self, pcoll):
19791980
self._main_tag,
19801981
self._allow_unknown_tags)
19811982

1983+
def with_exception_handling(self, main_tag=None, **kwargs):
1984+
if main_tag is None:
1985+
main_tag = self._main_tag or 'good'
1986+
named = self._do_transform.with_exception_handling(
1987+
main_tag=main_tag, **kwargs)
1988+
# named is _NamedPTransform wrapping _ExceptionHandlingWrapper
1989+
named.transform._extra_tags = self._tags
1990+
return named
1991+
19821992

19831993
class DoFnInfo(object):
19841994
"""This class represents the state in the ParDoPayload's function spec,
@@ -2320,7 +2330,8 @@ def __init__(
23202330
error_handler,
23212331
on_failure_callback,
23222332
allow_unsafe_userstate_in_process,
2323-
resource_hints):
2333+
resource_hints,
2334+
pardo_type_hints=None):
23242335
if partial and use_subprocess:
23252336
raise ValueError('partial and use_subprocess are mutually incompatible.')
23262337
self._fn = fn
@@ -2338,8 +2349,17 @@ def __init__(
23382349
self._on_failure_callback = on_failure_callback
23392350
self._allow_unsafe_userstate_in_process = allow_unsafe_userstate_in_process
23402351
self._resource_hints = resource_hints
2352+
self._pardo_type_hints = pardo_type_hints
2353+
self._extra_tags = None
23412354

2342-
def expand(self, pcoll):
2355+
def with_outputs(self, *tags, main=None):
2356+
self._extra_tags = tags
2357+
if main is not None:
2358+
self._main_tag = main
2359+
return self
2360+
2361+
def _build_pardo(self, pcoll):
2362+
"""Build the inner ParDo with the exception-handling wrapper DoFn."""
23432363
if self._allow_unsafe_userstate_in_process:
23442364
if self._use_subprocess or self._timeout:
23452365
# TODO(https://github.com/apache/beam/issues/35976): Implement this
@@ -2366,15 +2386,11 @@ def expand(self, pcoll):
23662386
*self._args,
23672387
**self._kwargs,
23682388
)
2369-
# This is the fix: propagate hints.
23702389
pardo.get_resource_hints().update(self._resource_hints)
2390+
return pardo
23712391

2372-
result = pcoll | pardo.with_outputs(
2373-
self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
2374-
#TODO(BEAM-18957): Fix when type inference supports tagged outputs.
2375-
result[self._main_tag].element_type = self._fn.infer_output_type(
2376-
pcoll.element_type)
2377-
2392+
def _post_process_result(self, pcoll, result):
2393+
"""Apply threshold checking and error handler logic to the result."""
23782394
if self._threshold < 1.0:
23792395

23802396
class MaybeWindow(ptransform.PTransform):
@@ -2408,10 +2424,52 @@ def check_threshold(bad, total, threshold, window=DoFn.WindowParam):
24082424

24092425
if self._error_handler:
24102426
self._error_handler.add_error_pcollection(result[self._dead_letter_tag])
2427+
if self._extra_tags is not None:
2428+
return result
24112429
return result[self._main_tag]
24122430
else:
24132431
return result
24142432

2433+
def expand_2_72_0(self, pcoll):
2434+
"""Pre-2.73.0 behavior: manual element_type override, no with_output_types.
2435+
"""
2436+
pardo = self._build_pardo(pcoll)
2437+
result = pcoll | pardo.with_outputs(
2438+
self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
2439+
#TODO(BEAM-18957): Fix when type inference supports tagged outputs.
2440+
result[self._main_tag].element_type = self._fn.infer_output_type(
2441+
pcoll.element_type)
2442+
2443+
return self._post_process_result(pcoll, result)
2444+
2445+
def expand(self, pcoll):
2446+
if pcoll.pipeline.options.is_compat_version_prior_to("2.73.0"):
2447+
return self.expand_2_72_0(pcoll)
2448+
2449+
pardo = self._build_pardo(pcoll)
2450+
2451+
if (self._pardo_type_hints and self._pardo_type_hints._has_output_types()):
2452+
main_output_type = self._pardo_type_hints.simple_output_type(self.label)
2453+
tagged_type_hints = dict(self._pardo_type_hints.tagged_output_types())
2454+
else:
2455+
main_output_type = self._fn.infer_output_type(pcoll.element_type)
2456+
tagged_type_hints = dict(self._fn.get_type_hints().tagged_output_types())
2457+
2458+
# Dead letter format: Tuple[element, Tuple[exception_type, repr, traceback]]
2459+
dead_letter_type = typehints.Tuple[pcoll.element_type,
2460+
typehints.Tuple[type,
2461+
str,
2462+
typehints.List[str]]]
2463+
2464+
tagged_type_hints[self._dead_letter_tag] = dead_letter_type
2465+
pardo = pardo.with_output_types(main_output_type, **tagged_type_hints)
2466+
2467+
all_tags = tuple(set(self._extra_tags or ()) | {self._dead_letter_tag})
2468+
result = pcoll | pardo.with_outputs(
2469+
*all_tags, main=self._main_tag, allow_unknown_tags=True)
2470+
2471+
return self._post_process_result(pcoll, result)
2472+
24152473

24162474
class _ExceptionHandlingWrapperDoFn(DoFn):
24172475
def __init__(

sdks/python/apache_beam/transforms/core_test.py

Lines changed: 219 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,15 @@
2222
import os
2323
import tempfile
2424
import unittest
25+
from typing import Iterable
26+
from typing import Literal
2527
from typing import TypeVar
2628

2729
import pytest
2830

2931
import apache_beam as beam
3032
from apache_beam.coders import coders
33+
from apache_beam.options.pipeline_options import PipelineOptions
3134
from apache_beam.testing.util import assert_that
3235
from apache_beam.testing.util import equal_to
3336
from apache_beam.transforms.resources import ResourceHint
@@ -515,6 +518,222 @@ class TagHint(ResourceHint):
515518
)
516519

517520

521+
class ExceptionHandlingWithOutputsTest(unittest.TestCase):
522+
"""Tests for combining with_exception_handling() and with_outputs()."""
523+
def _create_dofn_with_tagged_outputs(self):
524+
"""A DoFn that yields tagged outputs and can raise on even numbers."""
525+
class DoWithFailures(beam.DoFn):
526+
def process(
527+
self, element: int
528+
) -> Iterable[int
529+
| beam.pvalue.TaggedOutput[Literal['threes'], int]
530+
| beam.pvalue.TaggedOutput[Literal['fives'], str]]:
531+
if element % 2 == 0:
532+
raise ValueError(f'Even numbers not allowed {element}')
533+
if element % 3 == 0:
534+
yield beam.pvalue.TaggedOutput('threes', element) # type: ignore[misc]
535+
elif element % 5 == 0:
536+
yield beam.pvalue.TaggedOutput('fives', str(element)) # type: ignore[misc]
537+
else:
538+
yield element
539+
540+
return DoWithFailures()
541+
542+
def test_with_exception_handling_then_with_outputs(self):
543+
"""Direction 1: .with_exception_handling().with_outputs()"""
544+
545+
with beam.Pipeline() as p:
546+
results = (
547+
p
548+
| beam.Create([1, 2, 3, 4, 5, 6, 7])
549+
| beam.ParDo(self._create_dofn_with_tagged_outputs()).
550+
with_exception_handling().with_outputs(
551+
'threes', 'fives', main='main'))
552+
553+
assert_that(results.main, equal_to([1, 7]), 'main')
554+
assert_that(results.threes, equal_to([3]), 'threes')
555+
assert_that(results.fives, equal_to(['5']), 'fives')
556+
bad_elements = results.bad | beam.Keys()
557+
assert_that(bad_elements, equal_to([2, 4, 6]), 'bad')
558+
# Verify type hints from annotations are propagated
559+
self.assertEqual(results.main.element_type, int)
560+
self.assertEqual(results.threes.element_type, int)
561+
self.assertEqual(results.fives.element_type, str)
562+
self.assertEqual(
563+
results.bad.element_type,
564+
typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])
565+
566+
def test_with_outputs_then_with_exception_handling(self):
567+
"""Direction 2: .with_outputs().with_exception_handling()"""
568+
569+
with beam.Pipeline() as p:
570+
results = (
571+
p
572+
| beam.Create([1, 2, 3, 4, 5, 6, 7])
573+
| beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs(
574+
'threes', 'fives', main='main').with_exception_handling())
575+
576+
assert_that(results.main, equal_to([1, 7]), 'main')
577+
assert_that(results.threes, equal_to([3]), 'threes')
578+
assert_that(results.fives, equal_to(['5']), 'fives')
579+
bad_elements = results.bad | beam.Keys()
580+
assert_that(bad_elements, equal_to([2, 4, 6]), 'bad')
581+
# Verify type hints from annotations are propagated
582+
self.assertEqual(results.main.element_type, int)
583+
self.assertEqual(results.threes.element_type, int)
584+
self.assertEqual(results.fives.element_type, str)
585+
self.assertEqual(
586+
results.bad.element_type,
587+
typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])
588+
589+
def test_with_outputs_then_with_exception_handling_custom_dead_letter_tag(
590+
self):
591+
"""Direction 2 with custom dead_letter_tag."""
592+
593+
with beam.Pipeline() as p:
594+
results = (
595+
p
596+
| beam.Create([1, 2, 3])
597+
| beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs(
598+
'threes',
599+
main='main').with_exception_handling(dead_letter_tag='errors'))
600+
601+
assert_that(results.main, equal_to([1]), 'main')
602+
assert_that(results.threes, equal_to([3]), 'threes')
603+
bad_elements = results.errors | beam.Keys()
604+
assert_that(bad_elements, equal_to([2]), 'errors')
605+
self.assertEqual(results.threes.element_type, int)
606+
self.assertEqual(
607+
results.errors.element_type,
608+
typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])
609+
610+
def test_with_exception_handling_then_with_outputs_custom_dead_letter_tag(
611+
self):
612+
"""Direction 1 with custom dead_letter_tag."""
613+
614+
with beam.Pipeline() as p:
615+
results = (
616+
p
617+
| beam.Create([1, 2, 3])
618+
| beam.ParDo(
619+
self._create_dofn_with_tagged_outputs()).with_exception_handling(
620+
dead_letter_tag='errors').with_outputs('threes', main='main'))
621+
622+
assert_that(results.main, equal_to([1]), 'main')
623+
assert_that(results.threes, equal_to([3]), 'threes')
624+
bad_elements = results.errors | beam.Keys()
625+
assert_that(bad_elements, equal_to([2]), 'errors')
626+
self.assertEqual(results.threes.element_type, int)
627+
self.assertEqual(
628+
results.errors.element_type,
629+
typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])
630+
631+
def test_exception_handling_no_with_outputs_backward_compat(self):
632+
"""Without with_outputs(), behavior is unchanged."""
633+
634+
with beam.Pipeline() as p:
635+
good, bad = (
636+
p
637+
| beam.Create([1, 2, 7])
638+
| beam.ParDo(self._create_dofn_with_tagged_outputs())
639+
.with_exception_handling())
640+
641+
assert_that(good, equal_to([1, 7]), 'good')
642+
bad_elements = bad | beam.Keys()
643+
assert_that(bad_elements, equal_to([2]), 'bad')
644+
645+
def test_exception_handling_compat_version_uses_old_behavior(self):
646+
"""With compat version < 2.73.0, old expand path is used."""
647+
options = PipelineOptions(update_compatibility_version="2.72.0")
648+
with beam.Pipeline(options=options) as p:
649+
good, bad = (
650+
p
651+
| beam.Create([1, 2, 7])
652+
| beam.ParDo(self._create_dofn_with_tagged_outputs())
653+
.with_exception_handling())
654+
655+
assert_that(good, equal_to([1, 7]), 'good')
656+
bad_elements = bad | beam.Keys()
657+
assert_that(bad_elements, equal_to([2]), 'bad')
658+
659+
def test_exception_handling_compat_version_element_type_set_manually(self):
660+
"""With compat version < 2.73.0, element_type is set via manual override
661+
(the old behavior) rather than via with_output_types."""
662+
663+
options = PipelineOptions(update_compatibility_version="2.72.0")
664+
with beam.Pipeline(options=options) as p:
665+
results = (
666+
p
667+
| beam.Create([1, 2, 3])
668+
| beam.ParDo(self._create_dofn_with_tagged_outputs()).
669+
with_exception_handling().with_outputs('threes', main='main'))
670+
671+
# In old path, dead letter type is Any (no with_output_types call)
672+
self.assertEqual(results.bad.element_type, typehints.Any)
673+
# Tagged outputs still get types from DoFn Literal annotations
674+
# (via DoOutputsTuple.__getitem__ reading tagged_output_types)
675+
self.assertEqual(results.threes.element_type, int)
676+
# Main output type should still be inferred via manual override
677+
assert_that(results.main, equal_to([1]), 'main')
678+
679+
def test_with_outputs_then_exception_handling_with_map(self):
680+
"""with_outputs().with_exception_handling() also works on Map."""
681+
with beam.Pipeline() as p:
682+
results = (
683+
p
684+
| beam.Create([1, 2, 3, 4, 5])
685+
| beam.Map(lambda x: x if x % 2 != 0 else 1 / 0).with_outputs(
686+
main='main').with_exception_handling())
687+
assert_that(results.main, equal_to([1, 3, 5]), 'main')
688+
bad_elements = results.bad | beam.Keys()
689+
assert_that(bad_elements, equal_to([2, 4]), 'bad')
690+
691+
def test_with_output_types_chained_on_pardo(self):
692+
"""When type hints are chained on the ParDo (not annotations on the DoFn),
693+
tagged output types should still be propagated through
694+
with_exception_handling().with_outputs()."""
695+
class DoWithFailuresNoAnnotations(beam.DoFn):
696+
def process(self, element):
697+
if element % 2 == 0:
698+
raise ValueError(f'Even numbers not allowed {element}')
699+
if element % 3 == 0:
700+
yield beam.pvalue.TaggedOutput('threes', element)
701+
else:
702+
yield element
703+
704+
with beam.Pipeline() as p:
705+
results = (
706+
p
707+
| beam.Create([1, 2, 3, 7])
708+
| beam.ParDo(DoWithFailuresNoAnnotations()).with_output_types(
709+
int, threes=int).with_exception_handling().with_outputs(
710+
'threes', main='main'))
711+
712+
assert_that(results.main, equal_to([1, 7]), 'main')
713+
assert_that(results.threes, equal_to([3]), 'threes')
714+
bad_elements = results.bad | beam.Keys()
715+
assert_that(bad_elements, equal_to([2]), 'bad')
716+
self.assertEqual(results.main.element_type, int)
717+
self.assertEqual(results.threes.element_type, int)
718+
719+
def test_with_outputs_and_error_handler(self):
720+
"""with_outputs() + error_handler should return DoOutputsTuple, not a
721+
bare PCollection."""
722+
from apache_beam.transforms.error_handling import ErrorHandler
723+
with beam.Pipeline() as p:
724+
with ErrorHandler(beam.Map(lambda x: x)) as handler:
725+
results = (
726+
p
727+
| beam.Create([1, 2, 3, 4, 5, 6, 7])
728+
| beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs(
729+
'threes', 'fives',
730+
main='main').with_exception_handling(error_handler=handler))
731+
732+
assert_that(results.main, equal_to([1, 7]), 'main')
733+
assert_that(results.threes, equal_to([3]), 'threes')
734+
assert_that(results.fives, equal_to(['5']), 'fives')
735+
736+
518737
def test_callablewrapper_typehint():
519738
T = TypeVar("T")
520739

0 commit comments

Comments
 (0)