diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -1142,12 +1142,12 @@ * From GPU dialect: headers are at [include/mlir/Conversion/GPUTOSPIRV][MlirGpuToSpirvHeaders]; libraries are at [lib/Conversion/GPUToSPIRV][MlirGpuToSpirvLibs]. -* From standard dialect: headers are at - [include/mlir/Conversion/StandardTOSPIRV][MlirStdToSpirvHeaders]; libraries - are at [lib/Conversion/StandardToSPIRV][MlirStdToSpirvLibs]. +* From Func dialect: headers are at + [include/mlir/Conversion/FuncToSPIRV][MlirFuncToSpirvHeaders]; libraries + are at [lib/Conversion/FuncToSPIRV][MlirFuncToSpirvLibs]. These dialect to dialect conversions have their dedicated libraries, -`MLIRGPUToSPIRV` and `MLIRStandardToSPIRV`, respectively. +`MLIRGPUToSPIRV` and `MLIRFuncToSPIRV`, respectively. There are also common utilities when targeting SPIR-V from any dialect: @@ -1407,8 +1407,8 @@ [MlirSpirvUnittests]: https://github.com/llvm/llvm-project/tree/main/mlir/unittests/Dialect/SPIRV [MlirGpuToSpirvHeaders]: https://github.com/llvm/llvm-project/tree/main/mlir/include/mlir/Conversion/GPUToSPIRV [MlirGpuToSpirvLibs]: https://github.com/llvm/llvm-project/tree/main/mlir/lib/Conversion/GPUToSPIRV -[MlirStdToSpirvHeaders]: https://github.com/llvm/llvm-project/tree/main/mlir/include/mlir/Conversion/StandardToSPIRV -[MlirStdToSpirvLibs]: https://github.com/llvm/llvm-project/tree/main/mlir/lib/Conversion/StandardToSPIRV +[MlirFuncToSpirvHeaders]: https://github.com/llvm/llvm-project/tree/main/mlir/include/mlir/Conversion/FuncToSPIRV +[MlirFuncToSpirvLibs]: https://github.com/llvm/llvm-project/tree/main/mlir/lib/Conversion/FuncToSPIRV [MlirSpirvDialect]: https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVDialect.h [MlirSpirvTypes]: https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h [MlirSpirvOpsH]: 
https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h diff --git a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h @@ -0,0 +1,25 @@ +//===- ControlFlowToSPIRVPass.h - ControlFlow to SPIR-V Passes --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides passes to convert ControlFlow dialect to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRVPASS_H +#define MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRVPASS_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Creates a pass to convert ControlFlow ops to SPIR-V ops. +std::unique_ptr> createConvertControlFlowToSPIRVPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRVPASS_H diff --git a/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h @@ -0,0 +1,29 @@ +//===- FuncToSPIRV.h - Func to SPIR-V Patterns ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides patterns to convert Func dialect to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_FUNCTOSPIRV_FUNCTOSPIRV_H +#define MLIR_CONVERSION_FUNCTOSPIRV_FUNCTOSPIRV_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class SPIRVTypeConverter; + +/// Appends to a pattern list additional patterns for translating Func ops +/// to SPIR-V ops. Also adds the patterns to legalize ops not directly +/// translated to SPIR-V dialect. +void populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_FUNCTOSPIRV_FUNCTOSPIRV_H diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h rename from mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h rename to mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h --- a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h +++ b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h @@ -1,4 +1,4 @@ -//===- StandardToSPIRVPass.h - Standard to SPIR-V Passes --------*- C++ -*-===// +//===- FuncToSPIRVPass.h - Func to SPIR-V Passes ----------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,20 +6,20 @@ // //===----------------------------------------------------------------------===// // -// Provides passes to convert Standard dialect to SPIR-V dialect. +// Provides passes to convert Func dialect to SPIR-V dialect. 
// //===----------------------------------------------------------------------===// -#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRVPASS_H -#define MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRVPASS_H +#ifndef MLIR_CONVERSION_FUNCTOSPIRV_FUNCTOSPIRVPASS_H +#define MLIR_CONVERSION_FUNCTOSPIRV_FUNCTOSPIRVPASS_H #include "mlir/Pass/Pass.h" namespace mlir { -/// Creates a pass to convert standard ops to SPIR-V ops. -std::unique_ptr> createConvertStandardToSPIRVPass(); +/// Creates a pass to convert Func ops to SPIR-V ops. +std::unique_ptr> createConvertFuncToSPIRVPass(); } // namespace mlir -#endif // MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRVPASS_H +#endif // MLIR_CONVERSION_FUNCTOSPIRV_FUNCTOSPIRVPASS_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -19,6 +19,8 @@ #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" +#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h" +#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" @@ -44,7 +46,7 @@ #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h" #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" -#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h" +#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h" #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Conversion/TosaToSCF/TosaToSCF.h" #include "mlir/Conversion/TosaToStandard/TosaToStandard.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td 
--- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -95,7 +95,7 @@ // ArithmeticToSPIRV //===----------------------------------------------------------------------===// -def ConvertArithmeticToSPIRV : Pass<"convert-arith-to-spirv", "FuncOp"> { +def ConvertArithmeticToSPIRV : Pass<"convert-arith-to-spirv", "ModuleOp"> { let summary = "Convert Arithmetic dialect to SPIR-V dialect"; let constructor = "mlir::arith::createConvertArithmeticToSPIRVPass()"; let dependentDialects = ["spirv::SPIRVDialect"]; @@ -202,6 +202,38 @@ ]; } +//===----------------------------------------------------------------------===// +// ControlFlowToSPIRV +//===----------------------------------------------------------------------===// + +def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv", "ModuleOp"> { + let summary = "Convert ControlFlow dialect to SPIR-V dialect"; + let constructor = "mlir::createConvertControlFlowToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; + let options = [ + Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types", + "bool", /*default=*/"true", + "Emulate non-32-bit scalar types with 32-bit ones if " + "missing native support"> + ]; +} + +//===----------------------------------------------------------------------===// +// FuncToSPIRV +//===----------------------------------------------------------------------===// + +def ConvertFuncToSPIRV : Pass<"convert-func-to-spirv", "ModuleOp"> { + let summary = "Convert Func dialect to SPIR-V dialect"; + let constructor = "mlir::createConvertFuncToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; + let options = [ + Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types", + "bool", /*default=*/"true", + "Emulate non-32-bit scalar types with 32-bit ones if " + "missing native support"> + ]; +} + //===----------------------------------------------------------------------===// // GPUCommon 
//===----------------------------------------------------------------------===// @@ -637,12 +669,12 @@ } //===----------------------------------------------------------------------===// -// StandardToSPIRV +// TensorToSPIRV //===----------------------------------------------------------------------===// -def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> { - let summary = "Convert Standard dialect to SPIR-V dialect"; - let constructor = "mlir::createConvertStandardToSPIRVPass()"; +def ConvertTensorToSPIRV : Pass<"convert-tensor-to-spirv", "ModuleOp"> { + let summary = "Convert Tensor dialect to SPIR-V dialect"; + let constructor = "mlir::createConvertTensorToSPIRVPass()"; let dependentDialects = ["spirv::SPIRVDialect"]; let options = [ Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types", diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h b/mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h rename from mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h rename to mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h --- a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h +++ b/mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h @@ -1,4 +1,4 @@ -//===- StandardToSPIRV.h - Standard to SPIR-V Patterns --------*- C++ -*-===// +//===- TensorToSPIRV.h - Tensor to SPIR-V Patterns --------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,24 +6,18 @@ // //===----------------------------------------------------------------------===// // -// Provides patterns to convert Standard dialect to SPIR-V dialect. +// Provides patterns to convert Tensor dialect to SPIR-V dialect. 
// //===----------------------------------------------------------------------===// -#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRV_H -#define MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRV_H +#ifndef MLIR_CONVERSION_TENSORTOSPIRV_TENSORTOSPIRV_H +#define MLIR_CONVERSION_TENSORTOSPIRV_TENSORTOSPIRV_H #include "mlir/Transforms/DialectConversion.h" namespace mlir { class SPIRVTypeConverter; -/// Appends to a pattern list additional patterns for translating standard ops -/// to SPIR-V ops. Also adds the patterns to legalize ops not directly -/// translated to SPIR-V dialect. -void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter, - RewritePatternSet &patterns); - /// Appends to a pattern list additional patterns for translating tensor ops /// to SPIR-V ops. /// @@ -42,4 +36,4 @@ } // namespace mlir -#endif // MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRV_H +#endif // MLIR_CONVERSION_TENSORTOSPIRV_TENSORTOSPIRV_H diff --git a/mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h b/mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h @@ -0,0 +1,25 @@ +//===- TensorToSPIRVPass.h - Tensor to SPIR-V Passes ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides passes to convert Tensor dialect to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_TENSORTOSPIRV_TENSORTOSPIRVPASS_H +#define MLIR_CONVERSION_TENSORTOSPIRV_TENSORTOSPIRVPASS_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Creates a pass to convert Tensor ops to SPIR-V ops.
+std::unique_ptr> createConvertTensorToSPIRVPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_TENSORTOSPIRV_TENSORTOSPIRVPASS_H diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" #include "../PassDetail.h" #include "../SPIRVCommon/Pattern.h" +#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" @@ -843,7 +844,14 @@ TypeCastingOpPattern, CmpIOpBooleanPattern, CmpIOpPattern, CmpFOpNanNonePattern, CmpFOpPattern, - SelectOpPattern + SelectOpPattern, + + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern >(typeConverter, patterns.getContext()); // clang-format on @@ -861,7 +869,7 @@ struct ConvertArithmeticToSPIRVPass : public ConvertArithmeticToSPIRVBase { void runOnOperation() override { - auto module = getOperation()->getParentOfType(); + auto module = getOperation(); auto targetAttr = spirv::lookupTargetEnvOrDefault(module); auto target = SPIRVConversionTarget::get(targetAttr); @@ -870,10 +878,11 @@ SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(&getContext()); - mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); + arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); + populateFuncToSPIRVPatterns(typeConverter, patterns); + populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); - if (failed(applyPartialConversion(getOperation(), *target, - std::move(patterns)))) + if (failed(applyPartialConversion(module, *target, 
std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ArithmeticToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/ArithmeticToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRArithmetic + MLIRFuncToSPIRV MLIRSPIRVConversion MLIRSPIRV ) diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -8,6 +8,7 @@ add_subdirectory(ComplexToStandard) add_subdirectory(ControlFlowToLLVM) add_subdirectory(ControlFlowToSPIRV) +add_subdirectory(FuncToSPIRV) add_subdirectory(GPUCommon) add_subdirectory(GPUToNVVM) add_subdirectory(GPUToROCDL) @@ -34,7 +35,7 @@ add_subdirectory(ShapeToStandard) add_subdirectory(SPIRVToLLVM) add_subdirectory(StandardToLLVM) -add_subdirectory(StandardToSPIRV) +add_subdirectory(TensorToSPIRV) add_subdirectory(TosaToLinalg) add_subdirectory(TosaToSCF) add_subdirectory(TosaToStandard) diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ControlFlowToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/ControlFlowToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_conversion_library(MLIRControlFlowToSPIRV ControlFlowToSPIRV.cpp + ControlFlowToSPIRVPass.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp rename from mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp rename to mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp @@ -1,4 +1,4 @@ -//===- StandardToSPIRVPass.cpp - Standard to SPIR-V 
Passes ----------------===// +//===- ControlFlowToSPIRVPass.cpp - ControlFlow to SPIR-V Pass ------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,30 +6,27 @@ // //===----------------------------------------------------------------------===// // -// This file implements a pass to convert standard dialect to SPIR-V dialect. +// This file implements a pass to convert ControlFlow dialect to SPIR-V dialect. // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h" +#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h" #include "../PassDetail.h" -#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" -#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h" -#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" using namespace mlir; namespace { -/// A pass converting MLIR Standard operations into the SPIR-V dialect. -class ConvertStandardToSPIRVPass - : public ConvertStandardToSPIRVBase { +/// A pass converting MLIR ControlFlow operations into the SPIR-V dialect. 
+class ConvertControlFlowToSPIRVPass + : public ConvertControlFlowToSPIRVBase { void runOnOperation() override; }; } // namespace -void ConvertStandardToSPIRVPass::runOnOperation() { +void ConvertControlFlowToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); @@ -41,22 +38,14 @@ options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; SPIRVTypeConverter typeConverter(targetAttr, options); - // TODO ArithmeticToSPIRV/ControlFlowToSPIRV cannot be applied separately to - // StandardToSPIRV RewritePatternSet patterns(context); - arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns); - populateMathToSPIRVPatterns(typeConverter, patterns); - populateStandardToSPIRVPatterns(typeConverter, patterns); - populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64, - patterns); - populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); if (failed(applyPartialConversion(module, *target, std::move(patterns)))) return signalPassFailure(); } std::unique_ptr> -mlir::createConvertStandardToSPIRVPass() { - return std::make_unique(); +mlir::createConvertControlFlowToSPIRVPass() { + return std::make_unique(); } diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/FuncToSPIRV/CMakeLists.txt copy from mlir/lib/Conversion/ControlFlowToSPIRV/CMakeLists.txt copy to mlir/lib/Conversion/FuncToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/ControlFlowToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/FuncToSPIRV/CMakeLists.txt @@ -1,5 +1,6 @@ -add_mlir_conversion_library(MLIRControlFlowToSPIRV - ControlFlowToSPIRV.cpp +add_mlir_conversion_library(MLIRFuncToSPIRV + FuncToSPIRV.cpp + FuncToSPIRVPass.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV @@ -9,8 +10,8 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRFunc MLIRIR - MLIRControlFlow MLIRPass MLIRSPIRV MLIRSPIRVConversion diff 
--git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp @@ -0,0 +1,71 @@ +//===- FuncToSPIRV.cpp - Func to SPIR-V Patterns --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert Func dialect to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" +#include "../SPIRVCommon/Pattern.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "func-to-spirv-pattern" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Operation conversion +//===----------------------------------------------------------------------===// + +// Note that DRR cannot be used for the patterns in this file: we may need to +// convert type along the way, which requires ConversionPattern. DRR generates +// normal RewritePattern. + +namespace { + +/// Converts func.return to spv.Return.
+class ReturnOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (returnOp.getNumOperands() > 1) + return failure(); + + if (returnOp.getNumOperands() == 1) { + rewriter.replaceOpWithNewOp( + returnOp, adaptor.getOperands()[0]); + } else { + rewriter.replaceOpWithNewOp(returnOp); + } + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern population +//===----------------------------------------------------------------------===// + +void mlir::populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + + patterns.add(typeConverter, context); +} diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp @@ -0,0 +1,51 @@ +//===- FuncToSPIRVPass.cpp - Func to SPIR-V Passes ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert Func dialect to SPIR-V dialect.
+// +//===----------------------------------------------------------------------===// + +#include "../PassDetail.h" +#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" +#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" + +using namespace mlir; + +namespace { +/// A pass converting MLIR Func operations into the SPIR-V dialect. +class ConvertFuncToSPIRVPass + : public ConvertFuncToSPIRVBase { + void runOnOperation() override; +}; +} // namespace + +void ConvertFuncToSPIRVPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + auto targetAttr = spirv::lookupTargetEnvOrDefault(module); + std::unique_ptr target = + SPIRVConversionTarget::get(targetAttr); + + SPIRVTypeConverter::Options options; + options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; + SPIRVTypeConverter typeConverter(targetAttr, options); + + RewritePatternSet patterns(context); + populateFuncToSPIRVPatterns(typeConverter, patterns); + populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); + + if (failed(applyPartialConversion(module, *target, std::move(patterns)))) + return signalPassFailure(); +} + +std::unique_ptr> mlir::createConvertFuncToSPIRVPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt @@ -8,12 +8,12 @@ LINK_LIBS PUBLIC MLIRArithmeticToSPIRV MLIRGPUOps + MLIRFuncToSPIRV MLIRIR MLIRPass MLIRSCFToSPIRV MLIRSPIRV MLIRSPIRVConversion - MLIRStandardToSPIRV MLIRSupport MLIRTransforms ) diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ 
-15,9 +15,9 @@ #include "../PassDetail.h" #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" +#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" -#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" @@ -66,7 +66,7 @@ // patterns. mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); populateMemRefToSPIRVPatterns(typeConverter, patterns); - populateStandardToSPIRVPatterns(typeConverter, patterns); + populateFuncToSPIRVPatterns(typeConverter, patterns); if (failed(applyFullConversion(kernelModules, *target, std::move(patterns)))) return signalPassFailure(); diff --git a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt @@ -10,10 +10,10 @@ LINK_LIBS PUBLIC MLIRArithmeticToSPIRV + MLIRFuncToSPIRV MLIRMemRefToSPIRV MLIRSPIRV MLIRSPIRVConversion - MLIRStandardToSPIRV MLIRIR MLIRPass MLIRSCF diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp @@ -14,9 +14,9 @@ #include "../PassDetail.h" #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" +#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" -#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" @@ -45,7 +45,7 @@ // TODO: Change SPIR-V conversion to be progressive 
and remove the following // patterns. mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); - populateStandardToSPIRVPatterns(typeConverter, patterns); + populateFuncToSPIRVPatterns(typeConverter, patterns); populateMemRefToSPIRVPatterns(typeConverter, patterns); populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/TensorToSPIRV/CMakeLists.txt rename from mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt rename to mlir/lib/Conversion/TensorToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/TensorToSPIRV/CMakeLists.txt @@ -1,6 +1,6 @@ -add_mlir_conversion_library(MLIRStandardToSPIRV - StandardToSPIRV.cpp - StandardToSPIRVPass.cpp +add_mlir_conversion_library(MLIRTensorToSPIRV + TensorToSPIRV.cpp + TensorToSPIRVPass.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV @@ -11,11 +11,8 @@ LINK_LIBS PUBLIC MLIRArithmeticToSPIRV - MLIRControlFlowToSPIRV - MLIRFunc + MLIRFuncToSPIRV MLIRIR - MLIRMathToSPIRV - MLIRMemRef MLIRPass MLIRSPIRV MLIRSPIRVConversion diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp rename from mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp rename to mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp @@ -1,4 +1,4 @@ -//===- StandardToSPIRV.cpp - Standard to SPIR-V Patterns ------------------===// +//===- TensorToSPIRV.cpp - Tensor to SPIR-V Patterns ----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// // -// This file implements patterns to convert standard dialect to SPIR-V dialect. +// This file implements patterns to convert Tensor dialect to SPIR-V dialect. // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h" #include "../SPIRVCommon/Pattern.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" @@ -22,7 +22,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" -#define DEBUG_TYPE "std-to-spirv-pattern" +#define DEBUG_TYPE "tensor-to-spirv-pattern" using namespace mlir; @@ -30,22 +30,8 @@ // Operation conversion //===----------------------------------------------------------------------===// -// Note that DRR cannot be used for the patterns in this file: we may need to -// convert type along the way, which requires ConversionPattern. DRR generates -// normal RewritePattern. - namespace { -/// Converts func.return to spv.Return. -class ReturnOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - /// Converts tensor.extract into loading using access chains from SPIR-V local /// variables. 
class TensorExtractPattern final @@ -109,51 +95,13 @@ } // namespace -//===----------------------------------------------------------------------===// -// ReturnOp -//===----------------------------------------------------------------------===// - -LogicalResult -ReturnOpPattern::matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - if (returnOp.getNumOperands() > 1) - return failure(); - - if (returnOp.getNumOperands() == 1) { - rewriter.replaceOpWithNewOp(returnOp, - adaptor.getOperands()[0]); - } else { - rewriter.replaceOpWithNewOp(returnOp); - } - return success(); -} - //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// -namespace mlir { -void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter, - RewritePatternSet &patterns) { - MLIRContext *context = patterns.getContext(); - - patterns.add< - // Unary and binary patterns - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - - ReturnOpPattern>(typeConverter, context); -} - -void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, - int64_t byteCountThreshold, - RewritePatternSet &patterns) { +void mlir::populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + int64_t byteCountThreshold, + RewritePatternSet &patterns) { patterns.add(typeConverter, patterns.getContext(), byteCountThreshold); } - -} // namespace mlir diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp @@ -0,0 +1,55 @@ +//===- TensorToSPIRVPass.cpp - Tensor to SPIR-V Passes ----------------===// +// +// Part of the LLVM 
Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert Tensor dialect to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h" +#include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" +#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" +#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" + +using namespace mlir; + +namespace { +/// A pass converting MLIR Tensor operations into the SPIR-V dialect. +class ConvertTensorToSPIRVPass + : public ConvertTensorToSPIRVBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + auto targetAttr = spirv::lookupTargetEnvOrDefault(module); + std::unique_ptr target = + SPIRVConversionTarget::get(targetAttr); + + SPIRVTypeConverter::Options options; + options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; + SPIRVTypeConverter typeConverter(targetAttr, options); + + RewritePatternSet patterns(context); + arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); + populateFuncToSPIRVPatterns(typeConverter, patterns); + populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64, + patterns); + populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); + + if (failed(applyPartialConversion(module, *target, std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +mlir::createConvertTensorToSPIRVPass() { + return std::make_unique(); +} diff --git 
a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -convert-std-to-spirv -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-arith-to-spirv -verify-diagnostics %s | FileCheck %s //===----------------------------------------------------------------------===// // arithmetic ops diff --git a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir --- a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir +++ b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -convert-std-to-spirv -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-cf-to-spirv -verify-diagnostics %s | FileCheck %s //===----------------------------------------------------------------------===// // cf.br, cf.cond_br @@ -9,29 +9,26 @@ } { // CHECK-LABEL: func @simple_loop -func @simple_loop(index, index, index) { -^bb0(%begin : index, %end : index, %step : index): +func @simple_loop(%begin: i32, %end: i32, %step: i32) { // CHECK-NEXT: spv.Branch ^bb1 cf.br ^bb1 // CHECK-NEXT: ^bb1: // pred: ^bb0 // CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32) ^bb1: // pred: ^bb0 - cf.br ^bb2(%begin : index) + cf.br ^bb2(%begin : i32) // CHECK: ^bb2({{.*}}: i32): // 2 preds: ^bb1, ^bb3 -// CHECK-NEXT: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32 -// CHECK-NEXT: spv.BranchConditional {{.*}}, ^bb3, ^bb4 -^bb2(%0: index): // 2 preds: ^bb1, ^bb3 - %1 = arith.cmpi slt, %0, %end : index +// CHECK: spv.BranchConditional {{.*}}, ^bb3, ^bb4 +^bb2(%0: i32): // 2 preds: ^bb1, ^bb3 + %1 = arith.cmpi slt, %0, %end : i32 cf.cond_br %1, ^bb3, ^bb4 // CHECK: ^bb3: // pred: 
^bb2 -// CHECK-NEXT: {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32 -// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32) +// CHECK: spv.Branch ^bb2({{.*}} : i32) ^bb3: // pred: ^bb2 - %2 = arith.addi %0, %step : index - cf.br ^bb2(%2 : index) + %2 = arith.addi %0, %step : i32 + cf.br ^bb2(%2 : i32) // CHECK: ^bb4: // pred: ^bb2 ^bb4: // pred: ^bb2 diff --git a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt -split-input-file -convert-func-to-spirv -verify-diagnostics %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// func.return +//===----------------------------------------------------------------------===// + +module attributes { + spv.target_env = #spv.target_env<#spv.vce, {}> +} { + +// CHECK-LABEL: spv.func @return_one_val +// CHECK-SAME: (%[[ARG:.+]]: f32) +func @return_one_val(%arg0: f32) -> f32 { + // CHECK: spv.ReturnValue %[[ARG]] : f32 + return %arg0: f32 +} + +// Check that multiple-return functions are not converted. 
+// CHECK-LABEL: func @return_multi_val +func @return_multi_val(%arg0: f32) -> (f32, f32) { + // CHECK: return + return %arg0, %arg0: f32, f32 +} + +} + +// ----- diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir rename from mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir rename to mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt -split-input-file -convert-std-to-spirv %s -o - | FileCheck %s -// RUN: mlir-opt -split-input-file -convert-std-to-spirv="emulate-non-32-bit-scalar-types=false" %s -o - | FileCheck %s --check-prefix=NOEMU +// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-non-32-bit-scalar-types=false" %s -o - | FileCheck %s --check-prefix=NOEMU //===----------------------------------------------------------------------===// // Integer types diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir deleted file mode 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ /dev/null @@ -1,902 +0,0 @@ -// RUN: mlir-opt -split-input-file -convert-std-to-spirv -verify-diagnostics %s | FileCheck %s - -//===----------------------------------------------------------------------===// -// arith.select -//===----------------------------------------------------------------------===// - -module attributes { - spv.target_env = #spv.target_env< - #spv.vce, {}> -} { - -// Check integer operation conversions. 
-// CHECK-LABEL: @int32_scalar -func @int32_scalar(%lhs: i32, %rhs: i32) { - // CHECK: spv.IAdd %{{.*}}, %{{.*}}: i32 - %0 = arith.addi %lhs, %rhs: i32 - // CHECK: spv.ISub %{{.*}}, %{{.*}}: i32 - %1 = arith.subi %lhs, %rhs: i32 - // CHECK: spv.IMul %{{.*}}, %{{.*}}: i32 - %2 = arith.muli %lhs, %rhs: i32 - // CHECK: spv.SDiv %{{.*}}, %{{.*}}: i32 - %3 = arith.divsi %lhs, %rhs: i32 - // CHECK: spv.UDiv %{{.*}}, %{{.*}}: i32 - %4 = arith.divui %lhs, %rhs: i32 - // CHECK: spv.UMod %{{.*}}, %{{.*}}: i32 - %5 = arith.remui %lhs, %rhs: i32 - // CHECK: spv.GLSL.SMax %{{.*}}, %{{.*}}: i32 - %6 = arith.maxsi %lhs, %rhs : i32 - // CHECK: spv.GLSL.UMax %{{.*}}, %{{.*}}: i32 - %7 = arith.maxui %lhs, %rhs : i32 - // CHECK: spv.GLSL.SMin %{{.*}}, %{{.*}}: i32 - %8 = arith.minsi %lhs, %rhs : i32 - // CHECK: spv.GLSL.UMin %{{.*}}, %{{.*}}: i32 - %9 = arith.minui %lhs, %rhs : i32 - return -} - -// CHECK-LABEL: @scalar_srem -// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) -func @scalar_srem(%lhs: i32, %rhs: i32) { - // CHECK: %[[LABS:.+]] = spv.GLSL.SAbs %[[LHS]] : i32 - // CHECK: %[[RABS:.+]] = spv.GLSL.SAbs %[[RHS]] : i32 - // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : i32 - // CHECK: %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : i32 - // CHECK: %[[NEG:.+]] = spv.SNegate %[[ABS]] : i32 - // CHECK: %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32 - %0 = arith.remsi %lhs, %rhs: i32 - return -} - -// Check float unary operation conversions. -// CHECK-LABEL: @float32_unary_scalar -func @float32_unary_scalar(%arg0: f32) { - // CHECK: spv.GLSL.FAbs %{{.*}}: f32 - %0 = math.abs %arg0 : f32 - // CHECK: spv.GLSL.Ceil %{{.*}}: f32 - %1 = math.ceil %arg0 : f32 - // CHECK: spv.FNegate %{{.*}}: f32 - %5 = arith.negf %arg0 : f32 - // CHECK: spv.GLSL.Floor %{{.*}}: f32 - %10 = math.floor %arg0 : f32 - return -} - -// Check float binary operation conversions. 
-// CHECK-LABEL: @float32_binary_scalar -func @float32_binary_scalar(%lhs: f32, %rhs: f32) { - // CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32 - %0 = arith.addf %lhs, %rhs: f32 - // CHECK: spv.FSub %{{.*}}, %{{.*}}: f32 - %1 = arith.subf %lhs, %rhs: f32 - // CHECK: spv.FMul %{{.*}}, %{{.*}}: f32 - %2 = arith.mulf %lhs, %rhs: f32 - // CHECK: spv.FDiv %{{.*}}, %{{.*}}: f32 - %3 = arith.divf %lhs, %rhs: f32 - // CHECK: spv.FRem %{{.*}}, %{{.*}}: f32 - %4 = arith.remf %lhs, %rhs: f32 - // CHECK: spv.GLSL.FMax %{{.*}}, %{{.*}}: f32 - %5 = arith.maxf %lhs, %rhs: f32 - // CHECK: spv.GLSL.FMin %{{.*}}, %{{.*}}: f32 - %6 = arith.minf %lhs, %rhs: f32 - return -} - -// Check int vector types. -// CHECK-LABEL: @int_vector234 -func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) { - // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi8> - %0 = arith.divsi %arg0, %arg0: vector<2xi8> - // CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi64> - %1 = arith.divui %arg1, %arg1: vector<4xi64> - return -} - -// CHECK-LABEL: @vector_srem -// CHECK-SAME: (%[[LHS:.+]]: vector<3xi16>, %[[RHS:.+]]: vector<3xi16>) -func @vector_srem(%arg0: vector<3xi16>, %arg1: vector<3xi16>) { - // CHECK: %[[LABS:.+]] = spv.GLSL.SAbs %[[LHS]] : vector<3xi16> - // CHECK: %[[RABS:.+]] = spv.GLSL.SAbs %[[RHS]] : vector<3xi16> - // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : vector<3xi16> - // CHECK: %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : vector<3xi16> - // CHECK: %[[NEG:.+]] = spv.SNegate %[[ABS]] : vector<3xi16> - // CHECK: %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : vector<3xi1>, vector<3xi16> - %0 = arith.remsi %arg0, %arg1: vector<3xi16> - return -} - -// Check float vector types. 
-// CHECK-LABEL: @float_vector234 -func @float_vector234(%arg0: vector<2xf16>, %arg1: vector<3xf64>) { - // CHECK: spv.FAdd %{{.*}}, %{{.*}}: vector<2xf16> - %0 = arith.addf %arg0, %arg0: vector<2xf16> - // CHECK: spv.FMul %{{.*}}, %{{.*}}: vector<3xf64> - %1 = arith.mulf %arg1, %arg1: vector<3xf64> - return -} - -// CHECK-LABEL: @one_elem_vector -func @one_elem_vector(%arg0: vector<1xi32>) { - // CHECK: spv.IAdd %{{.+}}, %{{.+}}: i32 - %0 = arith.addi %arg0, %arg0: vector<1xi32> - return -} - -// CHECK-LABEL: @unsupported_5elem_vector -func @unsupported_5elem_vector(%arg0: vector<5xi32>) { - // CHECK: subi - %1 = arith.subi %arg0, %arg0: vector<5xi32> - return -} - -// CHECK-LABEL: @unsupported_2x2elem_vector -func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) { - // CHECK: muli - %2 = arith.muli %arg0, %arg0: vector<2x2xi32> - return -} - -} // end module - -// ----- - -// Check that types are converted to 32-bit when no special capabilities. -module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// CHECK-LABEL: @int_vector23 -func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) { - // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi32> - %0 = arith.divsi %arg0, %arg0: vector<2xi8> - // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<3xi32> - %1 = arith.divsi %arg1, %arg1: vector<3xi16> - return -} - -// CHECK-LABEL: @float_scalar -func @float_scalar(%arg0: f16, %arg1: f64) { - // CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32 - %0 = arith.addf %arg0, %arg0: f16 - // CHECK: spv.FMul %{{.*}}, %{{.*}}: f32 - %1 = arith.mulf %arg1, %arg1: f64 - return -} - -} // end module - -// ----- - -// Check that types are converted to 32-bit when no special capabilities that -// are not supported. 
-module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}} -func @int_vector4_invalid(%arg0: vector<4xi64>) { - // expected-error@below {{bitwidth emulation is not implemented yet on unsigned op}} - // expected-note@below {{see existing live user here}} - %0 = arith.divui %arg0, %arg0: vector<4xi64> - return -} - -} // end module - -// ----- - -//===----------------------------------------------------------------------===// -// std bit ops -//===----------------------------------------------------------------------===// - -module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// CHECK-LABEL: @bitwise_scalar -func @bitwise_scalar(%arg0 : i32, %arg1 : i32) { - // CHECK: spv.BitwiseAnd - %0 = arith.andi %arg0, %arg1 : i32 - // CHECK: spv.BitwiseOr - %1 = arith.ori %arg0, %arg1 : i32 - // CHECK: spv.BitwiseXor - %2 = arith.xori %arg0, %arg1 : i32 - return -} - -// CHECK-LABEL: @bitwise_vector -func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { - // CHECK: spv.BitwiseAnd - %0 = arith.andi %arg0, %arg1 : vector<4xi32> - // CHECK: spv.BitwiseOr - %1 = arith.ori %arg0, %arg1 : vector<4xi32> - // CHECK: spv.BitwiseXor - %2 = arith.xori %arg0, %arg1 : vector<4xi32> - return -} - -// CHECK-LABEL: @logical_scalar -func @logical_scalar(%arg0 : i1, %arg1 : i1) { - // CHECK: spv.LogicalAnd - %0 = arith.andi %arg0, %arg1 : i1 - // CHECK: spv.LogicalOr - %1 = arith.ori %arg0, %arg1 : i1 - // CHECK: spv.LogicalNotEqual - %2 = arith.xori %arg0, %arg1 : i1 - return -} - -// CHECK-LABEL: @logical_vector -func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { - // CHECK: spv.LogicalAnd - %0 = arith.andi %arg0, %arg1 : vector<4xi1> - // CHECK: spv.LogicalOr - %1 = arith.ori %arg0, %arg1 : vector<4xi1> - // CHECK: spv.LogicalNotEqual - %2 = arith.xori %arg0, %arg1 : vector<4xi1> - return -} - -// 
CHECK-LABEL: @shift_scalar -func @shift_scalar(%arg0 : i32, %arg1 : i32) { - // CHECK: spv.ShiftLeftLogical - %0 = arith.shli %arg0, %arg1 : i32 - // CHECK: spv.ShiftRightArithmetic - %1 = arith.shrsi %arg0, %arg1 : i32 - // CHECK: spv.ShiftRightLogical - %2 = arith.shrui %arg0, %arg1 : i32 - return -} - -// CHECK-LABEL: @shift_vector -func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { - // CHECK: spv.ShiftLeftLogical - %0 = arith.shli %arg0, %arg1 : vector<4xi32> - // CHECK: spv.ShiftRightArithmetic - %1 = arith.shrsi %arg0, %arg1 : vector<4xi32> - // CHECK: spv.ShiftRightLogical - %2 = arith.shrui %arg0, %arg1 : vector<4xi32> - return -} - -} // end module - -// ----- - -//===----------------------------------------------------------------------===// -// arith.cmpf -//===----------------------------------------------------------------------===// - -module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// CHECK-LABEL: @cmpf -func @cmpf(%arg0 : f32, %arg1 : f32) { - // CHECK: spv.FOrdEqual - %1 = arith.cmpf oeq, %arg0, %arg1 : f32 - // CHECK: spv.FOrdGreaterThan - %2 = arith.cmpf ogt, %arg0, %arg1 : f32 - // CHECK: spv.FOrdGreaterThanEqual - %3 = arith.cmpf oge, %arg0, %arg1 : f32 - // CHECK: spv.FOrdLessThan - %4 = arith.cmpf olt, %arg0, %arg1 : f32 - // CHECK: spv.FOrdLessThanEqual - %5 = arith.cmpf ole, %arg0, %arg1 : f32 - // CHECK: spv.FOrdNotEqual - %6 = arith.cmpf one, %arg0, %arg1 : f32 - // CHECK: spv.FUnordEqual - %7 = arith.cmpf ueq, %arg0, %arg1 : f32 - // CHECK: spv.FUnordGreaterThan - %8 = arith.cmpf ugt, %arg0, %arg1 : f32 - // CHECK: spv.FUnordGreaterThanEqual - %9 = arith.cmpf uge, %arg0, %arg1 : f32 - // CHECK: spv.FUnordLessThan - %10 = arith.cmpf ult, %arg0, %arg1 : f32 - // CHECK: FUnordLessThanEqual - %11 = arith.cmpf ule, %arg0, %arg1 : f32 - // CHECK: spv.FUnordNotEqual - %12 = arith.cmpf une, %arg0, %arg1 : f32 - return -} - -} // end module - -// ----- - -// With Kernel capability, we can convert NaN 
check to spv.Ordered/spv.Unordered. -module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// CHECK-LABEL: @cmpf -func @cmpf(%arg0 : f32, %arg1 : f32) { - // CHECK: spv.Ordered - %0 = arith.cmpf ord, %arg0, %arg1 : f32 - // CHECK: spv.Unordered - %1 = arith.cmpf uno, %arg0, %arg1 : f32 - return -} - -} // end module - -// ----- - -// Without Kernel capability, we need to convert NaN check to spv.IsNan. -module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// CHECK-LABEL: @cmpf -// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 -func @cmpf(%arg0 : f32, %arg1 : f32) { - // CHECK: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32 - // CHECK-NEXT: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32 - // CHECK-NEXT: %[[OR:.+]] = spv.LogicalOr %[[LHS_NAN]], %[[RHS_NAN]] : i1 - // CHECK-NEXT: %{{.+}} = spv.LogicalNot %[[OR]] : i1 - %0 = arith.cmpf ord, %arg0, %arg1 : f32 - - // CHECK-NEXT: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32 - // CHECK-NEXT: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32 - // CHECK-NEXT: %{{.+}} = spv.LogicalOr %[[LHS_NAN]], %[[RHS_NAN]] : i1 - %1 = arith.cmpf uno, %arg0, %arg1 : f32 - return -} - -} // end module - -// ----- - -//===----------------------------------------------------------------------===// -// arith.cmpi -//===----------------------------------------------------------------------===// - -module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// CHECK-LABEL: @cmpi -func @cmpi(%arg0 : i32, %arg1 : i32) { - // CHECK: spv.IEqual - %0 = arith.cmpi eq, %arg0, %arg1 : i32 - // CHECK: spv.INotEqual - %1 = arith.cmpi ne, %arg0, %arg1 : i32 - // CHECK: spv.SLessThan - %2 = arith.cmpi slt, %arg0, %arg1 : i32 - // CHECK: spv.SLessThanEqual - %3 = arith.cmpi sle, %arg0, %arg1 : i32 - // CHECK: spv.SGreaterThan - %4 = arith.cmpi sgt, %arg0, %arg1 : i32 - // CHECK: spv.SGreaterThanEqual - %5 = arith.cmpi sge, %arg0, %arg1 : i32 - // CHECK: spv.ULessThan - %6 = arith.cmpi ult, %arg0, %arg1 : i32 - 
// CHECK: spv.ULessThanEqual - %7 = arith.cmpi ule, %arg0, %arg1 : i32 - // CHECK: spv.UGreaterThan - %8 = arith.cmpi ugt, %arg0, %arg1 : i32 - // CHECK: spv.UGreaterThanEqual - %9 = arith.cmpi uge, %arg0, %arg1 : i32 - return -} - -// CHECK-LABEL: @boolcmpi -func @boolcmpi(%arg0 : i1, %arg1 : i1) { - // CHECK: spv.LogicalEqual - %0 = arith.cmpi eq, %arg0, %arg1 : i1 - // CHECK: spv.LogicalNotEqual - %1 = arith.cmpi ne, %arg0, %arg1 : i1 - return -} - -// CHECK-LABEL: @vecboolcmpi -func @vecboolcmpi(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { - // CHECK: spv.LogicalEqual - %0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1> - // CHECK: spv.LogicalNotEqual - %1 = arith.cmpi ne, %arg0, %arg1 : vector<4xi1> - return -} - -} // end module - -// ----- - -//===----------------------------------------------------------------------===// -// arith.constant -//===----------------------------------------------------------------------===// - -module attributes { - spv.target_env = #spv.target_env< - #spv.vce, {}> -} { - -// CHECK-LABEL: @constant -func @constant() { - // CHECK: spv.Constant true - %0 = arith.constant true - // CHECK: spv.Constant 42 : i32 - %1 = arith.constant 42 : i32 - // CHECK: spv.Constant 5.000000e-01 : f32 - %2 = arith.constant 0.5 : f32 - // CHECK: spv.Constant dense<[2, 3]> : vector<2xi32> - %3 = arith.constant dense<[2, 3]> : vector<2xi32> - // CHECK: spv.Constant 1 : i32 - %4 = arith.constant 1 : index - // CHECK: spv.Constant dense<1> : tensor<6xi32> : !spv.array<6 x i32, stride=4> - %5 = arith.constant dense<1> : tensor<2x3xi32> - // CHECK: spv.Constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32, stride=4> - %6 = arith.constant dense<1.0> : tensor<2x3xf32> - // CHECK: spv.Constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32, stride=4> - %7 = arith.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> - // CHECK: spv.Constant 
dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4> - %8 = arith.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> - // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4> - %9 = arith.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> - // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4> - %10 = arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> - return -} - -// CHECK-LABEL: @constant_16bit -func @constant_16bit() { - // CHECK: spv.Constant 4 : i16 - %0 = arith.constant 4 : i16 - // CHECK: spv.Constant 5.000000e+00 : f16 - %1 = arith.constant 5.0 : f16 - // CHECK: spv.Constant dense<[2, 3]> : vector<2xi16> - %2 = arith.constant dense<[2, 3]> : vector<2xi16> - // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf16> : !spv.array<5 x f16, stride=2> - %3 = arith.constant dense<4.0> : tensor<5xf16> - return -} - -// CHECK-LABEL: @constant_64bit -func @constant_64bit() { - // CHECK: spv.Constant 4 : i64 - %0 = arith.constant 4 : i64 - // CHECK: spv.Constant 5.000000e+00 : f64 - %1 = arith.constant 5.0 : f64 - // CHECK: spv.Constant dense<[2, 3]> : vector<2xi64> - %2 = arith.constant dense<[2, 3]> : vector<2xi64> - // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf64> : !spv.array<5 x f64, stride=8> - %3 = arith.constant dense<4.0> : tensor<5xf64> - return -} - -} // end module - -// ----- - -// Check that constants are converted to 32-bit when no special capability. 
-module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// CHECK-LABEL: @constant_16bit -func @constant_16bit() { - // CHECK: spv.Constant 4 : i32 - %0 = arith.constant 4 : i16 - // CHECK: spv.Constant 5.000000e+00 : f32 - %1 = arith.constant 5.0 : f16 - // CHECK: spv.Constant dense<[2, 3]> : vector<2xi32> - %2 = arith.constant dense<[2, 3]> : vector<2xi16> - // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32, stride=4> - %3 = arith.constant dense<4.0> : tensor<5xf16> - // CHECK: spv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32, stride=4> - %4 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16> - return -} - -// CHECK-LABEL: @constant_64bit -func @constant_64bit() { - // CHECK: spv.Constant 4 : i32 - %0 = arith.constant 4 : i64 - // CHECK: spv.Constant 5.000000e+00 : f32 - %1 = arith.constant 5.0 : f64 - // CHECK: spv.Constant dense<[2, 3]> : vector<2xi32> - %2 = arith.constant dense<[2, 3]> : vector<2xi64> - // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32, stride=4> - %3 = arith.constant dense<4.0> : tensor<5xf64> - // CHECK: spv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32, stride=4> - %4 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16> - return -} - -// CHECK-LABEL: @corner_cases -func @corner_cases() { - // CHECK: %{{.*}} = spv.Constant -1 : i32 - %0 = arith.constant 4294967295 : i64 // 2^32 - 1 - // CHECK: %{{.*}} = spv.Constant 2147483647 : i32 - %1 = arith.constant 2147483647 : i64 // 2^31 - 1 - // CHECK: %{{.*}} = spv.Constant -2147483648 : i32 - %2 = arith.constant 2147483648 : i64 // 2^31 - // CHECK: %{{.*}} = spv.Constant -2147483648 : i32 - %3 = arith.constant -2147483648 : i64 // -2^31 - - // CHECK: %{{.*}} = spv.Constant -1 : i32 - %5 = arith.constant -1 : i64 - // CHECK: %{{.*}} = spv.Constant -2 : i32 - %6 = 
arith.constant -2 : i64 - // CHECK: %{{.*}} = spv.Constant -1 : i32 - %7 = arith.constant -1 : index - // CHECK: %{{.*}} = spv.Constant -2 : i32 - %8 = arith.constant -2 : index - - - // CHECK: spv.Constant false - %9 = arith.constant false - // CHECK: spv.Constant true - %10 = arith.constant true - - return -} - -// CHECK-LABEL: @unsupported_cases -func @unsupported_cases() { - // CHECK: %{{.*}} = arith.constant 4294967296 : i64 - %0 = arith.constant 4294967296 : i64 // 2^32 - // CHECK: %{{.*}} = arith.constant -2147483649 : i64 - %1 = arith.constant -2147483649 : i64 // -2^31 - 1 - // CHECK: %{{.*}} = arith.constant 1.0000000000000002 : f64 - %2 = arith.constant 0x3FF0000000000001 : f64 // smallest number > 1 - return -} - -} // end module - -// ----- - -//===----------------------------------------------------------------------===// -// std cast ops -//===----------------------------------------------------------------------===// - -module attributes { - spv.target_env = #spv.target_env< - #spv.vce, {}> -} { - -// CHECK-LABEL: index_cast1 -func @index_cast1(%arg0: i16) { - // CHECK: spv.SConvert %{{.+}} : i16 to i32 - %0 = arith.index_cast %arg0 : i16 to index - return -} - -// CHECK-LABEL: index_cast2 -func @index_cast2(%arg0: index) { - // CHECK: spv.SConvert %{{.+}} : i32 to i16 - %0 = arith.index_cast %arg0 : index to i16 - return -} - -// CHECK-LABEL: index_cast3 -func @index_cast3(%arg0: i32) { - // CHECK-NOT: spv.SConvert - %0 = arith.index_cast %arg0 : i32 to index - return -} - -// CHECK-LABEL: index_cast4 -func @index_cast4(%arg0: index) { - // CHECK-NOT: spv.SConvert - %0 = arith.index_cast %arg0 : index to i32 - return -} - -// CHECK-LABEL: @fpext1 -func @fpext1(%arg0: f16) -> f64 { - // CHECK: spv.FConvert %{{.*}} : f16 to f64 - %0 = arith.extf %arg0 : f16 to f64 - return %0 : f64 -} - -// CHECK-LABEL: @fpext2 -func @fpext2(%arg0 : f32) -> f64 { - // CHECK: spv.FConvert %{{.*}} : f32 to f64 - %0 = arith.extf %arg0 : f32 to f64 - return %0 : f64 -} - 
-// CHECK-LABEL: @fptrunc1 -func @fptrunc1(%arg0 : f64) -> f16 { - // CHECK: spv.FConvert %{{.*}} : f64 to f16 - %0 = arith.truncf %arg0 : f64 to f16 - return %0 : f16 -} - -// CHECK-LABEL: @fptrunc2 -func @fptrunc2(%arg0: f32) -> f16 { - // CHECK: spv.FConvert %{{.*}} : f32 to f16 - %0 = arith.truncf %arg0 : f32 to f16 - return %0 : f16 -} - -// CHECK-LABEL: @sitofp1 -func @sitofp1(%arg0 : i32) -> f32 { - // CHECK: spv.ConvertSToF %{{.*}} : i32 to f32 - %0 = arith.sitofp %arg0 : i32 to f32 - return %0 : f32 -} - -// CHECK-LABEL: @sitofp2 -func @sitofp2(%arg0 : i64) -> f64 { - // CHECK: spv.ConvertSToF %{{.*}} : i64 to f64 - %0 = arith.sitofp %arg0 : i64 to f64 - return %0 : f64 -} - -// CHECK-LABEL: @uitofp_i16_f32 -func @uitofp_i16_f32(%arg0: i16) -> f32 { - // CHECK: spv.ConvertUToF %{{.*}} : i16 to f32 - %0 = arith.uitofp %arg0 : i16 to f32 - return %0 : f32 -} - -// CHECK-LABEL: @uitofp_i32_f32 -func @uitofp_i32_f32(%arg0 : i32) -> f32 { - // CHECK: spv.ConvertUToF %{{.*}} : i32 to f32 - %0 = arith.uitofp %arg0 : i32 to f32 - return %0 : f32 -} - -// CHECK-LABEL: @uitofp_i1_f32 -func @uitofp_i1_f32(%arg0 : i1) -> f32 { - // CHECK: %[[ZERO:.+]] = spv.Constant 0.000000e+00 : f32 - // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32 - // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, f32 - %0 = arith.uitofp %arg0 : i1 to f32 - return %0 : f32 -} - -// CHECK-LABEL: @uitofp_i1_f64 -func @uitofp_i1_f64(%arg0 : i1) -> f64 { - // CHECK: %[[ZERO:.+]] = spv.Constant 0.000000e+00 : f64 - // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f64 - // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, f64 - %0 = arith.uitofp %arg0 : i1 to f64 - return %0 : f64 -} - -// CHECK-LABEL: @uitofp_vec_i1_f32 -func @uitofp_vec_i1_f32(%arg0 : vector<4xi1>) -> vector<4xf32> { - // CHECK: %[[ZERO:.+]] = spv.Constant dense<0.000000e+00> : vector<4xf32> - // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<4xf32> - // CHECK: spv.Select %{{.*}}, %[[ONE]], 
%[[ZERO]] : vector<4xi1>, vector<4xf32> - %0 = arith.uitofp %arg0 : vector<4xi1> to vector<4xf32> - return %0 : vector<4xf32> -} - -// CHECK-LABEL: @uitofp_vec_i1_f64 -spv.func @uitofp_vec_i1_f64(%arg0: vector<4xi1>) -> vector<4xf64> "None" { - // CHECK: %[[ZERO:.+]] = spv.Constant dense<0.000000e+00> : vector<4xf64> - // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<4xf64> - // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xf64> - %0 = spv.Constant dense<0.000000e+00> : vector<4xf64> - %1 = spv.Constant dense<1.000000e+00> : vector<4xf64> - %2 = spv.Select %arg0, %1, %0 : vector<4xi1>, vector<4xf64> - spv.ReturnValue %2 : vector<4xf64> -} - -// CHECK-LABEL: @sexti1 -func @sexti1(%arg0: i16) -> i64 { - // CHECK: spv.SConvert %{{.*}} : i16 to i64 - %0 = arith.extsi %arg0 : i16 to i64 - return %0 : i64 -} - -// CHECK-LABEL: @sexti2 -func @sexti2(%arg0 : i32) -> i64 { - // CHECK: spv.SConvert %{{.*}} : i32 to i64 - %0 = arith.extsi %arg0 : i32 to i64 - return %0 : i64 -} - -// CHECK-LABEL: @zexti1 -func @zexti1(%arg0: i16) -> i64 { - // CHECK: spv.UConvert %{{.*}} : i16 to i64 - %0 = arith.extui %arg0 : i16 to i64 - return %0 : i64 -} - -// CHECK-LABEL: @zexti2 -func @zexti2(%arg0 : i32) -> i64 { - // CHECK: spv.UConvert %{{.*}} : i32 to i64 - %0 = arith.extui %arg0 : i32 to i64 - return %0 : i64 -} - -// CHECK-LABEL: @zexti3 -func @zexti3(%arg0 : i1) -> i32 { - // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 - // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 - // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, i32 - %0 = arith.extui %arg0 : i1 to i32 - return %0 : i32 -} - -// CHECK-LABEL: @zexti4 -func @zexti4(%arg0 : vector<4xi1>) -> vector<4xi32> { - // CHECK: %[[ZERO:.+]] = spv.Constant dense<0> : vector<4xi32> - // CHECK: %[[ONE:.+]] = spv.Constant dense<1> : vector<4xi32> - // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xi32> - %0 = arith.extui %arg0 : vector<4xi1> to vector<4xi32> - return %0 
: vector<4xi32> -} - -// CHECK-LABEL: @zexti5 -func @zexti5(%arg0 : vector<4xi1>) -> vector<4xi64> { - // CHECK: %[[ZERO:.+]] = spv.Constant dense<0> : vector<4xi64> - // CHECK: %[[ONE:.+]] = spv.Constant dense<1> : vector<4xi64> - // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xi64> - %0 = arith.extui %arg0 : vector<4xi1> to vector<4xi64> - return %0 : vector<4xi64> -} - -// CHECK-LABEL: @trunci1 -func @trunci1(%arg0 : i64) -> i16 { - // CHECK: spv.SConvert %{{.*}} : i64 to i16 - %0 = arith.trunci %arg0 : i64 to i16 - return %0 : i16 -} - -// CHECK-LABEL: @trunci2 -func @trunci2(%arg0: i32) -> i16 { - // CHECK: spv.SConvert %{{.*}} : i32 to i16 - %0 = arith.trunci %arg0 : i32 to i16 - return %0 : i16 -} - -// CHECK-LABEL: @trunc_to_i1 -func @trunc_to_i1(%arg0: i32) -> i1 { - // CHECK: %[[MASK:.*]] = spv.Constant 1 : i32 - // CHECK: %[[MASKED_SRC:.*]] = spv.BitwiseAnd %{{.*}}, %[[MASK]] : i32 - // CHECK: %[[IS_ONE:.*]] = spv.IEqual %[[MASKED_SRC]], %[[MASK]] : i32 - // CHECK-DAG: %[[TRUE:.*]] = spv.Constant true - // CHECK-DAG: %[[FALSE:.*]] = spv.Constant false - // CHECK: spv.Select %[[IS_ONE]], %[[TRUE]], %[[FALSE]] : i1, i1 - %0 = arith.trunci %arg0 : i32 to i1 - return %0 : i1 -} - -// CHECK-LABEL: @trunc_to_veci1 -func @trunc_to_veci1(%arg0: vector<4xi32>) -> vector<4xi1> { - // CHECK: %[[MASK:.*]] = spv.Constant dense<1> : vector<4xi32> - // CHECK: %[[MASKED_SRC:.*]] = spv.BitwiseAnd %{{.*}}, %[[MASK]] : vector<4xi32> - // CHECK: %[[IS_ONE:.*]] = spv.IEqual %[[MASKED_SRC]], %[[MASK]] : vector<4xi32> - // CHECK-DAG: %[[TRUE:.*]] = spv.Constant dense : vector<4xi1> - // CHECK-DAG: %[[FALSE:.*]] = spv.Constant dense : vector<4xi1> - // CHECK: spv.Select %[[IS_ONE]], %[[TRUE]], %[[FALSE]] : vector<4xi1>, vector<4xi1> - %0 = arith.trunci %arg0 : vector<4xi32> to vector<4xi1> - return %0 : vector<4xi1> -} - -// CHECK-LABEL: @fptosi1 -func @fptosi1(%arg0 : f32) -> i32 { - // CHECK: spv.ConvertFToS %{{.*}} : f32 to i32 - %0 = arith.fptosi 
%arg0 : f32 to i32 - return %0 : i32 -} - -// CHECK-LABEL: @fptosi2 -func @fptosi2(%arg0 : f16) -> i16 { - // CHECK: spv.ConvertFToS %{{.*}} : f16 to i16 - %0 = arith.fptosi %arg0 : f16 to i16 - return %0 : i16 -} - -} // end module - -// ----- - -// Checks that cast types will be adjusted when missing special capabilities for -// certain non-32-bit scalar types. -module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// CHECK-LABEL: @fpext1 -// CHECK-SAME: %[[ARG:.*]]: f32 -func @fpext1(%arg0: f16) -> f64 { - // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f64 - %0 = arith.extf %arg0 : f16 to f64 - return %0: f64 -} - -// CHECK-LABEL: @fpext2 -// CHECK-SAME: %[[ARG:.*]]: f32 -func @fpext2(%arg0 : f32) -> f64 { - // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f64 - %0 = arith.extf %arg0 : f32 to f64 - return %0: f64 -} - -} // end module - -// ----- - -// Checks that cast types will be adjusted when missing special capabilities for -// certain non-32-bit scalar types. 
-module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// CHECK-LABEL: @fptrunc1 -// CHECK-SAME: %[[ARG:.*]]: f32 -func @fptrunc1(%arg0 : f64) -> f16 { - // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f16 - %0 = arith.truncf %arg0 : f64 to f16 - return %0: f16 -} - -// CHECK-LABEL: @fptrunc2 -// CHECK-SAME: %[[ARG:.*]]: f32 -func @fptrunc2(%arg0: f32) -> f16 { - // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f16 - %0 = arith.truncf %arg0 : f32 to f16 - return %0: f16 -} - -// CHECK-LABEL: @sitofp -func @sitofp(%arg0 : i64) -> f64 { - // CHECK: spv.ConvertSToF %{{.*}} : i32 to f32 - %0 = arith.sitofp %arg0 : i64 to f64 - return %0: f64 -} - -} // end module - -// ----- - -//===----------------------------------------------------------------------===// -// func.return -//===----------------------------------------------------------------------===// - -module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// CHECK-LABEL: spv.func @return_one_val -// CHECK-SAME: (%[[ARG:.+]]: f32) -func @return_one_val(%arg0: f32) -> f32 { - // CHECK: spv.ReturnValue %[[ARG]] : f32 - return %arg0: f32 -} - -// Check that multiple-return functions are not converted. 
-// CHECK-LABEL: func @return_multi_val -func @return_multi_val(%arg0: f32) -> (f32, f32) { - // CHECK: return - return %arg0, %arg0: f32, f32 -} - -} - -// ----- - -//===----------------------------------------------------------------------===// -// tensor.extract -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @tensor_extract_constant -// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32, %[[C:.+]]: i32) -func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 { - // CHECK: %[[CST:.+]] = spv.Constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]> - %cst = arith.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32> - // CHECK: %[[VAR:.+]] = spv.Variable init(%[[CST]]) : !spv.ptr, Function> - // CHECK: %[[C0:.+]] = spv.Constant 0 : i32 - // CHECK: %[[C6:.+]] = spv.Constant 6 : i32 - // CHECK: %[[MUL0:.+]] = spv.IMul %[[C6]], %[[A]] : i32 - // CHECK: %[[ADD0:.+]] = spv.IAdd %[[C0]], %[[MUL0]] : i32 - // CHECK: %[[C3:.+]] = spv.Constant 3 : i32 - // CHECK: %[[MUL1:.+]] = spv.IMul %[[C3]], %[[B]] : i32 - // CHECK: %[[ADD1:.+]] = spv.IAdd %[[ADD0]], %[[MUL1]] : i32 - // CHECK: %[[C1:.+]] = spv.Constant 1 : i32 - // CHECK: %[[MUL2:.+]] = spv.IMul %[[C1]], %[[C]] : i32 - // CHECK: %[[ADD2:.+]] = spv.IAdd %[[ADD1]], %[[MUL2]] : i32 - // CHECK: %[[AC:.+]] = spv.AccessChain %[[VAR]][%[[ADD2]]] - // CHECK: %[[VAL:.+]] = spv.Load "Function" %[[AC]] : i32 - %extract = tensor.extract %cst[%a, %b, %c] : tensor<2x2x3xi32> - // CHECK: spv.ReturnValue %[[VAL]] - return %extract : i32 -} diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-opt -split-input-file -convert-tensor-to-spirv -verify-diagnostics %s | FileCheck %s + 
+//===----------------------------------------------------------------------===// +// tensor.extract +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @tensor_extract_constant +// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32, %[[C:.+]]: i32) +func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 { + // CHECK: %[[CST:.+]] = spv.Constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]> + %cst = arith.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32> + // CHECK: %[[VAR:.+]] = spv.Variable init(%[[CST]]) : !spv.ptr, Function> + // CHECK: %[[C0:.+]] = spv.Constant 0 : i32 + // CHECK: %[[C6:.+]] = spv.Constant 6 : i32 + // CHECK: %[[MUL0:.+]] = spv.IMul %[[C6]], %[[A]] : i32 + // CHECK: %[[ADD0:.+]] = spv.IAdd %[[C0]], %[[MUL0]] : i32 + // CHECK: %[[C3:.+]] = spv.Constant 3 : i32 + // CHECK: %[[MUL1:.+]] = spv.IMul %[[C3]], %[[B]] : i32 + // CHECK: %[[ADD1:.+]] = spv.IAdd %[[ADD0]], %[[MUL1]] : i32 + // CHECK: %[[C1:.+]] = spv.Constant 1 : i32 + // CHECK: %[[MUL2:.+]] = spv.IMul %[[C1]], %[[C]] : i32 + // CHECK: %[[ADD2:.+]] = spv.IAdd %[[ADD1]], %[[MUL2]] : i32 + // CHECK: %[[AC:.+]] = spv.AccessChain %[[VAR]][%[[ADD2]]] + // CHECK: %[[VAL:.+]] = spv.Load "Function" %[[AC]] : i32 + %extract = tensor.extract %cst[%a, %b, %c] : tensor<2x2x3xi32> + // CHECK: spv.ReturnValue %[[VAL]] + return %extract : i32 +} diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp --- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp +++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp @@ -12,13 +12,13 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h" #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" #include 
"mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" -#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/GPUDialect.h"