diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -818,13 +818,10 @@
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
   // Fold producer-consumer reshape ops where the operand type of the
-  // producer is the same as the return type of the consumer. This can only be
-  // verified if the shapes in question are static.
+  // producer is the same as the return type of the consumer.
   ReshapeOpTy reshapeSrcOp =
       reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
-  if (reshapeSrcOp && reshapeSrcOp.getSrcType().hasStaticShape() &&
-      reshapeOp.getResultType().hasStaticShape() &&
-      reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
+  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
     return reshapeSrcOp.src();
   // Reshape of a constant can be replaced with a new constant.
   if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
@@ -1028,6 +1025,57 @@
 Value mlir::linalg::ReshapeOp::getViewSource() { return src(); }
 
+/// Verify the shapes of the reshaped types using the following rules:
+/// 1) if a dimension in the collapsed type is static, then the corresponding
+///    dimensions in the expanded shape should be
+///    a) static
+///    b) the product should be the same as the collapsed shape.
+/// 2) if a dimension in the collapsed type is dynamic, one and only one of
+///    the corresponding dimensions in the expanded type should be dynamic.
+///    This rule is only needed with reshape operations that are expanding.
+template <typename Op>
+static LogicalResult verifyReshapeLikeShapes(Op op, ShapedType collapsedType,
+                                             ShapedType expandedType,
+                                             bool isExpandingReshape) {
+  ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
+  ArrayRef<int64_t> expandedShape = expandedType.getShape();
+  unsigned expandedDimStart = 0;
+  for (auto map : llvm::enumerate(op.getReassociationMaps())) {
+    Optional<int64_t> dynamicDims;
+    int64_t linearizedStaticShape = 1;
+    for (auto dim : llvm::enumerate(expandedShape.slice(
+             expandedDimStart, map.value().getNumResults()))) {
+      if (ShapedType::isDynamic(dim.value())) {
+        if (isExpandingReshape && dynamicDims) {
+          return op->emitOpError("invalid to have a single dimension (")
+                 << map.index() << ") expanded into multiple dynamic dims ("
+                 << expandedDimStart + dynamicDims.getValue() << ","
+                 << expandedDimStart + dim.index() << ")";
+        }
+        dynamicDims = dim.index();
+      } else {
+        linearizedStaticShape *= dim.value();
+      }
+    }
+    if (dynamicDims) {
+      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
+        return op->emitOpError("expected dimension ")
+               << map.index()
+               << " of collapsed type to be dynamic since one or more of the "
+                  "corresponding dimensions in the expanded type is dynamic";
+      }
+    } else {
+      if (collapsedShape[map.index()] != linearizedStaticShape) {
+        return op->emitOpError("expected dimension ")
+               << map.index() << " of collapsed type to be static value of "
+               << linearizedStaticShape << " ";
+      }
+    }
+    expandedDimStart += map.value().getNumResults();
+  }
+  return success();
+}
+
 // Common verifier for reshape-like types. Fills `expandedType` and
 // `collapsedType` with the proper `src` or `result` type.
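
To see the two rules enforced by the verifier above in action, consider these
illustrative IR fragments (hand-written for exposition; the shapes and value
names are not taken from the patch's tests):

  // Rule 1: with no dynamic dimensions in a reassociation group, the static
  // collapsed dimension must equal the product of the corresponding expanded
  // dimensions (here 6 = 2 * 3).
  %ok = linalg.tensor_reshape %t [affine_map<(d0, d1, d2) -> (d0, d1)>,
                                  affine_map<(d0, d1, d2) -> (d2)>] :
        tensor<6x4xf32> into tensor<2x3x4xf32>

  // Rule 2: expanding one dimension into more than one dynamic dimension is
  // rejected, since the split cannot be recovered from the operand alone.
  %bad = linalg.tensor_reshape %t2 [affine_map<(d0, d1, d2) -> (d0, d1)>,
                                    affine_map<(d0, d1, d2) -> (d2)>] :
         tensor<?x4xf32> into tensor<?x?x4xf32>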
 template <typename Op, typename T>
@@ -1071,7 +1119,7 @@
   if (!isReassociationValid(maps, &invalidIdx))
     return op.emitOpError("expected reassociation map #")
            << invalidIdx << " to be valid and contiguous";
-  return success();
+  return verifyReshapeLikeShapes(op, collapsedType, expandedType, !isCollapse);
 }
 
 static LogicalResult verify(ReshapeOp op) {
@@ -1150,8 +1198,6 @@
   if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
     return failure();
   auto maps = getAffineMaps(op.reassociation());
-  // TODO: expanding a ? with a non-constant is under-specified. Error
-  // out.
   RankedTensorType expectedType =
       computeTensorReshapeCollapsedType(expandedType, maps);
   if (collapsedType != expectedType)
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -43,8 +43,6 @@
 
 // -----
 
-// -----
-
 func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
                                              -> tensor<f32> {
   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
@@ -71,18 +69,18 @@
 
 // -----
 
-func @expanding_tensor_reshapes(%arg0 : tensor) -> tensor
+func @expanding_tensor_reshapes(%arg0 : tensor) -> tensor
 {
   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
                                     affine_map<(d0, d1, d2) -> (d2)>] :
-       tensor into tensor
+       tensor into tensor
   %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
                                  affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
                                  affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
-       tensor into tensor
-  return %1 : tensor
+       tensor into tensor
+  return %1 : tensor
 }
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
@@ -113,18 +111,18 @@
 
 // -----
 
-func @expanding_memref_reshapes(%arg0 : memref) -> memref
+func @expanding_memref_reshapes(%arg0 : memref) -> memref
 {
   %0 = linalg.reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
                              affine_map<(d0, d1, d2) -> (d2)>] :
-       memref into memref
+       memref into memref
   %1 = linalg.reshape %0 [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
                           affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
                           affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
-       memref into memref
-  return %1 : memref
+       memref into memref
+  return %1 : memref
 }
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
@@ -178,21 +176,20 @@
 
 // -----
 
-func @no_fold_tensor_reshape(%arg0 : tensor) -> tensor
+func @fold_tensor_reshape_dynamic(%arg0 : tensor) -> tensor
 {
   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
                                     affine_map<(d0, d1, d2) -> (d2)>] :
-       tensor into tensor
+       tensor into tensor
   %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
                                  affine_map<(d0, d1, d2) -> (d2)>] :
-       tensor into tensor
+       tensor into tensor
   return %1 : tensor
 }
-// CHECK-LABEL: @no_fold_tensor_reshape
-// CHECK:   linalg.tensor_reshape
-// CHECK:   linalg.tensor_reshape
+// CHECK-LABEL: @fold_tensor_reshape_dynamic
+// CHECK-NOT:   linalg.tensor_reshape
 
 // -----
@@ -213,21 +210,20 @@
 
 // -----
 
-func @no_fold_memref_reshape(%arg0 : memref) -> memref
+func @fold_memref_reshape_dynamic(%arg0 : memref) -> memref
 {
   %0 = linalg.reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
                              affine_map<(d0, d1, d2) -> (d2)>] :
-       memref into memref
+       memref into memref
   %1 = linalg.reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
                           affine_map<(d0, d1, d2) -> (d2)>] :
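
With the static-shape requirement dropped from foldReshapeOp, producer-consumer
reshapes now cancel whenever the producer's source type is exactly the
consumer's result type, dynamic dimensions included. An illustrative pair
(shapes invented for this example, not from the test file):

  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
                                    affine_map<(d0, d1, d2) -> (d2)>] :
       tensor<?x4x?xf32> into tensor<?x?xf32>
  %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
                                 affine_map<(d0, d1, d2) -> (d2)>] :
       tensor<?x?xf32> into tensor<?x4x?xf32>

Here %1 folds to %arg0: the producer's source type tensor<?x4x?xf32> equals
the consumer's result type, even though neither shape is fully static.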
-       memref into memref
+       memref into memref
   return %1 : memref
 }
-// CHECK-LABEL: @no_fold_memref_reshape
-// CHECK:   linalg.reshape
-// CHECK:   linalg.reshape
+// CHECK-LABEL: @fold_memref_reshape_dynamic
+// CHECK-NOT:   linalg.reshape
 
 // -----
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -409,3 +409,211 @@
     -> tensor
   return
 }
+
+
+// -----
+
+func @init_tensor_err(%arg0 : index, %arg1 : index)
+{
+  // expected-error @+1 {{specified type 'tensor<4x?x?x5xf32>' does not match the inferred type 'tensor<4x5x?x?xf32>'}}
+  %1 = linalg.init_tensor [4, 5, %arg0, %arg1] : tensor<4x?x?x5xf32>
+  return
+}
+
+// -----
+
+func @init_tensor_err(%arg0 : index)
+{
+  // expected-error @+1 {{expected 4 sizes values}}
+  %1 = linalg.init_tensor [4, 5, %arg0] : tensor<4x?x?x5xf32>
+  return
+}
+
+// -----
+
+func @init_tensor_err(%arg0 : index)
+{
+  // expected-error @+1 {{expected 2 dynamic sizes values}}
+  %1 = "linalg.init_tensor"(%arg0) {static_sizes = [4, -1, -1, 5]} : (index) -> tensor<4x?x?x5xf32>
+  return
+}
+
+// -----
+
+func @illegal_expanding_reshape_dynamic_tensor
+  (%arg0: tensor) -> tensor
+{
+  // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    tensor into tensor
+  return %0 : tensor
+}
+
+// -----
+
+func @illegal_expanding_reshape_dynamic_memref
+  (%arg0: memref) -> memref
+{
+  // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
+  %0 = linalg.reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    memref into memref
+  return %0 : memref
+}
+
+// -----
+
+func @illegal_expanding_reshape_static_tensor
+  (%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32>
+{
+  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    tensor<2x3x20xf32> into tensor<2x3x2x4x5xf32>
+  return %0 : tensor<2x3x2x4x5xf32>
+}
+
+// -----
+
+func @illegal_collapsing_reshape_static_tensor
+  (%arg0: tensor<2x3x2x4x5xf32>) -> tensor<2x3x20xf32>
+{
+  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    tensor<2x3x2x4x5xf32> into tensor<2x3x20xf32>
+  return %0 : tensor<2x3x20xf32>
+}
+
+// -----
+
+func @illegal_expanding_reshape_static_memref
+  (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32>
+{
+  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+  %0 = linalg.reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    memref<2x3x20xf32> into memref<2x3x2x4x5xf32>
+  return %0 : memref<2x3x2x4x5xf32>
+}
+
+// -----
+
+func @illegal_collapsing_reshape_static_memref
+  (%arg0: memref<2x3x2x4x5xf32>) -> memref<2x3x20xf32>
+{
+  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+  %0 = linalg.reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    memref<2x3x2x4x5xf32> into memref<2x3x20xf32>
+  return %0 : memref<2x3x20xf32>
+}
+
+// -----
+
+func @illegal_collapsing_reshape_mixed_tensor(%arg0 : tensor) -> tensor
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2) -> (d0, d1)>,
+     affine_map<(d0, d1, d2) -> (d2)>] :
+    tensor into tensor
+  return %0 : tensor
+}
+
+// -----
+
+func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor) -> tensor
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2) -> (d0)>,
+     affine_map<(d0, d1, d2) -> (d1, d2)>] :
+    tensor into tensor
+  return %0 : tensor
+}
+
+// -----
+
+func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor) -> tensor
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2) -> (d0, d1)>,
+     affine_map<(d0, d1, d2) -> (d2)>] :
+    tensor into tensor
+  return %0 : tensor
+}
+
+// -----
+
+func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor) -> tensor
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2) -> (d0)>,
+     affine_map<(d0, d1, d2) -> (d1, d2)>] :
+    tensor into tensor
+  return %0 : tensor
+}
+
+// -----
+
+func @illegal_collapsing_reshape_mixed_memref(%arg0 : memref) -> memref
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+  %0 = linalg.reshape %arg0
+    [affine_map<(d0, d1, d2) -> (d0, d1)>,
+     affine_map<(d0, d1, d2) -> (d2)>] :
+    memref into memref
+  return %0 : memref
+}
+
+// -----
+
+func @illegal_collapsing_reshape_mixed_memref_2(%arg0 : memref) -> memref
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+  %0 = linalg.reshape %arg0
+    [affine_map<(d0, d1, d2) -> (d0)>,
+     affine_map<(d0, d1, d2) -> (d1, d2)>] :
+    memref into memref
+  return %0 : memref
+}
+
+// -----
+
+func @illegal_expanding_reshape_mixed_memref(%arg0 : memref) -> memref
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+  %0 = linalg.reshape %arg0
+    [affine_map<(d0, d1, d2) -> (d0, d1)>,
+     affine_map<(d0, d1, d2) -> (d2)>] :
+    memref into memref
+  return %0 : memref
+}
+
+// -----
+
+func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref) -> memref
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+  %0 = linalg.reshape %arg0
+    [affine_map<(d0, d1, d2) -> (d0)>,
+     affine_map<(d0, d1, d2) -> (d1, d2)>] :
+    memref into memref
+  return %0 : memref
+}
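
One diagnostic from the new verifier is not exercised by the cases above: a
static collapsed dimension whose reassociation group contains a dynamic
expanded dimension. A sketch of IR that would trigger it (hypothetical, not
part of the patch's tests):

  // Emits: expected dimension 0 of collapsed type to be dynamic since one or
  // more of the corresponding dimensions in the expanded type is dynamic
  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
                                    affine_map<(d0, d1, d2) -> (d2)>] :
       tensor<?x4x5xf32> into tensor<8x5xf32>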
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -split-input-file %s | FileCheck %s
 
 // TODO: Re-enable LLVM lowering test after IndexedGenericOp is lowered.
 //
@@ -621,7 +621,7 @@
        memref into memref
   %r0 = linalg.reshape %0 [affine_map<(i, j, k) -> (i, j)>,
                            affine_map<(i, j, k) -> (k)>] :
-       memref into memref
+       memref into memref
   %1 = linalg.reshape %arg1 [affine_map<(i, j, k) -> (i, j)>,
                              affine_map<(i, j, k) -> (k)>] :
       memref into
@@ -629,7 +629,7 @@
   %r1 = linalg.reshape %1 [affine_map<(i, j, k) -> (i, j)>,
                            affine_map<(i, j, k) -> (k)>] :
       memref into
-      memref
+      memref
   %2 = linalg.reshape %arg2 [affine_map<(i, j, k) -> (i, j)>,
                              affine_map<(i, j, k) -> (k)>] :
       memref into
@@ -637,7 +637,7 @@
   %r2 = linalg.reshape %2 [affine_map<(i, j, k) -> (i, j)>,
                            affine_map<(i, j, k) -> (k)>] :
       memref into
-      memref
+      memref
   return
 }
@@ -648,15 +648,15 @@
 // CHECK: linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
 // CHECK-SAME:   memref into memref
 // CHECK: linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
-// CHECK-SAME:   memref into memref
+// CHECK-SAME:   memref into memref
 // CHECK: linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
 // CHECK-SAME:   memref into memref
 // CHECK: linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
-// CHECK-SAME:   memref into memref
+// CHECK-SAME:   memref into memref
 // CHECK: linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
 // CHECK-SAME:   memref into memref
 // CHECK: linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
-// CHECK-SAME:   memref into memref
+// CHECK-SAME:   memref into memref
 
 func @named_ops(%a3: memref, %b3: memref, %c3: memref,
                 %ta3: tensor, %tb3: tensor, %tc3: tensor)
@@ -720,27 +720,36 @@
 
 // -----
 
-func @init_tensor_err(%arg0 : index, %arg1 : index)
+func @legal_collapsing_reshape_dynamic_tensor
+  (%arg0: tensor) -> tensor
 {
-  // expected-error @+1 {{specified type 'tensor<4x?x?x5xf32>' does not match the inferred type 'tensor<4x5x?x?xf32>'}}
-  %1 = linalg.init_tensor [4, 5, %arg0, %arg1] : tensor<4x?x?x5xf32>
-  return
-}
-
-// -----
-
-func @init_tensor_err(%arg0 : index)
-{
-  // expected-error @+1 {{expected 4 sizes values}}
-  %1 = linalg.init_tensor [4, 5, %arg0] : tensor<4x?x?x5xf32>
-  return
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    tensor into tensor
+  return %0 : tensor
 }
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
+// CHECK: func @legal_collapsing_reshape_dynamic_tensor
+// CHECK:   linalg.tensor_reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 
 // -----
 
-func @init_tensor_err(%arg0 : index)
+func @legal_collapsing_reshape_dynamic_memref
+  (%arg0: memref) -> memref
 {
-  // expected-error @+1 {{expected 2 dynamic sizes values}}
-  %1 = "linalg.init_tensor"(%arg0) {static_sizes = [4, -1, -1, 5]} : (index) -> tensor<4x?x?x5xf32>
-  return
-}
+  %0 = linalg.reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    memref into memref
+  return %0 : memref
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
+// CHECK: func @legal_collapsing_reshape_dynamic_memref
+// CHECK:   linalg.reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
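
The two legal_collapsing tests above are the flip side of rule 2: collapsing
several dynamic dimensions into one is always well-defined, since the result
dimension is simply their product; only the expanding direction is ambiguous.
Illustrative fragments (shapes invented for this example):

  // Accepted: three dynamic dimensions collapse into one dynamic dimension.
  %0 = linalg.tensor_reshape %arg0
    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
    tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>

  // Rejected: the same reassociation in the expanding direction would have to
  // split one runtime size into three unknown factors.
  %1 = linalg.tensor_reshape %0
    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
    tensor<?x?x?xf32> into tensor<?x?x?x?x?xf32>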