"""Upload a firmware image to an STM32 device over its USB virtual COM port.

Usage: python <script> [FIRMWARE_FILE [DEVICE_ID]]

Protocol as implemented below: send a Start message, wait for a FileCheck
reply, stream the image as Package messages (each answered by a PackageAck)
when the device reports ready_for_data, and finish with a Done message.
Every message is framed by a 5-byte header (see HEADER_FORMAT).
"""
import os
import struct
import sys
import time

import serial
from crccheck.crc import Crc
from google.protobuf.message import DecodeError
from serial.tools import list_ports

from firmware_pb2 import Start, FileCheck, Package, PackageAck, Done
from usb_pb2 import PackageType

# CRC-32 with polynomial 0x04C11DB7 and initial value 0xFFFFFFFF; all other
# crccheck parameters (input/output reflection, final XOR) stay at their
# defaults. NOTE(review): presumably chosen to match the device-side CRC —
# confirm against the firmware.
Crc32 = Crc(32, 0x04C11DB7, 0xffffffff)

# Firmware images are split into chunks of this many bytes, one per Package.
CHUNK_SIZE = 256

# Frame header: uint16 payload length, uint16 PackageType, uint8 check byte.
# NOTE(review): little-endian layout assumed — confirm against the firmware.
HEADER_FORMAT = '<HHB'
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)  # 5 bytes


def load_firmware(filename, chunk_size=CHUNK_SIZE):
    """Read a firmware image and prepare it for transfer.

    Args:
        filename: Path of the binary image to read.
        chunk_size: Size in bytes of each transfer chunk (default 256).

    Returns:
        A tuple ``(crc, size, packages)``: the Crc32 of the whole image, its
        length in bytes, and the image split into ``chunk_size``-byte chunks
        (the last chunk may be shorter; an empty file yields no chunks).
    """
    with open(filename, 'rb') as f:
        data = f.read()
    packages = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]
    return Crc32.calc(data), len(data), packages


def _header_check(length: int, typeid: int) -> int:
    """Return the header check byte: byte-wise sum of length and type, mod 256."""
    check = ((length & 0xFF) + ((length >> 8) & 0xFF)
             + (typeid & 0xFF) + ((typeid >> 8) & 0xFF))
    # The sum of four bytes can reach 1020; truncate to one byte so it fits
    # the 'B' field (an unmasked value > 255 makes struct.pack raise).
    # NOTE(review): assumes the device computes the same sum in a uint8.
    return check & 0xFF


def make_header(typeid: PackageType, length: int) -> bytes:
    """Build the 5-byte frame header for a payload of ``length`` bytes.

    Raises:
        struct.error: if ``length`` or ``typeid`` does not fit in a uint16.
    """
    typeidint = int(typeid)
    return struct.pack(HEADER_FORMAT, length, typeidint,
                       _header_check(length, typeidint))


def send_package(typeid: PackageType, data: bytes, serial: serial.Serial):
    """Write one framed message (header followed by ``data``) to the port."""
    serial.write(make_header(typeid, len(data)) + data)


def read_header(serial):
    """Read and validate one frame header from the port.

    Returns:
        ``(length, typeid)`` on success, or ``(None, None)`` if the header was
        short (e.g. a read timeout) or its check byte did not match.
    """
    header_data = serial.read(HEADER_SIZE)
    if len(header_data) != HEADER_SIZE:
        print('Incomplete header')
        return None, None
    length, typeid, check = struct.unpack(HEADER_FORMAT, header_data)
    if check != _header_check(length, typeid):
        print('Header check byte mismatch')
        return None, None
    return length, typeid


def _receive_message(port, expected_type, message_cls, name):
    """Read one framed message and parse it as ``message_cls``.

    Returns the parsed message, or None on a bad header, short payload,
    unexpected type or decode error. A zero-length payload is valid: a
    protobuf message whose fields all hold default values serializes to
    zero bytes.
    """
    length, typeid = read_header(port)
    if length is None:
        return None
    # Consume the payload even if the type is wrong, so the next read starts
    # on a header boundary instead of in the middle of this message.
    message_data = port.read(length)
    if len(message_data) != length:
        print('Incomplete message')
        return None
    if typeid != expected_type:
        print(f'Unexpected package type {typeid}')
        return None
    message = message_cls()
    try:
        message.ParseFromString(message_data)
    except DecodeError:
        print(f'Failed to parse {name}')
        return None
    return message


def receive_ack(serial):
    """Receive a PackageAck message; return it, or None on any failure."""
    return _receive_message(serial, PackageType.FIRMWAREPACKAGEACK,
                            PackageAck, 'PackageAck')


def receive_file_check(serial):
    """Receive a FileCheck message; return it, or None on any failure."""
    return _receive_message(serial, PackageType.FIRMWAREFILECHECK,
                            FileCheck, 'FileCheck')


FILENAME = 'firmware.bin'
ID = 0
if len(sys.argv) > 1:
    FILENAME = sys.argv[1]
if len(sys.argv) > 2:
    ID = int(sys.argv[2])


def _find_stm_port():
    """Return the device path of the first STM32 virtual COM port, or None.

    Each port inspected is printed, up to and including the matching one.
    """
    for port in list_ports.comports():
        print(port)
        if "STM32 Virtual ComPort" in port.description:
            return port.device
    return None


def _send_packages(ser, packages, device_id):
    """Send every chunk as a Package and wait for its PackageAck.

    Returns True if all chunks were acknowledged, False at the first chunk
    that was not (including when no valid PackageAck arrived).
    """
    for counter, pack_data in enumerate(packages):
        package = Package()
        package.counter = counter
        package.crc_pac = Crc32.calc(pack_data)
        package.device_id = device_id
        package.data = pack_data
        print(package)
        print(hex(package.crc_pac))
        send_package(PackageType.FIRMWAREPACKAGE, package.SerializeToString(), ser)
        ack = receive_ack(ser)
        print(ack)
        # receive_ack returns None on failure; treat that as a NACK rather
        # than crashing on the attribute access.
        if ack is None or not ack.ack:
            print(f'Package {counter} not acknowledged')
            return False
    return True


def main():
    """Run the upload; return the process exit status (0 ok, -1 failure)."""
    stm_port = _find_stm_port()
    if stm_port is None:
        print("STM32 Virtual ComPort not found")
        return -1

    crc, size, packages = load_firmware(FILENAME)

    start = Start()
    start.name = os.path.basename(FILENAME)
    start.size = size
    start.packages = len(packages)
    start.device_id = ID
    start.crc_fw = crc

    # The context manager closes the port on every exit path.
    with serial.Serial(stm_port, baudrate=5000000) as ser:
        print(start)
        send_package(PackageType.FIRMWARESTART, start.SerializeToString(), ser)

        file_check = receive_file_check(ser)
        print(file_check)
        if file_check is None:
            print('Failed to receive FileCheck')
            return -1

        if (file_check.crc_fw == start.crc_fw
                and file_check.size == start.size
                and not file_check.ready_for_data):
            # The device already holds this image: skip straight to Done.
            print('No need for data transfer')
        elif file_check.ready_for_data:
            if not _send_packages(ser, packages, start.device_id):
                return -1
        else:
            print('Error in FileCheck message')
            return -1

        done = Done()
        done.size = start.size
        done.crc_fw = start.crc_fw
        done.device_id = start.device_id
        print(done)
        send_package(PackageType.FIRMWAREDONE, done.SerializeToString(), ser)
        # Block until the Done frame has been written before closing the port.
        ser.flush()
    return 0


if __name__ == "__main__":
    sys.exit(main())