diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -816,6 +816,9 @@
                                 ConversionPatternRewriter &rewriter) const final;
 };
 
+/// TensorConstantOp conversion inserts a linearized 1-D vector constant that is
+/// stored in memory. A linalg.reshape is introduced to convert to the desired
+/// n-D buffer form.
 class TensorConstantOpConverter
     : public BufferAssignmentOpConversionPattern<ConstantOp> {
 public:
@@ -827,6 +830,7 @@
                                 ConversionPatternRewriter &rewriter) const final;
 };
 
+/// TensorCastOp converts 1-to-1 to MemRefCastOp.
 class TensorCastOpConverter
     : public BufferAssignmentOpConversionPattern<TensorCastOp> {
 public:
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
@@ -4,12 +4,13 @@
 // RUN: | FileCheck %s
 
 func @main() {
-  %A = constant dense<[[1.0, 2.0], [4.0, 5.0]]> : tensor<2x2xf32>
+  %A = constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
   %B = constant dense<[[1.0, 2.0, 3.0, 4.0],
-                       [5.0, 6.0, 7.0, 8.0]]> : tensor<2x4xf32>
+                       [5.0, 6.0, 7.0, 8.0],
+                       [9.0, 10.0, 11.0, 12.0]]> : tensor<3x4xf32>
   %C = constant dense<1000.0> : tensor<2x4xf32>
 
-  %D = linalg.matmul ins(%A, %B: tensor<2x2xf32>, tensor<2x4xf32>)
+  %D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
                      init(%C: tensor<2x4xf32>) -> tensor<2x4xf32>
 
   %unranked = tensor_cast %D : tensor<2x4xf32> to tensor<*xf32>
@@ -17,10 +18,11 @@
 
   // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
   // CHECK-SAME: rank = 2 offset = 0 sizes = [2, 4] strides = [4, 1] data =
-  // CHECK-NEXT: [1011, 1014, 1017, 1020]
-  // CHECK-NEXT: [1029, 1038, 1047, 1056]
+  // CHECK-NEXT: [1038, 1044, 1050, 1056]
+  // CHECK-NEXT: [1083, 1098, 1113, 1128]
 
   return
 }
 
 func @print_memref_f32(%ptr : tensor<*xf32>)
+
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -239,22 +239,43 @@
 LogicalResult mlir::linalg::TensorConstantOpConverter::matchAndRewrite(
     ConstantOp op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  if (!op.getType().isa<RankedTensorType>())
+  RankedTensorType rankedTensorType = op.getType().dyn_cast<RankedTensorType>();
+  if (!rankedTensorType)
+    return failure();
+  if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) {
+        return s == 0 || ShapedType::isDynamic(s);
+      }))
     return failure();
-  auto attr = op.getValue().cast<DenseElementsAttr>();
-  Location loc = op.getLoc();
+  int64_t nElements = 1;
+  for (int64_t s : rankedTensorType.getShape())
+    nElements *= s;
+  Type elementType = rankedTensorType.getElementType();
   MemRefType memrefType =
       converter.convertType(op.getType()).cast<MemRefType>();
-  VectorType vectorType =
-      VectorType::get(memrefType.getShape(), memrefType.getElementType());
-  Value cstVec =
-      rewriter.create<ConstantOp>(loc, vectorType, attr.reshape(vectorType));
+  VectorType flatVectorType = VectorType::get({nElements}, elementType);
+  MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
+  MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);
 
-  MemRefType memrefOfVectorType = MemRefType::get({}, vectorType);
-  Value alloc =
-      rewriter.create<AllocOp>(loc, memrefOfVectorType, ValueRange{});
+  Location loc = op.getLoc();
+  auto attr = op.getValue().cast<DenseElementsAttr>();
+  Value alloc =
+      rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
+  Value cstVec = rewriter.create<ConstantOp>(loc, flatVectorType,
+                                             attr.reshape(flatVectorType));
   rewriter.create<StoreOp>(loc, cstVec, alloc);
-  rewriter.replaceOpWithNewOp<vector::TypeCastOp>(op, memrefType, alloc);
+
+  Value memref =
+      rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
+  if (rankedTensorType.getRank() > 1) {
+    // Introduce a linalg.reshape to convert the flat 1-D memref into the
+    // desired n-D memref form.
+    AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
+        /*numDims=*/rankedTensorType.getRank(), op.getContext());
+    memref = rewriter.create<linalg::ReshapeOp>(
+        loc, memrefType, memref,
+        rewriter.getAffineMapArrayAttr(collapseAllDims));
+  }
+  rewriter.replaceOp(op, memref);
 
   return success();
 }
diff --git a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
--- a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
+++ b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
@@ -126,28 +126,29 @@
 
 // -----
 
-func @foo() -> tensor<4xf32> {
+func @foo() -> tensor<2x3xf32> {
 // CHECK-LABEL: func @foo(
-// CHECK-SAME:   %[[A:[0-9a-z]*]]: memref<4xf32>) {
-
-  %0 = constant dense<[0.0, 1.0, 2.0, 3.0]> : tensor<4xf32>
-// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : vector<4xf32>
-// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<vector<4xf32>>
-// CHECK-NEXT: store %[[CST]], %[[ALLOC]][] : memref<vector<4xf32>>
-// CHECK-NEXT: %[[RES:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<4xf32>> to memref<4xf32>
-
-  return %0 : tensor<4xf32>
-// CHECK-NEXT: linalg.copy(%[[RES]], %[[A]]) : memref<4xf32>, memref<4xf32>
-// CHECK-NEXT: dealloc %[[ALLOC]] : memref<vector<4xf32>>
+// CHECK-SAME:   %[[A:[0-9a-z]*]]: memref<2x3xf32>) {
+
+  %0 = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<vector<6xf32>>
+// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : vector<6xf32>
+// CHECK-NEXT: store %[[CST]], %[[ALLOC]][] : memref<vector<6xf32>>
+// CHECK-NEXT: %[[FLAT:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<6xf32>> to memref<6xf32>
+// CHECK-NEXT: %[[RES:.*]] = linalg.reshape %[[FLAT]] {{.*}} : memref<6xf32> into memref<2x3xf32>
+
+  return %0 : tensor<2x3xf32>
+// CHECK-NEXT: linalg.copy(%[[RES]], %[[A]]) : memref<2x3xf32>, memref<2x3xf32>
+// CHECK-NEXT: dealloc %[[ALLOC]] : memref<vector<6xf32>>
 // CHECK-NEXT: return
 }
 
 func @bar() {
 // CHECK-LABEL: func @bar() {
 
-  %0 = call @foo() : () -> tensor<4xf32>
-// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<4xf32>
-// CHECK-NEXT: call @foo(%[[ALLOC]]) : (memref<4xf32>) -> ()
+  %0 = call @foo() : () -> tensor<2x3xf32>
+// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<2x3xf32>
+// CHECK-NEXT: call @foo(%[[ALLOC]]) : (memref<2x3xf32>) -> ()
 
   // Instead of relying on tensor_store which introduces aliasing, we rely on
   // the conversion of print_memref_f32(tensor<*xf32>) to
@@ -155,15 +156,15 @@
   // Note that this is skipping a step and we would need at least some function
   // attribute to declare that this conversion is valid (e.g. when we statically
   // know that things will play nicely at the C ABI boundary).
-  %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
+  %unranked = tensor_cast %0 : tensor<2x3xf32> to tensor<*xf32>
 // CHECK-NEXT: %[[UNRANKED:.*]] = memref_cast %[[ALLOC]] :
-// CHECK-SAME:   memref<4xf32> to memref<*xf32>
+// CHECK-SAME:   memref<2x3xf32> to memref<*xf32>
 
   call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
 // CHECK-NEXT: call @print_memref_f32(%[[UNRANKED]]) : (memref<*xf32>) -> ()
 
   return
-// CHECK-NEXT: dealloc %[[ALLOC]] : memref<4xf32>
+// CHECK-NEXT: dealloc %[[ALLOC]] : memref<2x3xf32>
 // CHECK-NEXT: return
 }
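
For reference, a minimal sketch of the buffer-side IR the updated conversion is
expected to produce for the tensor<2x3xf32> constant exercised in
tensors-to-buffers.mlir, reconstructed from the FileCheck lines above. The SSA
names are illustrative, and the printed form of the linalg.reshape
reassociation (the identity map built by getMultiDimIdentityMap) is elided as
{{.*}} in the test, so that part is an assumption rather than verified output:

  // Sketch only, not part of the patch: expected lowering of
  //   %0 = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
  %alloc = alloc() : memref<vector<6xf32>>
  %cst = constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]> : vector<6xf32>
  store %cst, %alloc[] : memref<vector<6xf32>>
  %flat = vector.type_cast %alloc : memref<vector<6xf32>> to memref<6xf32>
  %res = linalg.reshape %flat [affine_map<(d0, d1) -> (d0, d1)>]
      : memref<6xf32> into memref<2x3xf32>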