Skip to content

Commit 81888ce

Browse files
authored
Avoid transform after invoke, which appears to break CombineablePromise (#340)
* Avoid transform after invoke, which appears to break CombineablePromise This is a hotfix, a more meaningful fix would be to implement 'transform' for combineable promises so that they stay combineable * Add tests for combining invokes * InternalCombineablePromise should not be a wrapped promise
1 parent 46d2ec9 commit 81888ce

File tree

3 files changed

+219
-31
lines changed

3 files changed

+219
-31
lines changed

packages/restate-sdk/src/context_impl.ts

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,9 @@ export interface CallContext {
7474
delay?: number;
7575
}
7676

77-
export type InternalCombineablePromise<T> = CombineablePromise<T> &
78-
WrappedPromise<T> & {
79-
journalIndex: number;
80-
};
77+
export type InternalCombineablePromise<T> = CombineablePromise<T> & {
78+
journalIndex: number;
79+
};
8180

8281
export class ContextImpl implements ObjectContext {
8382
// here, we capture the context information for actions on the Restate context that
@@ -229,7 +228,8 @@ export class ContextImpl implements ObjectContext {
229228
method: string,
230229
data: Uint8Array,
231230
key?: string
232-
): InternalCombineablePromise<Uint8Array> {
231+
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
232+
): InternalCombineablePromise<any> {
233233
this.checkState("invoke");
234234

235235
const msg = new CallEntryMessage({
@@ -241,7 +241,7 @@ export class ContextImpl implements ObjectContext {
241241
return this.markCombineablePromise(
242242
this.stateMachine
243243
.handleUserCodeMessage(INVOKE_ENTRY_MESSAGE_TYPE, msg)
244-
.transform((v) => v as Uint8Array)
244+
.transform((v) => deserializeJson(v as Uint8Array))
245245
);
246246
}
247247

@@ -280,9 +280,7 @@ export class ContextImpl implements ObjectContext {
280280
const route = prop as string;
281281
return (...args: unknown[]) => {
282282
const requestBytes = serializeJson(args.shift());
283-
return this.invoke(name, route, requestBytes).transform(
284-
(responseBytes) => deserializeJson(responseBytes)
285-
);
283+
return this.invoke(name, route, requestBytes);
286284
};
287285
},
288286
}
@@ -302,9 +300,7 @@ export class ContextImpl implements ObjectContext {
302300
const route = prop as string;
303301
return (...args: unknown[]) => {
304302
const requestBytes = serializeJson(args.shift());
305-
return this.invoke(name, route, requestBytes, key).transform(
306-
(responseBytes) => deserializeJson(responseBytes)
307-
);
303+
return this.invoke(name, route, requestBytes, key);
308304
};
309305
},
310306
}
@@ -670,20 +666,23 @@ export class ContextImpl implements ObjectContext {
670666
]) as Promise<T>;
671667
};
672668

673-
return Object.defineProperties(p, {
674-
[RESTATE_CTX_SYMBOL]: {
675-
value: this,
676-
},
677-
journalIndex: {
678-
value: journalIndex,
679-
},
680-
orTimeout: {
681-
value: orTimeout.bind(this),
682-
},
683-
}) as InternalCombineablePromise<T>;
669+
defineProperty(p, RESTATE_CTX_SYMBOL, this);
670+
defineProperty(p, "journalIndex", journalIndex);
671+
defineProperty(p, "orTimeout", orTimeout.bind(this));
672+
673+
return p;
684674
}
685675
}
686676

