diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -152,23 +152,28 @@
     return rewriter.notifyMatchFailure(operation,
                                        "All results must be a shaped type");
 
-  // For now require no broadcasting. Consider making it support broadcasting
-  // operations.
-  Type uniqueInTy = operation->getOperand(0).getType();
-  bool allInputTypesEqual =
-      llvm::all_of(operation->getOperandTypes(),
-                   [&](Type operandTy) { return operandTy == uniqueInTy; });
-  if (!allInputTypesEqual)
-    return rewriter.notifyMatchFailure(operation,
-                                       "All operands must have the same type");
-  bool resultAndInputShapeEqual =
-      llvm::all_of(operation->getResultTypes(), [&](Type resultTy) {
-        return resultTy.cast<ShapedType>().getShape() == t0.getShape();
-      });
-
-  if (!resultAndInputShapeEqual)
+  // All TOSA elementwise ops should only return a single result.
+  assert(operation->getNumResults() == 1);
+
+  // Dimension sizes should all be broadcastable to the output shape.
+  auto resultTy = operation->getResultTypes().front().cast<ShapedType>();
+  auto shapeComparison = [&](Type operandTy) {
+    auto shapedTy = operandTy.cast<ShapedType>();
+    if (shapedTy.getRank() != resultTy.getRank())
+      return false;
+    for (int i = 0, r = shapedTy.getRank(); i < r; i++) {
+      if (resultTy.getDimSize(i) != shapedTy.getDimSize(i) &&
+          shapedTy.getDimSize(i) != 1)
+        return false;
+    }
+    return true;
+  };
+
+  bool allInputShapesEqual =
+      llvm::all_of(operation->getOperandTypes(), shapeComparison);
+  if (!allInputShapesEqual)
     return rewriter.notifyMatchFailure(
-        operation, "All results must have the same shape as the input");
+        operation, "All operands must have broadcast-compatible shapes");
 
   // Construct the indexing maps needed for linalg.generic ops.
   SmallVector<Type> bodyArgTypes;
@@ -194,12 +199,32 @@
   auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
       initTensors, [](Value v) { return getElementTypeOrSelf(v); }));
 
-  // Supports only non-broadcasted operation. Shoudl consider update indexing
+  // Supports only same-rank broadcasting. Should consider updating the indexing
   // map to be multidimensional.
   unsigned nloops = t0.getRank();
-  AffineMap commonIndexingMap = rewriter.getMultiDimIdentityMap(nloops);
-  SmallVector<AffineMap> indexingMaps(
-      operation->getNumOperands() + bodyResultTypes.size(), commonIndexingMap);
+  SmallVector<AffineMap> indexingMaps;
+  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());
+
+  // Input indexing maps may be broadcasted.
+  for (Type types : operation->getOperandTypes()) {
+    auto shape = types.cast<ShapedType>().getShape();
+    SmallVector<AffineExpr> dimExprs;
+    dimExprs.reserve(nloops);
+    for (unsigned i = 0; i < nloops; ++i) {
+      // If the dimension is one we can broadcast the input with a constant
+      // affine expression.
+      if (shape[i] == 1)
+        dimExprs.push_back(rewriter.getAffineConstantExpr(0));
+      else
+        dimExprs.push_back(rewriter.getAffineDimExpr(i));
+    }
+    indexingMaps.push_back(AffineMap::get(/*dimCount=*/nloops,
+                                          /*symbolCount=*/0, dimExprs,
+                                          rewriter.getContext()));
+  }
+
+  indexingMaps.append(operation->getNumResults(),
+                      rewriter.getMultiDimIdentityMap(nloops));
 
   bool didEncounterError = false;
   auto linalgOp = rewriter.create<linalg::GenericOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -22,17 +22,17 @@
 // CHECK: #map = affine_map<(d0) -> (d0)>
 
 // CHECK-LABEL: @test_abs
-func @test_abs(%arg0: tensor<1xf32>) -> tensor<1xf32> {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) {
+func @test_abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32):
   // CHECK: [[ELEMENT:%.+]] = absf %arg1
   // CHECK: linalg.yield [[ELEMENT]] : f32
-  // CHECK: } -> tensor<1xf32>
-  %0 = "tosa.abs"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  // CHECK: } -> tensor<2xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
 
   // CHECK: return [[GENERIC]]
-  return %0 : tensor<1xf32>
+  return %0 : tensor<2xf32>
 }
 
 // -----
@@ -40,23 +40,23 @@
 // CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: @test_abs
-func @test_abs(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2] : tensor<1x2xf32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<1x2xf32>) outs([[INIT]] : tensor<1x2xf32>) {
+func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<2x3xf32>) outs([[INIT]] : tensor<2x3xf32>) {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32):
   // CHECK: [[ELEMENT:%.+]] = absf %arg1
   // CHECK: linalg.yield [[ELEMENT]] : f32
-  // CHECK: } -> tensor<1x2xf32>
-  %0 = "tosa.abs"(%arg0) : (tensor<1x2xf32>) -> tensor<1x2xf32>
+  // CHECK: } -> tensor<2x3xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
 
   // CHECK: return [[GENERIC]]
-  return %0 : tensor<1x2xf32>
+  return %0 : tensor<2x3xf32>
 }
 
 // -----
 
+// CHECK-LABEL: @test_add
 func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
-  // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32>
   return %0 : tensor<2xf32>
 }
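
Note: the @test_add case above is checked only by its label. For reference, a sketch of the IR this lowering should produce for that broadcast case, following the pattern of the existing abs checks: the size-1 operand gets a constant-0 indexing map while the full-size operand and the result keep the identity map. The map numbering, SSA names, and the use of linalg.init_tensor/addf mirror the surrounding tests and are illustrative, not verbatim CHECK lines from this patch.

#map0 = affine_map<(d0) -> (0)>
#map1 = affine_map<(d0) -> (d0)>

func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
  // The tensor<1xf32> operand is broadcast along d0 through #map0.
  %init = linalg.init_tensor [2] : tensor<2xf32>
  %0 = linalg.generic {indexing_maps = [#map0, #map1, #map1],
                       iterator_types = ["parallel"]}
      ins(%arg0, %arg1 : tensor<1xf32>, tensor<2xf32>)
      outs(%init : tensor<2xf32>) {
  ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
    %1 = addf %arg2, %arg3 : f32
    linalg.yield %1 : f32
  } -> tensor<2xf32>
  return %0 : tensor<2xf32>
}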