diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -361,6 +361,12 @@

   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
+    // Skip the pattern if the op has any tensor with special encoding.
+    // Unit-dim folding would reshape such operands, which is not defined for
+    // encoded (e.g. sparse) tensors, so bail out and leave the op untouched.
+    if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
+          auto tensorType = type.dyn_cast<RankedTensorType>();
+          return tensorType && tensorType.getEncoding() != nullptr;
+        }))
+      return failure();
     MLIRContext *context = rewriter.getContext();
     Location loc = genericOp.getLoc();

diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -796,3 +796,34 @@
 //       CHECK:     linalg.yield %[[ARG]] : f32
 //       CHECK:   }
 //       CHECK:   return %[[ARG2]] : memref
+
+// -----
+
+// Negative test for case with tensor encoding.
+#matvec = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>, // A
+    affine_map<(i,j) -> (j)>,   // b
+    affine_map<(i,j) -> (i)>    // x (out)
+  ],
+  iterator_types = ["parallel", "reduction"]
+}
+
+#CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"] }>
+
+func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+  %0 = linalg.init_tensor [8] : tensor<8xf32>
+  %1 = linalg.generic #matvec
+       ins(%arg0, %arg1: tensor<8x8xf32, #CSR>, tensor<8xf32>)
+       outs(%0: tensor<8xf32>) {
+    ^bb(%a: f32, %b: f32, %x: f32):
+      %m = mulf %a, %b : f32
+      %add = addf %x, %m : f32
+      linalg.yield %add : f32
+  } -> tensor<8xf32>
+  return %1: tensor<8xf32>
+}
+
+// CHECK-LABEL: func @sparse_case
+//  CHECK-NEXT:   linalg.init_tensor
+//  CHECK-NEXT:   linalg.generic