diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -31,6 +31,8 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -440,6 +440,34 @@
   let constructor = "tosa::createTosaToLinalgOnTensors()";
 }
 
+//===----------------------------------------------------------------------===//
+// TosaToSCF
+//===----------------------------------------------------------------------===//
+
+def TosaToSCF : FunctionPass<"tosa-to-scf"> {
+  let summary = "Lower TOSA to the SCF dialect";
+  let description = [{
+    Pass that converts TOSA's control flow operations to the equivalent SCF
+    operations.
+  }];
+
+  let constructor = "tosa::createTosaToSCF()";
+}
+
+//===----------------------------------------------------------------------===//
+// TosaToStandard
+//===----------------------------------------------------------------------===//
+
+def TosaToStandard : FunctionPass<"tosa-to-standard"> {
+  let summary = "Lower TOSA to the Standard dialect";
+  let description = [{
+    Pass that converts TOSA operations to the equivalent operations using the
+    operations in the Standard dialect.
+  }];
+
+  let constructor = "tosa::createTosaToStandard()";
+}
+
 //===----------------------------------------------------------------------===//
 // VectorToSCF
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h b/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
@@ -0,0 +1,34 @@
+//===-- TosaToSCF.h - TOSA to SCF dialect lowerings -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to SCF Dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
+#define MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+/// Creates the pass registered as "tosa-to-scf".
+std::unique_ptr<Pass> createTosaToSCF();
+
+/// Adds the TOSA to SCF rewrite patterns to `patterns`.
+void populateTosaToSCFConversionPatterns(MLIRContext *context,
+                                         OwningRewritePatternList *patterns);
+
+/// Populates passes to convert from TOSA to SCF.
+void addTosaToSCFPasses(OpPassManager &pm);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
diff --git a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
@@ -0,0 +1,34 @@
+//===-- TosaToStandard.h - TOSA to Standard dialect lowerings ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to Standard Dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+#define MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+std::unique_ptr<Pass> createTosaToStandard();
+
+LogicalResult legalizeTosaControlFlow(Region &region);
+
+void populateTosaToStandardConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
+
+/// Populates passes to convert from TOSA to Standard.
+void addTosaToStandardPasses(OpPassManager &pm);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -22,6 +22,8 @@
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
 add_subdirectory(TosaToLinalg)
