import { z } from 'zod';
import { chunkTypeSchema } from './chunk';
import { pieceOfInformationTagNameSchema } from '../tag';

/**
 * Filter shape for chunks whose `type` is `'TABLE'`.
 *
 * Used as the `TABLE` branch of `findSimilarChunksFilterSchema` and also,
 * on its own, as the filter of `findSimilarTableRowsSchema`.
 * Unknown keys are stripped (default `z.object` behaviour).
 */
export const tableFilterSchema = z.object({
  type: z.literal('TABLE'),
  fieldReferenceIds: z.string().array().optional(),
  tagName: pieceOfInformationTagNameSchema.optional()
});

/**
 * Fields shared by every similarity-search filter, whatever the chunk type.
 *
 * `similarityThreshold`, when present, must lie in the closed range [0, 1].
 * Type-specific fields are layered on top of this in
 * `findSimilarChunksFilterSchema`.
 */
export const baseFindSimilarChunksFilterSchema = z.object({
  type: chunkTypeSchema,
  tagName: pieceOfInformationTagNameSchema.optional(),
  similarityThreshold: z.number().gte(0).lte(1).optional()
});

/** Builds a filter variant that carries nothing beyond its `type` discriminator. */
const typeOnlyFilterSchema = <T extends string>(type: T) => z.object({ type: z.literal(type) });

/**
 * Complete filter accepted by the `findSimilarChunks*` request schemas.
 *
 * A value must satisfy `baseFindSimilarChunksFilterSchema` AND exactly one
 * variant selected by `type`:
 * - `TABLE` additionally accepts the fields of `tableFilterSchema`;
 * - `FILE`, `PROJECT`, `CONVERSATION`, `PROJECT_NAME` and
 *   `PROJECT_DESCRIPTION` add no extra fields.
 * A `type` not listed here is rejected by the discriminated union, even if
 * `chunkTypeSchema` on its own would accept it.
 */
export const findSimilarChunksFilterSchema = baseFindSimilarChunksFilterSchema.and(
  z.discriminatedUnion('type', [
    tableFilterSchema,
    typeOnlyFilterSchema('FILE'),
    typeOnlyFilterSchema('PROJECT'),
    typeOnlyFilterSchema('CONVERSATION'),
    typeOnlyFilterSchema('PROJECT_NAME'),
    typeOnlyFilterSchema('PROJECT_DESCRIPTION')
  ])
);

/**
 * Fields shared by every `findSimilarChunks*` request, before the query
 * itself (text or embedding) is added by the derived schemas.
 *
 * - `knowledgeLibraryIds`: optional list of library ids.
 * - `limit`: validated as a non-negative integer. It presumably caps the
 *   number of returned chunks (the consumer is not visible here). Negative,
 *   fractional and infinite numbers cannot serve as such a cap, so they are
 *   rejected at the boundary rather than failing later downstream.
 * - `filter`: see `findSimilarChunksFilterSchema`.
 */
const findSimilarChunksBaseSchema = z.object({
  knowledgeLibraryIds: z.array(z.string()).optional(),
  // `.int()` also rejects Infinity, because Number.isInteger(Infinity) is false.
  limit: z.number().int().nonnegative(),
  filter: findSimilarChunksFilterSchema
});
/** Similarity-search request driven by a query embedding (an array of numbers). */
export const findSimilarChunksEmbeddingSchema = findSimilarChunksBaseSchema.extend({
  queryEmbedding: z.number().array()
});
/** Similarity-search request driven by a free-text `query` string. */
export const findSimilarChunksQuerySchema = findSimilarChunksBaseSchema.merge(
  z.object({ query: z.string() })
);

/**
 * Embedding-based similarity-search request schema.
 *
 * This is the same schema as `findSimilarChunksEmbeddingSchema`. Its
 * `filter` field is already typed as `findSimilarChunksFilterSchema` by the
 * shared base schema, so no override is needed. The name is kept as a
 * separate export so existing importers continue to work.
 */
export const findSimilarChunksSchema = findSimilarChunksEmbeddingSchema;

/**
 * Text-query similarity-search request whose filter is limited to `TABLE`
 * chunks.
 *
 * The general `filter` is replaced by `tableFilterSchema`, so `filter.type`
 * must be `'TABLE'`.
 * NOTE(review): `tableFilterSchema` has no `similarityThreshold` field,
 * unlike the general filter. A threshold sent here is therefore stripped
 * during parsing. Confirm this is intended.
 */
export const findSimilarTableRowsSchema = findSimilarChunksQuerySchema.merge(
  z.object({ filter: tableFilterSchema })
);

/** Parsed (output) type of `findSimilarChunksFilterSchema`. */
export type FindSimilarChunksFilter = z.output<typeof findSimilarChunksFilterSchema>;
/** Parsed (output) type of `findSimilarChunksEmbeddingSchema`. */
export type FindSimilarChunksEmbedding = z.output<typeof findSimilarChunksEmbeddingSchema>;
/** Parsed (output) type of `findSimilarChunksQuerySchema`. */
export type FindSimilarChunksQuery = z.output<typeof findSimilarChunksQuerySchema>;
/** Parsed (output) type of `findSimilarTableRowsSchema`. */
export type FindSimilarTableRows = z.output<typeof findSimilarTableRowsSchema>;
