Browse Source

Split Ollama into a smaller base version for the browser, and a subclass that supports Node (#47)

Sam Willis 1 year ago
parent
commit
2aa838c1cf
3 changed files with 206 additions and 183 deletions
  1. 2 1
      package.json
  2. 200 0
      src/browser.ts
  3. 4 182
      src/index.ts

+ 2 - 1
package.json

@@ -6,7 +6,8 @@
   "main": "dist/index.js",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",
   "types": "dist/index.d.ts",
   "exports": {
   "exports": {
-    ".": "./dist/index.js"
+    ".": "./dist/index.js",
+    "./browser": "./dist/browser.js"
   },
   },
   "scripts": {
   "scripts": {
     "format": "prettier --write .",
     "format": "prettier --write .",

+ 200 - 0
src/browser.ts

@@ -0,0 +1,200 @@
+import * as utils from './utils.js'
+import 'whatwg-fetch'
+
+import type {
+  Fetch,
+  Config,
+  GenerateRequest,
+  PullRequest,
+  PushRequest,
+  EmbeddingsRequest,
+  GenerateResponse,
+  EmbeddingsResponse,
+  ListResponse,
+  ProgressResponse,
+  ErrorResponse,
+  StatusResponse,
+  DeleteRequest,
+  CopyRequest,
+  ShowResponse,
+  ShowRequest,
+  ChatRequest,
+  ChatResponse,
+} from './interfaces.js'
+
+// Base Ollama API client intended to run in the browser: wraps the HTTP
+// endpoints generate/chat/pull/push/delete/copy/list/show/embeddings.
+// Node-only features (file paths, modelfiles, create) live in the subclass
+// exported from index.ts.
+export class Ollama {
+  // Client configuration; config.host stays '' when a proxy is configured.
+  protected readonly config: Config
+  // fetch implementation used for every request; injectable via config.fetch.
+  protected readonly fetch: Fetch
+  // One shared controller so abort() can cancel every in-flight request.
+  private abortController: AbortController
+
+  // Build a client. With config.proxy set, host is left '' (relative URLs);
+  // otherwise host defaults to the local Ollama server.
+  constructor(config?: Partial<Config>) {
+    this.config = {
+      host: '',
+    }
+    if (!config?.proxy) {
+      this.config.host = utils.formatHost(config?.host ?? 'http://127.0.0.1:11434')
+    }
+
+    // Default to the global fetch (whatwg-fetch is imported above as a polyfill).
+    this.fetch = fetch
+    if (config?.fetch != null) {
+      this.fetch = config.fetch
+    }
+
+    this.abortController = new AbortController()
+  }
+
+  // Abort any ongoing requests to Ollama
+  public abort() {
+    this.abortController.abort()
+    // Re-arm with a fresh controller so subsequent requests are not born aborted.
+    this.abortController = new AbortController()
+  }
+
+  // POST `request` to /api/<endpoint>. With request.stream === true, returns
+  // an async generator of parsed JSON messages; otherwise returns the single
+  // terminal message. Throws on an {error} payload or a truncated stream.
+  protected async processStreamableRequest<T extends object>(
+    endpoint: string,
+    request: { stream?: boolean } & Record<string, any>,
+  ): Promise<T | AsyncGenerator<T>> {
+    request.stream = request.stream ?? false
+    const response = await utils.post(
+      this.fetch,
+      `${this.config.host}/api/${endpoint}`,
+      {
+        ...request,
+      },
+      { signal: this.abortController.signal },
+    )
+
+    if (!response.body) {
+      throw new Error('Missing body')
+    }
+
+    // Each line of the body is one JSON-encoded message.
+    const itr = utils.parseJSON<T | ErrorResponse>(response.body)
+
+    if (request.stream) {
+      return (async function* () {
+        for await (const message of itr) {
+          if ('error' in message) {
+            throw new Error(message.error)
+          }
+          yield message
+          // message will be done in the case of chat and generate
+          // message will be success in the case of a progress response (pull, push, create)
+          if ((message as any).done || (message as any).status === 'success') {
+            return
+          }
+        }
+        // Stream ended without a terminal message: treat as an error.
+        throw new Error('Did not receive done or success response in stream.')
+      })()
+    } else {
+      // Non-streaming: expect exactly one terminal message.
+      // NOTE(review): assumes the iterator yields at least one value; on an
+      // empty body message.value is undefined and `.done` access throws — confirm.
+      const message = await itr.next()
+      if (!message.value.done && (message.value as any).status !== 'success') {
+        throw new Error('Expected a completed response.')
+      }
+      return message.value
+    }
+  }
+
+  // Normalize an image argument to the base64 string the API payload expects.
+  // NOTE(review): Buffer is a Node global and is not available in browsers —
+  // this relies on a bundler polyfill; confirm, or switch to a btoa/TextDecoder
+  // based encoder for the browser build.
+  async encodeImage(image: Uint8Array | Buffer | string): Promise<string> {
+    if (typeof image !== 'string') {
+      // image is Uint8Array or Buffer, convert it to base64
+      const result = Buffer.from(image).toString('base64')
+      return result
+    }
+    // the string may be base64 encoded
+    return image
+  }
+
+  // Overloads: stream:true yields an AsyncGenerator, otherwise one response.
+  generate(
+    request: GenerateRequest & { stream: true },
+  ): Promise<AsyncGenerator<GenerateResponse>>
+  generate(request: GenerateRequest & { stream?: false }): Promise<GenerateResponse>
+
+  // Completion via /api/generate; any attached images are base64-encoded first.
+  async generate(
+    request: GenerateRequest,
+  ): Promise<GenerateResponse | AsyncGenerator<GenerateResponse>> {
+    if (request.images) {
+      request.images = await Promise.all(request.images.map(this.encodeImage.bind(this)))
+    }
+    return this.processStreamableRequest<GenerateResponse>('generate', request)
+  }
+
+  // Overloads mirror generate(): streaming vs single response.
+  chat(request: ChatRequest & { stream: true }): Promise<AsyncGenerator<ChatResponse>>
+  chat(request: ChatRequest & { stream?: false }): Promise<ChatResponse>
+
+  // Chat via /api/chat; images on each message are base64-encoded first.
+  async chat(request: ChatRequest): Promise<ChatResponse | AsyncGenerator<ChatResponse>> {
+    if (request.messages) {
+      for (const message of request.messages) {
+        if (message.images) {
+          message.images = await Promise.all(
+            message.images.map(this.encodeImage.bind(this)),
+          )
+        }
+      }
+    }
+    return this.processStreamableRequest<ChatResponse>('chat', request)
+  }
+
+  pull(request: PullRequest & { stream: true }): Promise<AsyncGenerator<ProgressResponse>>
+  pull(request: PullRequest & { stream?: false }): Promise<ProgressResponse>
+
+  // Pull a model via /api/pull; the API expects the model under the `name` key.
+  async pull(
+    request: PullRequest,
+  ): Promise<ProgressResponse | AsyncGenerator<ProgressResponse>> {
+    return this.processStreamableRequest<ProgressResponse>('pull', {
+      name: request.model,
+      stream: request.stream,
+      insecure: request.insecure,
+    })
+  }
+
+  push(request: PushRequest & { stream: true }): Promise<AsyncGenerator<ProgressResponse>>
+  push(request: PushRequest & { stream?: false }): Promise<ProgressResponse>
+
+  // Push a model via /api/push; mirrors pull()'s name/stream/insecure mapping.
+  async push(
+    request: PushRequest,
+  ): Promise<ProgressResponse | AsyncGenerator<ProgressResponse>> {
+    return this.processStreamableRequest<ProgressResponse>('push', {
+      name: request.model,
+      stream: request.stream,
+      insecure: request.insecure,
+    })
+  }
+
+  // Delete a model via /api/delete; the response body is ignored and a
+  // success status is synthesized (utils.del throws on HTTP errors — presumably;
+  // verify against utils.ts).
+  async delete(request: DeleteRequest): Promise<StatusResponse> {
+    await utils.del(this.fetch, `${this.config.host}/api/delete`, {
+      name: request.model,
+    })
+    return { status: 'success' }
+  }
+
+  // Copy a model via /api/copy; like delete(), the status is synthesized.
+  async copy(request: CopyRequest): Promise<StatusResponse> {
+    await utils.post(this.fetch, `${this.config.host}/api/copy`, { ...request })
+    return { status: 'success' }
+  }
+
+  // List locally available models from /api/tags.
+  async list(): Promise<ListResponse> {
+    const response = await utils.get(this.fetch, `${this.config.host}/api/tags`)
+    const listResponse = (await response.json()) as ListResponse
+    return listResponse
+  }
+
+  // Fetch details for a single model via /api/show.
+  async show(request: ShowRequest): Promise<ShowResponse> {
+    const response = await utils.post(this.fetch, `${this.config.host}/api/show`, {
+      ...request,
+    })
+    const showResponse = (await response.json()) as ShowResponse
+    return showResponse
+  }
+
+  // Compute embeddings for a prompt via /api/embeddings.
+  async embeddings(request: EmbeddingsRequest): Promise<EmbeddingsResponse> {
+    const response = await utils.post(this.fetch, `${this.config.host}/api/embeddings`, {
+      ...request,
+    })
+    const embeddingsResponse = (await response.json()) as EmbeddingsResponse
+    return embeddingsResponse
+  }
+}
+
+export default new Ollama()
+
+// export all types from the main entry point so that packages importing types dont need to specify paths
+export * from './interfaces.js'

+ 4 - 182
src/index.ts

@@ -1,108 +1,20 @@
 import * as utils from './utils.js'
 import * as utils from './utils.js'
-import 'whatwg-fetch'
 import fs, { promises, createReadStream } from 'fs'
 import fs, { promises, createReadStream } from 'fs'
 import { join, resolve, dirname } from 'path'
 import { join, resolve, dirname } from 'path'
 import { createHash } from 'crypto'
 import { createHash } from 'crypto'
 import { homedir } from 'os'
 import { homedir } from 'os'
+import { Ollama as OllamaBrowser } from './browser.js'
 
 
 import type {
 import type {
-  Fetch,
-  Config,
-  GenerateRequest,
-  PullRequest,
-  PushRequest,
   CreateRequest,
   CreateRequest,
-  EmbeddingsRequest,
-  GenerateResponse,
-  EmbeddingsResponse,
-  ListResponse,
   ProgressResponse,
   ProgressResponse,
-  ErrorResponse,
-  StatusResponse,
-  DeleteRequest,
-  CopyRequest,
-  ShowResponse,
-  ShowRequest,
-  ChatRequest,
-  ChatResponse,
 } from './interfaces.js'
 } from './interfaces.js'
 
 
-export class Ollama {
-  private readonly config: Config
-  private readonly fetch: Fetch
-  private abortController: AbortController
+export class Ollama extends OllamaBrowser {
 
 
-  constructor(config?: Partial<Config>) {
-    this.config = {
-      host: '',
-    }
-    if (!config?.proxy) {
-      this.config.host = utils.formatHost(config?.host ?? 'http://127.0.0.1:11434')
-    }
-
-    this.fetch = fetch
-    if (config?.fetch != null) {
-      this.fetch = config.fetch
-    }
-
-    this.abortController = new AbortController()
-  }
-
-  // Abort any ongoing requests to Ollama
-  public abort() {
-    this.abortController.abort()
-    this.abortController = new AbortController()
-  }
-
-  private async processStreamableRequest<T extends object>(
-    endpoint: string,
-    request: { stream?: boolean } & Record<string, any>,
-  ): Promise<T | AsyncGenerator<T>> {
-    request.stream = request.stream ?? false
-    const response = await utils.post(
-      this.fetch,
-      `${this.config.host}/api/${endpoint}`,
-      {
-        ...request,
-      },
-      { signal: this.abortController.signal },
-    )
-
-    if (!response.body) {
-      throw new Error('Missing body')
-    }
-
-    const itr = utils.parseJSON<T | ErrorResponse>(response.body)
-
-    if (request.stream) {
-      return (async function* () {
-        for await (const message of itr) {
-          if ('error' in message) {
-            throw new Error(message.error)
-          }
-          yield message
-          // message will be done in the case of chat and generate
-          // message will be success in the case of a progress response (pull, push, create)
-          if ((message as any).done || (message as any).status === 'success') {
-            return
-          }
-        }
-        throw new Error('Did not receive done or success response in stream.')
-      })()
-    } else {
-      const message = await itr.next()
-      if (!message.value.done && (message.value as any).status !== 'success') {
-        throw new Error('Expected a completed response.')
-      }
-      return message.value
-    }
-  }
-
-  private async encodeImage(image: Uint8Array | Buffer | string): Promise<string> {
+  async encodeImage(image: Uint8Array | Buffer | string): Promise<string> {
     if (typeof image !== 'string') {
     if (typeof image !== 'string') {
-      // image is Uint8Array or Buffer, convert it to base64
-      const result = Buffer.from(image).toString('base64')
-      return result
+      return super.encodeImage(image)
     }
     }
     try {
     try {
       if (fs.existsSync(image)) {
       if (fs.existsSync(image)) {
@@ -209,62 +121,6 @@ export class Ollama {
     return digest
     return digest
   }
   }
 
 
-  generate(
-    request: GenerateRequest & { stream: true },
-  ): Promise<AsyncGenerator<GenerateResponse>>
-  generate(request: GenerateRequest & { stream?: false }): Promise<GenerateResponse>
-
-  async generate(
-    request: GenerateRequest,
-  ): Promise<GenerateResponse | AsyncGenerator<GenerateResponse>> {
-    if (request.images) {
-      request.images = await Promise.all(request.images.map(this.encodeImage.bind(this)))
-    }
-    return this.processStreamableRequest<GenerateResponse>('generate', request)
-  }
-
-  chat(request: ChatRequest & { stream: true }): Promise<AsyncGenerator<ChatResponse>>
-  chat(request: ChatRequest & { stream?: false }): Promise<ChatResponse>
-
-  async chat(request: ChatRequest): Promise<ChatResponse | AsyncGenerator<ChatResponse>> {
-    if (request.messages) {
-      for (const message of request.messages) {
-        if (message.images) {
-          message.images = await Promise.all(
-            message.images.map(this.encodeImage.bind(this)),
-          )
-        }
-      }
-    }
-    return this.processStreamableRequest<ChatResponse>('chat', request)
-  }
-
-  pull(request: PullRequest & { stream: true }): Promise<AsyncGenerator<ProgressResponse>>
-  pull(request: PullRequest & { stream?: false }): Promise<ProgressResponse>
-
-  async pull(
-    request: PullRequest,
-  ): Promise<ProgressResponse | AsyncGenerator<ProgressResponse>> {
-    return this.processStreamableRequest<ProgressResponse>('pull', {
-      name: request.model,
-      stream: request.stream,
-      insecure: request.insecure,
-    })
-  }
-
-  push(request: PushRequest & { stream: true }): Promise<AsyncGenerator<ProgressResponse>>
-  push(request: PushRequest & { stream?: false }): Promise<ProgressResponse>
-
-  async push(
-    request: PushRequest,
-  ): Promise<ProgressResponse | AsyncGenerator<ProgressResponse>> {
-    return this.processStreamableRequest<ProgressResponse>('push', {
-      name: request.model,
-      stream: request.stream,
-      insecure: request.insecure,
-    })
-  }
-
   create(
   create(
     request: CreateRequest & { stream: true },
     request: CreateRequest & { stream: true },
   ): Promise<AsyncGenerator<ProgressResponse>>
   ): Promise<AsyncGenerator<ProgressResponse>>
@@ -292,40 +148,6 @@ export class Ollama {
       modelfile: modelfileContent,
       modelfile: modelfileContent,
     })
     })
   }
   }
-
-  async delete(request: DeleteRequest): Promise<StatusResponse> {
-    await utils.del(this.fetch, `${this.config.host}/api/delete`, {
-      name: request.model,
-    })
-    return { status: 'success' }
-  }
-
-  async copy(request: CopyRequest): Promise<StatusResponse> {
-    await utils.post(this.fetch, `${this.config.host}/api/copy`, { ...request })
-    return { status: 'success' }
-  }
-
-  async list(): Promise<ListResponse> {
-    const response = await utils.get(this.fetch, `${this.config.host}/api/tags`)
-    const listResponse = (await response.json()) as ListResponse
-    return listResponse
-  }
-
-  async show(request: ShowRequest): Promise<ShowResponse> {
-    const response = await utils.post(this.fetch, `${this.config.host}/api/show`, {
-      ...request,
-    })
-    const showResponse = (await response.json()) as ShowResponse
-    return showResponse
-  }
-
-  async embeddings(request: EmbeddingsRequest): Promise<EmbeddingsResponse> {
-    const response = await utils.post(this.fetch, `${this.config.host}/api/embeddings`, {
-      ...request,
-    })
-    const embeddingsResponse = (await response.json()) as EmbeddingsResponse
-    return embeddingsResponse
-  }
 }
 }
 
 
 export default new Ollama()
 export default new Ollama()