diff --git a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
--- a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
+++ b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
@@ -56,6 +56,12 @@
 static cl::opt<bool> VerifyConstantTime("riscv-verify-constant-time",
                                         cl::init(false));
 
+static cl::opt<std::string> SecretRegisters(
+    "riscv-secret-registers",
+    cl::desc("A comma-separated list of integer registers that hold secrets "
+             "(an empty list treats every register as secret)"),
+    cl::init(""));
+
 namespace llvm {
 extern const SubtargetFeatureKV RISCVFeatureKV[RISCV::NumSubtargetFeatures];
 } // namespace llvm
@@ -85,6 +91,9 @@
   SmallVector<ParserOptionsSet, 4> ParserOptionsStack;
   ParserOptionsSet ParserOptions;
 
+  // Bit I is set while integer register xI is considered to hold a secret.
+  BitVector Secrets;
+
   SMLoc getLoc() const { return getParser().getTok().getLoc(); }
   bool isRV64() const { return getSTI().hasFeature(RISCV::Feature64Bit); }
   bool isRVE() const { return getSTI().hasFeature(RISCV::FeatureRVE); }
@@ -177,6 +186,12 @@
   // Check instruction constraints.
   bool validateInstruction(MCInst &Inst, OperandVector &Operands);
 
+  /// Initialize Secrets from the -riscv-secret-registers option.
+  void initSecrets();
+  /// Propagate secret taint through \p Inst and return false if \p Inst may
+  /// leak a secret, i.e. it reads a secret and is not constant-time.
+  bool verify(MCInst &Inst);
+
   /// Helper for processing MC instructions that have been successfully matched
   /// by MatchAndEmitInstruction. Modifications to the emitted instructions,
   /// like the expansion of pseudo instructions (e.g., "li"), can be performed
@@ -279,7 +294,7 @@
 
   RISCVAsmParser(const MCSubtargetInfo &STI, MCAsmParser &Parser,
                  const MCInstrInfo &MII, const MCTargetOptions &Options)
-      : MCTargetAsmParser(Options, STI, MII) {
+      : MCTargetAsmParser(Options, STI, MII), Secrets(32) {
     MCAsmParserExtension::Initialize(Parser);
 
     Parser.addAliasForDirective(".half", ".2byte");
@@ -310,6 +325,8 @@
 
     if (AddBuildAttributes)
       getTargetStreamer().emitTargetAttributes(STI, /*EmitStackAlign*/ false);
+    if (VerifyConstantTime)
+      initSecrets();
   }
 };
 
@@ -1312,12 +1329,17 @@
     if (validateInstruction(Inst, Operands))
       return true;
 
-    // TODO: Verify that only constant-time instructions will operate on
-    // secret operands.
-    if (VerifyConstantTime &&
-        !RISCVII::isConstantTime(MII.get(Inst.getOpcode()).TSFlags))
-      return Warning(IDLoc, "This instruction is not constant-time.");
+    // Warning() only returns true when warnings are fatal; otherwise the
+    // diagnosed instruction must still be emitted rather than dropped.
+    if (VerifyConstantTime && !SecretRegisters.getNumOccurrences() &&
+        !RISCVII::isConstantTime(MII.get(Inst.getOpcode()).TSFlags) &&
+        Warning(IDLoc, "This instruction is not constant-time."))
+      return true;
+
+    if (VerifyConstantTime && SecretRegisters.getNumOccurrences() &&
+        !verify(Inst) && Warning(IDLoc, "This instruction may leak secret."))
+      return true;
 
     return processInstruction(Inst, IDLoc, Operands, Out);
   case Match_MissingFeature: {
     assert(MissingFeatures.any() && "Unknown missing features!");
@@ -3583,6 +3605,71 @@
   return false;
 }
 
