diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -31,6 +31,31 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Pattern to convert a loop::IfOp within kernel functions into
+/// spirv::SelectionOp.
+class IfOpConversion final : public SPIRVOpLowering<loop::IfOp> {
+public:
+  using SPIRVOpLowering<loop::IfOp>::SPIRVOpLowering;
+
+  PatternMatchResult
+  matchAndRewrite(loop::IfOp IfOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Pattern to erase a loop::TerminatorOp.
+class TerminatorOpConversion final
+    : public SPIRVOpLowering<loop::TerminatorOp> {
+public:
+  using SPIRVOpLowering<loop::TerminatorOp>::SPIRVOpLowering;
+
+  PatternMatchResult
+  matchAndRewrite(loop::TerminatorOp terminatorOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // loop.terminator is implicit in spv.selection; simply erase it.
+    rewriter.eraseOp(terminatorOp);
+    return matchSuccess();
+  }
+};
+
 /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
 /// builin variables.
 template <typename SourceOp, spirv::BuiltIn builtin>
@@ -174,6 +199,58 @@
 }
 
 //===----------------------------------------------------------------------===//
+// loop::IfOp.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+IfOpConversion::matchAndRewrite(loop::IfOp ifOp, ArrayRef<Value> operands,
+                                ConversionPatternRewriter &rewriter) const {
+  // When lowering `loop::IfOp` we explicitly create a selection header block
+  // before the control flow diverges and a merge block where control flow
+  // subsequently converges.
+  loop::IfOpOperandAdaptor ifOperands(operands);
+  auto loc = ifOp.getLoc();
+
+  // Create `spv.selection` operation, selection header block and merge block.
+  auto selectionControl = rewriter.getI32IntegerAttr(
+      static_cast<uint32_t>(spirv::SelectionControl::None));
+  auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
+  selectionOp.addMergeBlock();
+  auto *mergeBlock = selectionOp.getMergeBlock();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  auto *selectionHeaderBlock = new Block();
+  selectionOp.body().getBlocks().push_front(selectionHeaderBlock);
+
+  // Inline `then` region before the merge block and branch to it.
+  auto &thenRegion = ifOp.thenRegion();
+  auto *thenBlock = &thenRegion.front();
+  rewriter.setInsertionPointToEnd(&thenRegion.back());
+  rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+  rewriter.inlineRegionBefore(thenRegion, mergeBlock);
+
+  auto *elseBlock = mergeBlock;
+  // If `else` region is not empty, inline that region before the merge block
+  // and branch to it.
+  if (!ifOp.elseRegion().empty()) {
+    auto &elseRegion = ifOp.elseRegion();
+    elseBlock = &elseRegion.front();
+    rewriter.setInsertionPointToEnd(&elseRegion.back());
+    rewriter.create<spirv::BranchOp>(loc, mergeBlock);
+    rewriter.inlineRegionBefore(elseRegion, mergeBlock);
+  }
+
+  // Create a `spv.BranchConditional` operation for selection header block.
+  rewriter.setInsertionPointToEnd(selectionHeaderBlock);
+  rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
+                                              thenBlock, ArrayRef<Value>(),
+                                              elseBlock, ArrayRef<Value>());
+
+  rewriter.eraseOp(ifOp);
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
 // Builtins.