+add_subdirectory(TosaToSCF)
+add_subdirectory(TosaToStandard)
 add_subdirectory(ArmSVEToLLVM)
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/TosaToSCF/CMakeLists.txt b/mlir/lib/Conversion/TosaToSCF/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSCF/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_mlir_conversion_library(MLIRTosaToSCF
+  TosaToSCF.cpp
+  TosaToSCFPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSCF
+  MLIRStandard
+  MLIRPass
+  MLIRTensor
+  MLIRTosa
+  MLIRTosaTransforms
+  MLIRSupport
+  )
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -0,0 +1,124 @@
+//===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the Tosa to the SCF dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+
+/// Moves the body of `srcRegion` into `dstRegion`, forwards `operands` to the
+/// uses of the entry block arguments, and turns every tosa.yield into an
+/// scf.yield.
+void inlineIfCase(Region &srcRegion, Region &dstRegion, OperandRange operands,
+                  PatternRewriter &rewriter) {
+  dstRegion.takeBody(srcRegion);
+  auto headBlock = &dstRegion.front();
+  for (auto it : llvm::zip(headBlock->getArguments(), operands)) {
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+  }
+
+  for (auto &block : dstRegion) {
+    // Yields are gathered during the walk and erased once it has finished.
+    llvm::SmallVector<Operation *> toDelete;
+    block.walk([&](tosa::YieldOp yield) {
+      rewriter.setInsertionPoint(yield);
+      rewriter.create<scf::YieldOp>(yield.getLoc(), yield.inputs());
+      toDelete.push_back(yield);
+    });
+    for (Operation *val : toDelete)
+      val->erase();
+  }
+
+  // Every use was replaced above, so the entry block arguments can be dropped.
+  headBlock->eraseArguments(
+      llvm::to_vector<4>(llvm::seq<unsigned>(0, headBlock->getNumArguments())));
+}
+
+/// Rewrites tosa.cond_if as scf.if; the scalar condition is read out of the
+/// condition tensor with tensor.extract.
+class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
+public:
+  using OpRewritePattern<tosa::IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::IfOp op,
+                                PatternRewriter &rewriter) const final {
+    auto condition = rewriter.create<tensor::ExtractOp>(op.getLoc(), op.cond());
+    auto newIf = rewriter.replaceOpWithNewOp<scf::IfOp>(op, op.getResultTypes(),
+                                                        condition, true);
+
+    inlineIfCase(op.then_branch(), newIf.thenRegion(), op.inputs(), rewriter);
+    inlineIfCase(op.else_branch(), newIf.elseRegion(), op.inputs(), rewriter);
+    return success();
+  }
+};
+
+/// Moves the body of `srcRegion` into `dstRegion` and rewrites its tosa.yield
+/// ops. With `isCond` set, the yielded tensor is extracted to a scalar and
+/// used in an scf.condition that forwards the block arguments; otherwise the
+/// yield becomes an scf.yield.
+void inlineWhileCase(Region &srcRegion, Region &dstRegion,
+                     OperandRange operands, PatternRewriter &rewriter,
+                     bool isCond) {
+  dstRegion.takeBody(srcRegion);
+
+  for (auto &block : dstRegion) {
+    llvm::SmallVector<Operation *> toDelete;
+    block.walk([&](tosa::YieldOp yield) {
+      rewriter.setInsertionPoint(yield);
+      if (isCond) {
+        auto condition = rewriter.create<tensor::ExtractOp>(
+            yield.getLoc(), yield.getOperand(0));
+        rewriter.create<scf::ConditionOp>(yield.getLoc(), condition,
+                                          block.getArguments());
+      } else {
+        rewriter.create<scf::YieldOp>(yield.getLoc(), yield.inputs());
+      }
+      toDelete.push_back(yield);
+    });
+    for (Operation *val : toDelete)
+      val->erase();
+  }
+}
+
+/// Rewrites tosa.while_loop as scf.while: the cond region becomes the "before"
+/// region and the body region becomes the "after" region.
+class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
+public:
+  using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::WhileOp op,
+                                PatternRewriter &rewriter) const final {
+    auto newWhile = rewriter.replaceOpWithNewOp<scf::WhileOp>(
+        op, op.getResultTypes(), op.inputs());
+
+    inlineWhileCase(op.cond(), newWhile.before(), op.inputs(), rewriter, true);
+    inlineWhileCase(op.body(), newWhile.after(), op.inputs(), rewriter, false);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaToSCFConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<IfOpConverter>(context);
+  patterns->insert<WhileOpConverter>(context);
+}
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
@@ -0,0 +1,58 @@
+//===- TosaToSCFPass.cpp - Lowering Tosa to SCF Dialect -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the SCF dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+/// Function pass that rewrites tosa.cond_if and tosa.while_loop into SCF ops.
+struct TosaToSCF : public TosaToSCFBase<TosaToSCF> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<tensor::TensorDialect, scf::SCFDialect>();
+  }
+
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    ConversionTarget target(getContext());
+    target.addLegalDialect<tensor::TensorDialect, scf::SCFDialect>();
+    target.addIllegalOp<tosa::IfOp, tosa::WhileOp>();
+    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+    FuncOp func = getFunction();
+    mlir::tosa::populateTosaToSCFConversionPatterns(func.getContext(),
+                                                    &patterns);
+    if (failed(applyFullConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToSCF() {
+  return std::make_unique<TosaToSCF>();
+}
+
+void mlir::tosa::addTosaToSCFPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTosaToSCF());
+}
diff --git a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRTosaToStandard
+  TosaToStandard.cpp
+  TosaToStandardPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRStandard
+  MLIRPass
+  MLIRTosa
+  MLIRTosaTransforms
+  MLIRSupport
+  )
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -0,0 +1,40 @@
+//===- TosaToStandard.cpp - Lowering Tosa to Standard Dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the Tosa to the Standard dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+
+class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
+public:
+  using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ConstOp op,
+                                PatternRewriter &rewriter) const final {
+    rewriter.replaceOpWithNewOp<::ConstantOp>(op, op.value());
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaToStandardConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<ConstOpConverter>(context);
+}
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
@@ -0,0 +1,57 @@
+//===- TosaToStandardPass.cpp - Lowering Tosa to Standard Dialect ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the Standard dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<StandardOpsDialect>();
+  }
+
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    ConversionTarget target(getContext());
+    target.addLegalDialect<StandardOpsDialect>();
+    target.addIllegalOp<tosa::ConstOp>();
+    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+    FuncOp func = getFunction();
+    mlir::tosa::populateTosaToStandardConversionPatterns(func.getContext(),
+                                                         &patterns);
+    if (failed(applyFullConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToStandard() {
+  return std::make_unique<TosaToStandard>();
+}
+
+void mlir::tosa::addTosaToStandardPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTosaToStandard());
+}
diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt --split-input-file --tosa-to-scf %s -verify-diagnostics -o -| FileCheck %s
+
+// CHECK-LABEL: func @while_test
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>)
+func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
+  // CHECK: [[WHILE:%.+]] = scf.while ([[ARG1:%.+]] = [[ARG0]])
+  %1 = "tosa.while_loop"(%arg0) ( {
+  ^bb0(%arg2: tensor<i32>):
+    // CHECK: "tosa.const"
+    %2 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK: [[COMPARE:%.+]] = "tosa.greater_equal"
+    %3 = "tosa.greater_equal"(%2, %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+
+    // CHECK: [[EX:%.+]] = tensor.extract [[COMPARE]]
+    // CHECK: scf.condition([[EX]]) [[ARG1]]
+    "tosa.yield"(%3) : (tensor<i1>) -> ()
+  }, {
+  // CHECK: ^bb0([[ARG1:%.+]]: tensor<i32>)
+  ^bb0(%arg2: tensor<i32>):
+    // CHECK: tosa.const
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK: [[ADD:%.+]] = "tosa.add"
+    %3 = "tosa.add"(%arg2, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+
+    // CHECK: scf.yield [[ADD]]
+    "tosa.yield"(%3) : (tensor<i32>) -> ()
+  }) : (tensor<i32>) -> (tensor<i32>)
+  return %1 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: func @if_test
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<f32>, [[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<i1>)
+func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
+  // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
+  // CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+
+  // CHECK: scf.yield [[ARG0]]
+  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
+    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+
+  // CHECK: } else {
+  }, {
+
+  // CHECK: scf.yield [[ARG1]]
+  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
+    "tosa.yield"(%arg6) : (tensor<f32>) -> ()
+
+  // CHECK: }
+  // CHECK: return [[IF]]
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<f32>)
+
+  return %0 : tensor<f32>
+}
diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt --split-input-file --tosa-to-standard %s -verify-diagnostics -o -| FileCheck %s
+
+// CHECK-LABEL: func @const_test
+func @const_test() -> (tensor<i32>) {
+  // CHECK: [[C3:%.+]] = constant dense<3> : tensor<i32>
+  %0 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+
+  // CHECK: return [[C3]]
+  return %0 : tensor<i32>
+}