ソースを参照

browser/utils: updates for web search and loading OLLAMA_API_KEY from the environment (#250)

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
Parth Sareen 16 時間 前
コミット
a6689ac591
4 ファイル変更125 行追加76 行削除
  1. 40 29
      examples/websearch/websearch-tools.ts
  2. 14 12
      src/browser.ts
  3. 7 8
      src/interfaces.ts
  4. 64 27
      src/utils.ts

+ 40 - 29
examples/websearch/websearch-tools.ts

@@ -1,12 +1,17 @@
-import { Ollama, type Message, type SearchResponse, type FetchResponse } from 'ollama'
+import {
+  Ollama,
+  type Message,
+  type WebSearchResponse,
+  type WebFetchResponse,
+} from 'ollama'
 
 async function main() {
-
-  if (!process.env.OLLAMA_API_KEY) throw new Error('Set OLLAMA_API_KEY to use web search tools')
-
-  const client = new Ollama({
-    headers: { Authorization: `Bearer ${process.env.OLLAMA_API_KEY}` },
-  })
+  // Set enviornment variable OLLAMA_API_KEY=<YOUR>.<KEY>
+  // or set the header manually
+  //   const client = new Ollama({
+  //     headers: { Authorization: `Bearer ${process.env.OLLAMA_API_KEY}` },
+  //   })
+  const client = new Ollama()
 
   // Tool schemas
   const webSearchTool = {
@@ -20,7 +25,7 @@ async function main() {
           query: { type: 'string', description: 'Search query string.' },
           max_results: {
             type: 'number',
-            description: 'The maximum number of results to return per query (default 5, max 10).',
+            description: 'The maximum number of results to return per query (default 3).',
           },
         },
         required: ['query'],
@@ -43,28 +48,32 @@ async function main() {
     },
   }
 
-	const availableTools = {
-		webSearch: async (args: { query: string; max_results?: number }): Promise<SearchResponse> => {
-			const res = await client.webSearch(args)
-			return res as SearchResponse
-		},
-		webFetch: async (args: { url: string }): Promise<FetchResponse> => {
-			const res = await client.webFetch(args)
-			return res as FetchResponse
-		},
-	}
+  const availableTools = {
+    webSearch: async (args: {
+      query: string
+      max_results?: number
+    }): Promise<WebSearchResponse> => {
+      const res = await client.webSearch(args)
+      return res as WebSearchResponse
+    },
+    webFetch: async (args: { url: string }): Promise<WebFetchResponse> => {
+      const res = await client.webFetch(args)
+      return res as WebFetchResponse
+    },
+  }
+
+  const query = 'What is Ollama?'
+  console.log('Prompt:', query, '\n')
 
   const messages: Message[] = [
     {
       role: 'user',
-      content: 'What is Ollama?',
+      content: query,
     },
   ]
 
-  console.log('----- Prompt:', messages.find((m) => m.role === 'user')?.content, '\n')
-  
   while (true) {
-	const response = await client.chat({
+    const response = await client.chat({
       model: 'qwen3',
       messages: messages,
       tools: [webSearchTool, webFetchTool],
@@ -76,7 +85,6 @@ async function main() {
     var content = ''
     var thinking = ''
     for await (const chunk of response) {
-
       if (chunk.message.thinking) {
         thinking += chunk.message.thinking
       }
@@ -97,14 +105,19 @@ async function main() {
           const functionToCall = availableTools[toolCall.function.name]
           if (functionToCall) {
             const args = toolCall.function.arguments as any
-            console.log('\nCalling function:', toolCall.function.name, 'with arguments:', args)
+            console.log(
+              '\nCalling function:',
+              toolCall.function.name,
+              'with arguments:',
+              args,
+            )
             const output = await functionToCall(args)
             console.log('Function result:', JSON.stringify(output).slice(0, 200), '\n')
-            
+
             messages.push(chunk.message)
             messages.push({
               role: 'tool',
-              content: JSON.stringify(output),
+              content: JSON.stringify(output).slice(0, 2000 * 4), // cap at ~2000 tokens
               tool_name: toolCall.function.name,
             })
           }
@@ -116,9 +129,7 @@ async function main() {
       process.stdout.write('\n')
       break
     }
-
-    
   }
 }
 
-main().catch(console.error)
+main().catch(console.error)

+ 14 - 12
src/browser.ts

@@ -24,10 +24,10 @@ import type {
   ShowRequest,
   ShowResponse,
   StatusResponse,
-  SearchRequest,
-  SearchResponse,
-  FetchRequest,
-  FetchResponse,
+  WebSearchRequest,
+  WebSearchResponse,
+  WebFetchRequest,
+  WebFetchResponse,
 } from './interfaces.js'
 import { defaultHost } from './constant.js'
 
@@ -49,6 +49,8 @@ export class Ollama {
     this.fetch = config?.fetch ?? fetch
   }
 
+
+
   // Abort any ongoing streamed requests to Ollama
   public abort() {
     for (const request of this.ongoingStreamedRequests) {
@@ -327,32 +329,32 @@ async encodeImage(image: Uint8Array | string): Promise<string> {
 
   /**
    * Performs web search using the Ollama web search API
-   * @param request {SearchRequest} - The search request containing query and options
-   * @returns {Promise<SearchResponse>} - The search results
+   * @param request {WebSearchRequest} - The search request containing query and options
+   * @returns {Promise<WebSearchResponse>} - The search results
    * @throws {Error} - If the request is invalid or the server returns an error
    */
-  async webSearch(request: SearchRequest): Promise<SearchResponse> {
+  async webSearch(request: WebSearchRequest): Promise<WebSearchResponse> {
     if (!request.query || request.query.length === 0) {
       throw new Error('Query is required')
     }
     const response = await utils.post(this.fetch, `https://ollama.com/api/web_search`, { ...request }, {
       headers: this.config.headers
     })
-    return (await response.json()) as SearchResponse
+    return (await response.json()) as WebSearchResponse
   }
 
   /**
    * Fetches a single page using the Ollama web fetch API
-   * @param request {FetchRequest} - The fetch request containing a URL
-   * @returns {Promise<FetchResponse>} - The fetch result
+   * @param request {WebFetchRequest} - The fetch request containing a URL
+   * @returns {Promise<WebFetchResponse>} - The fetch result
    * @throws {Error} - If the request is invalid or the server returns an error
    */
-  async webFetch(request: FetchRequest): Promise<FetchResponse> {
+  async webFetch(request: WebFetchRequest): Promise<WebFetchResponse> {
     if (!request.url || request.url.length === 0) {
       throw new Error('URL is required')
     }
     const response = await utils.post(this.fetch, `https://ollama.com/api/web_fetch`, { ...request }, { headers: this.config.headers })
-    return (await response.json()) as FetchResponse
+    return (await response.json()) as WebFetchResponse
   }
 }
 

+ 7 - 8
src/interfaces.ts

@@ -270,26 +270,25 @@ export interface StatusResponse {
   status: string
 }
 
-// Web Search types
-export interface SearchRequest {
+export interface WebSearchRequest {
   query: string
-  max_results?: number
+  maxResults?: number
 }
 
-export interface SearchResult {
+export interface WebSearchResult {
   content: string
 }
 
-export interface SearchResponse {
-  results: SearchResult[]
+export interface WebSearchResponse {
+  results: WebSearchResult[]
 }
 
 // Fetch types
-export interface FetchRequest {
+export interface WebFetchRequest {
   url: string
 }
 
-export interface FetchResponse {
+export interface WebFetchResponse {
   title: string
   url: string
   content: string

+ 64 - 27
src/utils.ts

@@ -28,7 +28,11 @@ export class AbortableAsyncIterator<T extends object> {
   private readonly itr: AsyncGenerator<T | ErrorResponse>
   private readonly doneCallback: () => void
 
-  constructor(abortController: AbortController, itr: AsyncGenerator<T | ErrorResponse>, doneCallback: () => void) {
+  constructor(
+    abortController: AbortController,
+    itr: AsyncGenerator<T | ErrorResponse>,
+    doneCallback: () => void,
+  ) {
     this.abortController = abortController
     this.itr = itr
     this.doneCallback = doneCallback
@@ -119,23 +123,27 @@ function getPlatform(): string {
  *   - An array of key-value pairs representing headers.
  * @returns {Record<string,string>} - A plain object representing the normalized headers.
  */
-function normalizeHeaders(headers?: HeadersInit | undefined): Record<string,string> {
+function normalizeHeaders(headers?: HeadersInit | undefined): Record<string, string> {
   if (headers instanceof Headers) {
-      // If headers are an instance of Headers, convert it to an object
-      const obj: Record<string, string> = {};
-        headers.forEach((value, key) => {
-          obj[key] = value;
-        });
-        return obj;
+    // If headers are an instance of Headers, convert it to an object
+    const obj: Record<string, string> = {}
+    headers.forEach((value, key) => {
+      obj[key] = value
+    })
+    return obj
   } else if (Array.isArray(headers)) {
-      // If headers are in array format, convert them to an object
-      return Object.fromEntries(headers);
+    // If headers are in array format, convert them to an object
+    return Object.fromEntries(headers)
   } else {
-      // Otherwise assume it's already a plain object
-      return headers || {};
+    // Otherwise assume it's already a plain object
+    return headers || {}
   }
 }
 
+const readEnvVar = (obj: object, key: string): string | undefined => {
+  return obj[key]
+}
+
 /**
  * A wrapper around fetch that adds default headers.
  * @param fetch {Fetch} - The fetch function to use
@@ -155,16 +163,41 @@ const fetchWithHeaders = async (
   } as HeadersInit
 
   // Normalizes headers into a plain object format.
-  options.headers = normalizeHeaders(options.headers);
-  
-  // Filter out default headers from custom headers
+  options.headers = normalizeHeaders(options.headers)
+
+  // Automatically add the API key to the headers if the URL is https://ollama.com
+  try {
+    const parsed = new URL(url)
+    if (parsed.protocol === 'https:' && parsed.hostname === 'ollama.com') {
+      const apiKey =
+        typeof process === 'object' &&
+        process !== null &&
+        typeof process.env === 'object' &&
+        process.env !== null
+          ? readEnvVar(process.env, 'OLLAMA_API_KEY')
+          : undefined
+      const authorization =
+        options.headers['authorization'] || options.headers['Authorization']
+      if (!authorization && apiKey) {
+        options.headers['Authorization'] = `Bearer ${apiKey}`
+      }
+    }
+  } catch (error) {
+    console.error('error parsing url', error)
+  }
+
   const customHeaders = Object.fromEntries(
-    Object.entries(options.headers).filter(([key]) => !Object.keys(defaultHeaders).some(defaultKey => defaultKey.toLowerCase() === key.toLowerCase()))
+    Object.entries(options.headers).filter(
+      ([key]) =>
+        !Object.keys(defaultHeaders).some(
+          (defaultKey) => defaultKey.toLowerCase() === key.toLowerCase(),
+        ),
+    ),
   )
 
   options.headers = {
     ...defaultHeaders,
-    ...customHeaders
+    ...customHeaders,
   }
 
   return fetch(url, options)
@@ -176,9 +209,13 @@ const fetchWithHeaders = async (
  * @param host {string} - The host to fetch
  * @returns {Promise<Response>} - The fetch response
  */
-export const get = async (fetch: Fetch, host: string, options?: { headers?: HeadersInit }): Promise<Response> => {
+export const get = async (
+  fetch: Fetch,
+  host: string,
+  options?: { headers?: HeadersInit },
+): Promise<Response> => {
   const response = await fetchWithHeaders(fetch, host, {
-    headers: options?.headers
+    headers: options?.headers,
   })
 
   await checkOk(response)
@@ -212,7 +249,7 @@ export const post = async (
   fetch: Fetch,
   host: string,
   data?: Record<string, unknown> | BodyInit,
-  options?: { signal?: AbortSignal, headers?: HeadersInit },
+  options?: { signal?: AbortSignal; headers?: HeadersInit },
 ): Promise<Response> => {
   const isRecord = (input: any): input is Record<string, unknown> => {
     return input !== null && typeof input === 'object' && !Array.isArray(input)
@@ -224,7 +261,7 @@ export const post = async (
     method: 'POST',
     body: formattedData,
     signal: options?.signal,
-    headers: options?.headers
+    headers: options?.headers,
   })
 
   await checkOk(response)
@@ -247,7 +284,7 @@ export const del = async (
   const response = await fetchWithHeaders(fetch, host, {
     method: 'DELETE',
     body: JSON.stringify(data),
-    headers: options?.headers
+    headers: options?.headers,
   })
 
   await checkOk(response)
@@ -332,16 +369,16 @@ export const formatHost = (host: string): string => {
   }
 
   // Build basic auth part if present
-  let auth = '';
+  let auth = ''
   if (url.username) {
-    auth = url.username;
+    auth = url.username
     if (url.password) {
-      auth += `:${url.password}`;
+      auth += `:${url.password}`
     }
-    auth += '@';
+    auth += '@'
   }
 
-  let formattedHost = `${url.protocol}//${auth}${url.hostname}:${port}${url.pathname}`;
+  let formattedHost = `${url.protocol}//${auth}${url.hostname}:${port}${url.pathname}`
   // remove trailing slashes
   if (formattedHost.endsWith('/')) {
     formattedHost = formattedHost.slice(0, -1)