diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -23,6 +23,7 @@ #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" +#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h" #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -230,6 +230,24 @@ let dependentDialects = ["omp::OpenMPDialect"]; } +//===----------------------------------------------------------------------===// +// SCFToSPIRV +//===----------------------------------------------------------------------===// + +def SCFToSPIRV : Pass<"convert-scf-to-spirv", "ModuleOp"> { + let summary = "Convert SCF dialect to SPIR-V dialect."; + let description = [{ + This pass converts SCF ops into SPIR-V structured control flow ops. + SPIR-V structured control flow ops do not support yielding values. + So for SCF ops yielding values, SPIR-V variables are created for + holding the values and load/store operations are emitted for updating + them. + Standard ops and builtin function signatures are converted as well.
+ }]; + let constructor = "mlir::createConvertSCFToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; +} + //===----------------------------------------------------------------------===// // SCFToStandard //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h @@ -0,0 +1,21 @@ +//===- SCFToSPIRVPass.h - SCF to SPIR-V Conversion Pass ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRVPASS_H +#define MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRVPASS_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Creates a pass to convert SCF ops into SPIR-V ops.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertSCFToSPIRVPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRVPASS_H diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -12,12 +12,11 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h" + #include "../PassDetail.h" #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h" -#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" #include "mlir/Dialect/GPU/GPUDialect.h" -#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" @@ -58,10 +57,8 @@ spirv::SPIRVConversionTarget::get(targetAttr); SPIRVTypeConverter typeConverter(targetAttr); - ScfToSPIRVContext scfContext; OwningRewritePatternList patterns; populateGPUToSPIRVPatterns(context, typeConverter, patterns); - populateSCFToSPIRVPatterns(context, typeConverter,scfContext, patterns); populateStandardToSPIRVPatterns(context, typeConverter, patterns); if (failed(applyFullConversion(kernelModules, *target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_conversion_library(MLIRSCFToSPIRV SCFToSPIRV.cpp + SCFToSPIRVPass.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToSPIRV diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp new file mode 100644 --- /dev/null +++
b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp @@ -0,0 +1,54 @@ +//===- SCFToSPIRVPass.cpp - SCF to SPIR-V Dialect Conversion Pass ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert SCF dialect into SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h" + +#include "../PassDetail.h" +#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" + +using namespace mlir; + +namespace { +/// Lowers SCF ops, together with Standard ops and builtin function signatures, +/// to SPIR-V, using the module's target environment (or the default one). +struct SCFToSPIRVPass : public SCFToSPIRVBase<SCFToSPIRVPass> { + void runOnOperation() override; +}; +} // namespace + +void SCFToSPIRVPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + auto targetAttr = spirv::lookupTargetEnvOrDefault(module); + std::unique_ptr<ConversionTarget> target = + spirv::SPIRVConversionTarget::get(targetAttr); + + SPIRVTypeConverter typeConverter(targetAttr); + ScfToSPIRVContext scfContext; + OwningRewritePatternList patterns; + populateSCFToSPIRVPatterns(context, typeConverter, scfContext, patterns); + populateStandardToSPIRVPatterns(context, typeConverter, patterns); + populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns); + + // Partial conversion: ops not marked illegal may remain unconverted. + if (failed(applyPartialConversion(module, *target, std::move(patterns)))) + return signalPassFailure(); +} + +std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToSPIRVPass() { + return std::make_unique<SCFToSPIRVPass>(); +} diff --git a/mlir/test/Conversion/GPUToSPIRV/test_spirv_entry_point.mlir
b/mlir/test/Conversion/GPUToSPIRV/entry-point.mlir rename from mlir/test/Conversion/GPUToSPIRV/test_spirv_entry_point.mlir rename to mlir/test/Conversion/GPUToSPIRV/entry-point.mlir diff --git a/mlir/test/Conversion/GPUToSPIRV/if.mlir b/mlir/test/Conversion/GPUToSPIRV/if.mlir deleted file mode 100644 --- a/mlir/test/Conversion/GPUToSPIRV/if.mlir +++ /dev/null @@ -1,167 +0,0 @@ -// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s - -module attributes { - gpu.container_module, - spv.target_env = #spv.target_env< - #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}> -} { - func @main(%arg0 : memref<10xf32>, %arg1 : i1) { - %c0 = constant 1 : index - gpu.launch_func @kernels::@kernel_simple_selection - blocks in (%c0, %c0, %c0) threads in (%c0, %c0, %c0) - args(%arg0 : memref<10xf32>, %arg1 : i1) - return - } - - gpu.module @kernels { - // CHECK-LABEL: @kernel_simple_selection - gpu.func @kernel_simple_selection(%arg2 : memref<10xf32>, %arg3 : i1) kernel - attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} { - %value = constant 0.0 : f32 - %i = constant 0 : index - - // CHECK: spv.selection { - // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[MERGE:\^.*]] - // CHECK-NEXT: [[TRUE]]: - // CHECK: spv.Branch [[MERGE]] - // CHECK-NEXT: [[MERGE]]: - // CHECK-NEXT: spv.mlir.merge - // CHECK-NEXT: } - // CHECK-NEXT: spv.Return - - scf.if %arg3 { - store %value, %arg2[%i] : memref<10xf32> - } - gpu.return - } - - // CHECK-LABEL: @kernel_nested_selection - gpu.func @kernel_nested_selection(%arg3 : memref<10xf32>, %arg4 : memref<10xf32>, %arg5 : i1, %arg6 : i1) kernel - attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} { - %i = constant 0 : index - %j = constant 9 : index - - // CHECK: spv.selection { - // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE_TOP:\^.*]], [[FALSE_TOP:\^.*]] - // CHECK-NEXT: [[TRUE_TOP]]: - // CHECK-NEXT: spv.selection { - // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE_NESTED_TRUE_PATH:\^.*]],
[[FALSE_NESTED_TRUE_PATH:\^.*]] - // CHECK-NEXT: [[TRUE_NESTED_TRUE_PATH]]: - // CHECK: spv.Branch [[MERGE_NESTED_TRUE_PATH:\^.*]] - // CHECK-NEXT: [[FALSE_NESTED_TRUE_PATH]]: - // CHECK: spv.Branch [[MERGE_NESTED_TRUE_PATH]] - // CHECK-NEXT: [[MERGE_NESTED_TRUE_PATH]]: - // CHECK-NEXT: spv.mlir.merge - // CHECK-NEXT: } - // CHECK-NEXT: spv.Branch [[MERGE_TOP:\^.*]] - // CHECK-NEXT: [[FALSE_TOP]]: - // CHECK-NEXT: spv.selection { - // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE_NESTED_FALSE_PATH:\^.*]], [[FALSE_NESTED_FALSE_PATH:\^.*]] - // CHECK-NEXT: [[TRUE_NESTED_FALSE_PATH]]: - // CHECK: spv.Branch [[MERGE_NESTED_FALSE_PATH:\^.*]] - // CHECK-NEXT: [[FALSE_NESTED_FALSE_PATH]]: - // CHECK: spv.Branch [[MERGE_NESTED_FALSE_PATH]] - // CHECK: [[MERGE_NESTED_FALSE_PATH]]: - // CHECK-NEXT: spv.mlir.merge - // CHECK-NEXT: } - // CHECK-NEXT: spv.Branch [[MERGE_TOP]] - // CHECK-NEXT: [[MERGE_TOP]]: - // CHECK-NEXT: spv.mlir.merge - // CHECK-NEXT: } - // CHECK-NEXT: spv.Return - - scf.if %arg5 { - scf.if %arg6 { - %value = load %arg3[%i] : memref<10xf32> - store %value, %arg4[%i] : memref<10xf32> - } else { - %value = load %arg4[%i] : memref<10xf32> - store %value, %arg3[%i] : memref<10xf32> - } - } else { - scf.if %arg6 { - %value = load %arg3[%j] : memref<10xf32> - store %value, %arg4[%j] : memref<10xf32> - } else { - %value = load %arg4[%j] : memref<10xf32> - store %value, %arg3[%j] : memref<10xf32> - } - } - gpu.return - } - // CHECK-LABEL: @simple_if_yield - gpu.func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) kernel - attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} { - // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function> - // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function> - // CHECK: spv.selection { - // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]] - // CHECK-NEXT: [[TRUE]]: - // CHECK: %[[RET1TRUE:.*]] = spv.constant 0.000000e+00 : f32 - // CHECK: %[[RET2TRUE:.*]] = spv.constant 1.000000e+00 :
f32 - // CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[RET1TRUE]] : f32 - // CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[RET2TRUE]] : f32 - // CHECK: spv.Branch ^[[MERGE:.*]] - // CHECK-NEXT: [[FALSE]]: - // CHECK: %[[RET2FALSE:.*]] = spv.constant 2.000000e+00 : f32 - // CHECK: %[[RET1FALSE:.*]] = spv.constant 3.000000e+00 : f32 - // CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[RET1FALSE]] : f32 - // CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[RET2FALSE]] : f32 - // CHECK: spv.Branch ^[[MERGE]] - // CHECK-NEXT: ^[[MERGE]]: - // CHECK: spv.mlir.merge - // CHECK-NEXT: } - // CHECK-DAG: %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32 - // CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32 - // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32 - // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32 - // CHECK: spv.Return - %0:2 = scf.if %arg3 -> (f32, f32) { - %c0 = constant 0.0 : f32 - %c1 = constant 1.0 : f32 - scf.yield %c0, %c1 : f32, f32 - } else { - %c0 = constant 2.0 : f32 - %c1 = constant 3.0 : f32 - scf.yield %c1, %c0 : f32, f32 - } - %i = constant 0 : index - %j = constant 1 : index - store %0#0, %arg2[%i] : memref<10xf32> - store %0#1, %arg2[%j] : memref<10xf32> - gpu.return - } - // TODO: The transformation should only be legal if - // VariablePointer capability is supported. This test is still useful to - // make sure we can handle scf op result with type change.
- // CHECK-LABEL: @simple_if_yield_type_change - // CHECK: %[[VAR:.*]] = spv.Variable : !spv.ptr<!spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>, Function> - // CHECK: spv.selection { - // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]] - // CHECK-NEXT: [[TRUE]]: - // CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer> - // CHECK: spv.Branch ^[[MERGE:.*]] - // CHECK-NEXT: [[FALSE]]: - // CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer> - // CHECK: spv.Branch ^[[MERGE]] - // CHECK-NEXT: ^[[MERGE]]: - // CHECK: spv.mlir.merge - // CHECK-NEXT: } - // CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer> - // CHECK: %[[ADD:.*]] = spv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer> - // CHECK: spv.Store "StorageBuffer" %[[ADD]], {{%.*}} : f32 - // CHECK: spv.Return - gpu.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) kernel - attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} { - %i = constant 0 : index - %value = constant 0.0 : f32 - %0 = scf.if %arg4 -> (memref<10xf32>) { - scf.yield %arg2 : memref<10xf32> - } else { - scf.yield %arg3 : memref<10xf32> - } - store %value, %0[%i] : memref<10xf32> - gpu.return - } - } -} diff --git a/mlir/test/Conversion/GPUToSPIRV/loop.mlir b/mlir/test/Conversion/GPUToSPIRV/loop.mlir deleted file mode 100644 --- a/mlir/test/Conversion/GPUToSPIRV/loop.mlir +++ /dev/null @@ -1,98 +0,0 @@ -// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s - -module attributes { - gpu.container_module, - spv.target_env = #spv.target_env< - #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}> -} { - func @loop(%arg0 : memref<10xf32>, %arg1 : memref<10xf32>) { - %c0 = constant 1 : index - gpu.launch_func @kernels::@loop_kernel - blocks in (%c0, %c0, %c0) threads in (%c0, %c0, %c0) - args(%arg0 : memref<10xf32>, %arg1 : memref<10xf32>) - return - } - - gpu.module @kernels { - gpu.func
@loop_kernel(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) kernel - attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} { - // CHECK: %[[LB:.*]] = spv.constant 4 : i32 - %lb = constant 4 : index - // CHECK: %[[UB:.*]] = spv.constant 42 : i32 - %ub = constant 42 : index - // CHECK: %[[STEP:.*]] = spv.constant 2 : i32 - %step = constant 2 : index - // CHECK: spv.loop { - // CHECK-NEXT: spv.Branch ^[[HEADER:.*]](%[[LB]] : i32) - // CHECK: ^[[HEADER]](%[[INDVAR:.*]]: i32): - // CHECK: %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32 - // CHECK: spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]] - // CHECK: ^[[BODY]]: - // CHECK: %[[ZERO1:.*]] = spv.constant 0 : i32 - // CHECK: %[[OFFSET1:.*]] = spv.constant 0 : i32 - // CHECK: %[[STRIDE1:.*]] = spv.constant 1 : i32 - // CHECK: %[[UPDATE1:.*]] = spv.IMul %[[STRIDE1]], %[[INDVAR]] : i32 - // CHECK: %[[INDEX1:.*]] = spv.IAdd %[[OFFSET1]], %[[UPDATE1]] : i32 - // CHECK: spv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDEX1]]{{\]}} - // CHECK: %[[ZERO2:.*]] = spv.constant 0 : i32 - // CHECK: %[[OFFSET2:.*]] = spv.constant 0 : i32 - // CHECK: %[[STRIDE2:.*]] = spv.constant 1 : i32 - // CHECK: %[[UPDATE2:.*]] = spv.IMul %[[STRIDE2]], %[[INDVAR]] : i32 - // CHECK: %[[INDEX2:.*]] = spv.IAdd %[[OFFSET2]], %[[UPDATE2]] : i32 - // CHECK: spv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDEX2]]] - // CHECK: %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32 - // CHECK: spv.Branch ^[[HEADER]](%[[INCREMENT]] : i32) - // CHECK: ^[[MERGE]] - // CHECK: spv.mlir.merge - // CHECK: } - scf.for %arg4 = %lb to %ub step %step { - %1 = load %arg2[%arg4] : memref<10xf32> - store %1, %arg3[%arg4] : memref<10xf32> - } - gpu.return - } - - - // CHECK-LABEL: @loop_yield - gpu.func @loop_yield(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) kernel - attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} { - // CHECK: %[[LB:.*]] = spv.constant 4 : i32 - %lb = constant 4 : index -
// CHECK: %[[UB:.*]] = spv.constant 42 : i32 - %ub = constant 42 : index - // CHECK: %[[STEP:.*]] = spv.constant 2 : i32 - %step = constant 2 : index - // CHECK: %[[INITVAR1:.*]] = spv.constant 0.000000e+00 : f32 - %s0 = constant 0.0 : f32 - // CHECK: %[[INITVAR2:.*]] = spv.constant 1.000000e+00 : f32 - %s1 = constant 1.0 : f32 - // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function> - // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function> - // CHECK: spv.loop { - // CHECK: spv.Branch ^[[HEADER:.*]](%[[LB]], %[[INITVAR1]], %[[INITVAR2]] : i32, f32, f32) - // CHECK: ^[[HEADER]](%[[INDVAR:.*]]: i32, %[[CARRIED1:.*]]: f32, %[[CARRIED2:.*]]: f32): - // CHECK: %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32 - // CHECK: spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]] - // CHECK: ^[[BODY]]: - // CHECK: %[[UPDATED:.*]] = spv.FAdd %[[CARRIED1]], %[[CARRIED1]] : f32 - // CHECK-DAG: %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32 - // CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[UPDATED]] : f32 - // CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[UPDATED]] : f32 - // CHECK: spv.Branch ^[[HEADER]](%[[INCREMENT]], %[[UPDATED]], %[[UPDATED]] : i32, f32, f32) - // CHECK: ^[[MERGE]]: - // CHECK: spv.mlir.merge - // CHECK: } - %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%si = %s0, %sj = %s1) -> (f32, f32) { - %sn = addf %si, %si : f32 - scf.yield %sn, %sn : f32, f32 - } - // CHECK-DAG: %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32 - // CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32 - // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32 - // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32 - store %result#0, %arg3[%lb] : memref<10xf32> - store %result#1, %arg3[%ub] : memref<10xf32> - gpu.return - } - } -} diff --git a/mlir/test/Conversion/SCFToSPIRV/for.mlir b/mlir/test/Conversion/SCFToSPIRV/for.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/SCFToSPIRV/for.mlir @@ -0,0 +1,87 @@ +// RUN: mlir-opt
-convert-scf-to-spirv %s -o - | FileCheck %s + +module attributes { + spv.target_env = #spv.target_env< + #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}> +} { + +func @loop_kernel(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) { + // CHECK: %[[LB:.*]] = spv.constant 4 : i32 + %lb = constant 4 : index + // CHECK: %[[UB:.*]] = spv.constant 42 : i32 + %ub = constant 42 : index + // CHECK: %[[STEP:.*]] = spv.constant 2 : i32 + %step = constant 2 : index + // CHECK: spv.loop { + // CHECK-NEXT: spv.Branch ^[[HEADER:.*]](%[[LB]] : i32) + // CHECK: ^[[HEADER]](%[[INDVAR:.*]]: i32): + // CHECK: %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32 + // CHECK: spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]] + // CHECK: ^[[BODY]]: + // CHECK: %[[ZERO1:.*]] = spv.constant 0 : i32 + // CHECK: %[[OFFSET1:.*]] = spv.constant 0 : i32 + // CHECK: %[[STRIDE1:.*]] = spv.constant 1 : i32 + // CHECK: %[[UPDATE1:.*]] = spv.IMul %[[STRIDE1]], %[[INDVAR]] : i32 + // CHECK: %[[INDEX1:.*]] = spv.IAdd %[[OFFSET1]], %[[UPDATE1]] : i32 + // CHECK: spv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDEX1]]{{\]}} + // CHECK: %[[ZERO2:.*]] = spv.constant 0 : i32 + // CHECK: %[[OFFSET2:.*]] = spv.constant 0 : i32 + // CHECK: %[[STRIDE2:.*]] = spv.constant 1 : i32 + // CHECK: %[[UPDATE2:.*]] = spv.IMul %[[STRIDE2]], %[[INDVAR]] : i32 + // CHECK: %[[INDEX2:.*]] = spv.IAdd %[[OFFSET2]], %[[UPDATE2]] : i32 + // CHECK: spv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDEX2]]] + // CHECK: %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32 + // CHECK: spv.Branch ^[[HEADER]](%[[INCREMENT]] : i32) + // CHECK: ^[[MERGE]] + // CHECK: spv.mlir.merge + // CHECK: } + scf.for %arg4 = %lb to %ub step %step { + %1 = load %arg2[%arg4] : memref<10xf32> + store %1, %arg3[%arg4] : memref<10xf32> + } + return +} + + +// CHECK-LABEL: @loop_yield +func @loop_yield(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) { + // CHECK: %[[LB:.*]] = spv.constant 4 : i32 + %lb = constant 4 : index + // CHECK: %[[UB:.*]] = spv.constant 42 : i32 + %ub = constant
42 : index + // CHECK: %[[STEP:.*]] = spv.constant 2 : i32 + %step = constant 2 : index + // CHECK: %[[INITVAR1:.*]] = spv.constant 0.000000e+00 : f32 + %s0 = constant 0.0 : f32 + // CHECK: %[[INITVAR2:.*]] = spv.constant 1.000000e+00 : f32 + %s1 = constant 1.0 : f32 + // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function> + // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function> + // CHECK: spv.loop { + // CHECK: spv.Branch ^[[HEADER:.*]](%[[LB]], %[[INITVAR1]], %[[INITVAR2]] : i32, f32, f32) + // CHECK: ^[[HEADER]](%[[INDVAR:.*]]: i32, %[[CARRIED1:.*]]: f32, %[[CARRIED2:.*]]: f32): + // CHECK: %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32 + // CHECK: spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]] + // CHECK: ^[[BODY]]: + // CHECK: %[[UPDATED:.*]] = spv.FAdd %[[CARRIED1]], %[[CARRIED1]] : f32 + // CHECK-DAG: %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32 + // CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[UPDATED]] : f32 + // CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[UPDATED]] : f32 + // CHECK: spv.Branch ^[[HEADER]](%[[INCREMENT]], %[[UPDATED]], %[[UPDATED]] : i32, f32, f32) + // CHECK: ^[[MERGE]]: + // CHECK: spv.mlir.merge + // CHECK: } + %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%si = %s0, %sj = %s1) -> (f32, f32) { + %sn = addf %si, %si : f32 + scf.yield %sn, %sn : f32, f32 + } + // CHECK-DAG: %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32 + // CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32 + // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32 + // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32 + store %result#0, %arg3[%lb] : memref<10xf32> + store %result#1, %arg3[%ub] : memref<10xf32> + return +} + +} // end module diff --git a/mlir/test/Conversion/SCFToSPIRV/if.mlir b/mlir/test/Conversion/SCFToSPIRV/if.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/SCFToSPIRV/if.mlir @@ -0,0 +1,156 @@ +// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s + +module
attributes { + spv.target_env = #spv.target_env< + #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}> +} { + +// CHECK-LABEL: @kernel_simple_selection +func @kernel_simple_selection(%arg2 : memref<10xf32>, %arg3 : i1) { + %value = constant 0.0 : f32 + %i = constant 0 : index + + // CHECK: spv.selection { + // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[MERGE:\^.*]] + // CHECK-NEXT: [[TRUE]]: + // CHECK: spv.Branch [[MERGE]] + // CHECK-NEXT: [[MERGE]]: + // CHECK-NEXT: spv.mlir.merge + // CHECK-NEXT: } + // CHECK-NEXT: spv.Return + + scf.if %arg3 { + store %value, %arg2[%i] : memref<10xf32> + } + return +} + +// CHECK-LABEL: @kernel_nested_selection +func @kernel_nested_selection(%arg3 : memref<10xf32>, %arg4 : memref<10xf32>, %arg5 : i1, %arg6 : i1) { + %i = constant 0 : index + %j = constant 9 : index + + // CHECK: spv.selection { + // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE_TOP:\^.*]], [[FALSE_TOP:\^.*]] + // CHECK-NEXT: [[TRUE_TOP]]: + // CHECK-NEXT: spv.selection { + // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE_NESTED_TRUE_PATH:\^.*]], [[FALSE_NESTED_TRUE_PATH:\^.*]] + // CHECK-NEXT: [[TRUE_NESTED_TRUE_PATH]]: + // CHECK: spv.Branch [[MERGE_NESTED_TRUE_PATH:\^.*]] + // CHECK-NEXT: [[FALSE_NESTED_TRUE_PATH]]: + // CHECK: spv.Branch [[MERGE_NESTED_TRUE_PATH]] + // CHECK-NEXT: [[MERGE_NESTED_TRUE_PATH]]: + // CHECK-NEXT: spv.mlir.merge + // CHECK-NEXT: } + // CHECK-NEXT: spv.Branch [[MERGE_TOP:\^.*]] + // CHECK-NEXT: [[FALSE_TOP]]: + // CHECK-NEXT: spv.selection { + // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE_NESTED_FALSE_PATH:\^.*]], [[FALSE_NESTED_FALSE_PATH:\^.*]] + // CHECK-NEXT: [[TRUE_NESTED_FALSE_PATH]]: + // CHECK: spv.Branch [[MERGE_NESTED_FALSE_PATH:\^.*]] + // CHECK-NEXT: [[FALSE_NESTED_FALSE_PATH]]: + // CHECK: spv.Branch [[MERGE_NESTED_FALSE_PATH]] + // CHECK: [[MERGE_NESTED_FALSE_PATH]]: + // CHECK-NEXT: spv.mlir.merge +
// CHECK-NEXT: } + // CHECK-NEXT: spv.Return + + scf.if %arg5 { + scf.if %arg6 { + %value = load %arg3[%i] : memref<10xf32> + store %value, %arg4[%i] : memref<10xf32> + } else { + %value = load %arg4[%i] : memref<10xf32> + store %value, %arg3[%i] : memref<10xf32> + } + } else { + scf.if %arg6 { + %value = load %arg3[%j] : memref<10xf32> + store %value, %arg4[%j] : memref<10xf32> + } else { + %value = load %arg4[%j] : memref<10xf32> + store %value, %arg3[%j] : memref<10xf32> + } + } + return +} + +// CHECK-LABEL: @simple_if_yield +func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) { + // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function> + // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function> + // CHECK: spv.selection { + // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]] + // CHECK-NEXT: [[TRUE]]: + // CHECK: %[[RET1TRUE:.*]] = spv.constant 0.000000e+00 : f32 + // CHECK: %[[RET2TRUE:.*]] = spv.constant 1.000000e+00 : f32 + // CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[RET1TRUE]] : f32 + // CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[RET2TRUE]] : f32 + // CHECK: spv.Branch ^[[MERGE:.*]] + // CHECK-NEXT: [[FALSE]]: + // CHECK: %[[RET2FALSE:.*]] = spv.constant 2.000000e+00 : f32 + // CHECK: %[[RET1FALSE:.*]] = spv.constant 3.000000e+00 : f32 + // CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[RET1FALSE]] : f32 + // CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[RET2FALSE]] : f32 + // CHECK: spv.Branch ^[[MERGE]] + // CHECK-NEXT: ^[[MERGE]]: + // CHECK: spv.mlir.merge + // CHECK-NEXT: } + // CHECK-DAG: %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32 + // CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32 + // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32 + // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32 + // CHECK: spv.Return + %0:2 = scf.if %arg3 -> (f32, f32) { + %c0 = constant 0.0 : f32 + %c1 = constant 1.0 : f32 + scf.yield %c0, %c1 : f32, f32 + } else { + %c0 = constant 2.0 : f32 + %c1 = constant 3.0
: f32 + scf.yield %c1, %c0 : f32, f32 + } + %i = constant 0 : index + %j = constant 1 : index + store %0#0, %arg2[%i] : memref<10xf32> + store %0#1, %arg2[%j] : memref<10xf32> + return +} + +// TODO: The transformation should only be legal if VariablePointer capability +// is supported. This test is still useful to make sure we can handle scf op +// result with type change. +// CHECK-LABEL: @simple_if_yield_type_change +func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) { + // CHECK: %[[VAR:.*]] = spv.Variable : !spv.ptr<!spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>, Function> + // CHECK: spv.selection { + // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]] + // CHECK-NEXT: [[TRUE]]: + // CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer> + // CHECK: spv.Branch ^[[MERGE:.*]] + // CHECK-NEXT: [[FALSE]]: + // CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer> + // CHECK: spv.Branch ^[[MERGE]] + // CHECK-NEXT: ^[[MERGE]]: + // CHECK: spv.mlir.merge + // CHECK-NEXT: } + // CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer> + // CHECK: %[[ADD:.*]] = spv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer> + // CHECK: spv.Store "StorageBuffer" %[[ADD]], {{%.*}} : f32 + // CHECK: spv.Return + %i = constant 0 : index + %value = constant 0.0 : f32 + %0 = scf.if %arg4 -> (memref<10xf32>) { + scf.yield %arg2 : memref<10xf32> + } else { + scf.yield %arg3 : memref<10xf32> + } + store %value, %0[%i] : memref<10xf32> + return +} + +} // end module