-
Notifications
You must be signed in to change notification settings - Fork 202
Expand file tree
/
Copy path_workflow.py
More file actions
636 lines (560 loc) · 26.2 KB
/
Copy path_workflow.py
File metadata and controls
636 lines (560 loc) · 26.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
"""Workflow test environment."""
from __future__ import annotations
import asyncio
import logging
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime, timedelta, timezone
from typing import (
Any,
cast,
)
import google.protobuf.empty_pb2
from typing_extensions import Self
import temporalio.api.nexus.v1
import temporalio.api.operatorservice.v1
import temporalio.api.testservice.v1
import temporalio.bridge.testing
import temporalio.client
import temporalio.common
import temporalio.converter
import temporalio.exceptions
import temporalio.runtime
import temporalio.service
import temporalio.worker
# Module-level logger, named after this module; used for the dev-server log
# level fallback and for best-effort cleanup warnings.
logger = logging.getLogger(name=__name__)
class WorkflowEnvironment:
    """Workflow environment for testing workflows.

    Most developers will want to use the class method
    :py:meth:`start_time_skipping` to start a test server process that
    automatically skips time as needed. Alternatively, :py:meth:`start_local`
    may be used for a full, local Temporal server with more features. To use an
    existing server, use :py:meth:`from_client`.

    This environment is an async context manager, so it can be used with
    ``async with`` to make sure it shuts down properly. Otherwise,
    :py:meth:`shutdown` can be manually called.

    To use the environment, simply use the :py:attr:`client` on it.

    Workflows invoked on the workflow environment are automatically configured
    to have ``assert`` failures fail the workflow with the assertion error.
    """

    @classmethod
    def from_client(cls, client: temporalio.client.Client) -> Self:
        """Create a workflow environment from the given client.

        :py:attr:`supports_time_skipping` will always return ``False`` for this
        environment. :py:meth:`sleep` will sleep the actual amount of time and
        :py:meth:`get_current_time` will return the current time. The base
        :py:meth:`shutdown` is a no-op, so the server behind ``client`` is
        left running.

        Args:
            client: The client to use for the environment.

        Returns:
            The workflow environment that runs against the given client.
        """
        # Add the assertion interceptor
        return cls(_client_with_interceptors(client, _AssertionErrorInterceptor()))

    @classmethod
    async def start_local(
        cls,
        *,
        namespace: str = "default",
        data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
        interceptors: Sequence[temporalio.client.Interceptor] = [],
        plugins: Sequence[temporalio.client.Plugin] = [],
        default_workflow_query_reject_condition: temporalio.common.QueryRejectCondition
        | None = None,
        retry_config: temporalio.service.RetryConfig | None = None,
        rpc_metadata: Mapping[str, str | bytes] = {},
        identity: str | None = None,
        tls: bool | temporalio.service.TLSConfig = False,
        ip: str = "127.0.0.1",
        port: int | None = None,
        download_dest_dir: str | None = None,
        ui: bool = False,
        runtime: temporalio.runtime.Runtime | None = None,
        search_attributes: Sequence[temporalio.common.SearchAttributeKey] = (),
        dev_server_existing_path: str | None = None,
        dev_server_database_filename: str | None = None,
        dev_server_log_format: str = "pretty",
        dev_server_log_level: str | None = "warn",
        dev_server_download_version: str = "default",
        dev_server_extra_args: Sequence[str] = [],
        dev_server_download_ttl: timedelta | None = None,
        ui_port: int | None = None,
    ) -> WorkflowEnvironment:
        """Start a full Temporal server locally, downloading if necessary.

        This environment is good for testing full server capabilities, but does
        not support time skipping like :py:meth:`start_time_skipping` does.
        :py:attr:`supports_time_skipping` will always return ``False`` for this
        environment. :py:meth:`sleep` will sleep the actual amount of time and
        :py:meth:`get_current_time` will return the current time.

        Internally, this uses the Temporal CLI dev server from
        https://github.com/temporalio/cli. This is a self-contained binary for
        Temporal using Sqlite persistence. This call will download the CLI to a
        temporary directory by default if it has not already been downloaded
        before and ``dev_server_existing_path`` is not set.

        In the future, the dev server implementation may be changed to another
        implementation. Therefore, all ``dev_server_`` prefixed parameters are
        dev-server specific and may not apply to newer versions.

        Args:
            namespace: Namespace name to use for this environment.
            data_converter: See parameter of the same name on
                :py:meth:`temporalio.client.Client.connect`.
            interceptors: See parameter of the same name on
                :py:meth:`temporalio.client.Client.connect`.
            plugins: See parameter of the same name on
                :py:meth:`temporalio.client.Client.connect`.
            default_workflow_query_reject_condition: See parameter of the same
                name on :py:meth:`temporalio.client.Client.connect`.
            retry_config: See parameter of the same name on
                :py:meth:`temporalio.client.Client.connect`.
            rpc_metadata: See parameter of the same name on
                :py:meth:`temporalio.client.Client.connect`.
            identity: See parameter of the same name on
                :py:meth:`temporalio.client.Client.connect`.
            tls: See parameter of the same name on
                :py:meth:`temporalio.client.Client.connect`.
            ip: IP address to bind to, or 127.0.0.1 by default.
            port: Port number to bind to, or an OS-provided port by default.
            download_dest_dir: Directory to download binary to if a download is
                needed. If unset, this is the system's temporary directory.
            ui: If ``True``, will start a UI in the dev server.
            runtime: Specific runtime to use or default if unset.
            search_attributes: Search attributes to register with the dev
                server.
            dev_server_existing_path: Existing path to the CLI binary.
                If present, no download will be attempted to fetch the binary.
            dev_server_database_filename: Path to the Sqlite database to use
                for the dev server. Unset default means only in-memory Sqlite
                will be used.
            dev_server_log_format: Log format for the dev server.
            dev_server_log_level: Log level to use for the dev server. Default
                is ``warn``, but if set to ``None`` this will translate the
                Python logger's level to a dev server log level.
            dev_server_download_version: Specific CLI version to download.
                Defaults to ``default`` which downloads the version known to
                work best with this SDK.
            dev_server_extra_args: Extra arguments for the CLI binary.
            dev_server_download_ttl: TTL for the downloaded CLI binary. If
                unset, it will be cached indefinitely.
            ui_port: UI port to use if UI is enabled.

        Returns:
            The started CLI dev server workflow environment.
        """
        # Use the logger's configured level if none given
        if not dev_server_log_level:
            if logger.isEnabledFor(logging.DEBUG):
                dev_server_log_level = "debug"
            elif logger.isEnabledFor(logging.INFO):
                dev_server_log_level = "info"
            elif logger.isEnabledFor(logging.WARNING):
                dev_server_log_level = "warn"
            elif logger.isEnabledFor(logging.ERROR):
                dev_server_log_level = "error"
            else:
                dev_server_log_level = "fatal"
        # Add search attributes; they are prepended to any user-provided
        # extra args so the user's args still come last.
        if search_attributes:
            new_args: list[str] = []
            for attr in search_attributes:
                new_args.append("--search-attribute")
                new_args.append(f"{attr.name}={attr._metadata_type}")
            new_args += dev_server_extra_args
            dev_server_extra_args = new_args
        # Start CLI dev server
        runtime = runtime or temporalio.runtime.Runtime.default()
        download_ttl_ms = None
        if dev_server_download_ttl is not None:
            download_ttl_ms = int(dev_server_download_ttl.total_seconds() * 1000)
        server = await temporalio.bridge.testing.EphemeralServer.start_dev_server(
            runtime._core_runtime,
            temporalio.bridge.testing.DevServerConfig(
                existing_path=dev_server_existing_path,
                sdk_name="sdk-python",
                sdk_version=temporalio.service.__version__,
                download_version=dev_server_download_version,
                download_dest_dir=download_dest_dir,
                namespace=namespace,
                ip=ip,
                port=port,
                database_filename=dev_server_database_filename,
                ui=ui,
                ui_port=ui_port,
                log_format=dev_server_log_format,
                log_level=dev_server_log_level,
                extra_args=dev_server_extra_args,
                download_ttl_ms=download_ttl_ms,
            ),
        )
        # If we can't connect to the server, we should shut it down
        try:
            return _EphemeralServerWorkflowEnvironment(
                await temporalio.client.Client.connect(
                    server.target,
                    namespace=namespace,
                    data_converter=data_converter,
                    interceptors=interceptors,
                    plugins=plugins,
                    default_workflow_query_reject_condition=default_workflow_query_reject_condition,
                    tls=tls,
                    retry_config=retry_config,
                    rpc_metadata=rpc_metadata,
                    identity=identity,
                    runtime=runtime,
                ),
                server,
            )
        except BaseException:
            # BaseException on purpose: cancellation while connecting must also
            # stop the server process we just started.
            try:
                await server.shutdown()
            except BaseException:
                # Best effort only; the connection error below is what the
                # caller needs to see.
                logger.warning(
                    "Failed stopping local server on client connection failure",
                    exc_info=True,
                )
            raise

    @classmethod
    async def start_time_skipping(
        cls,
        *,
        data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
        interceptors: Sequence[temporalio.client.Interceptor] = [],
        plugins: Sequence[temporalio.client.Plugin] = [],
        default_workflow_query_reject_condition: temporalio.common.QueryRejectCondition
        | None = None,
        retry_config: temporalio.service.RetryConfig | None = None,
        rpc_metadata: Mapping[str, str | bytes] = {},
        identity: str | None = None,
        port: int | None = None,
        download_dest_dir: str | None = None,
        runtime: temporalio.runtime.Runtime | None = None,
        test_server_existing_path: str | None = None,
        test_server_download_version: str = "default",
        test_server_extra_args: Sequence[str] = [],
        test_server_download_ttl: timedelta | None = None,
    ) -> WorkflowEnvironment:
        """Start a time skipping workflow environment.

        By default, this environment will automatically skip to the next events
        in time when a workflow's
        :py:meth:`temporalio.client.WorkflowHandle.result` is awaited on (which
        includes :py:meth:`temporalio.client.Client.execute_workflow`). Before
        the result is awaited on, time can be manually skipped forward using
        :py:meth:`sleep`. The currently known time can be obtained via
        :py:meth:`get_current_time`.

        Internally, this environment lazily downloads a test-server binary for
        the current OS/arch into the temp directory if it is not already there.
        Then the executable is started and will be killed when
        :py:meth:`shutdown` is called (which is implicitly done if this is
        started via
        ``async with await WorkflowEnvironment.start_time_skipping()``).

        Users can reuse this environment for testing multiple independent
        workflows, but not concurrently. Time skipping, which is automatically
        done when awaiting a workflow result and manually done on
        :py:meth:`sleep`, is global to the environment, not to the workflow
        under test.

        In the future, the test server implementation may be changed to another
        implementation. Therefore, all ``test_server_`` prefixed parameters are
        test server specific and may not apply to newer versions.

        Args:
            data_converter: See parameter of the same name on
                :py:meth:`temporalio.client.Client.connect`.
            interceptors: See parameter of the same name on
                :py:meth:`temporalio.client.Client.connect`.
            plugins: See parameter of the same name on
                :py:meth:`temporalio.client.Client.connect`.
            default_workflow_query_reject_condition: See parameter of the same
                name on :py:meth:`temporalio.client.Client.connect`.
            retry_config: See parameter of the same name on
                :py:meth:`temporalio.client.Client.connect`.
            rpc_metadata: See parameter of the same name on
                :py:meth:`temporalio.client.Client.connect`.
            identity: See parameter of the same name on
                :py:meth:`temporalio.client.Client.connect`.
            port: Port number to bind to, or an OS-provided port by default.
            download_dest_dir: Directory to download binary to if a download is
                needed. If unset, this is the system's temporary directory.
            runtime: Specific runtime to use or default if unset.
            test_server_existing_path: Existing path to the test server binary.
                If present, no download will be attempted to fetch the binary.
            test_server_download_version: Specific test server version to
                download. Defaults to ``default`` which downloads the version
                known to work best with this SDK.
            test_server_extra_args: Extra arguments for the test server binary.
            test_server_download_ttl: TTL for the downloaded test server binary.
                If unset, it will be cached indefinitely. A zero TTL is passed
                through as 0 ms rather than being treated as unset.

        Returns:
            The started workflow environment with time skipping.
        """
        # Start test server
        runtime = runtime or temporalio.runtime.Runtime.default()
        download_ttl_ms = None
        # Compare against None (as start_local does): timedelta(0) is falsy
        # and must not be confused with "no TTL given".
        if test_server_download_ttl is not None:
            download_ttl_ms = int(test_server_download_ttl.total_seconds() * 1000)
        server = await temporalio.bridge.testing.EphemeralServer.start_test_server(
            runtime._core_runtime,
            temporalio.bridge.testing.TestServerConfig(
                existing_path=test_server_existing_path,
                sdk_name="sdk-python",
                sdk_version=temporalio.service.__version__,
                download_version=test_server_download_version,
                download_dest_dir=download_dest_dir,
                download_ttl_ms=download_ttl_ms,
                port=port,
                extra_args=test_server_extra_args,
            ),
        )
        # If we can't connect to the server, we should shut it down
        try:
            return _EphemeralServerWorkflowEnvironment(
                await temporalio.client.Client.connect(
                    server.target,
                    data_converter=data_converter,
                    interceptors=interceptors,
                    plugins=plugins,
                    default_workflow_query_reject_condition=default_workflow_query_reject_condition,
                    retry_config=retry_config,
                    rpc_metadata=rpc_metadata,
                    identity=identity,
                    runtime=runtime,
                ),
                server,
            )
        except BaseException:
            # BaseException on purpose: cancellation while connecting must also
            # stop the server process we just started.
            try:
                await server.shutdown()
            except BaseException:
                # Best effort only; the connection error below is what the
                # caller needs to see.
                logger.warning(
                    "Failed stopping test server on client connection failure",
                    exc_info=True,
                )
            raise

    def __init__(self, client: temporalio.client.Client) -> None:
        """Create a workflow environment from a client.

        Most users should use one of the factory methods instead.
        """
        self._client = client

    async def __aenter__(self) -> WorkflowEnvironment:
        """Noop for ``async with`` support."""
        return self

    async def __aexit__(self, *args: Any) -> None:
        """For ``async with`` support to just call :py:meth:`shutdown`."""
        await self.shutdown()

    @property
    def client(self) -> temporalio.client.Client:
        """Client to this environment."""
        return self._client

    async def shutdown(self) -> None:
        """Shut down this environment.

        This is a no-op in the base class; subclasses that own a server
        override it.
        """
        pass

    async def sleep(self, duration: timedelta | float) -> None:
        """Sleep in this environment.

        This awaits a regular :py:func:`asyncio.sleep` in regular environments,
        or manually skips time in time-skipping environments.

        Args:
            duration: Amount of time to sleep, as a ``timedelta`` or a number
                of seconds.
        """
        await asyncio.sleep(
            duration.total_seconds() if isinstance(duration, timedelta) else duration
        )

    async def get_current_time(self) -> datetime:
        """Get the current time known to this environment.

        For non-time-skipping environments this is simply the system time (as
        a UTC-aware datetime). For time-skipping environments this is whatever
        time has been skipped to.
        """
        return datetime.now(timezone.utc)

    @property
    def supports_time_skipping(self) -> bool:
        """Whether this environment supports time skipping."""
        return False

    async def create_nexus_endpoint(
        self, endpoint_name: str, task_queue: str
    ) -> temporalio.api.nexus.v1.Endpoint:
        """Create a Nexus endpoint with the given name and task queue.

        The endpoint targets a worker in this environment client's namespace.

        Args:
            endpoint_name: The name of the Nexus endpoint to create.
            task_queue: The task queue to associate with the endpoint.

        Returns:
            The created Nexus endpoint.
        """
        response = await self._client.operator_service.create_nexus_endpoint(
            temporalio.api.operatorservice.v1.CreateNexusEndpointRequest(
                spec=temporalio.api.nexus.v1.EndpointSpec(
                    name=endpoint_name,
                    target=temporalio.api.nexus.v1.EndpointTarget(
                        worker=temporalio.api.nexus.v1.EndpointTarget.Worker(
                            namespace=self._client.namespace,
                            task_queue=task_queue,
                        )
                    ),
                )
            )
        )
        return response.endpoint

    async def delete_nexus_endpoint(
        self, endpoint: temporalio.api.nexus.v1.Endpoint
    ) -> None:
        """Delete a Nexus endpoint.

        Args:
            endpoint: The Nexus endpoint to delete. Its ``id`` and ``version``
                are sent with the delete request.
        """
        await self._client.operator_service.delete_nexus_endpoint(
            temporalio.api.operatorservice.v1.DeleteNexusEndpointRequest(
                id=endpoint.id,
                version=endpoint.version,
            )
        )

    @contextmanager
    def auto_time_skipping_disabled(self) -> Iterator[None]:
        """Disable any automatic time skipping if this is a time-skipping
        environment.

        This is a context manager for use via ``with``. Usually in time-skipping
        environments, waiting on a workflow result causes time to automatically
        skip until the next event. This can disable that. However, this only
        applies to results awaited inside this context. This will not disable
        automatic time skipping on previous results.

        This has no effect on non-time-skipping environments.
        """
        # It's always disabled for this base class
        yield None