+void RISCVAsmParser::initSecrets() {
+  // An empty list conservatively treats every register as secret.
+  if (SecretRegisters.empty())
+    Secrets.set();
+  for (StringRef Name : llvm::split(SecretRegisters, ",")) {
+    Name = Name.trim();
+    if (Name.empty())
+      continue;
+    MCRegister Reg = MatchRegisterName(Name);
+    if (!Reg)
+      Reg = MatchRegisterAltName(Name);
+    if (!Reg)
+      report_fatal_error("Unknown register '" + Name +
+                         "' in -riscv-secret-registers!");
+    if (Reg < RISCV::X0 || Reg > RISCV::X31)
+      report_fatal_error("Secrets should be scalar integer registers!");
+    Secrets.set(Reg - RISCV::X0);
+  }
+  // X0 is hard-wired to zero and never holds a secret, even when every other
+  // register is treated as secret.
+  Secrets.reset(0);
+}
+
+bool RISCVAsmParser::verify(MCInst &RawInst) {
+  // Analyze the uncompressed form so that compressed and full-size encodings
+  // share one operand layout.
+  MCInst Inst;
+  if (!RISCVRVC::uncompress(Inst, RawInst, *STI))
+    Inst = RawInst;
+  const MCInstrDesc &MCID = MII.get(Inst.getOpcode());
+  unsigned NumOperands = Inst.getNumOperands();
+  unsigned NumDefs = std::min(MCID.getNumDefs(), NumOperands);
+  auto IsGPR = [](const MCOperand &Op) {
+    return Op.isReg() && Op.getReg() >= RISCV::X0 && Op.getReg() <= RISCV::X31;
+  };
+
+  // Operands [0, NumDefs) are written and the rest are read. Only reads decide
+  // whether the instruction consumes a secret; merely overwriting a secret
+  // register (e.g. "li a0, 0") is not a use of it.
+  bool HasSecret = false;
+  for (unsigned I = NumDefs; I < NumOperands; ++I) {
+    const MCOperand &Op = Inst.getOperand(I);
+    if (IsGPR(Op) && Secrets.test(Op.getReg() - RISCV::X0))
+      HasSecret = true;
+  }
+
+  // A written integer register is secret iff some input was. Memory is not
+  // tracked, so a load never clears taint: it may reload a spilled secret.
+  // X0 discards writes and therefore stays public.
+  for (unsigned I = 0; I < NumDefs; ++I) {
+    const MCOperand &Op = Inst.getOperand(I);
+    if (!IsGPR(Op) || Op.getReg() == RISCV::X0)
+      continue;
+    unsigned Index = Op.getReg() - RISCV::X0;
+    if (HasSecret)
+      Secrets.set(Index);
+    else if (!MCID.mayLoad())
+      Secrets.reset(Index);
+  }
+
+  // Secrets may only reach instructions whose timing is independent of their
+  // operand values.
+  return !HasSecret || RISCVII::isConstantTime(MCID.TSFlags);
+}
+
 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVAsmParser() {
   RegisterMCAsmParser<RISCVAsmParser> X(getTheRISCV32Target());
   RegisterMCAsmParser<RISCVAsmParser> Y(getTheRISCV64Target());
diff --git a/llvm/test/MC/RISCV/secret-verify.s b/llvm/test/MC/RISCV/secret-verify.s
new file mode 100644
--- /dev/null
+++ b/llvm/test/MC/RISCV/secret-verify.s
@@ -0,0 +1,41 @@
+# RUN: llvm-mc %s -triple=riscv64 -mattr=+m -riscv-verify-constant-time -riscv-secret-registers=a0 2>&1 \
+# RUN: | FileCheck %s
+
+# This assembly was compiled from the following C code; key (a0) is secret:
+# int test(int key, int n) {
+#   int res = 0;
+#   for (int i = 0; i < n; i++) {
+#     if (key + 1 < i)
+#       res *= key;
+#     else
+#       res += key;
+#   }
+#   return res;
+# }
+
+test:
+  blez a1, .LBB0_6
+  li a3, 0
+  li a2, 0
+  addiw a4, a0, 1
+  j .LBB0_4
+.LBB0_2:
+  addw a2, a2, a0
+.LBB0_3:
+  addiw a3, a3, 1
+  beq a1, a3, .LBB0_7
+.LBB0_4:
+  # CHECK: :[[#@LINE+1]]:3: warning: This instruction may leak secret
+  bge a4, a3, .LBB0_2
+  mulw a2, a2, a0
+  j .LBB0_3
+.LBB0_6:
+  li a2, 0
+.LBB0_7:
+  mv a0, a2
+  ret
+
+# Instructions without register operands must not crash the verifier.
+no_reg_operands:
+  fence rw, rw
+  ecall