Diff

Created Diff never expires
22 removals
96 lines
21 additions
96 lines
import { writeFileSync } from "fs";
import { writeFileSync } from "fs";
import { codePartToCompactString } from "./utils/codePartToCompactString";
import { codePartToCompactString } from "./utils/codePartToCompactString";
import { getFlydeFiles } from "./utils/fs-helpers";
import { getFlydeFiles } from "./utils/fs-helpers";
import { preprocessStdLibParts } from "./utils/preprocessStdLibParts";
import { preprocessStdLibParts } from "./utils/preprocessStdLibParts";
import { join } from "path";
import { join } from "path";
import { chunkArray } from "./utils";
import { chunkArray } from "./utils";
import { readVersionResult } from "./utils/generatePartVersions";
import { readVersionResult } from "./utils/generatePartVersions";
import { fullChatInstructions } from "./benchmark/chat-completion-instructions";


(async function () {
(async function () {
const files = getFlydeFiles();
const files = getFlydeFiles();
const parts = preprocessStdLibParts(files);
const parts = preprocessStdLibParts(files);


const partsWithVersions = parts.map((part) => {
const partsWithVersions = parts.map((part) => {
const versionData = readVersionResult(part);
const versionData = readVersionResult(part);
return { ...part, ...versionData };
return { ...part, ...versionData };
});
});


const partsPerNamespace = partsWithVersions.reduce<
const partsPerNamespace = partsWithVersions.reduce<
Record<string, typeof parts>
Record<string, typeof parts>
>((acc, part) => {
>((acc, part) => {
const namespace = part.original.namespace ?? "n/a";
const namespace = part.original.namespace ?? "n/a";
if (!acc[namespace]) {
if (!acc[namespace]) {
acc[namespace] = [];
acc[namespace] = [];
}
}
acc[namespace].push(part);
acc[namespace].push(part);
return acc;
return acc;
}, {});
}, {});


const trainingIds = new Set<string>();
const trainingIds = new Set<string>();
const validationIds = new Set<string>();
const validationIds = new Set<string>();


Object.entries(partsPerNamespace).forEach(([k, parts]) => {
Object.entries(partsPerNamespace).forEach(([k, parts]) => {
const chunks = chunkArray(parts, 10);
const chunks = chunkArray(parts, 10);
chunks.forEach((chunk) => {
chunks.forEach((chunk) => {
chunk.forEach((part, idx) => {
chunk.forEach((part, idx) => {
if (idx === 0 && chunk.length > 4) {
if (idx === 0 && chunk.length > 4) {
validationIds.add(part.original.id);
validationIds.add(part.original.id);
} else {
} else {
trainingIds.add(part.original.id);
trainingIds.add(part.original.id);
}
}
});
});
});
});
});
});


const validationDataset = partsWithVersions.flatMap((part) => {
const validationDataset = partsWithVersions.flatMap((part) => {
if (!validationIds.has(part.original.id)) {
if (!validationIds.has(part.original.id)) {
return [];
return [];
}
}


const compactParts = [part.original.runFnString, part.alternativeFunction]
const compactParts = codePartToCompactString({
.map((code) => ({ ...part.original, runFnString: code }))
...part.original,
.map(codePartToCompactString);
runFnString: part.original.runFnString,
});


return part.prompts.map((desc, idx) => {
return part.prompts.map((desc, idx) => {
return {
return {
prompt: desc + "\n\n###\n\n",
prompt: desc + "\n\n###\n\n",
completion: " " + compactParts[idx % 2] + "###",
completion: " " + compactParts[idx % 2] + "###",
};
};
});
});
});
});


const trainingDataset = partsWithVersions.flatMap((part) => {
const trainingDataset = partsWithVersions.flatMap((part) => {
if (!trainingIds.has(part.original.id)) {
if (!trainingIds.has(part.original.id)) {
return [];
return [];
}
}


const compactParts = [part.original.runFnString, part.alternativeFunction]
const compactPart = codePartToCompactString({
.map((code) => ({ ...part.original, runFnString: code }))
...part.original,
.map(codePartToCompactString);
runFnString: part.original.runFnString,

return part.prompts.map((desc, idx) => {
return {
prompt: desc + "\n\n###\n\n",
completion: " " + compactParts[idx % 2] + "###",
};
});
});

const prompt = part.prompts[0];

return {
messages: [
{ role: "system", content: fullChatInstructions },
{ role: "user", content: prompt },
{ role: "assistant", content: compactPart },
],
};
});
});


console.log(
const datasetFileLocation = join(__dirname, `../dataset-cc.json`);
partsWithVersions.length,
trainingDataset.length,
validationDataset.length
);

const datasetFileLocation = join(__dirname, `../dataset.json`);


writeFileSync(
writeFileSync(
datasetFileLocation,
datasetFileLocation,
JSON.stringify([...trainingDataset, ...validationDataset], null, 2)
JSON.stringify([...trainingDataset, ...validationDataset], null, 2)
);
);


console.log(
console.log(
`Dataset written to: ${datasetFileLocation}. Remember, the last ${validationDataset.length} entries are validation data.`
`${trainingDataset.length} examples written to: ${datasetFileLocation}. Remember, the last ${validationDataset.length} entries are validation data.`
);
);
})();
})();