diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -183,6 +183,14 @@
       return llvm::to_vector<4>(llvm::map_range(reassociation(), [
       ](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
     }
+    SmallVector<ReassociationExprs, 4> getReassociationExprs() {
+      return
+        llvm::to_vector<4>(llvm::map_range(reassociation(),
+          [](Attribute a) {
+            return llvm::to_vector<2>(
+              a.cast<AffineMapAttr>().getValue().getResults());
+          }));
+    }
   }];
   let assemblyFormat = [{
     $src $reassociation attr-dict `:` type($src) `into` type(results)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -566,45 +566,6 @@
   return RankedTensorType::get(expandedShape, originalType.getElementType());
 }
 
-/// Get the value to use for the output in the expanded operation given the
-/// `indexingMap` for the output in the original op. Creates a
-/// `linalg.init_tensor` operation to materialize the tensor that carries the
-/// shape information. This is only used when the tensor_reshape is expanding
-/// and is a consumer. In such cases, the tensor_reshape op semantics guarantees
-/// that the shape of the output is computable from the shape of the input since
-/// at most one of the expanded dims can be dynamic.
-static Value getOutputValueForExpandedOp(OpBuilder &builder, Location loc,
-                                         AffineMap indexingMap, Value result,
-                                         const ExpansionInfo &expansionInfo) {
-  SmallVector<Value, 4> dynamicDims;
-  SmallVector<int64_t, 4> staticDims;
-  ShapedType resultType = result.getType().cast<ShapedType>();
-  ArrayRef<int64_t> origShape = resultType.getShape();
-  for (AffineExpr expr : indexingMap.getResults()) {
-    unsigned origDimPos = expr.cast<AffineDimExpr>().getPosition();
-    bool foundDynamic = false;
-    int64_t linearizedShape = 1;
-    for (int64_t extent : expansionInfo.getExpandedShapeOfDim(origDimPos)) {
-      if (ShapedType::isDynamic(extent)) {
-        assert(!foundDynamic &&
-               "Expanded dimensions of reshape can have only one dynamic dim");
-        staticDims.push_back(ShapedType::kDynamicSize);
-        foundDynamic = true;
-        continue;
-      }
-      staticDims.push_back(extent);
-      linearizedShape *= extent;
-    }
-    if (ShapedType::isDynamic(origShape[origDimPos])) {
-      Value origDim = builder.create<DimOp>(loc, result, origDimPos);
-      dynamicDims.push_back(builder.create<UnsignedDivIOp>(
-          loc, origDim, builder.create<ConstantIndexOp>(loc, linearizedShape)));
-    }
-  }
-  return builder.create<InitTensorOp>(loc, dynamicDims, staticDims,
-                                      resultType.getElementType());
-}
-
 /// Returns the reassociation maps to use in the `linalg.tensor_reshape`
 /// operation to convert the operands of the original operation to operands of
 /// the expanded operation. The same method is used to compute the
@@ -734,8 +695,16 @@
   SmallVector<Value, 1> outputs;
   for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
     AffineMap indexingMap = linalgOp.getOutputIndexingMap(result.index());
-    outputs.push_back(getOutputValueForExpandedOp(
-        rewriter, loc, indexingMap, result.value(), expansionInfo));
+    RankedTensorType expandedOutputType =
+        getExpandedType(result.value().getType().cast<RankedTensorType>(),
+                        indexingMap, expansionInfo);
+    if (expandedOutputType != result.value().getType()) {
+      SmallVector<ReassociationIndices> reassociation =
+          getReassociationForExpansion(indexingMap, expansionInfo);
+      outputs.push_back(rewriter.create<TensorReshapeOp>(
+          linalgOp.getLoc(), expandedOutputType, result.value(),
+          reassociation));
+    }
   }
 
   // The iterator types of the expanded op are all parallel.
@@ -779,47 +748,6 @@
   return resultVals;
 }
 
-static Value
-getOutputValueForLinearization(OpBuilder &builder, Location loc,
-                               Value origOutput,
-                               ArrayRef<AffineMap> reassociationMaps) {
-  SmallVector<Value, 4> dynamicDims;
-  SmallVector<int64_t, 4> staticDims;
-  auto shapedType = origOutput.getType().cast<ShapedType>();
-  ArrayRef<int64_t> origShape = shapedType.getShape();
-  for (auto map : reassociationMaps) {
-    Optional<Value> dynamicDim;
-    int64_t staticLinearizedShape = 1;
-    for (AffineDimExpr expr :
-         llvm::map_range(map.getResults(), [](AffineExpr e) {
-           return e.cast<AffineDimExpr>();
-         })) {
-      unsigned pos = expr.getPosition();
-      if (ShapedType::isDynamic(origShape[pos])) {
-        Value dim = builder.create<DimOp>(loc, origOutput, pos);
-        if (dynamicDim) {
-          dynamicDim = builder.create<MulIOp>(loc, dynamicDim.getValue(), dim);
-        } else {
-          dynamicDim = dim;
-        }
-      } else {
-        staticLinearizedShape *= origShape[pos];
-      }
-    }
-    if (dynamicDim) {
-      dynamicDim = builder.create<MulIOp>(
-          loc, dynamicDim.getValue(),
-          builder.create<ConstantIndexOp>(loc, staticLinearizedShape));
-      dynamicDims.push_back(dynamicDim.getValue());
-      staticDims.push_back(ShapedType::kDynamicSize);
-    } else {
-      staticDims.push_back(staticLinearizedShape);
-    }
-  }
-  return builder.create<InitTensorOp>(loc, dynamicDims, staticDims,
-                                      shapedType.getElementType());
-}
-
 namespace {
 
 /// Pattern to fold tensor_reshape op with its consumer by using the source of
@@ -973,7 +901,7 @@
                                        reshapeOp.getReassociationMaps());
   for (AffineExpr expr : modifiedMap.getResults()) {
     if (!expr.isPureAffine())
-      return reshapeOp.emitRemark("fused op indexing map is not affine");
+      return producer.emitRemark("fused op indexing map is not affine");
   }
   fusedIndexMaps.back() = modifiedMap;
 
@@ -983,9 +911,8 @@
     return reshapeOp.emitRemark("fused op loop bound computation failed");
 
   Location loc = producer.getLoc();
-  Value output =
-      getOutputValueForLinearization(rewriter, loc, producer.getOutputs()[0],
-                                     reshapeOp.getReassociationMaps());
+  Value output = rewriter.create<TensorReshapeOp>(
+      loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs());
   LinalgOp fusedOp = createLinalgOpOfSameType(
       producer, rewriter, loc, reshapeOp.getResultType(),
       /*inputs=*/producer.getInputs(),
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -14,7 +14,7 @@
      indexing_maps = [#map0, #map1, #map1],
      iterator_types = ["parallel", "parallel", "parallel"]}
        ins(%0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-       outs(%0 : tensor<?x?x?xf32>) {
+       outs(%0 : tensor<?x?x?xf32>) {
     ^bb0(%arg3: f32, %arg4: f32, %s: f32):       // no predecessors
       %1 = mulf %arg3, %arg4 : f32
       linalg.yield %1 : f32
@@ -32,19 +32,12 @@
 // CHECK: func @generic_op_reshape_producer_fusion
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
 // CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
 // CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 // CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME: [#[[MAP0]], #[[MAP3]], #[[MAP4]]]
-// CHECK-DAG: %[[D0:.+]] = dim %[[T0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = dim %[[T0]], %[[C1]]
-// CHECK-DAG: %[[D2:.+]] = dim %[[T0]], %[[C2]]
-// CHECK: %[[D3:.+]] = divi_unsigned %[[D0]], %[[C4]]
-// CHECK: %[[T2:.+]] = linalg.init_tensor [%[[D1]], %[[D2]], %[[D3]], 4]
+// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[T0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP3]], #[[MAP4]]]
 // CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP6]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -66,7 +59,7 @@
      indexing_maps = [#map0, #map0, #map0],
      iterator_types = ["parallel", "parallel"]}
        ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-       outs(%arg0 : tensor<?x?xf32>) {
+       outs(%arg0 : tensor<?x?xf32>) {
     ^bb0(%arg3: f32, %arg4: f32, %s: f32):       // no predecessors
       %1 = mulf %arg3, %arg4 : f32
       linalg.yield %1 : f32
@@ -83,19 +76,14 @@
 // CHECK: func @generic_op_reshape_consumer_fusion
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C20:.+]] = constant 20 : index
 // CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
 // CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
 // CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
-// CHECK-DAG: %[[D0:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK: %[[D2:.+]] = divi_unsigned %[[D1]], %[[C20]]
-// CHECK: %[[T2:.+]] = linalg.init_tensor [%[[D0]], 4, %[[D2]], 5]
+// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
 // CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -132,30 +120,25 @@
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
 // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 // CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
-// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>
+// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP10:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
 // CHECK: func @reshape_as_consumer_permutation
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C12:.+]] = constant 12 : index
 // CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
 // CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 // CHECK-SAME: tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
 // CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME: [#[[MAP3]], #[[MAP4]]]
 // CHECK-SAME: tensor<?x?xf32> into tensor<3x4x?x?xf32>
-// CHECK-DAG: %[[D0:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK: %[[D1:.+]] = divi_unsigned %[[D0]], %[[C2]]
-// CHECK-DAG: %[[D2:.+]] = dim %[[ARG0]], %[[C2]]
-// CHECK-DAG: %[[D3:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[D4:.+]] = divi_unsigned %[[D3]], %[[C12]]
-// CHECK: %[[T2:.+]] = linalg.init_tensor [%[[D1]], 2, %[[D2]], 3, 4, %[[D4]]]
+// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
 // CHECK: %[[T3:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+// CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>)
 // CHECK-SAME: outs(%[[T2]] : tensor<?x2x?x3x4x?xf32>)
@@ -170,18 +153,19 @@
 func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
                                          -> tensor<8x33x4xf32> {
   %cst = constant dense<2.000000e+00> : tensor<264x4xf32>
-  %0 = linalg.generic {
+  %0 = linalg.init_tensor [264, 4] : tensor<264x4xf32>
+  %1 = linalg.generic {
     indexing_maps = [#map0, #map0, #map0],
     iterator_types = ["parallel", "parallel"]}
     ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>)
-    outs(%arg0 : tensor<264x4xf32>) {
+    outs(%0 : tensor<264x4xf32>) {
   ^bb0(%arg1: f32, %arg2: f32, %s: f32):  // no predecessors
     %2 = mulf %arg1, %arg2 : f32
     linalg.yield %2 : f32
   } -> tensor<264x4xf32>
-  %1 = linalg.tensor_reshape %0 [#map1, #map2] :
+  %2 = linalg.tensor_reshape %1 [#map1, #map2] :
     tensor<264x4xf32> into tensor<8x33x4xf32>
-  return %1 : tensor<8x33x4xf32>
+  return %2 : tensor<8x33x4xf32>
 }
 
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -189,51 +173,54 @@
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK: func @generic_op_reshape_consumer_static
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
-// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK: %[[T0:.+]] = linalg.init_tensor [264, 4]
+// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
 // CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32>
-// CHECK: %[[T1:.+]] = linalg.init_tensor [8, 33, 4] : tensor<8x33x4xf32>
-// CHECK: %[[T2:.+]] = linalg.generic
+// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[T0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+// CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%[[T0]] : tensor<8x33x4xf32>)
-// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
-// CHECK: return %[[T2]] : tensor<8x33x4xf32>
+// CHECK-SAME: ins(%[[T1]] : tensor<8x33x4xf32>)
+// CHECK-SAME: outs(%[[T2]] : tensor<8x33x4xf32>)
+// CHECK: return %[[T3]] : tensor<8x33x4xf32>
 
 // -----
 
 func @scalar_reshape(
-    %arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>, %shape : tensor<10xf32>)
-    -> tensor<1x10xf32>
+    %arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>) -> tensor<1x10xf32>
 {
   %0 = linalg.tensor_reshape %arg1 [] : tensor<1xf32> into tensor<f32>
-  %1 = linalg.generic
+  %1 = linalg.init_tensor [10] : tensor<10xf32>
+  %2 = linalg.generic
     {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
     ins(%0 : tensor<f32>)
-    outs(%shape : tensor<10xf32>) {
+    outs(%1 : tensor<10xf32>) {
   ^bb0(%arg2: f32, %s: f32):  // no predecessors
     linalg.yield %arg2 : f32
   } -> tensor<10xf32>
-  %2 = linalg.tensor_reshape %1 [affine_map<(d0, d1) -> (d0, d1)>]
+  %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>]
     : tensor<10xf32> into tensor<1x10xf32>
-  return %2 : tensor<1x10xf32>
+  return %3 : tensor<1x10xf32>
 }
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> ()>
 // CHECK: func @scalar_reshape
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1xf32>
 // CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]] []
 // CHECK-SAME: tensor<1xf32> into tensor<f32>
-// CHECK: %[[T1:.+]] = linalg.init_tensor [1, 10] : tensor<1x10xf32>
-// CHECK: %[[T2:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK: %[[T1:.+]] = linalg.init_tensor [10]
+// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[T1]] [#[[MAP0]]]
+// CHECK: %[[T3:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]
 // CHECK-SAME: ins(%[[T0]] : tensor<f32>)
-// CHECK-SAME: outs(%[[T1]] : tensor<1x10xf32>)
-// CHECK: return %[[T2]] : tensor<1x10xf32>
+// CHECK-SAME: outs(%[[T2]] : tensor<1x10xf32>)
+// CHECK: return %[[T3]] : tensor<1x10xf32>
 
 // -----
 
@@ -331,15 +318,16 @@
 
 func @reshape_as_consumer_permutation
-       (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>, %shape : tensor<6x4x210xi32>)
+       (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
          -> tensor<2x3x4x5x6x7xi32> {
+  %shape = linalg.init_tensor [6, 4, 210] : tensor<6x4x210xi32>
   %c = linalg.indexed_generic {
          indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
                           affine_map<(d0, d1, d2) -> (d1, d2)>,
                           affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
          iterator_types = ["parallel", "parallel", "parallel"]}
          ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>)
-         outs(%shape : tensor<6x4x210xi32>) {
+         outs(%shape : tensor<6x4x210xi32>) {
   ^bb0(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i32, %arg4: i32, %s: i32):
     %1 = addi %arg3, %arg4 : i32
     %2 = index_cast %arg0 : index to i32
@@ -364,38 +352,43 @@
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
 // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 // CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 * 3 + d1)>
-// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2) -> (d0 * 42 + d1 * 7 + d2)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
-// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
-// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>
+// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP10:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK-DAG: #[[MAP11:.+]] = affine_map<(d0, d1) -> (d0 * 3 + d1)>
+// CHECK-DAG: #[[MAP12:.+]] = affine_map<(d0, d1, d2) -> (d0 * 42 + d1 * 7 + d2)>
 // CHECK: func @reshape_as_consumer_permutation
 // CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
-// CHECK-DAG: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [6, 4, 210]
+// CHECK-DAG: %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
 // CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK-DAG: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-DAG: %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME: [#[[MAP3]], #[[MAP4]]]
-// CHECK: %[[T2:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
-// CHECK: %[[T3:.+]] = linalg.indexed_generic
-// CHECK-SAME: indexing_maps = [#[[MAP7]], #[[MAP8]], #[[MAP9]]]
-// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
-// CHECK-SAME: outs(%[[T2]] : tensor<2x3x4x5x6x7xi32>)
+// CHECK: %[[T3:.+]] = linalg.tensor_reshape %[[T0]]
+// CHECK-SAME: [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+// CHECK: %[[T4:.+]] = linalg.indexed_generic
+// CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
+// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
+// CHECK-SAME: outs(%[[T3]] : tensor<2x3x4x5x6x7xi32>)
 // CHECK: ^{{.+}}(
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index, %[[ARG7:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
 // CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: i32)
-// CHECK-DAG: %[[T4:.+]] = affine.apply #[[MAP5]](%[[ARG2]], %[[ARG3]])
-// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP6]](%[[ARG4]], %[[ARG5]], %[[ARG6]])
-// CHECK-DAG: %[[T6:.+]] = addi %[[ARG8]], %[[ARG9]]
-// CHECK: %[[T7:.+]] = index_cast %[[T4]]
-// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
-// CHECK: %[[T9:.+]] = index_cast %[[T5]]
-// CHECK: %[[T10:.+]] = addi %[[T8]], %[[T9]]
-// CHECK: %[[T11:.+]] = index_cast %[[ARG7]]
-// CHECK: %[[T12:.+]] = addi %[[T10]], %[[T11]]
+// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP11]](%[[ARG2]], %[[ARG3]])
+// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP12]](%[[ARG4]], %[[ARG5]], %[[ARG6]])
+// CHECK-DAG: %[[T7:.+]] = addi %[[ARG8]], %[[ARG9]]
+// CHECK: %[[T8:.+]] = index_cast %[[T5]]
+// CHECK: %[[T9:.+]] = addi %[[T7]], %[[T8]]
+// CHECK: %[[T10:.+]] = index_cast %[[T6]]
+// CHECK: %[[T11:.+]] = addi %[[T9]], %[[T10]]
+// CHECK: %[[T12:.+]] = index_cast %[[ARG7]]
+// CHECK: %[[T13:.+]] = addi %[[T11]], %[[T12]]
 
 // -----
 
@@ -466,7 +459,7 @@
     indexing_maps = [#map0, #map0, #map1],
     iterator_types = ["parallel", "parallel"]}
     ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-    outs(%arg0 : tensor<?x?xf32>) {
+    outs(%arg0 : tensor<?x?xf32>) {
   ^bb0(%arg3: f32, %arg4: f32, %s: f32):  // no predecessors
     %1 = mulf %arg3, %arg4 : f32
     linalg.yield %1 : f32
@@ -479,8 +472,10 @@
 
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
 // CHECK: func @generic_op_reshape_consumer_fusion_projected
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -490,8 +485,11 @@
 // CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: tensor<?x?xf32> into tensor<?x?x4x5xf32>
-// CHECK: %[[T2:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]]]
+// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP2]], #[[MAP3]]]
+// CHECK: %[[T3:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
-// CHECK: return %[[T2]] : tensor<?x?x4x5xf32>
+// CHECK-SAME: outs(%[[T2]] : tensor<?x?x4x5xf32>)
+// CHECK: return %[[T3]] : tensor<?x?x4x5xf32>
diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization -verify-diagnostics %s | FileCheck %s
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
@@ -21,14 +21,19 @@
   return %1 : tensor<?x?x4x?xf32>
 }
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK: func @generic_op_reshape_producer_fusion
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 // CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]]
+// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP4]]]
 // CHECK-SAME: ins(%[[ARG0]], %{{.+}} : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>)
-// CHECK-SAME: outs(%{{.+}} : tensor<?x?x4x?xf32>)
+// CHECK-SAME: outs(%[[T0]] : tensor<?x?x4x?xf32>)
 
 // -----
 
@@ -52,47 +57,17 @@
   return %1 : tensor<?x?xf32>
 }
 
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
 // CHECK: func @generic_op_reshape_consumer_fusion
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C20:.+]] = constant 20 : index
-// CHECK: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK: %[[T1:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK: %[[T2:.+]] = muli %[[T1]], %[[C20]]
-// CHECK: %[[T3:.+]] = linalg.init_tensor [%[[T0]], %[[T2]]]
+// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
 // CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]]
-// CHECK-SAME: outs(%[[T3]] : tensor<?x?xf32>)
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x?xf32>,
-                                           %arg1 : tensor<?x?x?x?xf32>) ->
-  tensor<?x?xf32>
-{
-  %0 = linalg.generic {
-    indexing_maps = [#map0, #map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-    ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
-    outs(%arg0 : tensor<?x?x?x?xf32>) {
-  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
-    %1 = mulf %arg3, %arg4 : f32
-    linalg.yield %1 : f32
-  } -> tensor<?x?x?x?xf32>
-  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
-                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
-    tensor<?x?x?x?xf32> into tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
-// CHECK: %[[T0:.+]] = linalg.generic
-// CHECK: linalg.tensor_reshape %[[T0]]
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
 
 // -----
 
@@ -116,13 +91,19 @@
   return %1 : tensor<?x?x4x?xf32>
 }
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK: func @indexed_generic_op_reshape_producer_fusion
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 // CHECK: linalg.indexed_generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]]]
 // CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x?xf32>)
+// CHECK-SAME: outs(%[[T0]] : tensor<?x?x4x?xf32>)
 
 // -----
 
@@ -144,20 +125,17 @@
   return %1 : tensor<?x?xf32>
 }
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+// CHECK: func @indexed_generic_op_reshape_consumer_fusion
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C20:.+]] = constant 20 : index
-// CHECK: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK: %[[T1:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK: %[[T2:.+]] = muli %[[T1]], %[[C20]]
-// CHECK: %[[T3:.+]] = linalg.init_tensor [%[[T0]], %[[T2]]]
+// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
 // CHECK: linalg.indexed_generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-SAME: outs(%[[T3]] : tensor<?x?xf32>)
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
 // CHECK-NOT: linalg.tensor_reshape
 
 // -----
 
@@ -179,12 +157,12 @@
   return %2 : tensor<3x7x5xf32>
 }
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func @generic_op_021_permultation_reshape_producer_fusion
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: func @generic_op_021_permultation_reshape_producer_fusion
 // CHECK-NOT: linalg.tensor_reshape
 // CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 
 // -----
 
@@ -210,7 +188,7 @@
 // CHECK: func @generic_op_120_permultation_reshape_producer_fusion
 // CHECK-NOT: linalg.tensor_reshape
 // CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 
 // -----
 
@@ -237,7 +215,7 @@
 // CHECK: func @generic_op_102_permultation_reshape_producer_fusion
 // CHECK-NOT: linalg.tensor_reshape
 // CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 
 // -----
 
@@ -258,10 +236,39 @@
   return %2 : tensor<5x21xf32>
 }
 
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
 // CHECK: func @generic_op_102_permultation_reshape_consumer_fusion
-// CHECK-NOT: linalg.tensor_reshape
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x5x7xf32>
+// CHECK: %[[T0:.+]] = linalg.init_tensor [5, 3, 7]
+// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[T0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
 // CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<3x5x7xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<5x21xf32>)
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x?xf32>,
+                                           %arg1 : tensor<?x?x?x?xf32>) ->
+  tensor<?x?xf32>
+{
+  // expected-remark @+1 {{fused op indexing map is not affine}}
+  %0 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%arg0 : tensor<?x?x?x?xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %1 = mulf %arg3, %arg4 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?x?x?xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
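
Illustration (not part of the patch): the change above materializes the outs
operand of the fused op with a linalg.tensor_reshape of the original output
instead of reconstructing its shape arithmetically. For a hypothetical output
%out : tensor<?x?xf32> expanded into tensor<?x4x?xf32> (the SSA names are
illustrative; the ops are the ones exercised by the tests above), the removed
path emitted:

  // Recover each expanded extent by dividing the original dynamic extent
  // by the product of the static extents in its group.
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  %d0 = dim %out, %c0 : tensor<?x?xf32>
  %d1 = dim %out, %c1 : tensor<?x?xf32>
  %e0 = divi_unsigned %d0, %c4 : index
  %init = linalg.init_tensor [%e0, 4, %d1] : tensor<?x4x?xf32>

whereas the new path reshapes the output operand directly, reusing the
reassociation derived from the output's indexing map:

  // Group (d0, d1) collapses to the original dim 0; d2 maps to dim 1.
  %init = linalg.tensor_reshape %out
      [affine_map<(d0, d1, d2) -> (d0, d1)>,
       affine_map<(d0, d1, d2) -> (d2)>]
      : tensor<?x?xf32> into tensor<?x4x?xf32>

This is well-defined because at most one extent in each expanded group may be
dynamic, so the expanded shape is always recoverable from the original operand
and the reassociation.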