class _EphemeralServerWorkflowEnvironment(WorkflowEnvironment):
    """Workflow environment backed by an ephemeral server started by the SDK.

    The server is shut down together with the environment. Time skipping is
    supported only when the server reports ``has_test_service``; in that case
    a client interceptor is installed so that awaiting workflow results
    unlocks time skipping (see :py:meth:`time_skipping_unlocked`).
    """

    def __init__(
        self,
        client: temporalio.client.Client,
        server: temporalio.bridge.testing.EphemeralServer,
    ) -> None:
        # Add assertion interceptor to client and if time skipping is supported,
        # add time skipping interceptor
        self._supports_time_skipping = server.has_test_service
        interceptors: list[temporalio.client.Interceptor] = [
            _AssertionErrorInterceptor()
        ]
        if self._supports_time_skipping:
            interceptors.append(_TimeSkippingClientInterceptor(self))
        super().__init__(_client_with_interceptors(client, *interceptors))
        self._server = server
        # Toggled off temporarily by auto_time_skipping_disabled().
        self._auto_time_skipping = True

    async def shutdown(self) -> None:
        """Shut down the ephemeral server owned by this environment."""
        await self._server.shutdown()

    async def sleep(self, duration: timedelta | float) -> None:
        """Skip time on the test server, or really sleep if unsupported.

        Args:
            duration: Amount of time to sleep, as a ``timedelta`` or a number
                of seconds.
        """
        # Use regular sleep if no time skipping
        if not self._supports_time_skipping:
            return await super().sleep(duration)
        req = temporalio.api.testservice.v1.SleepRequest()
        req.duration.FromTimedelta(
            duration if isinstance(duration, timedelta) else timedelta(seconds=duration)
        )
        await self._client.test_service.unlock_time_skipping_with_sleep(req)

    async def get_current_time(self) -> datetime:
        """Return the test server's time, or system time if unsupported."""
        # Use regular time if no time skipping
        if not self._supports_time_skipping:
            return await super().get_current_time()
        resp = await self._client.test_service.get_current_time(
            google.protobuf.empty_pb2.Empty()
        )
        # Tag as UTC so the result matches the base class's aware datetimes.
        return resp.time.ToDatetime().replace(tzinfo=timezone.utc)

    @property
    def supports_time_skipping(self) -> bool:
        """Whether the underlying server exposes the test (time-skip) service."""
        return self._supports_time_skipping

    @contextmanager
    def auto_time_skipping_disabled(self) -> Iterator[None]:
        """Disable automatic time skipping for results awaited in the block.

        Nested use is safe: only the outermost context re-enables it.
        """
        already_disabled = not self._auto_time_skipping
        self._auto_time_skipping = False
        try:
            yield None
        finally:
            if not already_disabled:
                self._auto_time_skipping = True

    @asynccontextmanager
    async def time_skipping_unlocked(self) -> AsyncIterator[None]:
        """Unlock time skipping for the duration of the block.

        Does nothing when time skipping is unsupported or automatic time
        skipping is currently disabled. Otherwise time skipping is unlocked on
        entry and locked again on exit. On a successful exit a lock failure
        propagates. If the body raised, re-locking is best-effort and the
        body's exception is what propagates.
        """
        # If it's disabled or not supported, no locking/unlocking, just yield
        # and return
        if not self._supports_time_skipping or not self._auto_time_skipping:
            yield None
            return
        # Unlock to start time skipping, lock again to stop it
        await self.client.test_service.unlock_time_skipping(
            temporalio.api.testservice.v1.UnlockTimeSkippingRequest()
        )
        try:
            yield None
        except BaseException:
            # Body failed (including cancellation): lock it back, swallowing
            # any lock error so the original exception is the one raised.
            try:
                await self.client.test_service.lock_time_skipping(
                    temporalio.api.testservice.v1.LockTimeSkippingRequest()
                )
            except BaseException:
                logger.exception("Failed locking time skipping after error")
            raise
        else:
            # Body succeeded: lock it back exactly once, throwing on error.
            # (Keeping this out of the try avoids a second lock attempt and
            # a misleading "after error" log when this lock itself fails.)
            await self.client.test_service.lock_time_skipping(
                temporalio.api.testservice.v1.LockTimeSkippingRequest()
            )
