Lower the optional branch weights of spv.BranchConditional to the LLVM dialect.

Previously the conversion pattern returned failure() whenever branch weights
were present. With this patch the weights are packed into a dense
vector<2xi32> attribute and passed to the replacement conditional branch, and
the test that expected a legalization error now checks the emitted
`llvm.cond_br ... weights(...)` instead.

NOTE(review): this patch was reconstructed from a copy whose line breaks and
angle-bracket template arguments had been stripped. `ArrayRef<Value>` and
`replaceOpWithNewOp<LLVM::CondBrOp>` were restored, and the leading whitespace
of context lines was re-created from the usual LLVM formatting; confirm both
against the tree (or apply with `git apply --ignore-whitespace`).

diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -459,13 +459,18 @@
   LogicalResult
   matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    // There is no support of branch weights in LLVM dialect at the moment.
-    if (auto weights = op.branch_weights())
-      return failure();
+    // If branch weights exist, map them to 32-bit integer vector.
+    ElementsAttr branchWeights = nullptr;
+    if (auto weights = op.branch_weights()) {
+      VectorType weightType = VectorType::get(2, rewriter.getI32Type());
+      branchWeights =
+          DenseElementsAttr::get(weightType, weights.getValue().getValue());
+    }
 
     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
-        op, op.condition(), op.getTrueBlock(), op.getTrueBlockArguments(),
-        op.getFalseBlock(), op.getFalseBlockArguments());
+        op, op.condition(), op.getTrueBlockArguments(),
+        op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
+        op.getFalseBlock());
     return success();
   }
 };
diff --git a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
@@ -66,16 +66,14 @@
   ^inner_false(%arg3: i32, %arg4: i32):
     spv.Return
   }
-}
-
-// -----
 
-spv.module Logical GLSL450 {
   spv.func @cond_branch_with_weights(%cond: i1) -> () "None" {
-    // expected-error@+1 {{failed to legalize operation 'spv.BranchConditional' that was explicitly marked illegal}}
+    // CHECK: llvm.cond_br %{{.*}} weights(dense<[1, 2]> : vector<2xi32>), ^bb1, ^bb2
     spv.BranchConditional %cond [1, 2], ^true, ^false
+  // CHECK: ^bb1:
   ^true:
     spv.Return
+  // CHECK: ^bb2:
   ^false:
     spv.Return
   }