Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid of a11y tree & DOM for input to observe #459

Merged
merged 14 commits into from
Feb 6, 2025
5 changes: 5 additions & 0 deletions .changeset/chilled-apes-sneeze.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@browserbasehq/stagehand": patch
---

Create a11y + DOM hybrid input for observe
182 changes: 143 additions & 39 deletions lib/a11y/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ export function formatSimplifiedTree(
level = 0,
): string {
const indent = " ".repeat(level);
let result = `${indent}[${node.nodeId}] ${node.role}${node.name ? `: ${node.name}` : ""}\n`;
let result = `${indent}[${node.nodeId}] ${node.role}${
node.name ? `: ${node.name}` : ""
}\n`;

if (node.children?.length) {
result += node.children
Expand All @@ -29,39 +31,113 @@ export function formatSimplifiedTree(
* 1. Removes generic/none nodes with no children
* 2. Collapses generic/none nodes with single child
* 3. Keeps generic/none nodes with multiple children but cleans their subtrees
* and attempts to resolve their role to a DOM tag name
*/
function cleanStructuralNodes(
async function cleanStructuralNodes(
node: AccessibilityNode,
): AccessibilityNode | null {
// Filter out nodes with negative IDs
page?: StagehandPage,
logger?: (logLine: LogLine) => void,
): Promise<AccessibilityNode | null> {
// 1) Filter out nodes with negative IDs
if (node.nodeId && parseInt(node.nodeId) < 0) {
return null;
}

// Base case: leaf node
if (!node.children) {
// 2) Base case: if no children exist, this is effectively a leaf.
// If it's "generic" or "none", we remove it; otherwise, keep it.
if (!node.children || node.children.length === 0) {
return node.role === "generic" || node.role === "none" ? null : node;
}

// Recursively clean children
const cleanedChildren = node.children
.map((child) => cleanStructuralNodes(child))
.filter(Boolean) as AccessibilityNode[];

// Handle generic/none nodes specially
// 3) Recursively clean children
const cleanedChildrenPromises = node.children.map((child) =>
cleanStructuralNodes(child, page, logger),
);
const resolvedChildren = await Promise.all(cleanedChildrenPromises);
const cleanedChildren = resolvedChildren.filter(
(child): child is AccessibilityNode => child !== null,
);

// 4) **Prune** "generic" or "none" nodes first,
// before resolving them to their tag names.
if (node.role === "generic" || node.role === "none") {
if (cleanedChildren.length === 1) {
// Collapse single-child generic nodes
// Collapse single-child structural node
return cleanedChildren[0];
} else if (cleanedChildren.length > 1) {
// Keep generic nodes with multiple children
return { ...node, children: cleanedChildren };
} else if (cleanedChildren.length === 0) {
// Remove empty structural node
return null;
}
// If we have multiple children, we keep this node as a container.
// We'll update role below if needed.
}

// 5) If we still have a "generic"/"none" node after pruning
// (i.e., because it had multiple children), now we try
// to resolve and replace its role with the DOM tag name.
if (
page &&
logger &&
node.backendDOMNodeId !== undefined &&
(node.role === "generic" || node.role === "none")
) {
try {
const { object } = await page.sendCDP<{
object: { objectId?: string };
}>("DOM.resolveNode", {
backendNodeId: node.backendDOMNodeId,
});

if (object && object.objectId) {
try {
// Get the tagName for the node
const { result } = await page.sendCDP<{
result: { type: string; value?: string };
}>("Runtime.callFunctionOn", {
objectId: object.objectId,
functionDeclaration: `
function() {
return this.tagName ? this.tagName.toLowerCase() : "";
}
`,
returnByValue: true,
});

// If we got a tagName, update the node's role
if (result?.value) {
node.role = result.value;
}
} catch (tagNameError) {
logger({
category: "observation",
message: `Could not fetch tagName for node ${node.backendDOMNodeId}`,
level: 2,
auxiliary: {
error: {
value: tagNameError.message,
type: "string",
},
},
});
}
}
} catch (resolveError) {
logger({
category: "observation",
message: `Could not resolve DOM node ID ${node.backendDOMNodeId}`,
level: 2,
auxiliary: {
error: {
value: resolveError.message,
type: "string",
},
},
});
}
// Remove generic nodes with no children
return null;
}

// For non-generic nodes, keep them if they have children after cleaning
// 6) Return the updated node.
// If it has children, update them; otherwise keep it as-is.
return cleanedChildren.length > 0
? { ...node, children: cleanedChildren }
: node;
Expand All @@ -73,13 +149,23 @@ function cleanStructuralNodes(
* @param nodes - Flat array of accessibility nodes from the CDP
* @returns Object containing both the tree structure and a simplified string representation
*/
export function buildHierarchicalTree(nodes: AccessibilityNode[]): TreeResult {
export async function buildHierarchicalTree(
nodes: AccessibilityNode[],
page?: StagehandPage,
logger?: (logLine: LogLine) => void,
): Promise<TreeResult> {
// Map to store processed nodes for quick lookup
const nodeMap = new Map<string, AccessibilityNode>();

// First pass: Create nodes that are meaningful
// We only keep nodes that either have a name or children to avoid cluttering the tree
nodes.forEach((node) => {
// Skip node if its ID is negative (e.g., "-1000002014")
const nodeIdValue = parseInt(node.nodeId, 10);
if (nodeIdValue < 0) {
return;
}

const hasChildren = node.childIds && node.childIds.length > 0;
const hasValidName = node.name && node.name.trim() !== "";
const isInteractive =
Expand All @@ -99,6 +185,9 @@ export function buildHierarchicalTree(nodes: AccessibilityNode[]): TreeResult {
...(hasValidName && { name: node.name }), // Only include name if it exists and isn't empty
...(node.description && { description: node.description }),
...(node.value && { value: node.value }),
...(node.backendDOMNodeId !== undefined && {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't backendDOMNodeId === nodeId?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, nodeId is a11y-specific (remember, we used to have negative values for nodeIds).

backendDOMNodeId: node.backendDOMNodeId,
}),
});
});

Expand All @@ -119,13 +208,18 @@ export function buildHierarchicalTree(nodes: AccessibilityNode[]): TreeResult {
});

// Final pass: Build the root-level tree and clean up structural nodes
const finalTree = nodes
const rootNodes = nodes
.filter((node) => !node.parentId && nodeMap.has(node.nodeId)) // Get root nodes
.map((node) => nodeMap.get(node.nodeId))
.filter(Boolean)
.map((node) => cleanStructuralNodes(node))
.filter(Boolean) as AccessibilityNode[];

const cleanedTreePromises = rootNodes.map((node) =>
cleanStructuralNodes(node, page, logger),
);
const finalTree = (await Promise.all(cleanedTreePromises)).filter(
Boolean,
) as AccessibilityNode[];

// Generate a simplified string representation of the tree
const simplifiedFormat = finalTree
.map((node) => formatSimplifiedTree(node))
Expand All @@ -137,29 +231,43 @@ export function buildHierarchicalTree(nodes: AccessibilityNode[]): TreeResult {
};
}

/**
* Retrieves the full accessibility tree via CDP and transforms it into a hierarchical structure.
*/
export async function getAccessibilityTree(
page: StagehandPage,
logger: (logLine: LogLine) => void,
) {
): Promise<TreeResult> {
await page.enableCDP("Accessibility");

try {
// Fetch the full accessibility tree from Chrome DevTools Protocol
const { nodes } = await page.sendCDP<{ nodes: AXNode[] }>(
"Accessibility.getFullAXTree",
);
const startTime = Date.now();

// Extract specific sources
const sources = nodes.map((node) => ({
role: node.role?.value,
name: node.name?.value,
description: node.description?.value,
value: node.value?.value,
nodeId: node.nodeId,
parentId: node.parentId,
childIds: node.childIds,
}));
// Transform into hierarchical structure
const hierarchicalTree = buildHierarchicalTree(sources);
const hierarchicalTree = await buildHierarchicalTree(
nodes.map((node) => ({
role: node.role?.value,
name: node.name?.value,
description: node.description?.value,
value: node.value?.value,
nodeId: node.nodeId,
backendDOMNodeId: node.backendDOMNodeId,
parentId: node.parentId,
childIds: node.childIds,
})),
page,
logger,
);

logger({
category: "observation",
message: `got accessibility tree in ${Date.now() - startTime}ms`,
level: 1,
});

return hierarchicalTree;
} catch (error) {
Expand Down Expand Up @@ -258,7 +366,6 @@ export async function performPlaywrightMethod(
method: string,
args: unknown[],
xpath: string,
// domSettleTimeoutMs?: number,
) {
const locator = stagehandPage.locator(`xpath=${xpath}`).first();
const initialUrl = stagehandPage.url();
Expand Down Expand Up @@ -503,7 +610,6 @@ export async function performPlaywrightMethod(
await newOpenedTab.close();
await stagehandPage.goto(newOpenedTab.url());
await stagehandPage.waitForLoadState("domcontentloaded");
// await stagehandPage._waitForSettledDom(domSettleTimeoutMs);
}

await Promise.race([
Expand Down Expand Up @@ -564,6 +670,4 @@ export async function performPlaywrightMethod(
`Method ${method} not supported`,
);
}

// await stagehandPage._waitForSettledDom(domSettleTimeoutMs);
}
3 changes: 1 addition & 2 deletions lib/handlers/observeHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ export class StagehandObserveHandler {
isUsingAccessibilityTree: useAccessibilityTree,
returnAction,
});

const elementsWithSelectors = await Promise.all(
observationResponse.elements.map(async (element) => {
const { elementId, ...rest } = element;
Expand Down Expand Up @@ -137,7 +138,6 @@ export class StagehandObserveHandler {
message: `Invalid object ID returned for element: ${elementId}`,
level: 1,
});
return null;
}

const xpath = await getXPathByResolvedObjectId(
Expand All @@ -151,7 +151,6 @@ export class StagehandObserveHandler {
message: `Empty xpath returned for element: ${elementId}`,
level: 1,
});
return null;
}

return {
Expand Down
2 changes: 1 addition & 1 deletion lib/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ You will be given:
1. a instruction of elements to observe
2. ${
isUsingAccessibilityTree
? "a hierarchical accessibility tree showing the semantic structure of the page"
? "a hierarchical accessibility tree showing the semantic structure of the page. The tree is a hybrid of the DOM and the accessibility tree."
: "a numbered list of possible elements"
}

Expand Down
2 changes: 2 additions & 0 deletions types/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export interface AXNode {
description?: { value: string };
value?: { value: string };
nodeId: string;
backendDOMNodeId?: number;
parentId?: string;
childIds?: string[];
}
Expand All @@ -17,6 +18,7 @@ export type AccessibilityNode = {
childIds?: string[];
parentId?: string;
nodeId?: string;
backendDOMNodeId?: number;
};

export interface TreeResult {
Expand Down
Loading