class _AssertionErrorInterceptor(
    temporalio.client.Interceptor, temporalio.worker.Interceptor
):
    """Interceptor implementing both the client and worker interfaces.

    On the worker side it installs
    :py:class:`_AssertionErrorWorkflowInboundInterceptor` for every workflow.
    """

    def workflow_interceptor_class(
        self, input: temporalio.worker.WorkflowInterceptorClassInput
    ) -> type[temporalio.worker.WorkflowInboundInterceptor] | None:
        """Return the inbound interceptor class; ``input`` is not consulted."""
        inbound_cls = _AssertionErrorWorkflowInboundInterceptor
        return inbound_cls
class _AssertionErrorWorkflowInboundInterceptor(
    temporalio.worker.WorkflowInboundInterceptor
):
    """Inbound interceptor turning ``AssertionError`` into a workflow failure.

    Wraps workflow execution and signal handling. A failed ``assert`` is
    re-raised as a non-retryable ``ApplicationError`` of type
    ``"AssertionError"`` carrying the original message and traceback.
    """

    async def execute_workflow(
        self, input: temporalio.worker.ExecuteWorkflowInput
    ) -> Any:
        with self.assert_error_as_app_error():
            outcome = await super().execute_workflow(input)
            return outcome

    async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None:
        with self.assert_error_as_app_error():
            outcome = await super().handle_signal(input)
            return outcome

    @contextmanager
    def assert_error_as_app_error(self) -> Iterator[None]:
        """Convert an ``AssertionError`` raised in the block.

        The replacement error has its original context suppressed
        (``from None``) and reuses the assertion's traceback.
        """
        try:
            yield
        except AssertionError as assertion_failure:
            converted = temporalio.exceptions.ApplicationError(
                str(assertion_failure), type="AssertionError", non_retryable=True
            )
            raise converted.with_traceback(assertion_failure.__traceback__) from None
