diff --git a/src/config/types.memory.ts b/src/config/types.memory.ts index 54581f65fac..fb78f0eaecf 100644 --- a/src/config/types.memory.ts +++ b/src/config/types.memory.ts @@ -20,6 +20,7 @@ export type MemoryQmdConfig = { update?: MemoryQmdUpdateConfig; limits?: MemoryQmdLimitsConfig; scope?: SessionSendPolicyConfig; + weights?: Record; }; export type MemoryQmdMcporterConfig = { diff --git a/src/config/zod-schema.ts b/src/config/zod-schema.ts index f8ad6bfcbc9..bed1e9ce494 100644 --- a/src/config/zod-schema.ts +++ b/src/config/zod-schema.ts @@ -108,6 +108,7 @@ const MemoryQmdSchema = z update: MemoryQmdUpdateSchema.optional(), limits: MemoryQmdLimitsSchema.optional(), scope: SessionSendPolicySchema.optional(), + weights: z.record(z.string(), z.number().positive()).optional(), }) .strict(); diff --git a/src/memory/backend-config.ts b/src/memory/backend-config.ts index da1c13819a3..174bd7a7518 100644 --- a/src/memory/backend-config.ts +++ b/src/memory/backend-config.ts @@ -67,6 +67,7 @@ export type ResolvedQmdConfig = { limits: ResolvedQmdLimitsConfig; includeDefaultMemory: boolean; scope?: SessionSendPolicyConfig; + weights?: Record; }; const DEFAULT_BACKEND: MemoryBackend = "builtin"; @@ -344,6 +345,7 @@ export function resolveMemoryBackendConfig(params: { }, limits: resolveLimits(qmdCfg?.limits), scope: qmdCfg?.scope ?? DEFAULT_QMD_SCOPE, + weights: qmdCfg?.weights, }; return { diff --git a/src/memory/qmd-manager.ts b/src/memory/qmd-manager.ts index 46a80156677..1f9d78c7c97 100644 --- a/src/memory/qmd-manager.ts +++ b/src/memory/qmd-manager.ts @@ -737,6 +737,10 @@ export class QmdMemoryManager implements MemorySearchManager { this.qmd.limits.maxResults, opts?.maxResults ?? this.qmd.limits.maxResults, ); + // Fetch more results to allow for client-side re-ranking (weighting) + const hasWeights = this.qmd.weights && Object.keys(this.qmd.weights).length > 0; + const fetchLimit = hasWeights ? limit * 5 : limit; + const collectionNames = this.listManagedCollectionNames(); if (collectionNames.length === 0) { log.warn("qmd query skipped: no managed collections configured"); @@ -744,6 +748,8 @@ export class QmdMemoryManager implements MemorySearchManager { } const qmdSearchCommand = this.qmd.searchMode; const mcporterEnabled = this.qmd.mcporter.enabled; + // When weights are enabled, defer minScore filtering until after re-ranking + const upstreamMinScore = hasWeights ? 0 : (opts?.minScore ?? 0); const runSearchAttempt = async ( allowMissingCollectionRepair: boolean, ): Promise => { @@ -755,13 +761,12 @@ export class QmdMemoryManager implements MemorySearchManager { : qmdSearchCommand === "vsearch" ? "vector_search" : "deep_search"; - const minScore = opts?.minScore ?? 0; if (collectionNames.length > 1) { return await this.runMcporterAcrossCollections({ tool, query: trimmed, - limit, - minScore, + limit: fetchLimit, + minScore: upstreamMinScore, collectionNames, }); } @@ -769,8 +774,8 @@ export class QmdMemoryManager implements MemorySearchManager { mcporter: this.qmd.mcporter, tool, query: trimmed, - limit, - minScore, + limit: fetchLimit, + minScore: upstreamMinScore, collection: collectionNames[0], timeoutMs: this.qmd.limits.timeoutMs, }); @@ -778,12 +783,12 @@ export class QmdMemoryManager implements MemorySearchManager { if (collectionNames.length > 1) { return await this.runQueryAcrossCollections( trimmed, - limit, + fetchLimit, collectionNames, qmdSearchCommand, ); } - const args = this.buildSearchArgs(qmdSearchCommand, trimmed, limit); + const args = this.buildSearchArgs(qmdSearchCommand, trimmed, fetchLimit); args.push(...this.buildCollectionFilterArgs(collectionNames)); // Always scope to managed collections (default + custom). Even for `search`/`vsearch`, // pass collection filters; if a given QMD build rejects these flags, we fall back to `query`. @@ -803,9 +808,14 @@ export class QmdMemoryManager implements MemorySearchManager { ); try { if (collectionNames.length > 1) { - return await this.runQueryAcrossCollections(trimmed, limit, collectionNames, "query"); + return await this.runQueryAcrossCollections( + trimmed, + fetchLimit, + collectionNames, + "query", + ); } - const fallbackArgs = this.buildSearchArgs("query", trimmed, limit); + const fallbackArgs = this.buildSearchArgs("query", trimmed, fetchLimit); fallbackArgs.push(...this.buildCollectionFilterArgs(collectionNames)); const fallback = await this.runQmd(fallbackArgs, { timeoutMs: this.qmd.limits.timeoutMs, @@ -832,6 +842,8 @@ export class QmdMemoryManager implements MemorySearchManager { parsed = await runSearchAttempt(false); } const results: MemorySearchResult[] = []; + const weights = this.qmd.weights || {}; + for (const entry of parsed) { const docHints = this.normalizeDocHints({ preferredCollection: entry.collection, @@ -843,7 +855,18 @@ export class QmdMemoryManager implements MemorySearchManager { } const snippet = entry.snippet?.slice(0, this.qmd.limits.maxSnippetChars) ?? ""; const lines = this.extractSnippetLines(snippet); - const score = typeof entry.score === "number" ? entry.score : 0; + let score = typeof entry.score === "number" ? entry.score : 0; + + // Apply weights (first matching pattern wins) + if (hasWeights) { + for (const [pattern, weight] of Object.entries(weights)) { + if (this.matchesPath(doc.rel, pattern)) { + score *= weight; + break; + } + } + } + const minScore = opts?.minScore ?? 0; if (score < minScore) { continue; @@ -857,9 +880,44 @@ export class QmdMemoryManager implements MemorySearchManager { source: doc.source, }); } + + // Re-sort if weights were applied + if (hasWeights) { + results.sort((a, b) => b.score - a.score); + } + return this.clampResultsByInjectedChars(this.diversifyResultsBySource(results, limit)); } + private matchesPath(target: string, pattern: string): boolean { + // Simple glob matching + if (pattern === "**") { + return true; + } + + // Handle "dir/**" + if (pattern.endsWith("/**")) { + const prefix = pattern.slice(0, -3); + return target.startsWith(prefix + "/") || target === prefix; + } + + // Handle "**/*.md" or "**/foo" + if (pattern.startsWith("**/")) { + const rest = pattern.slice(3); + if (rest.startsWith("*")) { + return target.endsWith(rest.slice(1)); + } + return target.endsWith("/" + rest) || target === rest; + } + + // Handle "*.md" + if (pattern.startsWith("*")) { + return target.endsWith(pattern.slice(1)); + } + // Exact match + return target === pattern; + } + async sync(params?: { reason?: string; force?: boolean;