Untitled diff
13 removals
Words removed | 20 |
Total words | 289 |
Words removed (%) | 6.92 |
97 lines
13 additions
Words added | 44 |
Total words | 313 |
Words added (%) | 14.06 |
97 lines
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
# Licensed under the MIT license.
import argparse
import argparse
import datetime
import datetime
from distutils.dir_util import copy_tree
import os
import os
import shutil
import operator
import tempfile
import tempfile
import traceback
import seedot.common as Common
import seedot.common as Common
from seedot.main import Main
from seedot.mainX86 import MainX86
import seedot.util as Util
import seedot.util as Util
class MainDriverX86:
    """Command-line entry point that runs SeeDot on a trained model for the X86 target.

    Usage: call parseArgs() to read and validate the command line and prepare the
    scratch and output directories, then call run() to hand the inputs to MainX86.
    """

    def parseArgs(self):
        """Parse and validate the command line, then set up the working directories.

        Side effects:
            - stores the parsed namespace on ``self.args``;
            - sets ``Common.tempdir`` to ``--tempdir``, or to a newly created,
              timestamped ``SeeDot`` folder under the system temp directory;
            - sets ``Common.outdir`` to ``--outdir``, or to a newly created folder
              under ``Common.tempdir``.

        Exits through ``parser.error`` (prints usage, exit status 2) when the
        training/testing file, the model directory, or a user-supplied directory
        does not exist.
        """
        parser = argparse.ArgumentParser()
        # NOTE(review): --algo is optional with no default, so run() passes None
        # to MainX86 when it is omitted — confirm whether it should be required.
        parser.add_argument("-a", "--algo", choices=Common.Algo.All,
                            metavar='', help="Algorithm to run ('bonsai' or 'protonn')")
        parser.add_argument("--train", required=True,
                            metavar='', help="Training set file")
        parser.add_argument("--test", required=True,
                            metavar='', help="Testing set file")
        parser.add_argument("--model", required=True, metavar='',
                            help="Directory containing trained model (output from Bonsai/ProtoNN trainer)")
        # The generated-code datatype is not user-selectable here:
        # run() always uses Common.Version.Fixed.
        parser.add_argument("--tempdir", metavar='',
                            help="Scratch directory for intermediate files")
        parser.add_argument("-o", "--outdir", metavar='',
                            help="Directory to output the generated X86 files")

        self.args = parser.parse_args()

        # Validate with parser.error rather than assert: assert statements are
        # stripped under `python -O`, which would silently skip these checks.
        if not os.path.isfile(self.args.train):
            parser.error("Training set doesn't exist")
        if not os.path.isfile(self.args.test):
            parser.error("Testing set doesn't exist")
        if not os.path.isdir(self.args.model):
            parser.error("Model directory doesn't exist")

        # Scratch directory: user-supplied, or a fresh timestamped one.
        Common.tempdir = self._useOrCreateDir(
            parser,
            self.args.tempdir,
            os.path.join(tempfile.gettempdir(), "SeeDot",
                         datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')),
            "Scratch directory doesn't exist")

        # Output directory: user-supplied, or a subfolder of the scratch directory.
        # NOTE(review): "<redacted>" looks like a scrubbing placeholder, and '<'/'>'
        # are not valid in Windows paths — restore the intended folder name.
        Common.outdir = self._useOrCreateDir(
            parser,
            self.args.outdir,
            os.path.join(Common.tempdir, "<redacted>"),
            "Output directory doesn't exist")

    @staticmethod
    def _useOrCreateDir(parser, given, default, missingMessage):
        """Resolve a directory argument.

        If ``given`` is not None, it must be an existing directory (otherwise
        ``parser.error(missingMessage)`` exits) and it is returned unchanged.
        If ``given`` is None, ``default`` is created (if needed) and returned.
        """
        if given is not None:
            if not os.path.isdir(given):
                parser.error(missingMessage)
            return given
        os.makedirs(default, exist_ok=True)
        return default

    def checkMSBuildPath(self):
        """Find MSBuild.exe among ``Common.msbuildPathOptions`` and store it.

        Sets ``Common.msbuildPath`` to an existing candidate path. Only invoked
        from run() on Windows.

        Raises:
            Exception: if none of the candidate paths is an existing file.
        """
        found = False
        for path in Common.msbuildPathOptions:
            if os.path.isfile(path):
                found = True
                # NOTE(review): there is no break, so when several candidates
                # exist the last one in the list wins — confirm that is intended.
                Common.msbuildPath = path
        if not found:
            raise Exception(
                "Msbuild.exe not found at the following locations:\n%s\nPlease change the path and run again" % (
                    Common.msbuildPathOptions))

    def run(self):
        """Run SeeDot for the X86 target on the arguments stored by parseArgs().

        Must be called after parseArgs(). On Windows, MSBuild is located first
        (see checkMSBuildPath). The datatype is always Common.Version.Fixed.
        """
        # MSBuild lookup is Windows-specific, so it is skipped elsewhere.
        if Util.windows():
            self.checkMSBuildPath()

        algo = self.args.algo
        version = Common.Version.Fixed
        trainingInput = self.args.train
        testingInput = self.args.test
        modelDir = self.args.model

        print("\n================================")
        print("Executing on %s for X86" % (algo))
        print("--------------------------------")
        print("Train file: %s" % (trainingInput))
        print("Test file: %s" % (testingInput))
        print("Model directory: %s" % (modelDir))
        print("================================\n")

        obj = MainX86(algo, version, Common.Target.X86,
                      trainingInput, testingInput, modelDir, None)
        obj.run()
if __name__ == "__main__":
    # Script entry point: parse/validate the command line, then run the X86 flow.
    driver = MainDriverX86()
    driver.parseArgs()
    driver.run()