class _TimeSkippingClientInterceptor(temporalio.client.Interceptor):
    """Client interceptor that routes outbound calls through time skipping.

    Holds the owning environment so the outbound interceptor can unlock and
    lock time skipping around awaited results.
    """

    def __init__(self, env: _EphemeralServerWorkflowEnvironment) -> None:  # type: ignore[reportMissingSuperCall]
        self.env = env

    def intercept_client(
        self, next: temporalio.client.OutboundInterceptor
    ) -> temporalio.client.OutboundInterceptor:
        """Wrap ``next`` with the time-skipping outbound interceptor."""
        wrapped = _TimeSkippingClientOutboundInterceptor(next, self.env)
        return wrapped
class _TimeSkippingClientOutboundInterceptor(temporalio.client.OutboundInterceptor):
    """Outbound interceptor that makes started workflow handles time-skip.

    Every handle returned by ``start_workflow`` is re-classed in place to
    :py:class:`_TimeSkippingWorkflowHandle` and given a reference to the
    environment.
    """

    def __init__(
        self,
        next: temporalio.client.OutboundInterceptor,
        env: _EphemeralServerWorkflowEnvironment,
    ) -> None:
        super().__init__(next)
        self.env = env

    async def start_workflow(
        self, input: temporalio.client.StartWorkflowInput
    ) -> temporalio.client.WorkflowHandle[Any, Any]:
        started = await super().start_workflow(input)
        # Swap the instance's class (rather than wrapping it) so that the
        # overridden result() is used while every other attribute is kept.
        time_skipping_handle = cast(_TimeSkippingWorkflowHandle, started)
        time_skipping_handle.__class__ = _TimeSkippingWorkflowHandle
        time_skipping_handle.env = self.env
        return time_skipping_handle
