diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -301,6 +301,18 @@
   LogicalResult verify();
 };
 
+/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
+/// `originalShape` with some `1` entries erased, return the vector of booleans
+/// that specifies which of the entries of `originalShape` are kept to obtain
+/// `reducedShape`. The returned mask can be applied as a projection to
+/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
+/// which dimensions must be kept when e.g. computing MemRef strides under
+/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
+/// by dropping only `1` entries in `originalShape`.
+llvm::Optional<SmallVector<bool, 4>>
+computeRankReductionMask(ArrayRef<int64_t> originalShape,
+                         ArrayRef<int64_t> reducedShape);
+
 /// Prints dimension and symbol list.
 void printDimAndSymbolList(Operation::operand_iterator begin,
                            Operation::operand_iterator end, unsigned numDims,
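Note (editorial, not part of the patch): the mask semantics can be pictured with a small standalone C++ sketch that mirrors the greedy matching described above. It uses std types and a hypothetical rankReductionMask helper rather than MLIR's llvm::Optional/SmallVector/ArrayRef.

  // Standalone sketch of the rank-reduction mask computation (illustrative
  // only; not the MLIR implementation). `reduced` must be obtainable from
  // `original` by dropping only unit dimensions, otherwise no mask exists.
  #include <cstdint>
  #include <iostream>
  #include <optional>
  #include <vector>

  std::optional<std::vector<bool>>
  rankReductionMask(const std::vector<int64_t> &original,
                    const std::vector<int64_t> &reduced) {
    std::vector<bool> mask(original.size(), false);
    size_t reducedIdx = 0;
    for (size_t i = 0; i < original.size(); ++i) {
      // Greedily match the next entry of `reduced`.
      if (reducedIdx < reduced.size() && original[i] == reduced[reducedIdx]) {
        mask[i] = true;
        ++reducedIdx;
      } else if (original[i] != 1) {
        return std::nullopt; // Only unit dims may be dropped.
      }
    }
    if (reducedIdx != reduced.size())
      return std::nullopt; // Not every entry of `reduced` was matched.
    return mask;
  }

  int main() {
    // Shape {4, 1, 5} reduced to {4, 5}: the unit dim is dropped.
    auto mask = rankReductionMask({4, 1, 5}, {4, 5});
    if (mask)
      for (bool keep : *mask)
        std::cout << keep << ' '; // prints: 1 0 1
    std::cout << '\n';
  }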
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir b/mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>)
+
+func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>) -> (memref<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %f0 = constant 0.0 : f32
+  %x = dim %A, %c0 : memref<?x?xf32>
+  %y = dim %B, %c1 : memref<?x?xf32>
+  %C = alloc(%x, %y) : memref<?x?xf32>
+  linalg.fill(%C, %f0) : memref<?x?xf32>, f32
+  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+                outs(%C: memref<?x?xf32>)
+  return %C : memref<?x?xf32>
+}
+
+func @matvec(%A: memref<?x?xf32>, %B: memref<?x?xf32>) -> (memref<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %f0 = constant 0.0 : f32
+  %m = dim %A, %c0 : memref<?x?xf32>
+  %x = dim %A, %c1 : memref<?x?xf32>
+  %n = dim %B, %c1 : memref<?x?xf32>
+  %C = alloc(%m, %n) : memref<?x?xf32>
+  linalg.fill(%C, %f0) : memref<?x?xf32>, f32
+  scf.for %i = %c0 to %n step %c1 {
+    %b = subview %B[0, %i][%x, 1][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
+    %c = subview %C[0, %i][%m, 1][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
+    linalg.matvec ins(%A, %b: memref<?x?xf32>, memref<?xf32, offset: ?, strides: [?]>)
+                  outs(%c: memref<?xf32, offset: ?, strides: [?]>)
+  }
+  return %C : memref<?x?xf32>
+}
+
+func @main() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %m = constant 5 : index
+  %x = constant 3 : index
+  %n = constant 2 : index
+  %val1 = constant 13.0 : f32
+  %val2 = constant 17.0 : f32
+  %A = alloc(%m, %x) : memref<?x?xf32>
+  %B = alloc(%x, %n) : memref<?x?xf32>
+  linalg.fill(%A, %val1) : memref<?x?xf32>, f32
+  linalg.fill(%B, %val2) : memref<?x?xf32>, f32
+  store %val1, %B[%c0, %c0] : memref<?x?xf32>
+  %C1 = call @matmul(%A, %B) : (memref<?x?xf32>, memref<?x?xf32>) -> memref<?x?xf32>
+  %C2 = call @matvec(%A, %B) : (memref<?x?xf32>, memref<?x?xf32>) -> memref<?x?xf32>
+  scf.for %i = %c0 to %m step %c1 {
+    scf.for %j = %c0 to %n step %c1 {
+      %e1 = load %C1[%i, %j] : memref<?x?xf32>
+      %e2 = load %C2[%i, %j] : memref<?x?xf32>
+      %c = cmpf "oeq", %e1, %e2 : f32
+      assert %c, "Matmul does not produce same output as matvec"
+    }
+  }
+  %C2_ = memref_cast %C2 : memref<?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%C2_) : (memref<*xf32>) -> ()
+  return
+}
+
+// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [5, 2] strides = [2, 1] data =
+// CHECK-NEXT: [
+// CHECK-SAME: [611, 663],
+// CHECK-NEXT: [611, 663],
+// CHECK-NEXT: [611, 663],
+// CHECK-NEXT: [611, 663],
+// CHECK-NEXT: [611, 663]
+// CHECK-SAME: ]
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir b/mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>)
+
+func @main() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %f0 = constant 0.0 : f32
+  %f1 = constant 1.0 : f32
+  %f2 = constant 2.0 : f32
+  %f3 = constant 3.0 : f32
+  %A = alloc(%c2, %c2) : memref<?x?xf32>
+  store %f0, %A[%c0, %c0] : memref<?x?xf32>
+  store %f1, %A[%c0, %c1] : memref<?x?xf32>
+  store %f2, %A[%c1, %c0] : memref<?x?xf32>
+  store %f3, %A[%c1, %c1] : memref<?x?xf32>
+  %B = subview %A[%c1, 0][1, %c2][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
+  %C = subview %A[0, %c1][%c2, 1][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
+  %A_ = memref_cast %A : memref<?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%A_) : (memref<*xf32>) -> ()
+  %B_ = memref_cast %B : memref<?xf32, offset: ?, strides: [?]> to memref<*xf32>
+  call @print_memref_f32(%B_) : (memref<*xf32>) -> ()
+  %C_ = memref_cast %C : memref<?xf32, offset: ?, strides: [?]> to memref<*xf32>
+  call @print_memref_f32(%C_) : (memref<*xf32>) -> ()
+  return
+}
+
+// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [2, 2] strides = [2, 1] data =
+// CHECK-NEXT: [
+// CHECK-SAME: [0, 1],
+// CHECK-NEXT: [2, 3]
+// CHECK-SAME: ]
+// CHECK: [2, 3]
+// CHECK: [1, 3]
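Note (editorial, not part of the patch): the FileCheck constants follow directly from the test inputs. In matmul-vs-matvec.mlir, A is 5x3 filled with 13.0 and B is 3x2 filled with 17.0 except for B[0][0] = 13.0, so every entry in column 0 of the product is 13*13 + 13*17 + 13*17 = 611 and every entry in column 1 is 3 * 13*17 = 663. In rank-reducing-subview.mlir, A = [[0, 1], [2, 3]], so the two rank-reducing subviews print row 1 as [2, 3] and column 1 as [1, 3].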
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2928,6 +2928,14 @@
   }
 };
 
+/// Helper that extracts int64_t values from the assumed ArrayAttr of IntegerAttr.
+static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+  return llvm::to_vector<4>(
+      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+        return a.cast<IntegerAttr>().getInt();
+      }));
+}
+
 /// Conversion pattern that transforms a subview op into:
 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
@@ -2948,6 +2956,12 @@
             .dyn_cast_or_null<LLVM::LLVMType>();
 
     auto viewMemRefType = subViewOp.getType();
+    auto inferredType = SubViewOp::inferResultType(
+                            subViewOp.getSourceType(),
+                            extractFromI64ArrayAttr(subViewOp.static_offsets()),
+                            extractFromI64ArrayAttr(subViewOp.static_sizes()),
+                            extractFromI64ArrayAttr(subViewOp.static_strides()))
+                            .cast<MemRefType>();
     auto targetElementTy =
         typeConverter.convertType(viewMemRefType.getElementType())
             .dyn_cast<LLVM::LLVMType>();
@@ -2959,7 +2973,7 @@
     // Extract the offset and strides from the type.
     int64_t offset;
     SmallVector<int64_t, 4> strides;
-    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
+    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
     if (failed(successStrides))
       return failure();
 
@@ -2983,10 +2997,17 @@
         extracted);
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
 
+    auto shape = viewMemRefType.getShape();
+    auto inferredShape = inferredType.getShape();
+    size_t inferredShapeRank = inferredShape.size();
+    size_t resultShapeRank = shape.size();
+    SmallVector<bool, 4> mask =
+        computeRankReductionMask(inferredShape, shape).getValue();
+
     // Extract strides needed to compute offset.
     SmallVector<Value, 4> strideValues;
-    strideValues.reserve(viewMemRefType.getRank());
-    for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
+    strideValues.reserve(inferredShapeRank);
+    for (unsigned i = 0; i < inferredShapeRank; ++i)
       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
 
     // Offset.
@@ -2995,7 +3016,7 @@
       targetMemRef.setConstantOffset(rewriter, loc, offset);
     } else {
       Value baseOffset = sourceMemRef.offset(rewriter, loc);
-      for (unsigned i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
+      for (unsigned i = 0; i < inferredShapeRank; ++i) {
         Value offset =
             subViewOp.isDynamicOffset(i)
                 ? operands[subViewOp.getIndexOfDynamicOffset(i)]
@@ -3009,14 +3030,18 @@
     }
 
     // Update sizes and strides.
-    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
+    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
+         i >= 0 && j >= 0; --i) {
+      if (!mask[i])
+        continue;
+
       Value size =
           subViewOp.isDynamicSize(i)
               ? operands[subViewOp.getIndexOfDynamicSize(i)]
               : rewriter.create<LLVM::ConstantOp>(
                     loc, llvmIndexType,
                     rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
-      targetMemRef.setSize(rewriter, loc, i, size);
+      targetMemRef.setSize(rewriter, loc, j, size);
       Value stride;
       if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
         stride = rewriter.create<LLVM::ConstantOp>(
@@ -3030,7 +3055,8 @@
             rewriter.getI64IntegerAttr(subViewOp.getStaticStride(i)));
         stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
       }
-      targetMemRef.setStride(rewriter, loc, i, stride);
+      targetMemRef.setStride(rewriter, loc, j, stride);
+      j--;
     }
 
     rewriter.replaceOp(op, {targetMemRef});
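Note (editorial, not part of the patch): the updated size/stride loop above walks the full-rank inferred type with index i while writing the rank-reduced result descriptor with a separate index j, skipping dimensions whose mask entry is false. A standalone C++ sketch of that compaction, using plain std types and hypothetical names rather than the MemRefDescriptor API (the real pattern emits LLVM IR values and iterates the dimensions in reverse):

  // Assumed simplification of the mask-driven size/stride compaction.
  #include <cstdint>
  #include <vector>

  struct DescriptorShape {
    std::vector<int64_t> sizes;
    std::vector<int64_t> strides;
  };

  DescriptorShape compactByMask(const std::vector<int64_t> &inferredSizes,
                                const std::vector<int64_t> &inferredStrides,
                                const std::vector<bool> &mask) {
    DescriptorShape target;
    for (size_t i = 0; i < mask.size(); ++i) {
      if (!mask[i])
        continue; // Dropped unit dim: no slot in the result descriptor.
      target.sizes.push_back(inferredSizes[i]);
      target.strides.push_back(inferredStrides[i]);
    }
    return target;
  }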
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2823,6 +2823,30 @@
   }));
 }
 
+llvm::Optional<SmallVector<bool, 4>>
+mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
+                               ArrayRef<int64_t> reducedShape) {
+  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
+  SmallVector<bool, 4> mask(originalRank);
+  unsigned reducedIdx = 0;
+  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
+    // Skip matching dims greedily.
+    mask[originalIdx] =
+        (reducedIdx < reducedRank) &&
+        (originalShape[originalIdx] == reducedShape[reducedIdx]);
+    if (mask[originalIdx])
+      reducedIdx++;
+    // 1 is the only non-matching size that may be dropped.
+    else if (originalShape[originalIdx] != 1)
+      return {};
+  }
+
+  if (reducedIdx != reducedRank)
+    return {};
+
+  return mask;
+}
+
 enum SubViewVerificationResult {
   Success,
   RankTooLarge,
@@ -2859,20 +2883,10 @@
   if (reducedRank > originalRank)
     return SubViewVerificationResult::RankTooLarge;
 
-  unsigned reducedIdx = 0;
-  SmallVector<bool, 4> keepMask(originalRank);
-  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
-    // -2 is never used as a dim size so it will never match.
-    int reducedVal = reducedIdx < reducedRank ? reducedShape[reducedIdx] : -2;
-    // Skip matching dims greedily.
-    if ((keepMask[originalIdx] = originalShape[originalIdx] == reducedVal))
-      reducedIdx++;
-    // 1 is the only non-matching allowed.
-    else if (originalShape[originalIdx] != 1)
-      return SubViewVerificationResult::SizeMismatch;
-  }
-  // Must match the reduced rank.
-  if (reducedIdx != reducedRank)
+  auto optionalMask = computeRankReductionMask(originalShape, reducedShape);
+
+  // Sizes cannot be matched if computeRankReductionMask returns None.
+  if (!optionalMask.hasValue())
     return SubViewVerificationResult::SizeMismatch;
 
   // We are done for the tensor case.
@@ -2885,12 +2899,13 @@
   MLIRContext *c = original.getContext();
   int64_t originalOffset, reducedOffset;
   SmallVector<int64_t, 4> originalStrides, reducedStrides, keepStrides;
+  SmallVector<bool, 4> keepMask = optionalMask.getValue();
   getStridesAndOffset(original, originalStrides, originalOffset);
   getStridesAndOffset(reduced, reducedStrides, reducedOffset);
 
   // Filter strides based on the mask and check that they are the same
   // as reduced ones.
-  reducedIdx = 0;
+  unsigned reducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
     if (keepMask[originalIdx]) {
       if (originalStrides[originalIdx] != reducedStrides[reducedIdx++])
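Note (editorial, not part of the patch): as a worked instance of the stride filtering above, reducing memref<4x1x5xf32> to memref<4x5xf32> produces the keep-mask [1, 0, 1]; filtering the original strides [5, 5, 1] through that mask yields keepStrides = [5, 1], which matches the strides of the reduced type, so the subview verifies.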