677+
// wraps defineProperty such that it informs tsc of the correct type of its output
678+
function defineProperty<Obj extends object, Key extends PropertyKey, T>(
679+
obj: Obj,
680+
prop: Key,
681+
value: T
682+
): asserts obj is Obj & Readonly<Record<Key, T>> {
683+
Object.defineProperty(obj, prop, { value });
684+
}
685+
687686
function unpack<T>(
688687
a: string | RunAction<T>,
689688
b?: RunAction<T>

packages/restate-sdk/test/promise_combinators.test.ts

Lines changed: 187 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111

1212
import { describe, expect } from "@jest/globals";
1313
import * as restate from "../src/public_api";
14-
import { TestDriver, TestGreeter, TestResponse } from "./testdriver";
14+
import {
15+
GreeterApi,
16+
TestDriver,
17+
TestGreeter,
18+
TestResponse,
19+
} from "./testdriver";
1520
import {
1621
awakeableMessage,
1722
completionMessage,
@@ -27,6 +32,7 @@ import {
2732
sleepMessage,
2833
sideEffectMessage,
2934
ackMessage,
35+
invokeMessage,
3036
} from "./protoutils";
3137
import {
3238
COMBINATOR_ENTRY_MESSAGE,
@@ -35,6 +41,7 @@ import {
3541
import { TimeoutError } from "../src/types/errors";
3642
import { CombineablePromise } from "../src/context";
3743
import { Empty } from "../src/generated/proto/protocol_pb";
44+
import { Buffer } from "node:buffer";
3845

3946
class AwakeableSleepRaceGreeter implements TestGreeter {
4047
async greet(ctx: restate.ObjectContext): Promise<TestResponse> {
@@ -398,3 +405,182 @@ describe("AwakeableOrTimeoutGreeter", () => {
398405
]);
399406
});
400407
});
408+
409+
class InvokeRaceGreeter implements TestGreeter {
410+
async greet(ctx: restate.ObjectContext): Promise<TestResponse> {
411+
const jack = ctx.objectClient(GreeterApi, "Jack").greet({ name: "Jack" });
412+
const jill = ctx.objectClient(GreeterApi, "Jill").greet({ name: "Jill" });
413+
414+
const result = await CombineablePromise.race([jack, jill]);
415+
416+
return TestResponse.create({
417+
greeting: result.greeting,
418+
});
419+
}
420+
}
421+
422+
describe("InvokeRaceGreeter", () => {
423+
const jackInvoke = invokeMessage(
424+
"greeter",
425+
"greet",
426+
Buffer.from(JSON.stringify({ name: "Jack" })),
427+
undefined,
428+
undefined,
429+
"Jack"
430+
);
431+
const jillInvoke = invokeMessage(
432+
"greeter",
433+
"greet",
434+
Buffer.from(JSON.stringify({ name: "Jill" })),
435+
undefined,
436+
undefined,
437+
"Jill"
438+
);
439+
440+
it("should suspend without completions", async () => {
441+
const result = await new TestDriver(new InvokeRaceGreeter(), [
442+
startMessage(),
443+
inputMessage(greetRequest("Till")),
444+
]).run();
445+
446+
expect(result.length).toStrictEqual(3);
447+
expect(result[0]).toStrictEqual(jackInvoke);
448+
expect(result[1]).toStrictEqual(jillInvoke);
449+
expect(result[2]).toStrictEqual(suspensionMessage([1, 2]));
450+
});
451+
452+
it("handles completion of first invoke", async () => {
453+
const result = await new TestDriver(new InvokeRaceGreeter(), [
454+
startMessage(),
455+
inputMessage(greetRequest("Till")),
456+
completionMessage(1, greetResponse("Hi Jack")),
457+
ackMessage(3),
458+
]).run();
459+
460+
expect(result.length).toStrictEqual(5);
461+
expect(result[0]).toStrictEqual(jackInvoke);
462+
expect(result[1]).toStrictEqual(jillInvoke);
463+
expect(result.slice(2)).toStrictEqual([
464+
combinatorEntryMessage(0, [1]),
465+
outputMessage(greetResponse("Hi Jack")),
466+
END_MESSAGE,
467+
]);
468+
});
469+
470+
it("handles completion of second invoke", async () => {
471+
const result = await new TestDriver(new InvokeRaceGreeter(), [
472+
startMessage(),
473+
inputMessage(greetRequest("Till")),
474+
completionMessage(2, greetResponse("Hi Jill")),
475+
ackMessage(3),
476+
]).run();
477+
478+
expect(result.length).toStrictEqual(5);
479+
expect(result[0]).toStrictEqual(jackInvoke);
480+
expect(result[1]).toStrictEqual(jillInvoke);
481+
expect(result.slice(2)).toStrictEqual([
482+
combinatorEntryMessage(0, [2]),
483+
outputMessage(greetResponse(`Hi Jill`)),
484+
END_MESSAGE,
485+
]);
486+
});
487+
488+
it("handles replay of the first invoke", async () => {
489+
const result = await new TestDriver(new InvokeRaceGreeter(), [
490+
startMessage(),
491+
inputMessage(greetRequest("Till")),
492+
invokeMessage(
493+
"greeter",
494+
"greet",
495+
Buffer.from(JSON.stringify({ name: "Jack" })),
496+
greetResponse("Hi Jack"),
497+
undefined,
498+
"Jack"
499+
),
500+
ackMessage(3),
501+
]).run();
502+
503+
expect(result.length).toStrictEqual(4);
504+
expect(result[0]).toStrictEqual(jillInvoke);
505+
expect(result.slice(1)).toStrictEqual([
506+
combinatorEntryMessage(0, [1]),
507+
outputMessage(greetResponse("Hi Jack")),
508+
END_MESSAGE,
509+
]);
510+
});
511+
512+
it("handles replay of both invokes", async () => {
513+
const result = await new TestDriver(new InvokeRaceGreeter(), [
514+
startMessage(),
515+
inputMessage(greetRequest("Till")),
516+
invokeMessage(
517+
"greeter",
518+
"greet",
519+
Buffer.from(JSON.stringify({ name: "Jack" })),
520+
greetResponse("Hi Jack"),
521+
undefined,
522+
"Jack"
523+
),
524+
invokeMessage(
525+
"greeter",
526+
"greet",
527+
Buffer.from(JSON.stringify({ name: "Jill" })),
528+
greetResponse("Hi Jill"),
529+
undefined,
530+
"Jill"
531+
),
532+
ackMessage(3),
533+
]).run();
534+
535+
expect(result).toStrictEqual([
536+
// The first invoke will be chosen because Promise.race will pick the first promise, in case both are resolved
537+
combinatorEntryMessage(0, [1, 2]),
538+
outputMessage(greetResponse("Hi Jack")),
539+
END_MESSAGE,
540+
]);
541+
});
542+
543+
it("handles replay of the combinator with first invoke completed", async () => {
544+
const result = await new TestDriver(new InvokeRaceGreeter(), [
545+
startMessage(),
546+
inputMessage(greetRequest("Till")),
547+
invokeMessage(
548+
"greeter",
549+
"greet",
550+
Buffer.from(JSON.stringify({ name: "Jack" })),
551+
greetResponse("Hi Jack"),
552+
undefined,
553+
"Jack"
554+
),
555+
jillInvoke,
556+
combinatorEntryMessage(0, [1]),
557+
]).run();
558+
559+
expect(result).toStrictEqual([
560+
outputMessage(greetResponse("Hi Jack")),
561+
END_MESSAGE,
562+
]);
563+
});
564+
565+
it("handles replay of the combinator with second invoke completed", async () => {
566+
const result = await new TestDriver(new InvokeRaceGreeter(), [
567+
startMessage(),
568+
inputMessage(greetRequest("Till")),
569+
jackInvoke,
570+
invokeMessage(
571+
"greeter",
572+
"greet",
573+
Buffer.from(JSON.stringify({ name: "Jill" })),
574+
greetResponse("Hi Jill"),
575+
undefined,
576+
"Jill"
577+
),
578+
combinatorEntryMessage(0, [2]),
579+
]).run();
580+
581+
expect(result).toStrictEqual([
582+
outputMessage(greetResponse("Hi Jill")),
583+
END_MESSAGE,
584+
]);
585+
});
586+
});

packages/restate-sdk/test/testdriver.ts

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ import { StateMachine } from "../src/state_machine";
2323
import { InvocationBuilder } from "../src/invocation";
2424
import { EndpointImpl } from "../src/endpoint/endpoint_impl";
2525
import { ObjectContext } from "../src/context";
26-
import { ServiceDefinition, object } from "../src/public_api";
26+
import {
27+
object,
28+
VirtualObjectDefinition,
29+
VirtualObject,
30+
} from "../src/public_api";
2731
import { ProtocolMode } from "../src/types/discovery";
2832

2933
export type TestRequest = {
@@ -38,11 +42,10 @@ export const TestResponse = {
3842
create: (test: TestResponse): TestResponse => test,
3943
};
4044

41-
export type GreetType = {
42-
greet: (key: string, arg: TestRequest) => Promise<TestResponse>;
43-
};
44-
45-
export const GreeterApi: ServiceDefinition<"greeter", GreetType> = {
45+
export const GreeterApi: VirtualObjectDefinition<
46+
"greeter",
47+
VirtualObject<TestGreeter>
48+
> = {
4649
name: "greeter",
4750
};
4851

@@ -188,7 +191,7 @@ export class TestDriver implements Connection {
188191
rlog.debug(
189192
`Adding result to the result array. Message type: ${
190193
msg.messageType
191-
}, message:
194+
}, message:
192195
${
193196
msg.message instanceof Uint8Array
194197
? (msg.message as Uint8Array).toString()

0 commit comments

Comments
 (0)