class _TimeSkippingWorkflowHandle(temporalio.client.WorkflowHandle):
    """Workflow handle whose ``result()`` runs with time skipping unlocked."""

    # Not set by __init__: assigned by _TimeSkippingClientOutboundInterceptor
    # right after it swaps a started handle's class to this one.
    env: _EphemeralServerWorkflowEnvironment  # type: ignore[reportUninitializedInstanceAttribute]

    async def result(
        self,
        *,
        follow_runs: bool = True,
        rpc_metadata: Mapping[str, str | bytes] = {},
        rpc_timeout: timedelta | None = None,
    ) -> Any:
        """Await the workflow result inside the environment's unlocked window."""
        unlocked_window = self.env.time_skipping_unlocked()
        async with unlocked_window:
            outcome = await super().result(
                follow_runs=follow_runs,
                rpc_metadata=rpc_metadata,
                rpc_timeout=rpc_timeout,
            )
        return outcome
def _client_with_interceptors(
    client: temporalio.client.Client, *interceptors: temporalio.client.Interceptor
) -> temporalio.client.Client:
    """Build a new client from ``client``'s config with extra interceptors.

    The given interceptors are appended after those already configured.
    NOTE(review): this writes a fresh interceptor list into the mapping
    returned by ``client.config()``; the original client is only untouched if
    that method returns a copy — confirm.
    """
    cloned_config = client.config()
    cloned_config["interceptors"] = [*cloned_config["interceptors"], *interceptors]
    return temporalio.client.Client(**cloned_config)