//===----------------------------------------------------------------------===// @@ -348,12 +425,12 @@ ArrayRef workGroupSize) { patterns.insert(context, typeConverter, workGroupSize); patterns.insert< - GPUReturnOpConversion, ForOpConversion, KernelModuleConversion, - KernelModuleTerminatorConversion, + ForOpConversion, GPUReturnOpConversion, IfOpConversion, + KernelModuleConversion, KernelModuleTerminatorConversion, LaunchConfigConversion, LaunchConfigConversion, LaunchConfigConversion, LaunchConfigConversion>(context, - typeConverter); + spirv::BuiltIn::LocalInvocationId>, + TerminatorOpConversion>(context, typeConverter); } diff --git a/mlir/test/Conversion/GPUToSPIRV/if.mlir b/mlir/test/Conversion/GPUToSPIRV/if.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/GPUToSPIRV/if.mlir @@ -0,0 +1,87 @@ +// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s + +module attributes {gpu.container_module} { + func @main(%arg0 : memref<10xf32>, %arg1 : i1) { + %c0 = constant 1 : index + "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "kernel_simple_selection", kernel_module = @kernels} : (index, index, index, index, index, index, memref<10xf32>, i1) -> () + return + } + + module @kernels attributes {gpu.kernel_module} { + // CHECK-LABEL: @kernel_simple_selection + gpu.func @kernel_simple_selection(%arg2 : memref<10xf32>, %arg3 : i1) + attributes {gpu.kernel} { + %value = constant 0.0 : f32 + %i = constant 0 : index + + // CHECK: spv.selection { + // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[MERGE:\^.*]] + // CHECK-NEXT: [[TRUE]]: + // CHECK: spv.Branch [[MERGE]] + // CHECK-NEXT: [[MERGE]]: + // CHECK-NEXT: spv._merge + // CHECK-NEXT: } + // CHECK-NEXT: spv.Return + + loop.if %arg3 { + store %value, %arg2[%i] : memref<10xf32> + } + gpu.return + } + + // CHECK-LABEL: @kernel_nested_selection + gpu.func @kernel_nested_selection(%arg3 : memref<10xf32>, %arg4 : memref<10xf32>, %arg5 : i1, %arg6 : i1) + 
attributes {gpu.kernel} { + %i = constant 0 : index + %j = constant 9 : index + + // CHECK: spv.selection { + // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE_TOP:\^.*]], [[FALSE_TOP:\^.*]] + // CHECK-NEXT: [[TRUE_TOP]]: + // CHECK-NEXT: spv.selection { + // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE_NESTED_TRUE_PATH:\^.*]], [[FALSE_NESTED_TRUE_PATH:\^.*]] + // CHECK-NEXT: [[TRUE_NESTED_TRUE_PATH]]: + // CHECK: spv.Branch [[MERGE_NESTED_TRUE_PATH:\^.*]] + // CHECK-NEXT: [[FALSE_NESTED_TRUE_PATH]]: + // CHECK: spv.Branch [[MERGE_NESTED_TRUE_PATH]] + // CHECK-NEXT: [[MERGE_NESTED_TRUE_PATH]]: + // CHECK-NEXT: spv._merge + // CHECK-NEXT: } + // CHECK-NEXT: spv.Branch [[MERGE_TOP:\^.*]] + // CHECK-NEXT: [[FALSE_TOP]]: + // CHECK-NEXT: spv.selection { + // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE_NESTED_FALSE_PATH:\^.*]], [[FALSE_NESTED_FALSE_PATH:\^.*]] + // CHECK-NEXT: [[TRUE_NESTED_FALSE_PATH]]: + // CHECK: spv.Branch [[MERGE_NESTED_FALSE_PATH:\^.*]] + // CHECK-NEXT: [[FALSE_NESTED_FALSE_PATH]]: + // CHECK: spv.Branch [[MERGE_NESTED_FALSE_PATH]] + // CHECK: [[MERGE_NESTED_FALSE_PATH]]: + // CHECK-NEXT: spv._merge + // CHECK-NEXT: } + // CHECK-NEXT: spv.Branch [[MERGE_TOP]] + // CHECK-NEXT: [[MERGE_TOP]]: + // CHECK-NEXT: spv._merge + // CHECK-NEXT: } + // CHECK-NEXT: spv.Return + + loop.if %arg5 { + loop.if %arg6 { + %value = load %arg3[%i] : memref<10xf32> + store %value, %arg4[%i] : memref<10xf32> + } else { + %value = load %arg4[%i] : memref<10xf32> + store %value, %arg3[%i] : memref<10xf32> + } + } else { + loop.if %arg6 { + %value = load %arg3[%j] : memref<10xf32> + store %value, %arg4[%j] : memref<10xf32> + } else { + %value = load %arg4[%j] : memref<10xf32> + store %value, %arg3[%j] : memref<10xf32> + } + } + gpu.return + } + } +}