diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
 #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
 
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -20,10 +21,21 @@
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
+std::unique_ptr<Pass> createTosaToLinalgOnTensors();
 
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
 
+/// Populates passes to convert from TOSA to Linalg on tensors. If the
+/// pipeline succeeds, the function will contain only linalg and standard ops
+/// at the end of the pass. The pass manager `pm` passed in here is expected
+/// to operate on the module.
+void addTosaToLinalgOnTensorsPasses(OpPassManager &pm);
+
+/// Populates conversion patterns from the TOSA dialect to the Linalg dialect.
+void populateTosaToLinalgOnTensorsConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
+
 } // namespace tosa
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -6,7 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file declares the optimization passes for the TOSA Dialect in MLIR.
+// This file declares the optimization and lowering passes for the TOSA
+// Dialect in MLIR.
 //
 //===----------------------------------------------------------------------===//
 
@@ -24,3 +25,14 @@
 
   let constructor = "createTosaMakeBroadcastablePass()";
 }
+
+def TosaToLinalgOnTensors : FunctionPass<"tosa-to-linalg-on-tensors"> {
+  let summary = "Lower TOSA to Linalg on tensors";
+  let description = [{
+    Pass that converts TOSA operations to the equivalent operations on
+    tensors in the Linalg dialect.
+  }];
+
+  let constructor = "createTosaToLinalgOnTensors()";
+}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRTosaTransforms
+  Passes.cpp
   TosaMakeBroadcastable.cpp
+  TosaToLinalgOnTensors.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/Passes.cpp b/mlir/lib/Dialect/Tosa/Transforms/Passes.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/Passes.cpp
@@ -0,0 +1,28 @@
+//===-- Passes.cpp - TOSA optimization pass declarations --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the optimization and lowering passes for the TOSA Dialect.
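For illustration, the two entry points declared in Passes.h drive an ordinary
MLIR pass pipeline. A minimal sketch of a caller (the function below is
hypothetical and not part of this patch; only addTosaToLinalgOnTensorsPasses
and the generated pass registration come from it):

    #include "mlir/Dialect/Tosa/Transforms/Passes.h"
    #include "mlir/Pass/PassManager.h"

    // Lower every TOSA op in `module` to Linalg on tensors. PassManager
    // is-an OpPassManager, so it can be handed to the pipeline helper
    // directly; the two passes it adds are nested on each function.
    static mlir::LogicalResult lowerTosaToLinalg(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      mlir::tosa::addTosaToLinalgOnTensorsPasses(pm);
      return pm.run(module);
    }
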
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace tosa {
+
+void addTosaToLinalgOnTensorsPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass());
+  pm.addNestedPass<FuncOp>(createTosaToLinalgOnTensors());
+}
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaToLinalgOnTensors.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaToLinalgOnTensors.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaToLinalgOnTensors.cpp
@@ -0,0 +1,202 @@
+//===- TosaToLinalgOnTensors.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides lowerings from the TOSA dialect to the Linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace tosa {
+
+namespace {
+
+SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
+  static constexpr StringRef kParallelIterType = "parallel";
+  return SmallVector<StringRef, 3>(nParallelLoops, kParallelIterType);
+}
+
+Value GetInitTensor(OpBuilder &b, Location loc, ShapedType type,
+                    SmallVectorImpl<Value> &dynSizes) {
+  return b.create<linalg::InitTensorOp>(loc, dynSizes, type.getShape(),
+                                        type.getElementType());
+}
+
+// TODO(pifon): This logic is used everywhere, the code should be shared.
+SmallVector<Value, 4> ExtractDynamicSizes(OpBuilder &b, Location loc,
+                                          Value tensor) {
+  auto tensorType = tensor.getType().dyn_cast<RankedTensorType>();
+  if (!tensorType)
+    return {};
+  SmallVector<Value, 4> dynSizes;
+  for (auto &en : llvm::enumerate(tensorType.getShape())) {
+    if (en.value() != ShapedType::kDynamicSize)
+      continue;
+    dynSizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
+  }
+  return dynSizes;
+}
+
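+// Converts a pointwise TOSA op (SrcOp) into a linalg.generic whose region
+// applies the matching standard-dialect op (StdOp) elementwise. For example,
+// instantiated for tosa.abs on f32 tensors, the pattern rewrites
+//
+//   %0 = "tosa.abs"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+//
+// into a linalg.init_tensor / linalg.generic pair whose body computes absf
+// per element (see test/Dialect/Tosa/to_linalg_on_tensors.mlir).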
+template <typename SrcOp, typename StdOp>
+class PointwiseConverter : public OpConversionPattern<SrcOp> {
+public:
+  using OpConversionPattern<SrcOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SrcOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto loc = op.getLoc();
+    auto operation = op.getOperation();
+    auto results = operation->getResults();
+    ShapedType t0 = args[0].getType().template dyn_cast<ShapedType>();
+    if (!t0)
+      return failure();
+
+    // For now require no broadcasting: every operand type must match every
+    // result type. Consider extending this to support broadcasting ops.
+    for (auto inputTy : operation->getOperandTypes()) {
+      for (auto resultTy : operation->getResultTypes()) {
+        if (inputTy != resultTy)
+          return failure();
+      }
+    }
+
+    unsigned nloops = t0.getRank();
+    auto fail = [&](ShapedType t) {
+      return !t || !t.hasRank() || t.getRank() != nloops ||
+             !(t.getElementType().isSignlessIntOrFloat() ||
+               t.getElementType().isa<ComplexType>());
+    };
+
+    if (llvm::any_of(args,
+                     [&](Value v) {
+                       return fail(v.getType().dyn_cast<ShapedType>());
+                     }) ||
+        llvm::any_of(op.getOperation()->getResultTypes(),
+                     [&](Type t) { return fail(t.dyn_cast<ShapedType>()); }))
+      return emitError(loc, "tosa to linalg conversion expects ranked args of "
+                            "signless int, float or complex element type with ")
+             << nloops << " parallel iterators: " << *(op.getOperation());
+
+    // Construct the indexing maps needed for linalg.generic ops.
+    SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;
+
+    for (Value in : args)
+      bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
+
+    SmallVector<Value, 4> outputBuffers;
+    for (auto result : results) {
+      ShapedType resultType = result.getType().template cast<ShapedType>();
+      Type elementType = resultType.getElementType();
+      if (!elementType.isSignlessIntOrFloat())
+        return failure();
+
+      auto dynSizes = ExtractDynamicSizes(rewriter, loc, args[0]);
+      outputBuffers.push_back(
+          GetInitTensor(rewriter, loc, resultType, dynSizes));
+      opResultTypes.push_back(result.getType());
+    }
+
+    bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
+        outputBuffers, [](Value v) { return getElementTypeOrSelf(v); }));
+
+    // Supports only non-broadcasted operations. Should consider updating the
+    // indexing maps to handle broadcast dimensions.
+    AffineMap commonIndexingMap =
+        nloops ? rewriter.getMultiDimIdentityMap(nloops)
+               : AffineMap::get(nloops, 0, rewriter.getContext());
+    SmallVector<AffineMap, 4> indexingMaps(
+        args.size() + bodyResultTypes.size(), commonIndexingMap);
+
+    auto linalgOp = rewriter.create<linalg::GenericOp>(
+        loc, opResultTypes, ValueRange(args), outputBuffers, indexingMaps,
+        GetNParallelLoopsAttrs(nloops),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc,
+            ValueRange blockArgs) {
+          Value opResult = rewriter.create<StdOp>(
+              loc, bodyResultTypes,
+              llvm::to_vector<2>(blockArgs.take_front(args.size())));
+          nestedBuilder.create<linalg::YieldOp>(loc, opResult);
+        });
+    rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
+    return success();
+  }
+};
+
+struct TosaToLinalgOnTensors
+    : public TosaToLinalgOnTensorsBase<TosaToLinalgOnTensors> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, StandardOpsDialect>();
+  }
+
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    ConversionTarget target(getContext());
+    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
+
+    auto func = getFunction();
+    populateTosaToLinalgOnTensorsConversionPatterns(func.getContext(),
+                                                    &patterns);
+    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+void populateTosaToLinalgOnTensorsConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<PointwiseConverter<tosa::AbsOp, mlir::AbsFOp>>(context);
+  patterns->insert<PointwiseConverter<tosa::TanhOp, mlir::TanhOp>>(context);
+  patterns->insert<PointwiseConverter<tosa::AddOp, mlir::AddFOp>>(context);
+  patterns->insert<PointwiseConverter<tosa::AddOp, mlir::AddIOp>>(context);
+  patterns->insert<PointwiseConverter<tosa::SubOp, mlir::SubFOp>>(context);
+
+  patterns->insert<PointwiseConverter<tosa::SubOp, mlir::SubIOp>>(context);
+  patterns->insert<PointwiseConverter<tosa::PowOp, mlir::PowFOp>>(context);
+  patterns->insert<PointwiseConverter<tosa::BitwiseAndOp, mlir::AndOp>>(
+      context);
+  patterns->insert<PointwiseConverter<tosa::BitwiseOrOp, mlir::OrOp>>(context);
+  patterns->insert<PointwiseConverter<tosa::BitwiseXorOp, mlir::XOrOp>>(
+      context);
+  patterns->insert<PointwiseConverter<tosa::LogicalLeftShiftOp,
+                                      mlir::ShiftLeftOp>>(context);
+  patterns->insert<PointwiseConverter<tosa::LogicalRightShiftOp,
+                                      mlir::UnsignedShiftRightOp>>(context);
+}
+
+std::unique_ptr<Pass> createTosaToLinalgOnTensors() {
+  return std::make_unique<TosaToLinalgOnTensors>();
+}
+
+} // namespace tosa
+} // namespace mlir
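The populate hook is kept separate from the pass so that downstream pipelines
can combine these patterns with their own conversion targets. As a sketch
only (the function name, the strict target, and the error text below are
illustrative, not part of this patch; it assumes the same includes as
TosaToLinalgOnTensors.cpp):

    // Hypothetical strict variant: reuse the patterns, but mark the TOSA
    // dialect illegal so any op left unconverted fails the conversion
    // instead of silently surviving a partial conversion.
    void runStrictTosaToLinalg(mlir::FuncOp func) {
      mlir::MLIRContext *context = func.getContext();
      mlir::OwningRewritePatternList patterns;
      mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(context,
                                                                  &patterns);

      mlir::ConversionTarget target(*context);
      target.addLegalDialect<mlir::linalg::LinalgDialect,
                             mlir::StandardOpsDialect>();
      target.addIllegalDialect<mlir::tosa::TosaDialect>();

      if (mlir::failed(
              mlir::applyFullConversion(func, target, std::move(patterns))))
        func.emitError("unconverted TOSA operations remain");
    }
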
diff --git a/mlir/test/Dialect/Tosa/to_linalg_on_tensors.mlir b/mlir/test/Dialect/Tosa/to_linalg_on_tensors.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/to_linalg_on_tensors.mlir
@@ -0,0 +1,223 @@
+// RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @test_abs
+func @test_abs(%arg0: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) {
+  // CHECK: ^bb0(%arg1: f32, %arg2: f32): // no predecessors
+  // CHECK: [[ELEMENT:%.+]] = absf %arg1
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<1xf32>
+
+  %0 = "tosa.abs"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_tanh
+func @test_tanh(%arg0: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) {
+  // CHECK: ^bb0(%arg1: f32, %arg2: f32): // no predecessors
+  // CHECK: [[ELEMENT:%.+]] = tanh %arg1
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<1xf32>
+
+  %0 = "tosa.tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
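+// The binary cases below follow the same init_tensor + linalg.generic
+// structure as the unary cases above, with one extra block argument per
+// additional input.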
+// CHECK-LABEL: @test_add
+func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xf32>, tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) {
+  // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
+  // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<1xf32>
+
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_add
+func @test_add(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) {
+  // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
+  // CHECK: [[ELEMENT:%.+]] = addi %arg2, %arg3 : i32
+  // CHECK: linalg.yield [[ELEMENT]] : i32
+  // CHECK: } -> tensor<1xi32>
+
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_sub
+func @test_sub(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xf32>, tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) {
+  // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
+  // CHECK: [[ELEMENT:%.+]] = subf %arg2, %arg3 : f32
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<1xf32>
+
+  %0 = "tosa.sub"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_sub
+func @test_sub(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) {
+  // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
+  // CHECK: [[ELEMENT:%.+]] = subi %arg2, %arg3 : i32
+  // CHECK: linalg.yield [[ELEMENT]] : i32
+  // CHECK: } -> tensor<1xi32>
+
+  %0 = "tosa.sub"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_pow
+func @test_pow(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xf32>, tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) {
+  // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
+  // CHECK: [[ELEMENT:%.+]] = powf %arg2, %arg3 : f32
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<1xf32>
+
+  %0 = "tosa.pow"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_and
+func @test_and(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) {
+  // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
+  // CHECK: [[ELEMENT:%.+]] = and %arg2, %arg3 : i32
+  // CHECK: linalg.yield [[ELEMENT]] : i32
+  // CHECK: } -> tensor<1xi32>
+
+  %0 = "tosa.bitwise_and"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_or
+func @test_or(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) {
+  // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
+  // CHECK: [[ELEMENT:%.+]] = or %arg2, %arg3 : i32
+  // CHECK: linalg.yield [[ELEMENT]] : i32
+  // CHECK: } -> tensor<1xi32>
+
+  %0 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_xor
+func @test_xor(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) {
+  // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
+  // CHECK: [[ELEMENT:%.+]] = xor %arg2, %arg3 : i32
+  // CHECK: linalg.yield [[ELEMENT]] : i32
+  // CHECK: } -> tensor<1xi32>
+
+  %0 = "tosa.bitwise_xor"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xi32>
+}
+
+// -----
+
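+// Note: tosa.logical_right_shift lowers to shift_right_unsigned rather than
+// shift_right_signed. TOSA defines a logical (zero-filling) shift, and the
+// operands here are signless i32, so the std op must carry the unsigned
+// semantics.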
+// CHECK-LABEL: @test_logical_left_shift
+func @test_logical_left_shift(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) {
+  // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
+  // CHECK: [[ELEMENT:%.+]] = shift_left %arg2, %arg3 : i32
+  // CHECK: linalg.yield [[ELEMENT]] : i32
+  // CHECK: } -> tensor<1xi32>
+
+  %0 = "tosa.logical_left_shift"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_logical_right_shift
+func @test_logical_right_shift(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xi32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xi32>, tensor<1xi32>) outs([[INIT]] : tensor<1xi32>) {
+  // CHECK: ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
+  // CHECK: [[ELEMENT:%.+]] = shift_right_unsigned %arg2, %arg3 : i32
+  // CHECK: linalg.yield [[ELEMENT]] : i32
+  // CHECK: } -> tensor<1xi32>
+
+  %0 = "tosa.logical_right_shift"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xi32>
+}