Files
cls_master/tools/send_fw.py

181 lines
5.1 KiB
Python

import serial
import struct
import time
import sys
import os
from google.protobuf.message import DecodeError
from serial.tools import list_ports
from crccheck.crc import Crc
from firmware_pb2 import Start, FileCheck , Package, PackageAck, Done
from usb_pb2 import PackageType
# Shared CRC engine used both for the whole-image checksum (load_firmware)
# and for every per-package checksum sent to the device.
# Only width, polynomial and initial value are given; the remaining crccheck
# options (input/output reflection, final XOR) stay at the library defaults.
# NOTE(review): presumably chosen to match the CRC the firmware computes
# (e.g. an STM32 hardware CRC unit) — confirm against the device code.
Crc32 = Crc(
    32,          # register width in bits
    0x04C11DB7,  # generator polynomial
    0xFFFFFFFF,  # initial register value
)
def load_firmware(filename, chunk_size: int = 256):
    """Read a firmware image and split it into transfer-sized chunks.

    Args:
        filename: Path of the firmware binary to read.
        chunk_size: Maximum number of bytes per chunk. Defaults to 256, the
            size this tool has always used. Must be positive.

    Returns:
        A ``(crc, size, packages)`` tuple:

        - the checksum of the whole image, computed with the module-level
          ``Crc32`` engine;
        - the image length in bytes;
        - the list of chunks in file order.

        The last chunk may be shorter than ``chunk_size``. An empty file
        yields an empty chunk list.

    Raises:
        ValueError: If ``chunk_size`` is not positive (a negative step would
            otherwise silently produce no chunks at all).
        OSError: If the file cannot be opened or read.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    # Binary mode: the image is sent byte-for-byte, no text decoding.
    with open(filename, 'rb') as f:
        data = f.read()
    crc = Crc32.calc(data)
    size = len(data)
    packages = [data[i:i + chunk_size] for i in range(0, size, chunk_size)]
    return crc, size, packages
def make_header(typeid: int, length: int) -> bytes:
    """Build the 5-byte frame header that precedes every payload.

    Wire layout (little-endian): uint16 payload length, uint16 package type,
    uint8 check byte.

    Args:
        typeid: Package type, i.e. a ``PackageType`` enum value (protobuf
            enum values are plain ints).
        length: Payload length in bytes; must fit in a uint16.

    Returns:
        The packed header (5 bytes).

    Raises:
        struct.error: If ``length`` or ``typeid`` does not fit in a uint16.
    """
    typeidint = int(typeid)
    # The check byte is the sum of the four length/type bytes truncated to
    # 8 bits. Without the mask the sum can reach 1020 and struct.pack('B')
    # raises for any header whose bytes add up to more than 255.
    # NOTE(review): assumes the firmware accumulates this sum in a uint8
    # (i.e. modulo 256) — confirm against the device code.
    check = (
        (length & 0xFF)
        + ((length >> 8) & 0xFF)
        + (typeidint & 0xFF)
        + ((typeidint >> 8) & 0xFF)
    ) & 0xFF
    return struct.pack('<HHB', length, typeidint, check)
def send_package(typeid : PackageType, data: bytearray, serial: serial.Serial):
    """Frame *data* and transmit it in one write.

    The frame is the header produced by ``make_header(typeid, len(data))``
    immediately followed by the payload bytes; the whole frame is passed to a
    single ``serial.write`` call.
    """
    frame = make_header(typeid, len(data)) + data
    serial.write(frame)
def read_header(serial):
    """Read and validate one 5-byte frame header.

    Wire layout (little-endian): uint16 payload length, uint16 package type,
    uint8 check byte equal to the 8-bit sum of the four preceding bytes.

    Args:
        serial: Object with a ``read(n)`` method returning bytes (e.g. an
            open ``serial.Serial``).

    Returns:
        ``(length, typeid)`` on success, or ``(None, None)`` if fewer than
        5 bytes arrived or the check byte does not match.
    """
    header_size = 5  # uint16 length + uint16 type + uint8 check
    header_data = serial.read(header_size)
    # read() returns fewer bytes than requested when a timeout expires;
    # struct.unpack would raise on such a short buffer.
    if len(header_data) != header_size:
        print('Incomplete header')
        return None, None
    length, typeid, check = struct.unpack('<HHB', header_data)
    # The received check is a single byte, so the computed sum must be
    # truncated to 8 bits as well; otherwise any header whose bytes add up
    # to more than 255 could never match.
    check_calculated = (
        (length & 0xFF)
        + ((length >> 8) & 0xFF)
        + (typeid & 0xFF)
        + ((typeid >> 8) & 0xFF)
    ) & 0xFF
    if check != check_calculated:
        print('Header check byte mismatch')
        return None, None
    return length, typeid
def receive_ack(serial):
    """Receive one frame and decode it as a ``PackageAck``.

    Args:
        serial: Port-like object with ``read(n)``; the header is read via
            read_header().

    Returns:
        The parsed ``PackageAck``, or ``None`` if any of these fail:

        - the header is invalid;
        - the frame type is not ``PackageType.FIRMWAREPACKAGEACK``;
        - the payload is shorter than announced;
        - the payload cannot be parsed.

        A diagnostic is printed for the type, length and parse failures.
    """
    length, typeid = read_header(serial)
    # Test for None, not falsiness: a zero-length payload is the valid
    # protobuf encoding of a message whose fields all hold default values.
    if length is None:
        return None
    if typeid != PackageType.FIRMWAREPACKAGEACK:
        print(f'Unexpected package type {typeid}, expected PackageAck')
        return None
    message_data = serial.read(length)
    if len(message_data) != length:
        print('Incomplete message')
        return None
    ack = PackageAck()
    try:
        ack.ParseFromString(message_data)
    except DecodeError:
        print('Failed to parse PackageAck')
        return None
    return ack
def receive_file_check(serial):
    """Receive one frame and decode it as a ``FileCheck``.

    Args:
        serial: Port-like object with ``read(n)``; the header is read via
            read_header().

    Returns:
        The parsed ``FileCheck``, or ``None`` if any of these fail:

        - the header is invalid;
        - the frame type is not ``PackageType.FIRMWAREFILECHECK``;
        - the payload is shorter than announced;
        - the payload cannot be parsed.

        A diagnostic is printed for the type, length and parse failures.
    """
    length, typeid = read_header(serial)
    # Test for None, not falsiness: a zero-length payload is the valid
    # protobuf encoding of a message whose fields all hold default values.
    if length is None:
        return None
    if typeid != PackageType.FIRMWAREFILECHECK:
        print(f'Unexpected package type {typeid}, expected FileCheck')
        return None
    message_data = serial.read(length)
    if len(message_data) != length:
        print('Incomplete message')
        return None
    file_check = FileCheck()
    try:
        file_check.ParseFromString(message_data)
    except DecodeError:
        print('Failed to parse FileCheck')
        return None
    return file_check
# Command line: send_fw.py [FILENAME [ID]]
# Both values are taken from sys.argv at import time; missing arguments fall
# back to 'firmware.bin' and device id 0. A non-integer ID raises ValueError.
FILENAME = sys.argv[1] if len(sys.argv) > 1 else 'firmware.bin'
ID = int(sys.argv[2]) if len(sys.argv) > 2 else 0
def _find_stm_port():
    """Return the device path of the first "STM32 Virtual ComPort", or None.

    Every enumerated port is printed, which helps diagnose a missing device.
    """
    for port in list_ports.comports():
        print(port)
        if "STM32 Virtual ComPort" in port.description:
            return port.device
    return None


def _send_packages(ser, packages, device_id):
    """Send each firmware chunk as a Package and wait for its PackageAck.

    Returns:
        True if every package was acknowledged. False at the first package
        that was not; a missing or unparseable ack counts as not acknowledged.
    """
    for i, pack_data in enumerate(packages):
        package = Package()
        package.counter = i
        package.crc_pac = Crc32.calc(pack_data)
        package.device_id = device_id
        package.data = pack_data
        print(package)
        print(hex(package.crc_pac))
        send_package(PackageType.FIRMWAREPACKAGE, package.SerializeToString(), ser)
        ack = receive_ack(ser)
        print(ack)
        # receive_ack() returns None on any receive/parse failure; dereferencing
        # it unconditionally would crash with AttributeError.
        if ack is None or not ack.ack:
            print(f'Package {i} not acknowledged')
            return False
    return True


def main():
    """Upload FILENAME to device ID over the STM32 virtual COM port.

    Sequence:

    1. Send Start and read the device's FileCheck reply.
    2. Either skip the transfer, or stream every Package. The transfer is
       skipped when the reported CRC and size match the image and the device
       is not asking for data.
    3. Send Done.

    Returns:
        Process exit status: 0 on success, -1 on any failure (the same value
        the script has always exited with on error).
    """
    stm_port = _find_stm_port()
    if stm_port is None:
        print("STM32 Virtual ComPort not found")
        return -1
    # The context manager closes the port on every exit path.
    with serial.Serial(stm_port, baudrate=5000000) as ser:
        crc, size, packages = load_firmware(FILENAME)
        start = Start()
        start.name = os.path.basename(FILENAME)
        start.size = size
        start.packages = len(packages)
        start.device_id = ID
        start.crc_fw = crc
        print(start)
        send_package(PackageType.FIRMWARESTART, start.SerializeToString(), ser)

        file_check = receive_file_check(ser)
        print(file_check)
        # Explicit None test: protobuf messages do not express "missing"
        # through truthiness.
        if file_check is None:
            print('Failed to receive FileCheck')
            return -1

        if (file_check.crc_fw == start.crc_fw
                and file_check.size == start.size
                and not file_check.ready_for_data):
            # Device already holds this exact image and asks for no data.
            print('No need for data transfer')
        elif file_check.ready_for_data:
            if not _send_packages(ser, packages, start.device_id):
                return -1
        else:
            print('Error in FileCheck message')
            return -1

        done = Done()
        done.size = start.size
        done.crc_fw = start.crc_fw
        done.device_id = start.device_id
        print(done)
        send_package(PackageType.FIRMWAREDONE, done.SerializeToString(), ser)
        # Wait until the Done frame has been written out before the port closes.
        ser.flush()
    return 0


if __name__ == "__main__":
    sys.exit(main())