diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -29,6 +29,7 @@
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
+#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -416,6 +416,20 @@
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// TosaToLinalg
+//===----------------------------------------------------------------------===//
+
+def TosaToLinalgOnTensors : FunctionPass<"tosa-to-linalg-on-tensors"> {
+  let summary = "Lower TOSA to LinAlg on tensors";
+  let description = [{
+    Pass that converts TOSA operations to the equivalent operations using the
+    tensor operations in LinAlg.
+  }];
+
+  let constructor = "tosa::createTosaToLinalgOnTensors()";
+}
+
 //===----------------------------------------------------------------------===//
 // VectorToSCF
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -0,0 +1,35 @@
+//===-- TosaToLinalg.h - TOSA to Linalg conversions -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes and patterns that lower TOSA to Linalg.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H
+#define MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+std::unique_ptr<Pass> createTosaToLinalgOnTensors();
+
+/// Populates passes to convert from TOSA to Linalg on tensors. At the end of
+/// the pass, the function will only contain linalg ops or standard ops if the
+/// pipeline succeeds.
+void addTosaToLinalgOnTensorsPasses(OpPassManager &pm);
+
+/// Populates conversion patterns from TOSA dialect to Linalg dialect.
+void populateTosaToLinalgOnTensorsConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -20,6 +20,7 @@
 add_subdirectory(SPIRVToLLVM)
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
+add_subdirectory(TosaToLinalg)
 add_subdirectory(ArmSVEToLLVM)
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_conversion_library(MLIRTosaToLinalg
+  TosaToLinalg.cpp
+  TosaToLinalgPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLinalg
+  MLIRLinalgUtils
+  MLIRPass
+  MLIRTosa
+  MLIRTosaTransforms
+  MLIRSupport
+  )
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -0,0 +1,213 @@
+//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the Tosa to the Linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+/// Returns `nParallelLoops` copies of the parallel iterator type name, for use
+/// as the `iterator_types` of a linalg.generic op.
+static SmallVector<StringRef, 4>
+getNParallelLoopsAttrs(unsigned nParallelLoops) {
+  return SmallVector<StringRef, 4>(nParallelLoops,
+                                   getParallelIteratorTypeName());
+}
+
+/// Creates a linalg.init_tensor with the shape and element type of `type`.
+/// No dynamic sizes are passed, so `type` is expected to be statically shaped
+/// (the caller checks hasStaticShape() first).
+static Value getInitTensor(OpBuilder &b, Location loc, ShapedType type) {
+  return b.create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
+                                        type.getShape(), type.getElementType());
+}
+
+/// Builds the scalar computation performed by the elementwise TOSA op `op` on
+/// one element, using `args` (the scalar inputs of the enclosing
+/// linalg.generic body) as operands.
+///
+/// Returns the computed value, or a null Value after notifying a match failure
+/// when the (op, element type) combination is not handled, so that the calling
+/// pattern fails gracefully instead of aborting the compiler.
+static Value
+createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
+                                            ArrayRef<Type> resultTypes,
+                                            PatternRewriter &rewriter) {
+  Location loc = op->getLoc();
+  auto elementTy =
+      op->getResult(0).getType().cast<ShapedType>().getElementType();
+
+  // tosa::AbsOp
+  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::AbsFOp>(loc, resultTypes, args);
+
+  // tosa::AddOp
+  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::AddFOp>(loc, resultTypes, args);
+
+  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);
+
+  // tosa::BitwiseAndOp
+  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);
+
+  // tosa::BitwiseOrOp
+  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);
+
+  // tosa::BitwiseXorOp
+  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);
+
+  // tosa::LogicalLeftShiftOp
+  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.create<mlir::ShiftLeftOp>(loc, resultTypes, args);
+
+  // tosa::LogicalRightShiftOp
+  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);
+
+  // tosa::PowOp
+  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::PowFOp>(loc, resultTypes, args);
+
+  // tosa::SubOp
+  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);
+
+  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);
+
+  // tosa::TanhOp
+  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::TanhOp>(loc, resultTypes, args);
+
+  (void)rewriter.notifyMatchFailure(
+      op, "unhandled op for linalg body calculation for elementwise op");
+  return nullptr;
+}
+
+namespace {
+
+/// Lowers the elementwise TOSA op `SrcOp` to a linalg.generic on tensors with
+/// identity indexing maps and only parallel iterators. Broadcasting is not
+/// supported yet: all operands and results must have exactly the same type.
+template <typename SrcOp>
+class PointwiseConverter : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewriteHelper(Operation *operation,
+                                      PatternRewriter &rewriter) const {
+    Location loc = operation->getLoc();
+    if (operation->getNumOperands() == 0 || operation->getNumResults() == 0)
+      return rewriter.notifyMatchFailure(
+          operation, "expected at least one operand and one result");
+
+    // For now require no broadcasting: every operand and result must have the
+    // type of the first operand. Consider supporting broadcasting operations.
+    Type uniqueTy = operation->getOperand(0).getType();
+    if (!llvm::all_of(operation->getOperandTypes(),
+                      [&](Type operandTy) { return operandTy == uniqueTy; }))
+      return rewriter.notifyMatchFailure(
+          operation, "All operands must have the same type");
+    if (!llvm::all_of(operation->getResultTypes(),
+                      [&](Type resultTy) { return resultTy == uniqueTy; }))
+      return rewriter.notifyMatchFailure(
+          operation, "All results must have the same type as the input");
+
+    // Since all operand and result types equal `uniqueTy`, validating it
+    // validates every value. This must happen before getRank(), which asserts
+    // on unranked types.
+    auto shapedTy = uniqueTy.dyn_cast<ShapedType>();
+    if (!shapedTy || !shapedTy.hasRank() ||
+        !shapedTy.getElementType().isSignlessIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          operation, "tosa to linalg conversion expects ranked shaped types "
+                     "with signless int or float element types");
+    if (!shapedTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          operation,
+          "tosa to linalg conversion expects statically shaped tensors");
+
+    unsigned nloops = shapedTy.getRank();
+
+    // One init tensor per result carries the result shape into the
+    // linalg.generic; the body yields one scalar of the element type each.
+    SmallVector<Value, 4> initTensors;
+    SmallVector<Type, 4> opResultTypes;
+    for (Type resultTy : operation->getResultTypes()) {
+      initTensors.push_back(getInitTensor(rewriter, loc, shapedTy));
+      opResultTypes.push_back(resultTy);
+    }
+    SmallVector<Type, 4> bodyResultTypes(operation->getNumResults(),
+                                         shapedTy.getElementType());
+
+    // Only non-broadcasted operations are supported, so every operand and
+    // result is accessed through the same identity map. Supporting
+    // broadcasting would require per-operand indexing maps.
+    AffineMap identityMap = rewriter.getMultiDimIdentityMap(nloops);
+    SmallVector<AffineMap, 4> indexingMaps(
+        operation->getNumOperands() + operation->getNumResults(), identityMap);
+
+    bool didEncounterError = false;
+    auto linalgOp = rewriter.create<linalg::GenericOp>(
+        loc, opResultTypes, operation->getOperands(), initTensors, indexingMaps,
+        getNParallelLoopsAttrs(nloops),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc,
+            ValueRange blockArgs) {
+          // Only the leading block arguments (one per input) feed the
+          // computation; the trailing ones correspond to the init tensors.
+          Value opResult = createLinalgBodyCalculationForElementwiseOp(
+              operation, blockArgs.take_front(operation->getNumOperands()),
+              bodyResultTypes, rewriter);
+          if (!opResult) {
+            didEncounterError = true;
+            return;
+          }
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc, opResult);
+        });
+
+    // NOTE(review): the partially built linalg.generic is not erased here;
+    // this relies on the driver rolling back IR created by a failed pattern,
+    // as the dialect conversion used by this pass does.
+    if (didEncounterError)
+      return failure();
+
+    rewriter.replaceOp(operation, linalgOp->getResults());
+    return success();
+  }
+
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const final {
+    return matchAndRewriteHelper(op, rewriter);
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<
+      PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
+      PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::AbsOp>,
+      PointwiseConverter<tosa::TanhOp>, PointwiseConverter<tosa::BitwiseAndOp>,
+      PointwiseConverter<tosa::BitwiseOrOp>,
+      PointwiseConverter<tosa::BitwiseXorOp>,
+      PointwiseConverter<tosa::LogicalLeftShiftOp>,
+      PointwiseConverter<tosa::LogicalRightShiftOp>>(context);
+}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -0,0 +1,59 @@
+//===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the Linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TosaToLinalgOnTensors
+    : public TosaToLinalgOnTensorsBase<TosaToLinalgOnTensors> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, StandardOpsDialect>();
+  }
+
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    ConversionTarget target(getContext());
+    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
+    target.addIllegalDialect<tosa::TosaDialect>();
+    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+    FuncOp func = getFunction();
+    mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
+        func.getContext(), &patterns);
+    if (failed(applyFullConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgOnTensors() {
+  return std::make_unique<TosaToLinalgOnTensors>();
+}
+
+void mlir::tosa::addTosaToLinalgOnTensorsPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass());
+  pm.addNestedPass<FuncOp>(createTosaToLinalgOnTensors());
+}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -0,0 +1,132 @@
+// RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s -verify-diagnostics -o -| FileCheck %s
+
+// CHECK: #map = affine_map<() -> ()>
+
+// CHECK-LABEL: @test_abs
+func @test_abs(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor<f32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = []} ins(%arg0 : tensor<f32>) outs([[INIT]] : tensor<f32>) {
+  // CHECK: ^bb0(%arg1: f32, %arg2: f32):
+  // CHECK: [[ELEMENT:%.+]] = absf %arg1
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<f32>
+
+  %0 = "tosa.abs"(%arg0) : (tensor<f32>) -> tensor<f32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @test_abs
+func @test_abs(%arg0: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) {
+  // CHECK: ^bb0(%arg1: f32, %arg2: f32):
+  // CHECK: [[ELEMENT:%.+]] = absf %arg1
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<1xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_abs
+func @test_abs(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2] : tensor<1x2xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<1x2xf32>) outs([[INIT]] : tensor<1x2xf32>) {
+  // CHECK: ^bb0(%arg1: f32, %arg2: f32):
+  // CHECK: [[ELEMENT:%.+]] = absf %arg1
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<1x2xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<1x2xf32>) -> tensor<1x2xf32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1x2xf32>
+}
+
+// -----
+
+func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+  // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+func @test_abs_i32(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+  // expected-error @+1 {{failed to legalize operation 'tosa.abs'}}
+  %0 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_simple_f32
+func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: tanh
+  %0 = "tosa.tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: absf
+  %1 = "tosa.abs"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: addf
+  %2 = "tosa.add"(%0, %0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: subf
+  %3 = "tosa.sub"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: pow
+  %4 = "tosa.pow"(%1, %2) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_simple_i32
+func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: addi
+  %0 = "tosa.add"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: subi
+  %1 = "tosa.sub"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: and
+  %2 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: or
+  %3 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: xor
+  %4 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: shift_left
+  %5 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: shift_right_unsigned
+  %6 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  return
+}
+