diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -31,6 +31,7 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -440,6 +440,22 @@
   let constructor = "tosa::createTosaToLinalgOnTensors()";
 }
 
+//===----------------------------------------------------------------------===//
+// TosaToStandard
+//===----------------------------------------------------------------------===//
+
+def TosaToStandard : FunctionPass<"tosa-to-standard"> {
+  let summary = "Lower TOSA to the Standard dialect";
+  let description = [{
+    Pass that converts TOSA operations to the equivalent operations in the
+    Standard dialect: `tosa.const` becomes `std.constant`, and
+    `tosa.cond_if` / `tosa.while_loop` are lowered to blocks connected by
+    `br` / `cond_br`.
+  }];
+
+  let constructor = "tosa::createTosaToStandard()";
+}
+
 //===----------------------------------------------------------------------===//
 // VectorToSCF
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
@@ -0,0 +1,42 @@
+//===-- TosaToStandard.h - TOSA to Standard conversion ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to Standard Dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+#define MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+/// Creates the function pass that lowers TOSA constants and control flow to
+/// the Standard dialect (see populateTosaToStandardConversionPatterns and
+/// legalizeTosaControlFlow).
+std::unique_ptr<Pass> createTosaToStandard();
+
+/// Inlines the `tosa.cond_if` and `tosa.while_loop` ops found in `region`,
+/// replacing them with blocks connected by `br` / `cond_br`. Returns failure
+/// if any of them could not be legalized.
+LogicalResult legalizeTosaControlFlow(Region &region);
+
+/// Populates `patterns` with the patterns that rewrite TOSA ops into Standard
+/// dialect ops (currently only `tosa.const` -> `std.constant`).
+void populateTosaToStandardConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
+
+/// Populates passes to convert from TOSA to Standard.
+void addTosaToStandardPasses(OpPassManager &pm);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -22,6 +22,7 @@
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
 add_subdirectory(TosaToLinalg)
