diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -48,26 +48,6 @@ } } -// Generates an affine map for parallel operations on a given type. This -// performs implicit broadcasting across any dimension of size-1. -static AffineMap createAffineMapForType(ShapedType type, - PatternRewriter &rewriter) { - unsigned rank = type.getRank(); - auto shape = type.getShape(); - SmallVector dimExprs; - dimExprs.reserve(rank); - for (unsigned i = 0; i < rank; ++i) { - // If the dimension is one we can broadcast the input with a constant - // affine expression. - if (shape[i] == 1) - dimExprs.push_back(rewriter.getAffineConstantExpr(0)); - else - dimExprs.push_back(rewriter.getAffineDimExpr(i)); - } - return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs, - rewriter.getContext()); -} - template static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min, mlir::ConstantOp max, P pred, @@ -464,11 +444,14 @@ PatternRewriter &rewriter) { auto loc = operation->getLoc(); auto results = operation->getResults(); - auto t0 = operation->getOperand(0).getType().template dyn_cast(); - if (!t0) + auto resultTy = operation->getOperand(0).getType().dyn_cast(); + + if (!resultTy) return rewriter.notifyMatchFailure(operation, "All results must be a shaped type"); + unsigned rank = resultTy.getRank(); + assert(operation->getNumResults() == 1 && "All TOSA elementwise ops should only return a single result."); @@ -496,23 +479,42 @@ auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range( initTensors, [](Value v) { return getElementTypeOrSelf(v); })); - unsigned nloops = t0.getRank(); + SmallVector operands; SmallVector indexingMaps; indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size()); // Input indexing maps may be broadcasted. - for (Type type : operation->getOperandTypes()) { - indexingMaps.push_back( - createAffineMapForType(type.cast(), rewriter)); + for (Value operand : operation->getOperands()) { + ShapedType type = operand.getType().cast(); + SmallVector newShape; + SmallVector affineExprs; + newShape.reserve(type.getRank()); + for (auto it : llvm::enumerate(type.getShape())) { + if (it.value() != 1) { + newShape.push_back(it.value()); + affineExprs.push_back( + mlir::getAffineDimExpr(it.index(), rewriter.getContext())); + } + } + + if (newShape.size() != rank) { + operand = rewriter.create( + loc, RankedTensorType::get(newShape, type.getElementType()), operand); + } + + operands.push_back(operand); + indexingMaps.push_back(AffineMap::get( + /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs, + rewriter.getContext())); } indexingMaps.append(operation->getNumResults(), - rewriter.getMultiDimIdentityMap(nloops)); + rewriter.getMultiDimIdentityMap(rank)); bool didEncounterError = false; auto linalgOp = rewriter.create( - loc, opResultTypes, operation->getOperands(), initTensors, indexingMaps, - getNParallelLoopsAttrs(nloops), + loc, opResultTypes, operands, initTensors, indexingMaps, + getNParallelLoopsAttrs(rank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { Value opResult = createLinalgBodyCalculationForElementwiseOp( operation, blockArgs.take_front(operation->getNumOperands()), @@ -650,12 +652,20 @@ auto elementTy = resultTy.getElementType(); Value input = op->getOperand(0); + llvm::SmallVector reduceShape; + for (unsigned i = 0; i < inputTy.getRank(); i++) { + if (axis != i) + reduceShape.push_back(inputTy.getDimSize(i)); + } + + Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType()); + // First fill the output buffer with the init value. - auto initTensor = rewriter - .create(loc, ArrayRef({}), - resultTy.getShape(), - resultTy.getElementType()) - .result(); + auto initTensor = + rewriter + .create(loc, ArrayRef({}), reduceShape, + resultTy.getElementType()) + .result(); auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); if (!fillValueAttr) @@ -676,14 +686,12 @@ : getParallelIteratorTypeName()); if (axis != i) dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); - else - dstExprs.push_back(rewriter.getAffineConstantExpr(0)); } bool didEncounterError = false; auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs}); auto linalgOp = rewriter.create( - loc, resultTy, input, filledTensor, maps, iteratorTypes, + loc, reduceTy, input, filledTensor, maps, iteratorTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { auto result = createLinalgBodyCalculationForReduceOp( op, blockArgs, elementTy, rewriter); @@ -696,7 +704,8 @@ if (!didEncounterError) return failure(); - rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); + rewriter.replaceOpWithNewOp(op, resultTy, + linalgOp.getResults()); return success(); } @@ -971,9 +980,12 @@ } currDstDim++; } + + // Check if any remaining dimensions exist. If either is rank-0 we only + // require the directly lowering. if (currSrcDim != expandedShape.size() || currDstDim != collapsedShape.size()) - isCollapsingSource = false; + isCollapsingSource = collapsedShape.empty() || expandedShape.empty(); // Otherwise, we need to first reduce all source dimensions into one and // then expand to the destination dimensions. @@ -1084,56 +1096,65 @@ op.double_round() && llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); - // We need to broadcast along the last dimension, so make all dims 1. - SmallVector multiplierShape; - multiplierShape.resize(rank, 1); - - SmallVector shiftShape; - shiftShape.resize(rank, 1); - - // Set the channel dimension to match the number of shift/broadcast - // channels. - if (!multiplierShape.empty()) - multiplierShape.back() = multiplierValues.size(); - if (!shiftShape.empty()) - shiftShape.back() = shiftValues.size(); - - // Create the tensor types. - auto multiplierType = - RankedTensorType::get(multiplierShape, rewriter.getI32Type()); - auto shiftType = - RankedTensorType::get(shiftShape, rewriter.getIntegerType(8)); + SmallVector indexingMaps = { + rewriter.getMultiDimIdentityMap(rank)}; + SmallVector genericInputs = {input}; + + // If we are rescaling per-channel then we need to store the multiplier + // values in a buffer. + Value multiplierConstant; + int64_t multiplierArg = 0; + if (multiplierValues.size() == 1) { + multiplierConstant = rewriter.create( + loc, rewriter.getI32IntegerAttr(multiplierValues.front())); + } else { + SmallVector multiplierExprs{ + rewriter.getAffineDimExpr(rank - 1)}; + auto multiplierType = + RankedTensorType::get({static_cast(multiplierValues.size())}, + rewriter.getI32Type()); + genericInputs.push_back(rewriter.create( + loc, DenseIntElementsAttr::get(multiplierType, multiplierValues))); + + indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, + /*symbolCount=*/0, multiplierExprs, + rewriter.getContext())); + + multiplierArg = indexingMaps.size() - 1; + } - auto multiplierConst = rewriter.create( - loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)); + // If we are rescaling per-channel then we need to store the shift + // values in a buffer. + Value shiftConstant; + int64_t shiftArg = 0; + if (shiftValues.size() == 1) { + shiftConstant = rewriter.create( + loc, rewriter.getI8IntegerAttr(shiftValues.front())); + } else { + SmallVector shiftExprs = { + rewriter.getAffineDimExpr(rank - 1)}; + auto shiftType = + RankedTensorType::get({static_cast(shiftValues.size())}, + rewriter.getIntegerType(8)); + genericInputs.push_back(rewriter.create( + loc, DenseIntElementsAttr::get(shiftType, shiftValues))); + indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, + /*symbolCount=*/0, shiftExprs, + rewriter.getContext())); + shiftArg = indexingMaps.size() - 1; + } - auto shiftConst = rewriter.create( - loc, DenseIntElementsAttr::get(shiftType, shiftValues)); + // Indexing maps for output values. + indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); // Construct the indexing maps needed for linalg.generic ops. - SmallVector bodyArgTypes = {getElementTypeOrSelf(inputTy), - rewriter.getI32Type(), - rewriter.getI32Type()}; Value initTensor = rewriter.create( loc, ArrayRef({}), outputTy.getShape(), outputTy.getElementType()); - SmallVector indexingMaps; - - // Indexing map for input values. - indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); - - // Shift and multiplier will need to broadcast across their non channel - // values. - indexingMaps.push_back(createAffineMapForType(multiplierType, rewriter)); - indexingMaps.push_back(createAffineMapForType(shiftType, rewriter)); - - // Indexing maps for output values. - indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); - auto linalgOp = rewriter.create( - loc, outputTy, ValueRange{input, multiplierConst, shiftConst}, - ValueRange{initTensor}, indexingMaps, getNParallelLoopsAttrs(rank), + loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps, + getNParallelLoopsAttrs(rank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { // For now we do all of our math in 64-bit. This is not optimal but @@ -1145,8 +1166,9 @@ op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder); Value value = blockArgs[0]; - Value multiplier = blockArgs[1]; - Value shift = blockArgs[2]; + Value multiplier = multiplierConstant ? multiplierConstant + : blockArgs[multiplierArg]; + Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg]; if (value.getType().getIntOrFloatBitWidth() < 32) { value = nestedBuilder.create( @@ -1608,17 +1630,6 @@ SmallVector multiples; getValuesFromIntArrayAttribute(op.multiples(), multiples); - llvm::SmallVector reshapeShape; - reshapeShape.reserve(rank * 2); - for (int i = 0; i < rank; i++) { - reshapeShape.push_back(1); - reshapeShape.push_back(inputShape[i]); - } - - ShapedType reshapeTy = RankedTensorType::get(reshapeShape, elementTy); - Value reshape = rewriter.create( - loc, reshapeTy, input, rewriter.getI64ArrayAttr(reshapeTy.getShape())); - // Broadcast the newly added dimensions to their appropriate multiple. SmallVector genericShape; for (int i = 0; i < rank; i++) { @@ -1629,12 +1640,21 @@ auto initTensor = rewriter.create( op.getLoc(), ArrayRef({}), genericShape, elementTy); + // We needs to map the input shape to the non-broadcasted dimensions. + SmallVector dimExprs; + dimExprs.reserve(rank); + for (unsigned i = 0; i < rank; ++i) + dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1)); + + auto readAffineMap = + AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs, + rewriter.getContext()); + SmallVector affineMaps = { - createAffineMapForType(reshapeTy, rewriter), - rewriter.getMultiDimIdentityMap(genericShape.size())}; + readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())}; auto genericOp = rewriter.create( - loc, RankedTensorType::get(genericShape, elementTy), reshape, + loc, RankedTensorType::get(genericShape, elementTy), input, ValueRange{initTensor}, affineMaps, getNParallelLoopsAttrs(genericShape.size()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -55,13 +55,14 @@ // ----- -// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (0)> +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> ()> // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: @test_broadcast func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32> - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) { + // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %arg1 : tensor, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) { // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32 // CHECK: linalg.yield [[ELEMENT]] : f32 @@ -72,14 +73,16 @@ // ----- -// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (0, d1)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, 0)> -// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)> +// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-LABEL: @test_multibroadcast func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> { // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32> - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<1x3xf32>, tensor<2x1xf32>) outs([[INIT]] : tensor<2x3xf32>) { + // CHECK: [[RESHAPE1:%.+]] = linalg.tensor_reshape %arg0 [#map0] + // CHECK: [[RESHAPE2:%.+]] = linalg.tensor_reshape %arg1 [#map0] + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) { // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32 // CHECK: linalg.yield [[ELEMENT]] : f32 @@ -472,28 +475,30 @@ // ----- // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)> -// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, 0)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)> +// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-LABEL: @reduce_float // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32> func @reduce_float(%arg0: tensor<5x4xf32>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 4] + // CHECK: [[INIT:%.+]] = linalg.init_tensor [4] // CHECK: [[CST0:%.+]] = constant 0.0 // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]]) - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<1x4xf32>) + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<4xf32>) // CHECK: ^bb0(%arg1: f32, %arg2: f32) // CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32 // CHECK: linalg.yield [[RES]] : f32 + // CHECK: linalg.tensor_reshape [[GENERIC]] [#map0] : tensor<4xf32> into tensor<1x4xf32> %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32> - // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 1] + // CHECK: [[INIT:%.+]] = linalg.init_tensor [5] // CHECK: [[CST0:%.+]] = constant 0.0 // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]]) - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5x1xf32>) + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5xf32>) // CHECK: ^bb0(%arg1: f32, %arg2: f32) // CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32 // CHECK: linalg.yield [[RES]] : f32 + // CHECK: linalg.tensor_reshape [[GENERIC]] [#map0] : tensor<5xf32> into tensor<5x1xf32> %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32> // CHECK: constant 1.0 @@ -521,28 +526,30 @@ // ----- // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)> -// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, 0)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)> +// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-LABEL: @reduce_int // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi32> func @reduce_int(%arg0: tensor<5x4xi32>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 4] + // CHECK: [[INIT:%.+]] = linalg.init_tensor [4] // CHECK: [[CST0:%.+]] = constant 0 // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]]) - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<1x4xi32>) + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<4xi32>) // CHECK: ^bb0(%arg1: i32, %arg2: i32) // CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32 // CHECK: linalg.yield [[RES]] : i32 + // CHECK: linalg.tensor_reshape [[GENERIC]] [#map0] : tensor<4xi32> into tensor<1x4xi32> %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32> - // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 1] + // CHECK: [[INIT:%.+]] = linalg.init_tensor [5] // CHECK: [[CST0:%.+]] = constant 0 // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]]) - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5x1xi32>) + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5xi32>) // CHECK: ^bb0(%arg1: i32, %arg2: i32) // CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32 // CHECK: linalg.yield [[RES]] : i32 + // CHECK: linalg.tensor_reshape [[GENERIC]] [#map0] : tensor<5xi32> into tensor<5x1xi32> %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32> // CHECK: constant 1 @@ -570,18 +577,19 @@ // ----- // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)> // CHECK-LABEL: @reduce_bool // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi1> func @reduce_bool(%arg0: tensor<5x4xi1>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 4] + // CHECK: [[INIT:%.+]] = linalg.init_tensor [4] // CHECK: [[CST0:%.+]] = constant true // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]]) - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi1>) outs([[FILL]] : tensor<1x4xi1>) + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi1>) outs([[FILL]] : tensor<4xi1>) // CHECK: ^bb0(%arg1: i1, %arg2: i1) // CHECK: [[RES:%.+]] = and %arg1, %arg2 : i1 // CHECK: linalg.yield [[RES]] : i1 + // CHECK: linalg.tensor_reshape [[GENERIC]] [#map0] : tensor<4xi1> into tensor<1x4xi1> %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1> // CHECK: constant false @@ -636,14 +644,45 @@ // ----- // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)> // CHECK-LABEL: @rescale -func @rescale(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) { - // CHECK: [[C0:%.+]] = constant dense<19689> - // CHECK: [[C1:%.+]] = constant dense<15> - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[C0]], [[C1]] : tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) outs([[INIT]] : tensor<1xi8>) +func @rescale(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) { + // CHECK: [[C0:%.+]] = constant 19689 + // CHECK: [[C1:%.+]] = constant 15 + // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>) + // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8): + // CHECK: [[C243:%.+]] = constant 243 + // CHECK: [[C252:%.+]] = constant 252 + + // CHECK-DAG: [[IN32:%.+]] = sexti [[IN]] + // CHECK-DAG: [[IN_ZEROED:%.+]] = subi [[IN32]], [[C243]] + // CHECK-DAG: [[SCALED:%.+]] = "tosa.apply_scale"([[IN_ZEROED]], [[C0]], [[C1]]) {double_round = false} + // CHECK-DAG: [[SCALED_ZEROED:%.+]] = addi [[SCALED]], [[C252]] + // CHECK-DAG: [[CMIN:%.+]] = constant -128 + // CHECK-DAG: [[CMAX:%.+]] = constant 127 + // CHECK-DAG: [[MINLT:%.+]] = cmpi slt, [[SCALED_ZEROED]], [[CMIN]] + // CHECK-DAG: [[MAXLT:%.+]] = cmpi slt, [[CMAX]], [[SCALED_ZEROED]] + // CHECK-DAG: [[LOWER:%.+]] = select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]] + // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]] + // CHECK-DAG: [[TRUNC:%.+]] = trunci [[BOUNDED]] + // CHECK-DAG: linalg.yield [[TRUNC]] + %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> (tensor<2xi8>) + + // CHECK: return [[GENERIC]] + return %0 : tensor<2xi8> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: @rescale_per_channel +func @rescale_per_channel(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) { + // CHECK: [[MULTIPLIERS:%.+]] = constant dense<[42, 43]> + // CHECK: [[SHIFTS:%.+]] = constant dense<[14, 15]> + // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[MULTIPLIERS]], [[SHIFTS]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>) outs([[INIT]] : tensor<2xi8>) // CHECK: ^bb0([[IN:%.+]]: i8, [[MULTIPLIER:%.+]]: i32, [[SHIFT:%.+]]: i8, [[UNUSED:%.+]]: i8): // CHECK: [[C243:%.+]] = constant 243 // CHECK: [[C252:%.+]] = constant 252 @@ -660,28 +699,30 @@ // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]] // CHECK-DAG: [[TRUNC:%.+]] = trunci [[BOUNDED]] // CHECK-DAG: linalg.yield [[TRUNC]] - %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>) + %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [42 : i32, 43 : i32], shift = [14 : i32, 15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> (tensor<2xi8>) // CHECK: return [[GENERIC]] - return %0 : tensor<1xi8> + return %0 : tensor<2xi8> } +// ----- + // CHECK-LABEL: @rescaleDoubleRound -func @rescaleDoubleRound(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) { +func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) { // CHECK: linalg.generic // CHECK: "tosa.apply_scale" // CHECK-SAME: {double_round = true} - %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [33 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>) - return %0 : tensor<1xi8> + %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [33 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>) -> (tensor<2xi8>) + return %0 : tensor<2xi8> } // CHECK-LABEL: @rescaleUnnecessaryDoubleRound -func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) { +func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) { // CHECK: linalg.generic // CHECK: "tosa.apply_scale" // CHECK-SAME: {double_round = false} - %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>) - return %0 : tensor<1xi8> + %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>) -> (tensor<2xi8>) + return %0 : tensor<2xi8> } // ----- @@ -708,32 +749,29 @@ // ----- -// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)> -// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)> -// CHECK: #[[$MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +// CHECK: #[[$MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)> // CHECK: #[[$MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)> // CHECK: #[[$MAP5:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // CHECK-LABEL: @tile func @tile(%arg0 : tensor<2x3xi8>) -> () { - // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8> // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 2, 1, 3] - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>) + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>) // CHECK: linalg.yield %arg1 : i8 - // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP0]], #[[$MAP1]]] + // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP2]], #[[$MAP3]]] %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<2x3xi8>) -> (tensor<4x3xi8>) - // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8> // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2, 2, 3] - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>) + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>) // CHECK: linalg.yield %arg1 : i8 // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP4]], #[[$MAP5]]] %1 = "tosa.tile"(%arg0) {multiples = [1, 2]} : (tensor<2x3xi8>) -> (tensor<2x6xi8>) - // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8> // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2, 7, 3] - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>) + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>) // CHECK: linalg.yield %arg1 : i8 // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP4]], #[[$MAP5]]] %2 = "tosa.tile"(%arg0) {multiples = [5, 7]} : (tensor<2x3xi8>) -> (tensor<10x21xi8>)