diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -0,0 +1,36 @@ +//===-- TosaToLinalg.h - TOSA optimization pass declarations ----------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the passes for the TOSA Linalg Dialect in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H +#define MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace tosa { + +std::unique_ptr createTosaToLinalgOnTensors(); + +/// Populates passes to convert from TOSA to Linalg on tensors. At the end of +/// the pass, the function will only contain linalg ops or standard ops if the +/// pipeline succeeds. +void addTosaToLinalgOnTensorsPasses(OpPassManager &pm); + +/// Populates conversion patterns from TOSA dialect to Linalg dialect. 
+void populateTosaToLinalgOnTensorsConversionPatterns( + MLIRContext *context, OwningRewritePatternList *patterns); + +} // namespace tosa +} // namespace mlir + +#endif // MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -20,6 +20,7 @@ std::unique_ptr createTosaMakeBroadcastablePass(); std::unique_ptr createTosaTestQuantUtilAPIPass(); +std::unique_ptr createTosaToLinalgOnTensors(); #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -6,7 +6,8 @@ // //===----------------------------------------------------------------------===// // -// This file declares the optimization passes for the TOSA Dialect in MLIR. +// This file declares the optimization and lowering passes for the TOSA Dialect +// in MLIR. // //===----------------------------------------------------------------------===// @@ -24,3 +25,14 @@ let constructor = "createTosaMakeBroadcastablePass()"; } + +def TosaToLinalgOnTensors : FunctionPass<"tosa-to-linalg-on-tensors"> { + let summary = "Lower TOSA to LinAlg on tensors"; + let description = [{ + Pass that converts TOSA operations to the equivalent operations using the + tensor operations in LinAlg. 
+ }]; + + let constructor = "createTosaToLinalgOnTensors()"; +} + diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -20,6 +20,7 @@ add_subdirectory(SPIRVToLLVM) add_subdirectory(StandardToLLVM) add_subdirectory(StandardToSPIRV) +add_subdirectory(TosaToLinalg) add_subdirectory(ArmSVEToLLVM) add_subdirectory(VectorToROCDL) add_subdirectory(VectorToLLVM) diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_conversion_library(MLIRTosaToLinalg + TosaToLinalg.cpp + TosaToLinalgPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa + ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLinalg + MLIRLinalgUtils + MLIRPass + MLIRTosa + MLIRTosaTransforms + MLIRSupport + ) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -0,0 +1,190 @@ +//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// These rewriters lower from the Tosa to the Linalg dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/PassDetail.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { + +SmallVector getNParallelLoopsAttrs(unsigned nParallelLoops) { + return SmallVector(nParallelLoops, + getParallelIteratorTypeName()); +} + +Value getInitTensor(OpBuilder &b, Location loc, ShapedType type) { + return b.create(loc, ArrayRef({}), + type.getShape(), type.getElementType()); +} + +} // namespace + +namespace mlir { +namespace tosa { +namespace { + +template +class PointwiseConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + virtual Operation * + createGenericOp(Location loc, int nloops, ArrayRef indexingMaps, + ArrayRef args, ArrayRef outputBuffers, + ArrayRef opResultTypes, ArrayRef bodyArgTypes, + ArrayRef bodyResultTypes, + ConversionPatternRewriter &rewriter) const = 0; + + LogicalResult + matchAndRewriteHelper(Operation *operation, ArrayRef args, + ConversionPatternRewriter &rewriter) const { + auto loc = operation->getLoc(); + auto results = operation->getResults(); + auto t0 = args[0].getType().template dyn_cast(); + if (!t0) + return rewriter.notifyMatchFailure(operation, + "All results must be a shaped type"); + + // For now require no broadcasting. Consider making it support broadcasting + // operations. 
+ auto uniqueTy = operation->getOperand(0).getType(); + bool allInputTypesEqual = + llvm::all_of(operation->getOperandTypes(), + [&](Type operandTy) { return operandTy == uniqueTy; }); + if (!allInputTypesEqual) + return rewriter.notifyMatchFailure( + operation, "All operands must have the same type"); + bool allResultTypesEqual = + llvm::all_of(operation->getResultTypes(), + [&](Type resultTy) { return resultTy == uniqueTy; }); + if (!allResultTypesEqual) + return rewriter.notifyMatchFailure( + operation, "All results must have the same type as the input"); + + unsigned nloops = t0.getRank(); + auto checkValidShapedType = [&](Type t) { + auto st = t.dyn_cast(); + return !t || !st || !st.hasRank() || + !st.getElementType().isSignlessIntOrFloat(); + }; + + auto checkValidShapedValue = [&](Value v) { + return checkValidShapedType(v.getType()); + }; + + if (llvm::any_of(args, checkValidShapedValue) || + llvm::any_of(operation->getResultTypes(), checkValidShapedType)) + return emitError(loc, "tosa to linalg conversion expects ranked args of " + "signless int, float or complex element type with ") + << nloops << " parallel iterators: " << *(operation); + + // Construct the indexing maps needed for linalg.generic ops. + SmallVector bodyArgTypes, bodyResultTypes, opResultTypes; + + for (Value in : args) + bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType())); + + SmallVector outputBuffers; + for (auto result : results) { + auto resultType = result.getType().template cast(); + if (!resultType.hasStaticShape()) + return rewriter.notifyMatchFailure( + operation, + "tosa to linalg conversion expects statically shaped tensors"); + + outputBuffers.push_back(getInitTensor(rewriter, loc, resultType)); + opResultTypes.push_back(result.getType()); + } + + bodyResultTypes = llvm::to_vector<4>(llvm::map_range( + outputBuffers, [](Value v) { return getElementTypeOrSelf(v); })); + + // Supports only non-broadcasted operation. 
Should consider update indexing + // map to be multidimensional. + AffineMap commonIndexingMap = rewriter.getMultiDimIdentityMap(nloops); + SmallVector indexingMaps(args.size() + bodyResultTypes.size(), + commonIndexingMap); + + auto linalgOp = + createGenericOp(loc, nloops, indexingMaps, args, outputBuffers, + opResultTypes, bodyArgTypes, bodyResultTypes, rewriter); + if (!linalgOp) + return failure(); + + rewriter.replaceOp(operation, linalgOp->getResults()); + return success(); + } + + LogicalResult + matchAndRewrite(SrcOp op, ArrayRef args, + ConversionPatternRewriter &rewriter) const final { + return matchAndRewriteHelper(op, args, rewriter); + } +}; + +template +class SimplePointwiseConverter : public PointwiseConverter { +public: + using PointwiseConverter::PointwiseConverter; + + Operation * + createGenericOp(Location loc, int nloops, ArrayRef indexingMaps, + ArrayRef args, ArrayRef outputBuffers, + ArrayRef opResultTypes, ArrayRef bodyArgTypes, + ArrayRef bodyResultTypes, + ConversionPatternRewriter &rewriter) const override { + if (!bodyResultTypes.front().isa()) + return nullptr; + + return rewriter.create( + loc, opResultTypes, ValueRange(args), outputBuffers, indexingMaps, + getNParallelLoopsAttrs(nloops), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + Value opResult = rewriter.create( + loc, bodyResultTypes, blockArgs.take_front(args.size())); + nestedBuilder.create(loc, opResult); + }); + } +}; + +} // namespace + +void populateTosaToLinalgOnTensorsConversionPatterns( + MLIRContext *context, OwningRewritePatternList *patterns) { + patterns->insert< + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter, + SimplePointwiseConverter>(context); +} + +} // namespace tosa +} 
// namespace mlir diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -0,0 +1,58 @@ +//===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transformation pass legalizes Tosa operations to the Linalg dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/PassDetail.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { +struct TosaToLinalgOnTensors + : public TosaToLinalgOnTensorsBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnFunction() override { + OwningRewritePatternList patterns; + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalDialect(); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + + auto func = getFunction(); + mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns( + func.getContext(), &patterns); + if (failed(applyFullConversion(func, target, std::move(patterns)))) + signalPassFailure(); + 
} +}; +} // namespace + +std::unique_ptr mlir::tosa::createTosaToLinalgOnTensors() { + return std::make_unique(); +} + +void mlir::tosa::addTosaToLinalgOnTensorsPasses(OpPassManager &pm) { + pm.addNestedPass(createTosaMakeBroadcastablePass()); + pm.addNestedPass(createTosaToLinalgOnTensors()); +} diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -0,0 +1,257 @@ +// RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s -verify-diagnostics -o -| FileCheck %s + +// CHECK-LABEL: @test_abs +func @test_abs(%arg0: tensor) -> tensor { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = []} ins(%arg0 : tensor) outs([[INIT]] : tensor) { + // CHECK: ^bb0(%arg1: f32, %arg2: f32): + // CHECK: [[ELEMENT:%.+]] = absf %arg1 + // CHECK: linalg.yield [[ELEMENT]] : f32 + // CHECK: } -> tensor + + %0 = "tosa.abs"(%arg0) : (tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @test_abs +func @test_abs(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: absf + %0 = "tosa.abs"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xf32> +} + + +// ----- + +// CHECK-LABEL: @test_tanh +func @test_tanh(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: tanh + %0 = "tosa.tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xf32> +} + +// ----- + +// CHECK-LABEL: @test_add +func @test_add(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: add + %0 = "tosa.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + 
return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @test_add +func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: add + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xf32> +} + +// ----- + +func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{failed to legalize operation 'tosa.add'}} + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @test_add +func @test_add(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: add + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: @test_sub +func @test_sub(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: sub + %0 = "tosa.sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} + + +// ----- + +// CHECK-LABEL: @test_sub +func @test_sub(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: sub + %0 = "tosa.sub"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xf32> +} + +// ----- + +// CHECK-LABEL: @test_sub +func @test_sub(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: sub + %0 = "tosa.sub"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: @test_pow +func @test_pow(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> 
tensor<1xf32> { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: pow + %0 = "tosa.pow"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xf32> +} + +// ----- + +// CHECK-LABEL: @test_and +func @test_and(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: and + %0 = "tosa.bitwise_and"(%arg0, %arg1) : (tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @test_and +func @test_and(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: and + %0 = "tosa.bitwise_and"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi32> +} + + +// ----- + +// CHECK-LABEL: @test_or +func @test_or(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: or + %0 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: @test_or +func @test_or(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: or + %0 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} + + +// ----- + +// CHECK-LABEL: @test_xor +func @test_xor(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: xor + %0 = "tosa.bitwise_xor"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: @test_xor +func @test_xor(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: xor + %0 = "tosa.bitwise_xor"(%arg0, %arg1) : 
(tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} + + +// ----- + +// CHECK-LABEL: @test_logical_left_shift +func @test_logical_left_shift(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: shift_left + %0 = "tosa.logical_left_shift"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: @test_logical_left_shift +func @test_logical_left_shift(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: shift_left + %0 = "tosa.logical_left_shift"(%arg0, %arg1) : (tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} +// ----- + +// CHECK-LABEL: @test_logical_right_shift +func @test_logical_right_shift(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: shift_right_unsigned + %0 = "tosa.logical_right_shift"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: @test_logical_right_shift +func @test_logical_right_shift(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[GENERIC:%.+]] = linalg.generic + // CHECK: shift_right_unsigned + %0 = "tosa.logical_right_shift"(%arg0, %arg1) : (tensor, tensor) -> tensor + + // CHECK: return [[GENERIC]] + return %0 : tensor +} +