+add_subdirectory(TosaToStandard)
 add_subdirectory(ArmSVEToLLVM)
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_conversion_library(MLIRTosaToStandard
+  TosaToStandard.cpp
+  TosaToStandardPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRStandard
+  MLIRPass
+  MLIRTensor
+  MLIRTosa
+  MLIRTosaTransforms
+  MLIRSupport
+  )
diff --git
a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -0,0 +1,198 @@
+//===- TosaToStandard.cpp - Lowering Tosa to Standard Dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the Tosa to the Standard dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+
+/// Rewrites `tosa.const` into a `std.constant` holding the same attribute.
+class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
+public:
+  using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ConstOp op,
+                                PatternRewriter &rewriter) const final {
+    rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op, op.value());
+    return success();
+  }
+};
+
+/// Replaces the `tosa.yield` terminator of each clone of a block in `region`
+/// with a `br` to `targetBlock` that forwards the yielded values. `mapper`
+/// must map the blocks of `region` to their clones. Used for both branches
+/// of `tosa.cond_if`.
+LogicalResult replaceYieldOps(Region *region, Block *targetBlock, Location loc,
+                              const BlockAndValueMapping &mapper,
+                              OpBuilder *builder) {
+  for (Block &oldBlock : region->getBlocks()) {
+    Block *block = mapper.lookup(&oldBlock);
+    auto yieldOp = dyn_cast<tosa::YieldOp>(block->getTerminator());
+    if (!yieldOp)
+      continue;
+    builder->setInsertionPointToEnd(block);
+    builder->create<BranchOp>(loc, targetBlock, yieldOp.getOperands());
+    yieldOp.erase();
+  }
+
+  return success();
+}
+
+/// Lowers `ifOp` to a `cond_br` on the extracted condition. The then/else
+/// regions are cloned into the parent region and their yields become
+/// branches to a new tail block, whose arguments replace the op's results.
+LogicalResult legalizeTosaIf(IfOp ifOp) {
+  Operation *operation = ifOp.getOperation();
+  OpBuilder builder(operation);
+  Block *origBlock = operation->getBlock();
+  // `ifOp` and every op after it move into the new tail block.
+  Block *tailBlock = origBlock->splitBlock(operation);
+  Location loc = ifOp.getLoc();
+
+  // Clone the then/else regions in front of the tail block so the
+  // conditional branch below can target them.
+  BlockAndValueMapping mapper;
+  ifOp.then_branch().cloneInto(origBlock->getParent(),
+                               Region::iterator(tailBlock), mapper);
+  ifOp.else_branch().cloneInto(origBlock->getParent(),
+                               Region::iterator(tailBlock), mapper);
+
+  Block *thenBranch = mapper.lookup(&ifOp.then_branch().front());
+  Block *elseBranch = mapper.lookup(&ifOp.else_branch().front());
+
+  // Move to the conditional section, extract the value from the condition
+  // tensor, and execute the conditional branch to the then/else cases.
+  builder.setInsertionPointToEnd(origBlock);
+  auto condValue = builder.create<tensor::ExtractOp>(loc, ifOp.cond());
+  builder.create<CondBranchOp>(loc, condValue, thenBranch, ifOp.inputs(),
+                               elseBranch, ifOp.inputs());
+
+  // Yields need to be replaced with branches to the tail block.
+  if (failed(replaceYieldOps(&ifOp.then_branch(), tailBlock, loc, mapper,
+                             &builder)))
+    return failure();
+  if (failed(replaceYieldOps(&ifOp.else_branch(), tailBlock, loc, mapper,
+                             &builder)))
+    return failure();
+
+  // The values forwarded by those branches arrive as tail block arguments.
+  tailBlock->addArguments(ifOp.getResultTypes());
+  for (auto it : llvm::enumerate(ifOp.getResults()))
+    it.value().replaceAllUsesWith(tailBlock->getArgument(it.index()));
+
+  operation->erase();
+  return success();
+}
+
+/// Lowers `whileOp` to an unstructured loop. The cond/body regions are cloned
+/// into the parent region; the condition's yield becomes a `cond_br` to the
+/// body or to a new tail block, and the body's yield branches back to the
+/// condition. The tail block's arguments replace the op's results.
+LogicalResult legalizeTosaWhile(WhileOp whileOp) {
+  Operation *operation = whileOp.getOperation();
+  OpBuilder builder(operation);
+  Location loc = whileOp.getLoc();
+
+  Block *origBlock = operation->getBlock();
+  Block *tailBlock = origBlock->splitBlock(operation);
+
+  // Clone the condition and body regions in front of the tail block and
+  // look up the copies of their entry blocks.
+  BlockAndValueMapping mapper;
+  whileOp.cond().cloneInto(origBlock->getParent(), Region::iterator(tailBlock),
+                           mapper);
+  whileOp.body().cloneInto(origBlock->getParent(), Region::iterator(tailBlock),
+                           mapper);
+
+  Block *condBlock = mapper.lookup(&whileOp.cond().front());
+  Block *bodyBlock = mapper.lookup(&whileOp.body().front());
+
+  // Enter the loop: branch from the original block to the condition block.
+  builder.setInsertionPointToEnd(origBlock);
+  builder.create<BranchOp>(loc, condBlock, whileOp.getOperands());
+
+  // Both successors of the condition's cond_br receive the loop-carried
+  // values, i.e. the arguments of the condition entry block.
+  llvm::SmallVector<Value, 4> successorArgs(condBlock->args_begin(),
+                                            condBlock->args_end());
+
+  // Replace each yield in the condition region with a tensor.extract of the
+  // yielded condition and a conditional branch to the body block or tail.
+  for (Block &block : whileOp.cond()) {
+    Block *newBlock = mapper.lookup(&block);
+    auto yieldOp = dyn_cast<tosa::YieldOp>(newBlock->getTerminator());
+    if (!yieldOp)
+      continue;
+    builder.setInsertionPointToEnd(newBlock);
+    auto condValue =
+        builder.create<tensor::ExtractOp>(loc, yieldOp.getOperand(0));
+    builder.create<CondBranchOp>(loc, condValue, bodyBlock, successorArgs,
+                                 tailBlock, successorArgs);
+    yieldOp.erase();
+  }
+
+  // Replace each yield in the body region with a branch back to the
+  // condition, passing the yielded values as the next iteration's inputs.
+  for (Block &block : whileOp.body()) {
+    Block *newBlock = mapper.lookup(&block);
+    auto yieldOp = dyn_cast<tosa::YieldOp>(newBlock->getTerminator());
+    if (!yieldOp)
+      continue;
+    builder.setInsertionPointToEnd(newBlock);
+    builder.create<BranchOp>(loc, condBlock, yieldOp.getOperands());
+    yieldOp.erase();
+  }
+
+  // The loop exits into the tail block, whose arguments replace the op's
+  // results; then erase the original while loop.
+  tailBlock->addArguments(whileOp.getResultTypes());
+  for (auto it : llvm::enumerate(whileOp.getResults()))
+    it.value().replaceAllUsesWith(tailBlock->getArgument(it.index()));
+  operation->erase();
+
+  return success();
+}
+
+} // namespace
+
+void mlir::tosa::populateTosaToStandardConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<ConstOpConverter>(context);
+}
+
+LogicalResult mlir::tosa::legalizeTosaControlFlow(Region &region) {
+  // Legalizing a while/if inlines its regions into `region`, which can expose
+  // control-flow ops that were nested inside it. Repeat until no
+  // tosa.while_loop or tosa.cond_if is left directly in `region`.
+  while (true) {
+    auto whileOps = llvm::to_vector<5>(region.getOps<tosa::WhileOp>());
+    for (tosa::WhileOp whileOp : whileOps)
+      if (failed(legalizeTosaWhile(whileOp)))
+        return failure();
+
+    auto ifOps = llvm::to_vector<5>(region.getOps<tosa::IfOp>());
+    for (tosa::IfOp ifOp : ifOps)
+      if (failed(legalizeTosaIf(ifOp)))
+        return failure();
+
+    if (whileOps.empty() && ifOps.empty())
+      return success();
+  }
+}
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
@@ -0,0 +1,60 @@
+//===- TosaToStandardPass.cpp - Lowering Tosa to Standard Dialect ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the Standard dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+/// Applies the TOSA-to-Standard patterns, then legalizes TOSA control flow.
+struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<StandardOpsDialect, tensor::TensorDialect>();
+  }
+
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    ConversionTarget target(getContext());
+    target.addLegalDialect<StandardOpsDialect>();
+    target.addIllegalOp<tosa::ConstOp>();
+    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+    FuncOp func = getFunction();
+    populateTosaToStandardConversionPatterns(&getContext(), &patterns);
+    // Bail out once the pass has failed instead of rewriting control flow.
+    if (failed(applyFullConversion(func, target, std::move(patterns))))
+      return signalPassFailure();
+
+    if (failed(legalizeTosaControlFlow(func.body())))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToStandard() {
+  return std::make_unique<TosaToStandard>();
+}
+
+void mlir::tosa::addTosaToStandardPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTosaToStandard());
+}
diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir new
file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt --split-input-file --tosa-to-standard %s -verify-diagnostics -o - | FileCheck %s
+
+
+// CHECK-LABEL: func @const_test
+func @const_test() -> (tensor<i32>) {
+  // CHECK: [[C3:%.+]] = constant dense<3> : tensor<i32>
+  %0 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+
+  // CHECK: return [[C3]]
+  return %0 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: func @while_test
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>)
+func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
+  // CHECK-NOT: "tosa.while_loop"
+  // CHECK: br [[B1:\^.+]]([[ARG0]] : tensor<i32>)
+  %1 = "tosa.while_loop"(%arg0) ( {
+  // CHECK: [[B1]]([[ARG1:%.+]]: tensor<i32>)
+  ^bb0(%arg2: tensor<i32>):  // no predecessors
+  // CHECK: [[C3:%.+]] = constant dense<3> : tensor<i32>
+  // CHECK: [[GT:%.+]] = "tosa.greater_equal"([[C3]], [[ARG1]])
+  // CHECK: [[EX:%.+]] = tensor.extract [[GT]][]
+  // CHECK: cond_br [[EX]], [[B2:\^.+]]([[ARG1]] : tensor<i32>), [[B3:\^.+]]([[ARG1]] : tensor<i32>)
+    %2 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.greater_equal"(%2, %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "tosa.yield"(%3) : (tensor<i1>) -> ()
+  }, {
+  // CHECK: [[B2]]([[ARG2:%.+]]: tensor<i32>)
+  ^bb0(%arg2: tensor<i32>):  // no predecessors
+  // CHECK: [[C1:%.+]] = constant dense<1> : tensor<i32>
+  // CHECK: [[ADD:%.+]] = "tosa.add"([[ARG2]], [[C1]])
+  // CHECK: br [[B1]]([[ADD]] : tensor<i32>)
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg2, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "tosa.yield"(%3) : (tensor<i32>) -> ()
+  }) : (tensor<i32>) -> (tensor<i32>)
+  // CHECK: [[B3]]([[RESULT:%.+]]: tensor<i32>)
+  // CHECK: return [[RESULT]]
+  return %1 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: func @if_test
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<f32>, [[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<i1>)
+func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
+  // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]][] : tensor<i1>
+  // CHECK: cond_br [[EX]], [[B1:\^.+]]([[ARG0]], [[ARG1]] : tensor<f32>, tensor<f32>), [[B2:\^.+]]([[ARG0]], [[ARG1]] : tensor<f32>, tensor<f32>)
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  // CHECK: [[B1]]([[ARG3:%.+]]: tensor<f32>, [[ARG4:%.+]]: tensor<f32>)
+  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
+  // CHECK: br [[B3:\^.+]]([[ARG3]] : tensor<f32>)
+    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+  }, {
+  // CHECK: [[B2]]([[ARG5:%.+]]: tensor<f32>, [[ARG6:%.+]]: tensor<f32>):
+  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
+  // CHECK: br [[B3]]([[ARG6]] : tensor<f32>)
+    "tosa.yield"(%arg6) : (tensor<f32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<f32>)
+
+  // CHECK: [[B3]]([[ARG7:%.+]]: tensor<f32>):
+  // CHECK: return [[ARG7]] : tensor<f32>
+  return %0 : tensor<f32>
+}