diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/Sequence.h"
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -25,6 +26,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -518,112 +520,147 @@
   return nullptr;
 }
 
-static LogicalResult
-elementwiseMatchAndRewriteHelper(Operation *operation,
-                                 PatternRewriter &rewriter) {
-  auto loc = operation->getLoc();
-
-  assert(operation->getNumResults() == 1 &&
-         "All TOSA elementwise ops should only return a single result.");
-
-  auto result = operation->getResult(0);
-  auto resultTy = dyn_cast<RankedTensorType>(result.getType());
-
-  if (!resultTy)
-    return rewriter.notifyMatchFailure(
-        operation, "All results must be a ranked tensor type");
-
-  unsigned rank = resultTy.getRank();
-
-  // Construct the indexing maps needed for linalg.generic ops.
-  SmallVector<Type> bodyArgTypes;
+static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
+                        int64_t rank) {
+  // No need to expand if we are already at the desired rank
+  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
+  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
+  int64_t numExtraDims = rank - shapedType.getRank();
+  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
+  if (!numExtraDims)
+    return tensor;
+
+  // Compute reassociation indices
+  SmallVector<SmallVector<int64_t, 2>> reassociationIndices(shapedType.getRank());
+  int64_t index = 0;
+  for (index = 0; index <= numExtraDims; index++)
+    reassociationIndices[0].push_back(index);
+  for (size_t position = 1; position < reassociationIndices.size(); position++)
+    reassociationIndices[position].push_back(index++);
+
+  // Compute result type
+  SmallVector<int64_t> resultShape;
+  for (index = 0; index < numExtraDims; index++)
+    resultShape.push_back(1);
+  for (auto size : shapedType.getShape())
+    resultShape.push_back(size);
+  auto resultType = RankedTensorType::get(resultShape, shapedType.getElementType());
+
+  // Emit 'tensor.expand_shape' op
+  return rewriter.create<tensor::ExpandShapeOp>(
+      loc, resultType, tensor, reassociationIndices);
+}
 
-  for (Value in : operation->getOperands())
-    bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
+static Value
+getDominantValueForDim(Value lhs, Value rhs, unsigned index) {
+  auto lhsDimSize = lhs.getType().cast<ShapedType>().getDimSize(index);
+  auto rhsDimSize = rhs.getType().cast<ShapedType>().getDimSize(index);
+  if ((ShapedType::isDynamic(lhsDimSize) && rhsDimSize > 1) ||
+      (lhsDimSize == 1 && ShapedType::isDynamic(rhsDimSize)) ||
+      (lhsDimSize == 1 && rhsDimSize > 1))
+    return rhs;
+  return lhs;
+}
 
-  SmallVector<Type> opResultTypes;
-  SmallVector<Value> emptyTensors;
+static Value
+getDominantValueForDim(ValueRange values, unsigned index) {
+  auto dominantValue = values.front();
+  for (auto value : values.drop_front())
+    dominantValue = getDominantValueForDim(dominantValue, value, index);
+  return dominantValue;
+}
 
-  SmallVector<Value> dynDims;
-  dynDims.resize(rank);
+static OpFoldResult
+getTensorDim(PatternRewriter &rewriter, Location loc, Value tensor, int64_t index) {
+  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
+  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
+  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
+  if (shapedType.isDynamicDim(index))
+    return rewriter.create<tensor::DimOp>(loc, tensor, index).getResult();
+  return rewriter.getIndexAttr(shapedType.getDimSize(index));
+}
 
-  for (auto arg : operation->getOperands()) {
-    auto operandTy = cast<ShapedType>(arg.getType());
-    for (int i = 0; i < operandTy.getRank(); i++) {
-      if (operandTy.isDynamicDim(i) && !dynDims[i])
-        dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i);
-    }
+static Value createOutputTensor(PatternRewriter &rewriter, Location loc,
+                                ValueRange values, Type elementType) {
+  SmallVector<OpFoldResult> shape;
+  auto rank = values.front().getType().cast<ShapedType>().getRank();
+  for (auto index : llvm::seq<int64_t>(0, rank)) {
+    auto dominantValue = getDominantValueForDim(values, index);
+    auto dim = getTensorDim(rewriter, loc, dominantValue, index);
+    shape.push_back(dim);
   }
+  return rewriter.create<tensor::EmptyOp>(loc, shape, elementType);
+}
 
-  SmallVector<Value> filteredDims = condenseValues(dynDims);
-
-  emptyTensors.push_back(
-      rewriter.create<tensor::EmptyOp>(loc, resultTy, filteredDims));
-  opResultTypes.push_back(result.getType());
-
-  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
-      emptyTensors, [](Value v) { return getElementTypeOrSelf(v); }));
-
-  SmallVector<Value> operands;
-  SmallVector<AffineMap> indexingMaps;
-  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());
-
-  // Input indexing maps may be broadcasted.
-  for (Value operand : operation->getOperands()) {
-    ShapedType type = cast<ShapedType>(operand.getType());
-
-    if (type.getShape() == resultTy.getShape()) {
-      operands.push_back(operand);
-      indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
-      continue;
-    }
+static bool
+operandsAndResultsRanked(Operation* operation) {
+  auto isRanked = [](Value value) { return isa<RankedTensorType>(value.getType()); };
+  return llvm::all_of(operation->getOperands(), isRanked) &&
+         llvm::all_of(operation->getResults(), isRanked);
+}
 
-    SmallVector<int64_t> newShape;
-    SmallVector<AffineExpr> affineExprs;
-    newShape.reserve(type.getRank());
-    for (const auto &it : llvm::enumerate(type.getShape())) {
-      if (it.value() == resultTy.getDimSize(it.index())) {
-        newShape.push_back(it.value());
-        affineExprs.push_back(
-            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
-      }
-    }
+static LogicalResult
+elementwiseMatchAndRewriteHelper(Operation *operation,
+                                 PatternRewriter &rewriter) {
-    if (newShape.size() != rank) {
-      operand = rewriter.create<tosa::ReshapeOp>(
-          loc, RankedTensorType::get(newShape, type.getElementType()), operand,
-          rewriter.getDenseI64ArrayAttr(newShape));
+  // Collect op properties
+  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
+  assert(operation->getNumOperands() >= 1 && "elementwise op expects at least 1 operand");
+  auto loc = operation->getLoc();
+  auto result = operation->getResult(0);
+  auto resultType = result.getType().cast<RankedTensorType>();
+
+  // Check supported features for this pass
+  if (!operandsAndResultsRanked(operation))
+    return rewriter.notifyMatchFailure(operation, "Unranked tensors not supported");
+
+  // Equalize input ranks
+  auto rank = resultType.getRank();
+  auto expandedOperands = llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
+    return expandRank(rewriter, loc, operand, rank);
+  });
+
+  // Create output tensor
+  auto outputTensor = createOutputTensor(
+      rewriter, loc, expandedOperands, resultType.getElementType());
+
+  // Build affine maps
+  auto affineMaps = llvm::map_to_vector(expandedOperands, [&](auto operand) {
+    auto shape = cast<ShapedType>(operand.getType()).getShape();
+    SmallVector<AffineExpr> affineExprs;
+    for (auto it : llvm::enumerate(shape)) {
+      auto affineExpr = it.value() == 1 ?
+          rewriter.getAffineConstantExpr(0) :
+          rewriter.getAffineDimExpr(it.index());
+      affineExprs.push_back(affineExpr);
     }
+    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
+  });
+  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
 
-    operands.push_back(operand);
-    indexingMaps.push_back(AffineMap::get(
-        /*dimCount=*/rank, /*symbolCount=*/0, affineExprs,
-        rewriter.getContext()));
-  }
-
-  indexingMaps.append(operation->getNumResults(),
-                      rewriter.getMultiDimIdentityMap(rank));
-
-  bool didEncounterError = false;
+  // Emit 'linalg.generic' op
+  bool encounteredError = false;
   auto linalgOp = rewriter.create<linalg::GenericOp>(
-      loc, opResultTypes, operands, emptyTensors, indexingMaps,
+      loc, outputTensor.getType(), expandedOperands, outputTensor, affineMaps,
       getNParallelLoopsAttrs(rank),
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
         Value opResult = createLinalgBodyCalculationForElementwiseOp(
             operation, blockArgs.take_front(operation->getNumOperands()),
-            bodyResultTypes, rewriter);
+            {resultType.getElementType()}, rewriter);
         if (!opResult) {
-          didEncounterError = true;
+          encounteredError = true;
           return;
         }
         nestedBuilder.create<linalg::YieldOp>(loc, opResult);
       });
-
-  if (didEncounterError)
+  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");
-
-  rewriter.replaceOp(operation, linalgOp->getResults());
+
+  // Cast 'linalg.generic' result into original result type if needed
+  auto castResult = rewriter.createOrFold<tensor::CastOp>(
+      loc, resultType, linalgOp->getResult(0));
+  rewriter.replaceOp(operation, castResult);
   return success();
 }
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -195,18 +195,22 @@
 static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
                                             ArrayRef<int64_t> existing) {
-  auto isCompatible = [](int64_t dim1, int64_t dim2) {
-    // If the inferred and existing dim is the same, or one of them is unknown
-    // then it is compatible, else if the inferred dim is 1 then it is also
-    // compatible. But if the existing dim is 1 and the inferred is greater than
-    // 1 then flag.
-    return dim1 == dim2 || ShapedType::isDynamic(dim1) ||
-           ShapedType::isDynamic(dim2) || dim1 == 1;
+  auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
+    // The following criterion is used to determine the validity of an existing
+    // dimension:
+    //
+    //   inferredDim   existingDim   Behavior
+    //   -----------   -----------   --------
+    //   dynamic       dynamic       OK
+    //   dynamic       static        Error
+    //   static        dynamic       OK
+    //   static        static        OK if equal
+    return ShapedType::isDynamic(existingDim) || inferredDim == existingDim;
   };
   if (inferred.size() != existing.size())
     return false;
-  for (auto p : llvm::zip(inferred, existing))
-    if (!isCompatible(std::get<0>(p), std::get<1>(p)))
+  for (auto [inferredDim, existingDim] : llvm::zip(inferred, existing))
+    if (!isCompatible(inferredDim, existingDim))
       return false;
   return true;
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1,89 +1,148 @@
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics -o -| FileCheck %s
+
 // CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK-LABEL: @test_abs_scalar
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @test_abs_scalar(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<f32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins(%[[ARG0]] : tensor<f32>) outs([[INIT]] : tensor<f32>) {
-  // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32):
-  // CHECK: [[ELEMENT:%.+]] = math.absf %[[ARG1]]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins([[ARG0]] : tensor<f32>) outs([[INIT]] : tensor<f32>) {
+  // CHECK: ^bb0([[ARG1:%.*]]: f32, [[ARG2:%.*]]: f32):
+  // CHECK: [[ELEMENT:%.*]] = math.absf [[ARG1]] : f32
   // CHECK: linalg.yield [[ELEMENT]] : f32
   // CHECK: } -> tensor<f32>
+  %0 = "tosa.abs"(%arg0) : (tensor<f32>) -> tensor<f32>
 
-  %0 = "tosa.abs"(%arg0) : (tensor<f32>) -> tensor<f32>
+  // CHECK: return [[GENERIC]] : tensor<f32>
+  return %0 : tensor<f32>
+}
 
-  // CHECK: return [[GENERIC]]
-  return %0 : tensor<f32>
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_abs_static_dynamic
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @test_abs_static_dynamic(%arg0: tensor<3x?xf32>) -> tensor<3x?xf32> {
+  // CHECK: [[ONE:%.+]] = arith.constant 1 : index
+  // CHECK: [[DIM:%.+]] = tensor.dim [[ARG0]], [[ONE]] : tensor<3x?xf32>
+  // CHECK: [[INIT:%.+]] = tensor.empty([[DIM]]) : tensor<3x?xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[ARG0]] : tensor<3x?xf32>) outs([[INIT]] : tensor<3x?xf32>) {
+  // CHECK: ^bb0([[ARG1:%.*]]: f32, [[ARG2:%.*]]: f32):
+  // CHECK: [[ELEMENT:%.*]] = math.absf [[ARG1]] : f32
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<3x?xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<3x?xf32>) -> tensor<3x?xf32>
+
+  // CHECK: return [[GENERIC]] : tensor<3x?xf32>
+  return %0 : tensor<3x?xf32>
 }
 
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK-LABEL: @test_abs_cast_result
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @test_abs_cast_result(%arg0: tensor<2xf32>) -> tensor<?xf32> {
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
-  // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32):
-  // CHECK: [[ELEMENT:%.+]] = math.absf %[[ARG1]]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]] : tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
+  // CHECK: ^bb0([[ARG1:%.*]]: f32, [[ARG2:%.*]]: f32):
+  // CHECK: [[ELEMENT:%.*]] = math.absf [[ARG1]] : f32
   // CHECK: linalg.yield [[ELEMENT]] : f32
   // CHECK: } -> tensor<2xf32>
-  %0 = "tosa.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<2xf32>) -> tensor<?xf32>
 
-  // CHECK: return [[GENERIC]]
-  return %0 : tensor<2xf32>
+  // CHECK: [[CAST:%.+]] = tensor.cast [[GENERIC]] : tensor<2xf32> to tensor<?xf32>
+  // CHECK: return [[CAST]] : tensor<?xf32>
+  return %0 : tensor<?xf32>
 }
 
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xf32>) outs([[INIT]] : tensor<2x3xf32>) {
-  // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32):
-  // CHECK: [[ELEMENT:%.+]] = math.absf %[[ARG1]]
+// CHECK-LABEL: @test_add_all_static
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_all_static(%arg0: tensor<3x5xf32>, %arg1: tensor<3x5xf32>) -> tensor<3x5xf32> {
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<3x5xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[ARG0]], [[ARG1]] : tensor<3x5xf32>, tensor<3x5xf32>) outs([[INIT]] : tensor<3x5xf32>) {
+  // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+  // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
   // CHECK: linalg.yield [[ELEMENT]] : f32
-  // CHECK: } -> tensor<2x3xf32>
-  %0 = "tosa.abs"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
-
-  // CHECK: return [[GENERIC]]
-  return %0 : tensor<2x3xf32>
+  // CHECK: } -> tensor<3x5xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32>
+
+  // CHECK: return [[GENERIC]] : tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: %[[C0:.+]] = arith.constant 0
-  // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
-  // CHECK: linalg.generic
-  // CHECK: math.absf
-  %0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
-  return %0 : tensor<?xf32>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_add_all_dynamic
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK-DAG: [[ZERO:%.+]] = arith.constant 0 : index
+  // CHECK-DAG: [[ONE:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[DIM0:%.+]] = tensor.dim [[ARG0]], [[ZERO]] : tensor<?x?xf32>
+  // CHECK-DAG: [[DIM1:%.+]] = tensor.dim [[ARG0]], [[ONE]] : tensor<?x?xf32>
+  // CHECK: [[INIT:%.+]] = tensor.empty([[DIM0]], [[DIM1]]) : tensor<?x?xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[ARG0]], [[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs([[INIT]] : tensor<?x?xf32>) {
+  // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+  // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<?x?xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  // CHECK: return [[GENERIC]] : tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
 }
 
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: @test_abs_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs_dyn(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
-  // CHECK: %[[C1:.+]] = arith.constant 1
-  // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
-  // CHECK: linalg.generic
-  // CHECK: math.absf
-  %0 = "tosa.abs"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
-  return %0 : tensor<2x?xf32>
+// CHECK-LABEL: @test_add_static_dynamic
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_static_dynamic(%arg0: tensor<?x3xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x3xf32> {
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[ARG0]], [[ARG1]] : tensor<?x3xf32>, tensor<2x?xf32>) outs([[INIT]] : tensor<2x3xf32>) {
+  // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+  // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<2x3xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<?x3xf32>, tensor<2x?xf32>) -> tensor<2x3xf32>
+
+  // CHECK: return [[GENERIC]] : tensor<2x3xf32>
+  return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_add_broadcast_to_static
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_broadcast_to_static(%arg0: tensor<1x5xf32>, %arg1: tensor<3x1xf32>) -> tensor<3x5xf32> {
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<3x5xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins([[ARG0]], [[ARG1]] : tensor<1x5xf32>, tensor<3x1xf32>) outs([[INIT]] : tensor<3x5xf32>) {
+  // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+  // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<3x5xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x5xf32>, tensor<3x1xf32>) -> tensor<3x5xf32>
+
+  // CHECK: return [[GENERIC]] : tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
 }
 
 // -----
@@ -100,68 +159,122 @@
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> ()>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
-
-// CHECK-LABEL: @test_broadcast
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<2xf32>
-func.func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG0]])
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %[[ARG1]] : tensor<f32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
-  // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
-  // CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_add_broadcast_to_dynamic
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_broadcast_to_dynamic(%arg0: tensor<1x?xf32>, %arg1: tensor<?x1xf32>) -> tensor<?x?xf32> {
+  // CHECK-DAG: [[ZERO:%.+]] = arith.constant 0 : index
+  // CHECK-DAG: [[ONE:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[DIM0:%.+]] = tensor.dim [[ARG1]], [[ZERO]] : tensor<?x1xf32>
+  // CHECK-DAG: [[DIM1:%.+]] = tensor.dim [[ARG0]], [[ONE]] : tensor<1x?xf32>
+  // CHECK: [[INIT:%.+]] = tensor.empty([[DIM0]], [[DIM1]]) : tensor<?x?xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins([[ARG0]], [[ARG1]] : tensor<1x?xf32>, tensor<?x1xf32>) outs([[INIT]] : tensor<?x?xf32>) {
+  // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+  // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
   // CHECK: linalg.yield [[ELEMENT]] : f32
-  // CHECK: } -> tensor<2xf32>
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32>
-  return %0 : tensor<2xf32>
+  // CHECK: } -> tensor<?x?xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x?xf32>, tensor<?x1xf32>) -> tensor<?x?xf32>
+
+  // CHECK: return [[GENERIC]] : tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
 }
 
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
-// CHECK-LABEL: @test_broadcast_swapped_args
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<2xf32
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xf32>
-func.func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
+// CHECK-LABEL: @test_add_cast_result
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_cast_result(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<?xf32> {
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG1]])
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]], [[RESHAPE]] : tensor<2xf32>, tensor<f32>) outs([[INIT]] : tensor<2xf32>) {
-  // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
-  // CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]] : tensor<2xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
+  // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+  // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
   // CHECK: linalg.yield [[ELEMENT]] : f32
   // CHECK: } -> tensor<2xf32>
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32>
-  return %0 : tensor<2xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<?xf32>
+
+  // CHECK: [[CAST:%.+]] = tensor.cast [[GENERIC]] : tensor<2xf32> to tensor<?xf32>
+  // CHECK: return [[CAST]] : tensor<?xf32>
+  return %0 : tensor<?xf32>
 }
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 
-// CHECK-LABEL: @test_multibroadcast
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
-func.func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
+// CHECK-LABEL: @test_add_expand_rank
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_expand_rank(%arg0: tensor<3xf32>, %arg1: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  // CHECK: [[EXPANDED:%.+]] = tensor.expand_shape [[ARG0]] {{\[\[}}0, 1]] : tensor<3xf32> into tensor<1x3xf32>
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32>
-  // CHECK: [[RESHAPE1:%.+]] = "tosa.reshape"(%[[ARG0]]) <{new_shape = array<i64: 3>}
-  // CHECK: [[RESHAPE2:%.+]] = "tosa.reshape"(%[[ARG1]]) <{new_shape = array<i64: 2>}
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) {
-  // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
-  // CHECK: [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins([[EXPANDED]], [[ARG1]] : tensor<1x3xf32>, tensor<2x3xf32>) outs([[INIT]] : tensor<2x3xf32>) {
+  // CHECK: ^bb0([[ARG2:%.*]]: f32, [[ARG3:%.*]]: f32, [[ARG4:%.*]]: f32):
+  // CHECK: [[ELEMENT:%.*]] = arith.addf [[ARG2]], [[ARG3]] : f32
   // CHECK: linalg.yield [[ELEMENT]] : f32
   // CHECK: } -> tensor<2x3xf32>
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+
+  // CHECK: return [[GENERIC]] : tensor<2x3xf32>
   return %0 : tensor<2x3xf32>
 }
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_select
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]]:
+func.func @test_select(%arg0: tensor<1x?x?xi1>, %arg1: tensor, %arg2: tensor) -> tensor {
+  // CHECK-DAG: [[ZERO:%.+]] = arith.constant 0 : index
+  // CHECK-DAG: [[DIM0:%.+]] = tensor.dim [[ARG1]], [[ZERO]] : tensor
+  // CHECK: [[INIT:%.+]] = tensor.empty([[DIM0]]) : tensor
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]] : tensor<1x?x?xi1>, tensor, tensor) outs([[INIT]] : tensor) {
+  // CHECK: ^bb0([[ARG3:%.+]]: i1, [[ARG4:%.+]]: f32, [[ARG5:%.+]]: f32, [[ARG6:%.+]]: f32):
+  // CHECK: [[ELEMENT:%.+]] = arith.select [[ARG3]], [[ARG4]], [[ARG5]] : f32
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x?x?xi1>, tensor, tensor) -> tensor
+
+  // CHECK: return [[GENERIC]] : tensor
+  return %0 : tensor
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (0, 0, d2)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_select_expand_rank
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]]:
+func.func @test_select_expand_rank(%arg0: tensor<4xi1>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  // CHECK: [[EXPANDED0:%.+]] = tensor.expand_shape [[ARG0]] {{\[\[}}0, 1, 2]] : tensor<4xi1> into tensor<1x1x4xi1>
+  // CHECK: [[EXPANDED1:%.+]] = tensor.expand_shape [[ARG1]] {{\[\[}}0, 1], [2]] : tensor<3x4xf32> into tensor<1x3x4xf32>
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3x4xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[EXPANDED0]], [[EXPANDED1]], [[ARG2]] : tensor<1x1x4xi1>, tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs([[INIT]] : tensor<2x3x4xf32>) {
+  // CHECK: ^bb0([[ARG3:%.+]]: i1, [[ARG4:%.+]]: f32, [[ARG5:%.+]]: f32, [[ARG6:%.+]]: f32):
+  // CHECK: [[ELEMENT:%.+]] = arith.select [[ARG3]], [[ARG4]], [[ARG5]] : f32
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<2x3x4xf32>
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+  return %0 : tensor<2x3x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_simple_f32
 func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: linalg.generic
@@ -1412,20 +1525,6 @@
 // -----
 
-// Regression test for using the wrong rank.
-
-// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> ()>
-// CHECK-LABEL: @select_fp32
-func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %arg2: tensor<f32>) -> tensor<1x12x5x5xf32> {
-  // CHECK: linalg.generic
-  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
-  return %0 : tensor<1x12x5x5xf32>
-}
-
-// -----
-
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir
--- a/mlir/test/Dialect/traits.mlir
+++ b/mlir/test/Dialect/traits.mlir
@@ -111,9 +111,18 @@
 // -----
 
-func.func @broadcast_tensor_tensor_tensor(%arg0: tensor, %arg1: tensor<*xi32>) -> tensor {
-  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor, tensor<*xi32>) -> tensor
-  return %0 : tensor
+// Error for inferred dynamic dimension but existing static dimensions
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<2xi32> {
+  // expected-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands's shapes '?'}}
+  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<2xi32>
+  return %0 : tensor<2xi32>
+}
+
+// -----
+
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor, %arg1: tensor<*xi32>) -> tensor {
+  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor, tensor<*xi32>) -> tensor
+  return %0 : tensor
 }
 
 // -----
@@ -145,10 +154,19 @@
 // -----
 
-func.func @broadcast_tensor_tensor_tensor(tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32> {
-^bb0(%arg0: tensor<?x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
-  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
-  return %0 : tensor<8x7x6x5xi32>
+// Correct use of broadcast semantics for input dimensions
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<?x1x6x1xi32>, %arg1: tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32> {
+  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32>
+  return %0 : tensor<?x7x6x5xi32>
+}
+
+// -----
+
+// Incorrect attempt to use broadcast semantics for result
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<5xi32> {
+  // expected-error @+1 {{op result type '5' not broadcast compatible with broadcasted operands's shapes '1'}}
+  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<5xi32>
+  return %0 : tensor<5xi32>
 }
 
 // -----
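For reference, the reassociation grouping used by expandRank above (leading unit dimensions folded into the first source dimension, remaining source dimensions mapped one-to-one) can be reproduced with a small standalone sketch. This is an editorial illustration, not part of the patch: it is plain C++ with no MLIR dependency, the helper name computeReassociation is made up for the example, and, like the patch, it assumes a source rank of at least one.

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the grouping rule in expandRank: expanding a rank-R tensor to
// rank R+N prepends N unit dims; dims [0 .. N] collapse into source dim 0,
// and every remaining source dim maps one-to-one. Assumes srcRank >= 1.
static std::vector<std::vector<int64_t>>
computeReassociation(int64_t srcRank, int64_t numExtraDims) {
  std::vector<std::vector<int64_t>> indices(srcRank);
  int64_t index = 0;
  for (index = 0; index <= numExtraDims; ++index)
    indices[0].push_back(index);
  for (size_t position = 1; position < indices.size(); ++position)
    indices[position].push_back(index++);
  return indices;
}

int main() {
  // tensor<3x4xf32> expanded to rank 4 (tensor<1x1x3x4xf32>): prints "0 1 2"
  // and "3", i.e. the reassociation [[0, 1, 2], [3]] seen in the tests above.
  for (const auto &group : computeReassociation(/*srcRank=*/2, /*numExtraDims=*/2)) {
    for (int64_t i : group)
      std::cout << i << ' ';
    std::cout << '\n';
  }
  return 0;
}

The same rule yields [[0, 1], [2]] for the rank-2 to rank-3 expansion checked in @test_select_expand_rank.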