diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -408,6 +408,45 @@
   }
 };
 
+/// Converts `spv.Branch` to `llvm.br`, passing the converted operands as block
+/// arguments to the same target block.
+class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::BranchOp branchOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, operands,
+                                            branchOp.getTarget());
+    return success();
+  }
+};
+
+/// Converts `spv.BranchConditional` to `llvm.cond_br`. Ops that carry branch
+/// weights are rejected and therefore fail to legalize.
+/// NOTE(review): the original op's condition and block arguments are
+/// forwarded; the converted `operands` are unused -- confirm this is intended.
+class BranchConditionalConversionPattern
+    : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
+public:
+  using SPIRVToLLVMConversion<
+      spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // There is no support of branch weights in LLVM dialect at the moment.
+    if (op.branch_weights())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
+        op, op.condition(), op.getTrueBlock(), op.getTrueBlockArguments(),
+        op.getFalseBlock(), op.getFalseBlockArguments());
+    return success();
+  }
+};
+
 /// Converts SPIR-V operations that have straightforward LLVM equivalent
 /// into LLVM dialect operations.
 template <typename SPIRVOp, typename LLVMOp>
@@ -808,6 +847,9 @@
       // Constant op
       ConstantScalarAndVectorPattern,
 
+      // Control Flow ops
+      BranchConversionPattern, BranchConditionalConversionPattern,
+
       // Function Call op
       FunctionCallPattern,
 
diff --git a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt -convert-spirv-to-llvm -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.Branch
+//===----------------------------------------------------------------------===//
+
+spv.module Logical GLSL450 {
+  spv.func @branch_without_arguments() -> () "None" {
+    // CHECK: llvm.br ^bb1
+    spv.Branch ^label
+  // CHECK: ^bb1
+  ^label:
+    spv.Return
+  }
+
+  spv.func @branch_with_arguments() -> () "None" {
+    %0 = spv.constant 0 : i32
+    %1 = spv.constant true
+    // CHECK: llvm.br ^bb1(%{{.*}}, %{{.*}} : !llvm.i32, !llvm.i1)
+    spv.Branch ^label(%0, %1: i32, i1)
+  // CHECK: ^bb1(%{{.*}}: !llvm.i32, %{{.*}}: !llvm.i1)
+  ^label(%arg0: i32, %arg1: i1):
+    spv.Return
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.BranchConditional
+//===----------------------------------------------------------------------===//
+
+spv.module Logical GLSL450 {
+  spv.func @cond_branch_without_arguments() -> () "None" {
+    // CHECK: %[[COND:.*]] = llvm.mlir.constant(true) : !llvm.i1
+    %cond = spv.constant true
+    // CHECK: llvm.cond_br %[[COND]], ^bb1, ^bb2
+    spv.BranchConditional %cond, ^true, ^false
+    // CHECK: ^bb1:
+  ^true:
+    spv.Return
+    // CHECK: ^bb2:
+  ^false:
+    spv.Return
+  }
+
+  spv.func @cond_branch_with_arguments_nested() -> () "None" {
+    // CHECK: %[[COND1:.*]] = llvm.mlir.constant(true) : !llvm.i1
+    %cond = spv.constant true
+    %0 = spv.constant 0 : i32
+    // CHECK: %[[COND2:.*]] = llvm.mlir.constant(false) : !llvm.i1
+    %false = spv.constant false
+    // CHECK: llvm.cond_br %[[COND1]], ^bb1(%{{.*}}, %[[COND2]] : !llvm.i32, !llvm.i1), ^bb2
+    spv.BranchConditional %cond, ^outer_true(%0, %false: i32, i1), ^outer_false
+  // CHECK: ^bb1(%{{.*}}: !llvm.i32, %[[COND:.*]]: !llvm.i1):
+  ^outer_true(%arg0: i32, %arg1: i1):
+    // CHECK: llvm.cond_br %[[COND]], ^bb3, ^bb4(%{{.*}}, %{{.*}} : !llvm.i32, !llvm.i32)
+    spv.BranchConditional %arg1, ^inner_true, ^inner_false(%arg0, %arg0: i32, i32)
+  // CHECK: ^bb2:
+  ^outer_false:
+    spv.Return
+  // CHECK: ^bb3:
+  ^inner_true:
+    spv.Return
+  // CHECK: ^bb4(%{{.*}}: !llvm.i32, %{{.*}}: !llvm.i32):
+  ^inner_false(%arg3: i32, %arg4: i32):
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.func @cond_branch_with_weights(%cond: i1) -> () "None" {
+    // expected-error@+1 {{failed to legalize operation 'spv.BranchConditional' that was explicitly marked illegal}}
+    spv.BranchConditional %cond [1, 2], ^true, ^false
+  ^true:
+    spv.Return
+  ^false:
+    spv.Return
+  }
+}