diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -152,23 +152,8 @@
     return rewriter.notifyMatchFailure(operation,
                                        "All results must be a shaped type");
 
-  // For now require no broadcasting. Consider making it support broadcasting
-  // operations.
-  Type uniqueInTy = operation->getOperand(0).getType();
-  bool allInputTypesEqual =
-      llvm::all_of(operation->getOperandTypes(),
-                   [&](Type operandTy) { return operandTy == uniqueInTy; });
-  if (!allInputTypesEqual)
-    return rewriter.notifyMatchFailure(operation,
-                                       "All operands must have the same type");
-  bool resultAndInputShapeEqual =
-      llvm::all_of(operation->getResultTypes(), [&](Type resultTy) {
-        return resultTy.cast<ShapedType>().getShape() == t0.getShape();
-      });
-
-  if (!resultAndInputShapeEqual)
-    return rewriter.notifyMatchFailure(
-        operation, "All results must have the same shape as the input");
+  assert(operation->getNumResults() == 1 &&
+         "All TOSA elementwise ops should only return a single result.");
 
   // Construct the indexing maps needed for linalg.generic ops.
   SmallVector<Type, 4> bodyArgTypes;
 
@@ -194,12 +179,30 @@
   auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
       initTensors, [](Value v) { return getElementTypeOrSelf(v); }));
 
-  // Supports only non-broadcasted operation. Shoudl consider update indexing
-  // map to be multidimensional.
   unsigned nloops = t0.getRank();
-  AffineMap commonIndexingMap = rewriter.getMultiDimIdentityMap(nloops);
-  SmallVector<AffineMap, 4> indexingMaps(
-      operation->getNumOperands() + bodyResultTypes.size(), commonIndexingMap);
+  SmallVector<AffineMap, 4> indexingMaps;
+  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());
+
+  // Input indexing maps may be broadcast.
+  for (Type type : operation->getOperandTypes()) {
+    auto shape = type.cast<ShapedType>().getShape();
+    SmallVector<AffineExpr, 4> dimExprs;
+    dimExprs.reserve(nloops);
+    for (unsigned i = 0; i < nloops; ++i) {
+      // If the dimension is one, we can broadcast the input with a constant
+      // affine expression.
+      if (shape[i] == 1)
+        dimExprs.push_back(rewriter.getAffineConstantExpr(0));
+      else
+        dimExprs.push_back(rewriter.getAffineDimExpr(i));
+    }
+    indexingMaps.push_back(AffineMap::get(/*dimCount=*/nloops,
+                                          /*symbolCount=*/0, dimExprs,
+                                          rewriter.getContext()));
+  }
+
+  indexingMaps.append(operation->getNumResults(),
+                      rewriter.getMultiDimIdentityMap(nloops));
 
   bool didEncounterError = false;
   auto linalgOp = rewriter.create<linalg::GenericOp>(
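For intuition before the test changes: the per-operand loop above just builds an affine map whose result expression is the constant 0 wherever the operand has a unit dimension, and the identity dimension expression everywhere else. Below is a minimal standalone sketch of that construction, assuming an MLIR development tree for the headers; `buildBroadcastMap` and the `main()` driver are illustrative names, not part of this patch:

```cpp
// Sketch of the per-operand indexing-map construction used in the patch.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Size-1 dimensions are read at the fixed index 0 (broadcast); every other
// dimension follows the matching loop induction variable.
static AffineMap buildBroadcastMap(ArrayRef<int64_t> shape, Builder &builder) {
  SmallVector<AffineExpr, 4> dimExprs;
  dimExprs.reserve(shape.size());
  for (unsigned i = 0, e = shape.size(); i != e; ++i)
    dimExprs.push_back(shape[i] == 1 ? builder.getAffineConstantExpr(0)
                                     : builder.getAffineDimExpr(i));
  return AffineMap::get(/*dimCount=*/shape.size(), /*symbolCount=*/0, dimExprs,
                        builder.getContext());
}

int main() {
  MLIRContext context;
  Builder builder(&context);
  // A tensor<1x3xf32> operand of a rank-2 op: prints (d0, d1) -> (0, d1),
  // the map FileCheck captures as #[[$MAP0]] in @test_multibroadcast below.
  buildBroadcastMap({1, 3}, builder).print(llvm::errs());
  llvm::errs() << "\n";
}
```

For a `tensor<2x1xf32>` operand the same helper would print `(d0, d1) -> (d0, 0)`, matching `#[[$MAP1]]` in `@test_multibroadcast` below.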
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1,11 +1,11 @@
 // RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s -verify-diagnostics -o -| FileCheck %s
 
-// CHECK: #map = affine_map<() -> ()>
+// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
 
 // CHECK-LABEL: @test_abs
 func @test_abs(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor<f32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = []} ins(%arg0 : tensor<f32>) outs([[INIT]] : tensor<f32>) {
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins(%arg0 : tensor<f32>) outs([[INIT]] : tensor<f32>) {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32):
   // CHECK: [[ELEMENT:%.+]] = absf %arg1
   // CHECK: linalg.yield [[ELEMENT]] : f32
@@ -19,54 +19,73 @@
 
 // -----
 
-// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
 // CHECK-LABEL: @test_abs
-func @test_abs(%arg0: tensor<1xf32>) -> tensor<1xf32> {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) {
+func @test_abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32):
   // CHECK: [[ELEMENT:%.+]] = absf %arg1
   // CHECK: linalg.yield [[ELEMENT]] : f32
-  // CHECK: } -> tensor<1xf32>
-  %0 = "tosa.abs"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  // CHECK: } -> tensor<2xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
   // CHECK: return [[GENERIC]]
-  return %0 : tensor<1xf32>
+  return %0 : tensor<2xf32>
 }
 
 // -----
 
-// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: @test_abs
-func @test_abs(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2] : tensor<1x2xf32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<1x2xf32>) outs([[INIT]] : tensor<1x2xf32>) {
+func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<2x3xf32>) outs([[INIT]] : tensor<2x3xf32>) {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32):
   // CHECK: [[ELEMENT:%.+]] = absf %arg1
   // CHECK: linalg.yield [[ELEMENT]] : f32
-  // CHECK: } -> tensor<1x2xf32>
-  %0 = "tosa.abs"(%arg0) : (tensor<1x2xf32>) -> tensor<1x2xf32>
+  // CHECK: } -> tensor<2x3xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
   // CHECK: return [[GENERIC]]
-  return %0 : tensor<1x2xf32>
+  return %0 : tensor<2x3xf32>
 }
 
 // -----
 
-func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
-  // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @test_broadcast
+func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
+  // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+  // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<2xf32>
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32>
   return %0 : tensor<2xf32>
 }
 
 // -----
 
-func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
-  // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
-  return %0 : tensor<1xf32>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_multibroadcast
+func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<1x3xf32>, tensor<2x1xf32>) outs([[INIT]] : tensor<2x3xf32>) {
+  // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+  // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
+  // CHECK: linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<2x3xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
+  return %0 : tensor<2x3xf32>
}
 
 // -----
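Scope note on the design: broadcasting here is encoded entirely in the indexing maps, so no intermediate broadcast tensor is materialized and the lowered op remains a single elementwise linalg.generic. As written, the construction assumes every operand already has the result's rank (it reads shape[i] for each of the result's nloops dimensions) and treats only size-1 dimensions as broadcast, in line with TOSA's implicit broadcasting of unit dimensions.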