diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -0,0 +1,37 @@ +//===-- TosaToLinalg.h - TOSA to Linalg conversion declarations -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the pass and patterns that lower TOSA to Linalg. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H +#define MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace tosa { + +/// Creates the pass that lowers TOSA operations to Linalg on tensors. +std::unique_ptr createTosaToLinalgOnTensors(); + +/// Populates passes to convert from TOSA to Linalg on tensors. At the end of +/// the pass, the function will only contain linalg ops or standard ops if the +/// pipeline succeeds. +void addTosaToLinalgOnTensorsPasses(OpPassManager &pm); + +/// Populates conversion patterns from TOSA dialect to Linalg dialect. 
+void populateTosaToLinalgOnTensorsConversionPatterns( + MLIRContext *context, OwningRewritePatternList *patterns); + +} // namespace tosa +} // namespace mlir + +#endif // MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -20,6 +20,7 @@ std::unique_ptr createTosaMakeBroadcastablePass(); std::unique_ptr createTosaTestQuantUtilAPIPass(); +std::unique_ptr createTosaToLinalgOnTensors(); #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -6,7 +6,8 @@ // //===----------------------------------------------------------------------===// // -// This file declares the optimization passes for the TOSA Dialect in MLIR. +// This file declares the optimization and lowering passes for the TOSA Dialect +// in MLIR. // //===----------------------------------------------------------------------===// @@ -24,3 +25,14 @@ let constructor = "createTosaMakeBroadcastablePass()"; } + +def TosaToLinalgOnTensors : FunctionPass<"tosa-to-linalg-on-tensors"> { + let summary = "Lower TOSA to LinAlg on tensors"; + let description = [{ + Pass that converts TOSA operations to the equivalent operations using the + tensor operations in LinAlg. 
+ }]; + + let constructor = "createTosaToLinalgOnTensors()"; +} + diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -20,6 +20,7 @@ add_subdirectory(SPIRVToLLVM) add_subdirectory(StandardToLLVM) add_subdirectory(StandardToSPIRV) +add_subdirectory(TosaToLinalg) add_subdirectory(ArmSVEToLLVM) add_subdirectory(VectorToROCDL) add_subdirectory(VectorToLLVM) diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_conversion_library(MLIRTosaToLinalg + TosaToLinalg.cpp + TosaToLinalgPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa + ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLinalg + MLIRLinalgUtils + MLIRPass + MLIRTosa + MLIRTosaTransforms + MLIRSupport + ) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -0,0 +1,203 @@ +//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// These rewriters lower from the Tosa to the Linalg dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/PassDetail.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace tosa { + +namespace { + +/// Returns `nParallelLoops` copies of the "parallel" iterator type name. +SmallVector getNParallelLoopsAttrs(unsigned nParallelLoops) { + static constexpr StringRef kParallelIterType = "parallel"; + return SmallVector(nParallelLoops, kParallelIterType); +} + +/// Creates, at `loc`, the tensor used as an output operand from the dynamic +/// sizes in `dynSizes` plus the shape and element type of `type`. +Value getInitTensor(OpBuilder &b, Location loc, ShapedType type, + SmallVectorImpl &dynSizes) { + return b.create(loc, dynSizes, type.getShape(), + type.getElementType()); +} + +// TODO(pifon): This logic is used everywhere, the code should be shared. 
+SmallVector extractDynamicSizes(OpBuilder &b, Location loc, + Value tensor) { + auto tensorType = tensor.getType().dyn_cast(); + if (!tensorType) + return {}; + SmallVector dynSizes; + for (auto &en : llvm::enumerate(tensorType.getShape())) { + if (en.value() != ShapedType::kDynamicSize) + continue; + dynSizes.push_back(b.create(loc, tensor, en.index())); + } + return dynSizes; +} + +template +class PointwiseConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + virtual Operation * + createGenericOp(Location loc, int nloops, ArrayRef indexingMaps, + ArrayRef args, ArrayRef outputBuffers, + ArrayRef opResultTypes, ArrayRef bodyArgTypes, + ArrayRef bodyResultTypes, + ConversionPatternRewriter &rewriter) const = 0; + + LogicalResult + matchAndRewriteHelper(Operation *operation, ArrayRef args, + ConversionPatternRewriter &rewriter) const { + auto loc = operation->getLoc(); + auto results = operation->getResults(); + auto t0 = args[0].getType().template dyn_cast(); + if (!t0) + return rewriter.notifyMatchFailure(operation, + "All results must be a shaped type"); + + // For now require no broadcasting. Consider making it support broadcasting + // operations. 
+ auto uniqueTy = operation->getOperand(0).getType(); + bool allInputTypesEqual = + llvm::all_of(operation->getOperandTypes(), + [&](Type operandTy) { return operandTy == uniqueTy; }); + if (!allInputTypesEqual) + return rewriter.notifyMatchFailure( + operation, "All operands must have the same type"); + bool allResultTypesEqual = + llvm::all_of(operation->getResultTypes(), + [&](Type resultTy) { return resultTy == uniqueTy; }); + if (!allResultTypesEqual) + return rewriter.notifyMatchFailure( + operation, "All results must have the same type as the input"); + + unsigned nloops = t0.getRank(); + auto checkValidShapedType = [&](ShapedType t) { + return !t || !t.hasRank() || t.getRank() != nloops || + !(t.getElementType().isSignlessIntOrFloat() || + t.getElementType().isa()); + }; + + if (llvm::any_of(args, + [&](Value v) { + return checkValidShapedType( + v.getType().dyn_cast()); + }) || + llvm::any_of(operation->getResultTypes(), [&](Type t) { + return checkValidShapedType(t.dyn_cast()); + })) + return emitError(loc, "tosa to linalg conversion expects ranked args of " + "signless int, float or complex element type with ") + << nloops << " parallel iterators: " << *(operation); + + // Construct the indexing maps needed for linalg.generic ops. + SmallVector bodyArgTypes, bodyResultTypes, opResultTypes; + + for (Value in : args) + bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType())); + + SmallVector outputBuffers; + for (auto result : results) { + auto resultType = result.getType().template cast(); + + auto dynSizes = extractDynamicSizes(rewriter, loc, args[0]); + outputBuffers.push_back( + getInitTensor(rewriter, loc, resultType, dynSizes)); + opResultTypes.push_back(result.getType()); + } + + bodyResultTypes = llvm::to_vector<4>(llvm::map_range( + outputBuffers, [](Value v) { return getElementTypeOrSelf(v); })); + + // Supports only non-broadcasted operation. Shoudl consider update indexing + // map to be multidimensional. 
+ AffineMap commonIndexingMap = rewriter.getMultiDimIdentityMap(nloops); + SmallVector indexingMaps(args.size() + bodyResultTypes.size(), + commonIndexingMap); + + auto linalgOp = + createGenericOp(loc, nloops, indexingMaps, args, outputBuffers, + opResultTypes, bodyArgTypes, bodyResultTypes, rewriter); + if (!linalgOp) + return failure(); + + rewriter.replaceOp(operation, linalgOp->getResults()); + return success(); + } + + LogicalResult + matchAndRewrite(SrcOp op, ArrayRef args, + ConversionPatternRewriter &rewriter) const final { + auto operation = op.getOperation(); + return matchAndRewriteHelper(operation, args, rewriter); + } +}; + +template +class SimplePointwiseConverter : public PointwiseConverter { +public: + using PointwiseConverter::PointwiseConverter; + + Operation * + createGenericOp(Location loc, int nloops, ArrayRef indexingMaps, + ArrayRef args, ArrayRef outputBuffers, + ArrayRef opResultTypes, ArrayRef bodyArgTypes, + ArrayRef bodyResultTypes, + ConversionPatternRewriter &rewriter) const override { + if (!bodyResultTypes.front().isa()) + return nullptr; + + return rewriter.create( + loc, opResultTypes, ValueRange(args), outputBuffers, indexingMaps, + getNParallelLoopsAttrs(nloops), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + Value opResult = rewriter.create( + loc, bodyResultTypes, + llvm::to_vector<2>(blockArgs.take_front(args.size()))); + nestedBuilder.create(loc, opResult); + }); + } +}; + +} // namespace + +void populateTosaToLinalgOnTensorsConversionPatterns( + MLIRContext *context, OwningRewritePatternList *patterns) { + patterns->insert< + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter>(context); +} + +} // namespace tosa +} 
// namespace mlir diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -0,0 +1,68 @@ +//===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transformation pass legalizes Tosa operations to the Linalg dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/PassDetail.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace tosa { + +namespace { +/// Function pass that runs a full dialect conversion using the patterns from +/// `populateTosaToLinalgOnTensorsConversionPatterns`. +struct TosaToLinalgOnTensors + : public TosaToLinalgOnTensorsBase { +public: + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + + void runOnFunction() override { + OwningRewritePatternList patterns; + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalDialect(); + // Any op without an explicit legality rule is treated as legal. + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + + auto func = getFunction(); + populateTosaToLinalgOnTensorsConversionPatterns(func.getContext(), + &patterns); + if (failed(applyFullConversion(func, target, std::move(patterns)))) + signalPassFailure(); + 
 } +}; +} // namespace + +/// Creates a new instance of the TosaToLinalgOnTensors pass. +std::unique_ptr createTosaToLinalgOnTensors() { + return std::make_unique(); +} + +/// Adds, as nested passes of `pm`, the pass created by +/// `createTosaMakeBroadcastablePass` followed by the TOSA to Linalg lowering. +void addTosaToLinalgOnTensorsPasses(OpPassManager &pm) { + pm.addNestedPass(createTosaMakeBroadcastablePass()); + pm.addNestedPass(createTosaToLinalgOnTensors()); +} + +} // namespace tosa +} // namespace mlir diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -0,0 +1,406 @@ +// RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s -verify-diagnostics -o -| FileCheck %s + +// CHECK-LABEL: @test_abs +func @test_abs(%arg0: tensor) -> tensor { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = []} ins(%arg0 : tensor) outs([[INIT]] : tensor) { + // CHECK: ^bb0(%arg1: f32, %arg2: f32): + // CHECK: [[ELEMENT:%.+]] = absf %arg1 + // CHECK: linalg.yield [[ELEMENT]] : f32 + // CHECK: } -> tensor + + %0 = "tosa.abs"(%arg0) : (tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @test_abs +func @test_abs(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) { + // CHECK: ^bb0(%arg1: f32, %arg2: f32): + // CHECK: [[ELEMENT:%.+]] = absf %arg1 + // CHECK: linalg.yield [[ELEMENT]] : f32 + // CHECK: } -> tensor<1xf32> + + %0 = "tosa.abs"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xf32> +} + + +// ----- + +// CHECK-LABEL: @test_tanh +func @test_tanh(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32> + // CHECK: [[GENERIC:%.+]] 
= linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) { + // CHECK: ^bb0(%arg1: f32, %arg2: f32): + // CHECK: [[ELEMENT:%.+]] = tanh %arg1 + // CHECK: linalg.yield [[ELEMENT]] : f32 + // CHECK: } -> tensor<1xf32> + + %0 = "tosa.tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xf32> +} + +// ----- + +// CHECK-LABEL: @test_add +func @test_add(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg0, %arg1 : tensor, tensor) outs([[INIT]] : tensor) { + // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): + // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32 + // CHECK: linalg.yield [[ELEMENT]] : f32 + // CHECK: } -> tensor + + %0 = "tosa.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @test_add +func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xf32>, tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) { + // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): + // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32 + // CHECK: linalg.yield [[ELEMENT]] : f32 + // CHECK: } -> tensor<1xf32> + + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xf32> +} + +// ----- + +func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{failed to legalize operation 'tosa.add'}} + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32> + 
return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @test_add +func @test_add(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) { + // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): + // CHECK: [[ELEMENT:%.+]] = addi %arg2, %arg3 : i32 + // CHECK: linalg.yield [[ELEMENT]] : i32 + // CHECK: } -> tensor<1xi32> + + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: @test_sub +func @test_sub(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg0, %arg1 : tensor, tensor) outs([[INIT]] : tensor) { + // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): + // CHECK: [[ELEMENT:%.+]] = subf %arg2, %arg3 : f32 + // CHECK: linalg.yield [[ELEMENT]] : f32 + // CHECK: } -> tensor + + %0 = "tosa.sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} + + +// ----- + +// CHECK-LABEL: @test_sub +func @test_sub(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xf32>, tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) { + // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): + // CHECK: [[ELEMENT:%.+]] = subf %arg2, %arg3 : f32 + // CHECK: linalg.yield [[ELEMENT]] : f32 + // CHECK: } -> tensor<1xf32> + + %0 = "tosa.sub"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> 
tensor<1xf32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xf32> +} + +// ----- + +// CHECK-LABEL: @test_sub +func @test_sub(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) { + // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): + // CHECK: [[ELEMENT:%.+]] = subi %arg2, %arg3 : i32 + // CHECK: linalg.yield [[ELEMENT]] : i32 + // CHECK: } -> tensor<1xi32> + + %0 = "tosa.sub"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi32> +} + +// ----- + +func @test_sub(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{failed to legalize operation 'tosa.sub'}} + %0 = "tosa.sub"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @test_pow +func @test_pow(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xf32>, tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) { + // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): + // CHECK: [[ELEMENT:%.+]] = powf %arg2, %arg3 : f32 + // CHECK: linalg.yield [[ELEMENT]] : f32 + // CHECK: } -> tensor<1xf32> + + %0 = "tosa.pow"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xf32> +} + +// ----- + +func @test_pow(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{failed to legalize operation 'tosa.pow'}} + %0 = "tosa.pow"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) 
-> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @test_and +func @test_and(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg0, %arg1 : tensor, tensor) outs([[INIT]] : tensor) { + // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): + // CHECK: [[ELEMENT:%.+]] = and %arg2, %arg3 : i32 + // CHECK: linalg.yield [[ELEMENT]] : i32 + // CHECK: } -> tensor + + %0 = "tosa.bitwise_and"(%arg0, %arg1) : (tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @test_and +func @test_and(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) { + // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): + // CHECK: [[ELEMENT:%.+]] = and %arg2, %arg3 : i32 + // CHECK: linalg.yield [[ELEMENT]] : i32 + // CHECK: } -> tensor<1xi32> + + %0 = "tosa.bitwise_and"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi32> +} + +// ----- + +func @test_and(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> { + // expected-error @+1 {{failed to legalize operation 'tosa.bitwise_and'}} + %0 = "tosa.bitwise_and"(%arg0, %arg1) : (tensor<1xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: @test_or +func @test_or(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : 
tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) { + // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): + // CHECK: [[ELEMENT:%.+]] = or %arg2, %arg3 : i32 + // CHECK: linalg.yield [[ELEMENT]] : i32 + // CHECK: } -> tensor<1xi32> + + %0 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: @test_or +func @test_or(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg0, %arg1 : tensor, tensor) outs([[INIT]] : tensor) { + // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): + // CHECK: [[ELEMENT:%.+]] = or %arg2, %arg3 : i32 + // CHECK: linalg.yield [[ELEMENT]] : i32 + // CHECK: } -> tensor + + %0 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} + +// ----- + +func @test_or(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> { + // expected-error @+1 {{failed to legalize operation 'tosa.bitwise_or'}} + %0 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor<1xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: @test_xor +func @test_xor(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) { + // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): + // CHECK: [[ELEMENT:%.+]] = xor %arg2, %arg3 : i32 + // CHECK: linalg.yield [[ELEMENT]] : i32 + // CHECK: } -> tensor<1xi32> + + %0 = "tosa.bitwise_xor"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return 
%0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: @test_xor +func @test_xor(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg0, %arg1 : tensor, tensor) outs([[INIT]] : tensor) { + // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): + // CHECK: [[ELEMENT:%.+]] = xor %arg2, %arg3 : i32 + // CHECK: linalg.yield [[ELEMENT]] : i32 + // CHECK: } -> tensor + + %0 = "tosa.bitwise_xor"(%arg0, %arg1) : (tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} + +// ----- + +func @test_xor(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> { + // expected-error @+1 {{failed to legalize operation 'tosa.bitwise_xor'}} + %0 = "tosa.bitwise_xor"(%arg0, %arg1) : (tensor<1xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: @test_logical_left_shift +func @test_logical_left_shift(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) { + // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): + // CHECK: [[ELEMENT:%.+]] = shift_left %arg2, %arg3 : i32 + // CHECK: linalg.yield [[ELEMENT]] : i32 + // CHECK: } -> tensor<1xi32> + + %0 = "tosa.logical_left_shift"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: @test_logical_left_shift +func @test_logical_left_shift(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg0, 
%arg1 : tensor, tensor) outs([[INIT]] : tensor) { + // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): + // CHECK: [[ELEMENT:%.+]] = shift_left %arg2, %arg3 : i32 + // CHECK: linalg.yield [[ELEMENT]] : i32 + // CHECK: } -> tensor + + %0 = "tosa.logical_left_shift"(%arg0, %arg1) : (tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} + +// ----- + +func @test_logical_left_shift(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> { + // expected-error @+1 {{failed to legalize operation 'tosa.logical_left_shift'}} + %0 = "tosa.logical_left_shift"(%arg0, %arg1) : (tensor<1xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: @test_logical_right_shift +func @test_logical_right_shift(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) { + // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): + // CHECK: [[ELEMENT:%.+]] = shift_right_unsigned %arg2, %arg3 : i32 + // CHECK: linalg.yield [[ELEMENT]] : i32 + // CHECK: } -> tensor<1xi32> + + %0 = "tosa.logical_right_shift"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: @test_logical_right_shift +func @test_logical_right_shift(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg0, %arg1 : tensor, tensor) outs([[INIT]] : tensor) { + // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): + // CHECK: [[ELEMENT:%.+]] = shift_right_unsigned %arg2, %arg3 : i32 + // CHECK: linalg.yield [[ELEMENT]] : i32 + // CHECK: } -> 
tensor + + %0 = "tosa.logical_right_shift"(%arg0, %arg1) : (tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} + +// ----- + +func @test_logical_right_shift(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> { + // expected-error @+1 {{failed to legalize operation 'tosa.logical_right_shift'}} + %0 = "tosa.logical_right_shift"(%arg0, %arg1) : (tensor<1xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} +