Parcourir la source

match host formatting

Bruce MacDonald il y a 1 an
Parent
commit
56a4a8791c
4 fichiers modifiés avec 118 ajouts et 32 suppressions
  1. 16 20
      src/index.ts
  2. 2 2
      src/interfaces.ts
  3. 46 8
      src/utils.ts
  4. 54 2
      test/index.spec.ts

+ 16 - 20
src/index.ts

@@ -31,10 +31,10 @@ export class Ollama {
   private readonly config: Config
   private readonly fetch: Fetch
 
-	constructor (config?: Partial<Config>) {
-		this.config = {
-			host: config?.host ?? "http://127.0.0.1:11434"
-		};
+  constructor(config?: Partial<Config>) {
+    this.config = {
+      host: utils.formatHost(config?.host),
+    }
 
     this.fetch = fetch
     if (config?.fetch != null) {
@@ -47,11 +47,9 @@ export class Ollama {
     request: { stream?: boolean } & Record<string, any>,
   ): Promise<T | AsyncGenerator<T>> {
     request.stream = request.stream ?? false
-    const response = await utils.post(
-      this.fetch,
-      `${this.config.address}/api/${endpoint}`,
-      { ...request },
-    )
+    const response = await utils.post(this.fetch, `${this.config.host}/api/${endpoint}`, {
+      ...request,
+    })
 
     if (!response.body) {
       throw new Error('Missing body')
@@ -158,7 +156,7 @@ export class Ollama {
     const digest = `sha256:${sha256sum}`
 
     try {
-      await utils.head(this.fetch, `${this.config.address}/api/blobs/${digest}`)
+      await utils.head(this.fetch, `${this.config.host}/api/blobs/${digest}`)
     } catch (e) {
       if (e instanceof Error && e.message.includes('404')) {
         // Create a new readable stream for the fetch request
@@ -180,7 +178,7 @@ export class Ollama {
 
         await utils.post(
           this.fetch,
-          `${this.config.address}/api/blobs/${digest}`,
+          `${this.config.host}/api/blobs/${digest}`,
           readableStream,
         )
       } else {
@@ -280,25 +278,25 @@ export class Ollama {
   }
 
   async delete(request: DeleteRequest): Promise<StatusResponse> {
-    await utils.del(this.fetch, `${this.config.address}/api/delete`, {
+    await utils.del(this.fetch, `${this.config.host}/api/delete`, {
       name: request.model,
     })
     return { status: 'success' }
   }
 
   async copy(request: CopyRequest): Promise<StatusResponse> {
-    await utils.post(this.fetch, `${this.config.address}/api/copy`, { ...request })
+    await utils.post(this.fetch, `${this.config.host}/api/copy`, { ...request })
     return { status: 'success' }
   }
 
   async list(): Promise<ListResponse> {
-    const response = await utils.get(this.fetch, `${this.config.address}/api/tags`)
+    const response = await utils.get(this.fetch, `${this.config.host}/api/tags`)
     const listResponse = (await response.json()) as ListResponse
     return listResponse
   }
 
   async show(request: ShowRequest): Promise<ShowResponse> {
-    const response = await utils.post(this.fetch, `${this.config.address}/api/show`, {
+    const response = await utils.post(this.fetch, `${this.config.host}/api/show`, {
       ...request,
     })
     const showResponse = (await response.json()) as ShowResponse
@@ -306,11 +304,9 @@ export class Ollama {
   }
 
   async embeddings(request: EmbeddingsRequest): Promise<EmbeddingsResponse> {
-    const response = await utils.post(
-      this.fetch,
-      `${this.config.address}/api/embeddings`,
-      { request },
-    )
+    const response = await utils.post(this.fetch, `${this.config.host}/api/embeddings`, {
+      request,
+    })
     const embeddingsResponse = (await response.json()) as EmbeddingsResponse
     return embeddingsResponse
   }

+ 2 - 2
src/interfaces.ts

@@ -1,8 +1,8 @@
 export type Fetch = typeof fetch
 
 export interface Config {
-	host: string,
-	fetch?: Fetch
+  host: string
+  fetch?: Fetch
 }
 
 // request types

+ 46 - 8
src/utils.ts

@@ -25,16 +25,16 @@ const checkOk = async (response: Response): Promise<void> => {
   }
 }
 
-export const get = async (fetch: Fetch, address: string): Promise<Response> => {
-  const response = await fetch(formatAddress(address))
+export const get = async (fetch: Fetch, host: string): Promise<Response> => {
+  const response = await fetch(host)
 
   await checkOk(response)
 
   return response
 }
 
-export const head = async (fetch: Fetch, address: string): Promise<Response> => {
-  const response = await fetch(formatAddress(address), {
+export const head = async (fetch: Fetch, host: string): Promise<Response> => {
+  const response = await fetch(host, {
     method: 'HEAD',
   })
 
@@ -45,7 +45,7 @@ export const head = async (fetch: Fetch, address: string): Promise<Response> =>
 
 export const post = async (
   fetch: Fetch,
-  address: string,
+  host: string,
   data?: Record<string, unknown> | BodyInit,
 ): Promise<Response> => {
   const isRecord = (input: any): input is Record<string, unknown> => {
@@ -54,7 +54,7 @@ export const post = async (
 
   const formattedData = isRecord(data) ? JSON.stringify(data) : data
 
-  const response = await fetch(formatAddress(address), {
+  const response = await fetch(host, {
     method: 'POST',
     body: formattedData,
   })
@@ -66,10 +66,10 @@ export const post = async (
 
 export const del = async (
   fetch: Fetch,
-  address: string,
+  host: string,
   data?: Record<string, unknown>,
 ): Promise<Response> => {
-  const response = await fetch(formatAddress(address), {
+  const response = await fetch(host, {
     method: 'DELETE',
     body: JSON.stringify(data),
   })
@@ -110,3 +110,41 @@ export const parseJSON = async function* <T = unknown>(
     }
   }
 }
+
+export const formatHost = (host: string): string => {
+  if (!host) {
+    host = 'http://127.0.0.1:11434'
+  }
+
+  let isExplicitProtocol = host.includes('://')
+
+  if (host.startsWith(':')) {
+    // if host starts with ':', prepend the default hostname
+    host = `http://127.0.0.1${host}`
+    isExplicitProtocol = false
+  }
+
+  if (!isExplicitProtocol) {
+    host = `http://${host}`
+  }
+
+  const url = new URL(host)
+
+  let port = url.port
+  if (!port) {
+    if (!isExplicitProtocol) {
+      port = '11434'
+    } else {
+      // Assign default ports based on the protocol
+      port = url.protocol === 'https:' ? '443' : '80'
+    }
+  }
+
+  let formattedHost = `${url.protocol}//${url.hostname}:${port}${url.pathname}`
+  // remove trailing slashes
+  if (formattedHost.endsWith('/')) {
+    formattedHost = formattedHost.slice(0, -1)
+  }
+
+  return formattedHost
+}

+ 54 - 2
test/index.spec.ts

@@ -1,3 +1,55 @@
-describe('Empty test', () => {
-  it('runs', () => {})
+import { formatHost } from '../src/utils'
+
+describe('formatHost Function Tests', () => {
+  it('should return default URL for empty string', () => {
+    expect(formatHost('')).toBe('http://127.0.0.1:11434')
+  })
+
+  it('should parse plain IP address', () => {
+    expect(formatHost('1.2.3.4')).toBe('http://1.2.3.4:11434')
+  })
+
+  it('should parse IP address with port', () => {
+    expect(formatHost('1.2.3.4:56789')).toBe('http://1.2.3.4:56789')
+  })
+
+  it('should parse HTTP URL', () => {
+    expect(formatHost('http://1.2.3.4')).toBe('http://1.2.3.4:80')
+  })
+
+  it('should parse HTTPS URL', () => {
+    expect(formatHost('https://1.2.3.4')).toBe('https://1.2.3.4:443')
+  })
+
+  it('should parse HTTPS URL with port', () => {
+    expect(formatHost('https://1.2.3.4:56789')).toBe('https://1.2.3.4:56789')
+  })
+
+  it('should parse domain name', () => {
+    expect(formatHost('example.com')).toBe('http://example.com:11434')
+  })
+
+  it('should parse domain name with port', () => {
+    expect(formatHost('example.com:56789')).toBe('http://example.com:56789')
+  })
+
+  it('should parse HTTP domain', () => {
+    expect(formatHost('http://example.com')).toBe('http://example.com:80')
+  })
+
+  it('should parse HTTPS domain', () => {
+    expect(formatHost('https://example.com')).toBe('https://example.com:443')
+  })
+
+  it('should parse HTTPS domain with port', () => {
+    expect(formatHost('https://example.com:56789')).toBe('https://example.com:56789')
+  })
+
+  it('should handle trailing slash in domain', () => {
+    expect(formatHost('example.com/')).toBe('http://example.com:11434')
+  })
+
+  it('should handle trailing slash in domain with port', () => {
+    expect(formatHost('example.com:56789/')).toBe('http://example.com:56789')
+  })
 })