Skip to content

Commit d2456ec

Browse files
authored
feat(entitlements): add entitlement/permission system for tasks and workflows (#370)
Adds a declarative entitlement system that mirrors the existing schema pattern: static class-level declarations, dynamic instance overrides, change events, and graph-level aggregation. This enables consumers to inspect required permissions before workflow execution and optionally enforce them.
1 parent 258e2ac commit d2456ec

30 files changed

+2333
-59
lines changed

packages/ai/src/task/base/AiTask.ts

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,24 @@
1010
*/
1111

1212
import { Job } from "@workglow/job-queue";
13+
import type {
14+
IExecuteContext,
15+
IExecuteReactiveContext,
16+
TaskConfig,
17+
TaskEntitlement,
18+
TaskEntitlements,
19+
TaskOutput,
20+
} from "@workglow/task-graph";
1321
import {
22+
Entitlements,
1423
Task,
1524
TaskConfigSchema,
1625
TaskConfigurationError,
1726
TaskInput,
1827
hasStructuredOutput,
1928
} from "@workglow/task-graph";
20-
import type {
21-
IExecuteContext,
22-
IExecuteReactiveContext,
23-
TaskConfig,
24-
TaskOutput,
25-
} from "@workglow/task-graph";
26-
import type { DataPortSchema, JsonSchema } from "@workglow/util/schema";
2729
import type { ServiceRegistry } from "@workglow/util";
30+
import type { DataPortSchema, JsonSchema } from "@workglow/util/schema";
2831

2932
import { AiJob, AiJobInput } from "../../job/AiJob";
3033
import { MODEL_REPOSITORY } from "../../model/ModelRegistry";
@@ -63,6 +66,35 @@ export class AiTask<
6366
Config extends TaskConfig<Input> = TaskConfig<Input>,
6467
> extends Task<Input, Output, Config> {
6568
public static override type: string = "AiTask";
69+
public static override hasDynamicEntitlements: boolean = true;
70+
71+
public static override entitlements(): TaskEntitlements {
72+
return {
73+
entitlements: [{ id: Entitlements.AI_INFERENCE, reason: "Runs AI model inference" }],
74+
};
75+
}
76+
77+
public override entitlements(): TaskEntitlements {
78+
const base: TaskEntitlement[] = [
79+
{ id: Entitlements.AI_INFERENCE, reason: "Runs AI model inference" },
80+
];
81+
// Prefer runInputData.model (runtime) over defaults.model (construction-time)
82+
const runModel = this.runInputData?.model;
83+
const modelId =
84+
typeof runModel === "string"
85+
? runModel
86+
: typeof this.defaults.model === "string"
87+
? this.defaults.model
88+
: undefined;
89+
if (modelId) {
90+
base.push({
91+
id: Entitlements.AI_MODEL,
92+
reason: `Uses model ${modelId}`,
93+
resources: [modelId],
94+
});
95+
}
96+
return { entitlements: base };
97+
}
6698

6799
public static override configSchema(): DataPortSchema {
68100
return aiTaskConfigSchema;

packages/task-graph/src/common.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
export * from "./task-graph/Dataflow";
88
export * from "./task-graph/DataflowEvents";
99

10+
export * from "./task-graph/GraphEntitlementUtils";
1011
export * from "./task-graph/GraphFormatScanner";
1112
export * from "./task-graph/GraphSchemaUtils";
1213
export * from "./task-graph/ITaskGraph";
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/**
2+
* @license
3+
* Copyright 2025 Steven Roussey <sroussey@gmail.com>
4+
* SPDX-License-Identifier: Apache-2.0
5+
*/
6+
7+
import {
8+
type EntitlementId,
9+
type TaskEntitlement,
10+
type TaskEntitlements,
11+
type TrackedTaskEntitlement,
12+
type TrackedTaskEntitlements,
13+
EMPTY_ENTITLEMENTS,
14+
mergeEntitlementPair,
15+
} from "../task/TaskEntitlements";
16+
import type { TaskIdType } from "../task/TaskTypes";
17+
import { TaskStatus } from "../task/TaskTypes";
18+
import type { TaskGraph } from "./TaskGraph";
19+
20+
// ========================================================================
21+
// Options
22+
// ========================================================================
23+
24+
export interface GraphEntitlementOptions {
25+
/**
26+
* When true, annotate each entitlement with the source task IDs that require it.
27+
*/
28+
readonly trackOrigins?: boolean;
29+
/**
30+
* Controls which ConditionalTask branches are included.
31+
* - "all" (default): Include entitlements from ALL branches (conservative, pre-execution analysis)
32+
* - "active": Only include entitlements from currently active branches (runtime, after conditions evaluated)
33+
*/
34+
readonly conditionalBranches?: "all" | "active";
35+
}
36+
37+
// ========================================================================
38+
// Graph Entitlement Computation
39+
// ========================================================================
40+
41+
/**
42+
* Computes the aggregated entitlements for a TaskGraph.
43+
* Returns the union of all task entitlements in the graph.
44+
*
45+
* When `trackOrigins` is true, returns TrackedTaskEntitlements with source task IDs.
46+
*/
47+
export function computeGraphEntitlements(
48+
graph: TaskGraph,
49+
options?: GraphEntitlementOptions & { readonly trackOrigins: true }
50+
): TrackedTaskEntitlements;
51+
export function computeGraphEntitlements(
52+
graph: TaskGraph,
53+
options?: GraphEntitlementOptions
54+
): TaskEntitlements;
55+
export function computeGraphEntitlements(
56+
graph: TaskGraph,
57+
options?: GraphEntitlementOptions
58+
): TaskEntitlements | TrackedTaskEntitlements {
59+
const tasks = graph.getTasks();
60+
if (tasks.length === 0) return EMPTY_ENTITLEMENTS;
61+
62+
const trackOrigins = options?.trackOrigins ?? false;
63+
const conditionalBranches = options?.conditionalBranches ?? "all";
64+
65+
// Accumulate entitlements by ID
66+
const merged = new Map<
67+
EntitlementId,
68+
{ entitlement: TaskEntitlement; sourceTaskIds: TaskIdType[] }
69+
>();
70+
71+
for (const task of tasks) {
72+
// For ConditionalTask with "active" mode, skip disabled tasks
73+
if (conditionalBranches === "active" && task.status !== undefined) {
74+
if (task.status === TaskStatus.DISABLED) continue;
75+
}
76+
77+
const taskEntitlements = task.entitlements();
78+
for (const entitlement of taskEntitlements.entitlements) {
79+
const existing = merged.get(entitlement.id);
80+
if (existing) {
81+
// Merge: optional=false wins, resources are unioned
82+
existing.entitlement = mergeEntitlementPair(existing.entitlement, entitlement);
83+
if (trackOrigins) {
84+
existing.sourceTaskIds.push(task.id);
85+
}
86+
} else {
87+
merged.set(entitlement.id, {
88+
entitlement,
89+
sourceTaskIds: trackOrigins ? [task.id] : [],
90+
});
91+
}
92+
}
93+
}
94+
95+
if (merged.size === 0) return EMPTY_ENTITLEMENTS;
96+
97+
if (trackOrigins) {
98+
const entitlements: TrackedTaskEntitlement[] = [];
99+
for (const { entitlement, sourceTaskIds } of merged.values()) {
100+
entitlements.push({ ...entitlement, sourceTaskIds });
101+
}
102+
return { entitlements };
103+
}
104+
105+
return { entitlements: Array.from(merged.values()).map((e) => e.entitlement) };
106+
}

packages/task-graph/src/task-graph/TaskGraph.ts

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,19 @@
44
* SPDX-License-Identifier: Apache-2.0
55
*/
66

7-
import { DirectedAcyclicGraph } from "@workglow/util/graph";
87
import { EventEmitter, ServiceRegistry, uuid4 } from "@workglow/util";
8+
import { DirectedAcyclicGraph } from "@workglow/util/graph";
99
import { TaskOutputRepository } from "../storage/TaskOutputRepository";
1010
import type { ITask } from "../task/ITask";
1111
import type { StreamEvent } from "../task/StreamTypes";
12+
import type { TaskEntitlements } from "../task/TaskEntitlements";
1213
import type { JsonTaskItem, TaskGraphJson, TaskGraphJsonOptions } from "../task/TaskJSON";
1314
import type { TaskIdType, TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes";
14-
import { ensureTask } from "./Conversions";
1515
import type { PipeFunction } from "./Conversions";
16-
import { Dataflow } from "./Dataflow";
16+
import { ensureTask } from "./Conversions";
1717
import type { DataflowIdType } from "./Dataflow";
18+
import { Dataflow } from "./Dataflow";
19+
import { computeGraphEntitlements } from "./GraphEntitlementUtils";
1820
import { addBoundaryNodesToDependencyJson, addBoundaryNodesToGraphJson } from "./GraphSchemaUtils";
1921
import type { ITaskGraph } from "./ITaskGraph";
2022
import {
@@ -27,8 +29,8 @@ import {
2729
TaskGraphStatusEvents,
2830
TaskGraphStatusListeners,
2931
} from "./TaskGraphEvents";
30-
import { CompoundMergeStrategy, GraphResult, TaskGraphRunner } from "./TaskGraphRunner";
3132
import type { GraphResultArray } from "./TaskGraphRunner";
33+
import { CompoundMergeStrategy, GraphResult, TaskGraphRunner } from "./TaskGraphRunner";
3234

3335
/**
3436
* Configuration for running a task graph
@@ -57,9 +59,16 @@ export interface TaskGraphRunConfig {
5759
* Defaults to no limit. Set this to prevent runaway graph construction.
5860
*/
5961
maxTasks?: number;
62+
/**
63+
* When true, check entitlements via the registered IEntitlementEnforcer before
64+
* graph execution begins. Throws TaskEntitlementError if any required (non-optional)
65+
* entitlements are denied. Default: false.
66+
*/
67+
enforceEntitlements?: boolean;
6068
}
6169

62-
export interface TaskGraphRunReactiveConfig extends TaskGraphRunConfig {
70+
export interface TaskGraphRunReactiveConfig
71+
extends Omit<TaskGraphRunConfig, "enforceEntitlements" | "timeout"> {
6372
/** Optional service registry to use for this task graph */
6473
registry?: ServiceRegistry;
6574
}
@@ -563,6 +572,67 @@ export class TaskGraph implements ITaskGraph {
563572
};
564573
}
565574

575+
/**
576+
* Subscribes to entitlement changes on all tasks (existing and future).
577+
* When any task's entitlements change, the graph recomputes and emits its own
578+
* `entitlementChange` event. Structural changes (task_added, task_removed) also trigger.
579+
*
580+
* @param callback - Function called with the aggregated entitlements whenever they change
581+
* @returns a function to unsubscribe from all entitlement events
582+
*/
583+
public subscribeToTaskEntitlements(
584+
callback: (entitlements: TaskEntitlements) => void
585+
): () => void {
586+
const globalUnsubs: (() => void)[] = [];
587+
const taskUnsubs = new Map<TaskIdType, () => void>();
588+
589+
const emitChange = () => {
590+
const entitlements = computeGraphEntitlements(this);
591+
this.emit("entitlementChange", entitlements);
592+
callback(entitlements);
593+
};
594+
595+
const subscribeTask = (taskId: TaskIdType) => {
596+
const task = this.getTask(taskId);
597+
if (!task || typeof task.subscribe !== "function") return;
598+
const unsub = task.subscribe("entitlementChange", () => emitChange());
599+
taskUnsubs.set(taskId, unsub);
600+
};
601+
602+
// Subscribe to entitlementChange events on all existing tasks
603+
for (const task of this.getTasks()) {
604+
subscribeTask(task.id);
605+
}
606+
607+
// Emit the initial state immediately so subscribers don't miss the current entitlements
608+
emitChange();
609+
610+
// Subscribe to new tasks being added
611+
globalUnsubs.push(
612+
this.subscribe("task_added", (taskId: TaskIdType) => {
613+
subscribeTask(taskId);
614+
emitChange();
615+
})
616+
);
617+
618+
globalUnsubs.push(
619+
this.subscribe("task_removed", (taskId: TaskIdType) => {
620+
const unsub = taskUnsubs.get(taskId);
621+
if (unsub) {
622+
unsub();
623+
taskUnsubs.delete(taskId);
624+
}
625+
emitChange();
626+
})
627+
);
628+
629+
return () => {
630+
globalUnsubs.forEach((unsub) => unsub());
631+
taskUnsubs.forEach((unsub) => unsub());
632+
taskUnsubs.clear();
633+
};
634+
}
635+
566636
/**
567637
* Registers an event listener for the specified event
568638
* @param name - The event name to listen for

packages/task-graph/src/task-graph/TaskGraphEvents.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import { EventParameters } from "@workglow/util";
88
import type { StreamEvent } from "../task/StreamTypes";
9+
import type { TaskEntitlements } from "../task/TaskEntitlements";
910
import { TaskIdType } from "../task/TaskTypes";
1011
import { DataflowIdType } from "./Dataflow";
1112

@@ -26,6 +27,8 @@ export type TaskGraphStatusListeners = {
2627
task_stream_chunk: (taskId: TaskIdType, event: StreamEvent) => void;
2728
/** Fired when a task in the graph finishes streaming */
2829
task_stream_end: (taskId: TaskIdType, output: Record<string, any>) => void;
30+
/** Fired when the aggregated entitlements of the graph change */
31+
entitlementChange: (entitlements: TaskEntitlements) => void;
2932
};
3033
export type TaskGraphStatusEvents = keyof TaskGraphStatusListeners;
3134
export type TaskGraphStatusListener<Event extends TaskGraphStatusEvents> =

0 commit comments

Comments
 (0)