diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -58,6 +58,7 @@
 #include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1054,6 +1054,18 @@
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// UBToSPIRV
+//===----------------------------------------------------------------------===//
+
+def UBToSPIRVConversionPass : Pass<"convert-ub-to-spirv"> {
+  let summary = "Convert UB dialect to SPIR-V dialect";
+  let description = [{
+    This pass converts supported UB ops to SPIR-V dialect instructions.
+  }];
+  let dependentDialects = ["spirv::SPIRVDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // VectorToGPU
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h b/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/UBToSPIRV/UBToSPIRV.h
@@ -0,0 +1,31 @@
+//===- UBToSPIRV.h - UB to SPIR-V dialect conversion ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_UBTOSPIRV_UBTOSPIRV_H
+#define MLIR_CONVERSION_UBTOSPIRV_UBTOSPIRV_H
+
+#include <memory>
+
+namespace mlir {
+
+class SPIRVTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_UBTOSPIRVCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace ub {
+/// Appends to `patterns` the patterns that lower UB dialect ops to SPIR-V,
+/// using `converter` to convert result types.
+void populateUBToSPIRVConversionPatterns(SPIRVTypeConverter &converter,
+                                         RewritePatternSet &patterns);
+} // namespace ub
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_UBTOSPIRV_UBTOSPIRV_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -48,6 +48,7 @@
 add_subdirectory(TosaToSCF)
 add_subdirectory(TosaToTensor)
 add_subdirectory(UBToLLVM)
+add_subdirectory(UBToSPIRV)
 add_subdirectory(VectorToArmSME)
 add_subdirectory(VectorToGPU)
 add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/UBToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/UBToSPIRV/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/UBToSPIRV/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRUBToSPIRV
+  UBToSPIRV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/UBToSPIRV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRSPIRVConversion
+  MLIRSPIRVDialect
+  MLIRUBDialect
+  )
diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
@@ -0,0 +1,91 @@
+//===- UBToSPIRV.cpp - UB to SPIR-V dialect conversion --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to lower UB dialect ops to the SPIR-V dialect,
+// along with the `convert-ub-to-spirv` pass that applies them.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_UBTOSPIRVCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Lowers a scalar `ub.poison` to `spirv.Undef` of the converted type, since
+/// SPIR-V has no dedicated poison value.
+struct PoisonOpLowering final : public OpConversionPattern<ub::PoisonOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ub::PoisonOp op, OpAdaptor /*adaptor*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type origType = op.getType();
+    // Only scalar integer/index/float result types are supported.
+    if (!origType.isIntOrIndexOrFloat())
+      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+        diag << "unsupported type " << origType;
+      });
+
+    Type resType = getTypeConverter()->convertType(origType);
+    if (!resType)
+      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+        diag << "failed to convert result type " << origType;
+      });
+
+    rewriter.replaceOpWithNewOp<spirv::UndefOp>(op, resType);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+/// Applies the UB-to-SPIR-V patterns with a partial conversion, using the
+/// SPIR-V target environment looked up from (or defaulted for) the anchor op.
+struct UBToSPIRVConversionPass final
+    : public impl::UBToSPIRVConversionPassBase<UBToSPIRVConversionPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+    std::unique_ptr<SPIRVConversionTarget> target =
+        SPIRVConversionTarget::get(targetAttr);
+
+    SPIRVConversionOptions options;
+    SPIRVTypeConverter typeConverter(targetAttr, options);
+
+    RewritePatternSet patterns(&getContext());
+    ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, *target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+void mlir::ub::populateUBToSPIRVConversionPatterns(
+    SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<PoisonOpLowering>(converter, patterns.getContext());
+}
diff --git a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt -split-input-file -convert-ub-to-spirv -verify-diagnostics %s | FileCheck %s
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @check_poison
+func.func @check_poison() {
+// CHECK: {{.*}} = spirv.Undef : i32
+  %0 = ub.poison : index
+// CHECK: {{.*}} = spirv.Undef : i16
+  %1 = ub.poison : i16
+// CHECK: {{.*}} = spirv.Undef : f64
+  %2 = ub.poison : f64
+// Non-scalar types are not converted and are left in place.
+// CHECK: {{.*}} = ub.poison : vector<4xf32>
+  %3 = ub.poison : vector<4xf32>
+  return
+}
+
+}