// Patch summary (text before the first `diff --git` header is not part of any hunk):
//  * Ops.h / Ops.cpp: extractFromI64ArrayAttr drops its `static` qualifier in Ops.cpp, is
//    defined as mlir::extractFromI64ArrayAttr, and gains a declaration in Ops.h. The
//    StandardToLLVM.cpp hunk below is the new caller.
//  * StandardToLLVM.cpp (SubViewOp lowering): strides and offset are now read from the type
//    returned by SubViewOp::inferResultType(...) rather than from the op's own result type.
//    A per-dimension `mask` records which inferred dimensions line up with the result shape;
//    the size/stride loop skips unmasked dimensions and writes the rest at result index `j`.
//  * Two new integration tests: matmul-vs-matvec.mlir (asserts matmul and a column-wise
//    matvec loop produce equal elements) and rank-reducing-subview.mlir (prints a 2x2 memref
//    and two subviews of it taken with a unit size in one dimension).
// NOTE(review): in this copy the line breaks are collapsed and text inside some angle
// brackets appears to be missing (e.g. `SmallVector extractFromI64ArrayAttr`, `.cast()`,
// `rewriter.create(`, bare `memref`); confirm against the original patch before applying.
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -301,6 +301,9 @@ LogicalResult verify(); }; +/// Extracts the int64_t values from `attr`, assumed to be an ArrayAttr of IntegerAttr. +SmallVector extractFromI64ArrayAttr(Attribute attr); + /// Prints dimension and symbol list. void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, diff --git a/mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir b/mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir new file mode 100644 --- /dev/null +++ b/mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir @@ -0,0 +1,74 @@ +// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | \ +// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ +// RUN: | FileCheck %s + +func @print_memref_f32(memref<*xf32>) + +func @matmul(%A: memref, %B: memref) -> (memref) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %f0 = constant 0.0 : f32 + %x = dim %A, %c0 : memref + %y = dim %B, %c1 : memref + %C = alloc(%x, %y) : memref + linalg.fill(%C, %f0) : memref, f32 + linalg.matmul ins(%A, %B: memref, memref) + outs(%C: memref) + return %C : memref +} + +func @matvec(%A: memref, %B: memref) -> (memref) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %f0 = constant 0.0 : f32 + %m = dim %A, %c0 : memref + %x = dim %A, %c1 : memref + %n = dim %B, %c1 : memref + %C = alloc(%m, %n) : memref + linalg.fill(%C, %f0) : memref, f32 + scf.for %i = %c0 to %n step %c1 { + %b = subview %B[0, %i][%x, 1][1, 1] : memref to memref + %c = subview %C[0, %i][%m, 1][1, 1] : memref to memref + linalg.matvec ins(%A, %b: memref, memref) + outs(%c: memref) + } + return %C : memref +} + +func @main() { + %c0 = constant 0 : index + 
%c1 = constant 1 : index + %m = constant 5 : index + %x = constant 3 : index + %n = constant 2 : index + %val1 = constant 13.0 : f32 + %val2 = constant 17.0 : f32 + %A = alloc(%m, %x) : memref + %B = alloc(%x, %n) : memref + linalg.fill(%A, %val1) : memref, f32 + linalg.fill(%B, %val2) : memref, f32 + store %val1, %B[%c0, %c0] : memref + %C1 = call @matmul(%A, %B) : (memref, memref) -> memref + %C2 = call @matvec(%A, %B) : (memref, memref) -> memref + scf.for %i = %c0 to %m step %c1 { + scf.for %j = %c0 to %n step %c1 { + %e1 = load %C1[%i, %j] : memref + %e2 = load %C2[%i, %j] : memref + %c = cmpf "oeq", %e1, %e2 : f32 + assert %c, "Matmul does not produce same output as matvec" + } + } + %C2_ = memref_cast %C2 : memref to memref<*xf32> + call @print_memref_f32(%C2_) : (memref<*xf32>) -> () + return +} + +// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [5, 2] strides = [2, 1] data = +// CHECK-NEXT: [ +// CHECK-SAME: [611, 663], +// CHECK-NEXT: [611, 663], +// CHECK-NEXT: [611, 663], +// CHECK-NEXT: [611, 663], +// CHECK-NEXT: [611, 663] +// CHECK-SAME: ] diff --git a/mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir b/mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir new file mode 100644 --- /dev/null +++ b/mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | \ +// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ +// RUN: | FileCheck %s + +func @print_memref_f32(memref<*xf32>) + +func @main() { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %f0 = constant 0.0 : f32 + %f1 = constant 1.0 : f32 + %f2 = constant 2.0 : f32 + %f3 = constant 3.0 : f32 + %A = alloc(%c2, %c2) : memref + store %f0, %A[%c0, %c0] : memref + store %f1, %A[%c0, %c1] : memref + store %f2, %A[%c1, %c0] : memref + 
store %f3, %A[%c1, %c1] : memref + %B = subview %A[%c1, 0][1, %c2][1, 1] : memref to memref + %C = subview %A[0, %c1][%c2, 1][1, 1] : memref to memref + %A_ = memref_cast %A : memref to memref<*xf32> + call @print_memref_f32(%A_) : (memref<*xf32>) -> () + %B_ = memref_cast %B : memref to memref<*xf32> + call @print_memref_f32(%B_) : (memref<*xf32>) -> () + %C_ = memref_cast %C : memref to memref<*xf32> + call @print_memref_f32(%C_) : (memref<*xf32>) -> () + return +} + +// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [2, 2] strides = [2, 1] data = +// CHECK-NEXT: [ +// CHECK-SAME: [0, 1], +// CHECK-NEXT: [2, 3] +// CHECK-SAME: ] +// CHECK: [2, 3] +// CHECK: [1, 3] diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2948,6 +2948,12 @@ .dyn_cast_or_null(); auto viewMemRefType = subViewOp.getType(); + auto inferredType = SubViewOp::inferResultType( + subViewOp.getSourceType(), + extractFromI64ArrayAttr(subViewOp.static_offsets()), + extractFromI64ArrayAttr(subViewOp.static_sizes()), + extractFromI64ArrayAttr(subViewOp.static_strides())) + .cast(); auto targetElementTy = typeConverter.convertType(viewMemRefType.getElementType()) .dyn_cast(); @@ -2959,7 +2965,7 @@ // Extract the offset and strides from the type. 
int64_t offset; SmallVector strides; - auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); + auto successStrides = getStridesAndOffset(inferredType, strides, offset); if (failed(successStrides)) return failure(); @@ -2983,10 +2989,20 @@ extracted); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); + auto shape = viewMemRefType.getShape(); + auto inferredShape = inferredType.getShape(); + size_t inferredShapeRank = inferredShape.size(); + size_t resultShapeRank = shape.size(); + SmallVector mask(inferredShapeRank); + for (unsigned i = 0, j = 0; i < inferredShapeRank && j < resultShapeRank; + ++i) + if ((mask[i] = inferredShape[i] == shape[j])) + j++; + // Extract strides needed to compute offset. SmallVector strideValues; - strideValues.reserve(viewMemRefType.getRank()); - for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) + strideValues.reserve(inferredShapeRank); + for (unsigned i = 0; i < inferredShapeRank; ++i) strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); // Offset. @@ -2995,7 +3011,7 @@ targetMemRef.setConstantOffset(rewriter, loc, offset); } else { Value baseOffset = sourceMemRef.offset(rewriter, loc); - for (unsigned i = 0, e = viewMemRefType.getRank(); i < e; ++i) { + for (unsigned i = 0; i < inferredShapeRank; ++i) { Value offset = subViewOp.isDynamicOffset(i) ? operands[subViewOp.getIndexOfDynamicOffset(i)] @@ -3009,14 +3025,18 @@ } // Update sizes and strides. - for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { + for (int i = inferredShapeRank - 1, j = resultShapeRank - 1; + i >= 0 && j >= 0; --i) { + if (!mask[i]) + continue; + Value size = subViewOp.isDynamicSize(i) ? 
operands[subViewOp.getIndexOfDynamicSize(i)] : rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i))); - targetMemRef.setSize(rewriter, loc, i, size); + targetMemRef.setSize(rewriter, loc, j, size); Value stride; if (!ShapedType::isDynamicStrideOrOffset(strides[i])) { stride = rewriter.create( @@ -3030,7 +3050,8 @@ rewriter.getI64IntegerAttr(subViewOp.getStaticStride(i))); stride = rewriter.create(loc, stride, strideValues[i]); } - targetMemRef.setStride(rewriter, loc, i, stride); + targetMemRef.setStride(rewriter, loc, j, stride); + j--; } rewriter.replaceOp(op, {targetMemRef}); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -2815,8 +2815,7 @@ return success(); } -/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr. -static SmallVector extractFromI64ArrayAttr(Attribute attr) { +SmallVector mlir::extractFromI64ArrayAttr(Attribute attr) { return llvm::to_vector<4>( llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { return a.cast().getInt();