diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1882,6 +1882,51 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefVectorCastOp
+//===----------------------------------------------------------------------===//
+
+def MemRefVectorCastOp : Std_Op<"memref_vector_cast", [NoSideEffect]> {
+  let summary = "memref elemental type vectorizing cast operation";
+  let description = [{
+    The `memref_vector_cast` operation converts a memref with a non-vector
+    element type to a memref with a vector element type, preserving the
+    source memref's base element type, rank, and all dimension sizes except
+    the last one. The last dimension size of the source memref is divided
+    (floor division) by the vector size to obtain the corresponding dimension
+    size of the target memref type. The source memref type's last dimension
+    is expected to be at least as large as the vector size. Examples:
+
+    ```mlir
+    %MV = memref_vector_cast %M : memref<8x16xf32> to memref<8x2xvector<8xf32>>
+    %AV = memref_vector_cast %A : memref<?x?xf32> to memref<?x?xvector<8xf32>>
+    ```
+  }];
+
+  let arguments = (ins AnyMemRef:$source);
+  let results = (outs AnyMemRef);
+
+  let parser = [{
+    return impl::parseCastOp(parser, result);
+  }];
+  let printer = [{
+    return printStandardCastOp(this->getOperation(), p);
+  }];
+
+  let verifier = [{ return ::verifyCastOp(*this); }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if `a` and `b` are valid operand and result type pairs
+    /// for the operation.
+    static bool areCastCompatible(Type a, Type b);
+
+    /// The result of a memref_vector_cast is always a memref.
+    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+  }];
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
@@ -2981,7 +3026,7 @@
 
   let extraClassDeclaration = [{
     /// The result of a tensor_load is always a tensor.
-    TensorType getType() {
+    TensorType getType() {
       Type resultType = getResult().getType();
       if (resultType.isa<TensorType>())
         return resultType.cast<TensorType>();
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2332,6 +2332,122 @@
   }
 };
 
+struct MemRefVectorCastOpLowering
+    : public ConvertOpToLLVMPattern<MemRefVectorCastOp> {
+  using ConvertOpToLLVMPattern<MemRefVectorCastOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto memRefVectorCastOp = cast<MemRefVectorCastOp>(op);
+    MemRefType sourceType =
+        memRefVectorCastOp.getOperand().getType().cast<MemRefType>();
+    MemRefType targetType = memRefVectorCastOp.getType();
+    if (!isSupportedMemRefType(targetType) ||
+        !isSupportedMemRefType(sourceType))
+      return failure();
+
+    MemRefVectorCastOp::Adaptor transformed(operands);
+    MemRefDescriptor srcMemRefDesc(transformed.source());
+
+    Type targetStructType =
+        typeConverter.convertType(memRefVectorCastOp.getType());
+    if (!targetStructType)
+      return failure();
+    Location loc = op->getLoc();
+    MemRefDescriptor memRefDescriptor =
+        MemRefDescriptor::undef(rewriter, loc, targetStructType);
+    LLVM::LLVMType targetElementPtrType = memRefDescriptor.getElementType();
+
+    Value srcBuffer = srcMemRefDesc.allocatedPtr(rewriter, loc);
+    Value targetBuffer = rewriter.create<LLVM::BitcastOp>(
+        loc, targetElementPtrType, ArrayRef<Value>(srcBuffer));
+    memRefDescriptor.setAllocatedPtr(rewriter, loc, targetBuffer);
+
+    Value srcBufferAligned = srcMemRefDesc.alignedPtr(rewriter, loc);
+    Value targetBufAligned = rewriter.create<LLVM::BitcastOp>(
+        loc, targetElementPtrType, ArrayRef<Value>(srcBufferAligned));
+    memRefDescriptor.setAlignedPtr(rewriter, loc, targetBufAligned);
+
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    if (failed(getStridesAndOffset(targetType, strides, offset)))
+      return failure();
+
+    // Unhandled dynamic offset.
+    if (offset == MemRefType::getDynamicStrideOrOffset())
+      return failure();
+
+    memRefDescriptor.setOffset(rewriter, loc,
+                               createIndexConstant(rewriter, loc, offset));
+
+    // Get the sizes of the target memref: all but the last dimension are the
+    // same as the source memref's. Static sizes are materialized as
+    // constants; dynamic ones are copied from the source descriptor.
+    SmallVector<Value, 4> sizes;
+    sizes.reserve(targetType.getRank());
+    for (unsigned pos = 0, e = targetType.getRank() - 1; pos < e; ++pos) {
+      int64_t dimSize = targetType.getDimSize(pos);
+      if (dimSize == MemRefType::kDynamicSize)
+        sizes.push_back(srcMemRefDesc.size(rewriter, loc, pos));
+      else
+        sizes.push_back(createIndexConstant(rewriter, loc, dimSize));
+    }
+
+    if (targetType.getShape().back() != MemRefType::kDynamicSize) {
+      // The op is already verified to have the right size for the last
+      // dimension.
+      sizes.push_back(
+          createIndexConstant(rewriter, loc, targetType.getShape().back()));
+    } else {
+      // Divide the source's dynamic last dimension size by the vector width.
+      Value vecWidth = createIndexConstant(
+          rewriter, loc,
+          targetType.getElementType().cast<VectorType>().getNumElements());
+      sizes.push_back(rewriter.create<LLVM::UDivOp>(
+          loc, srcMemRefDesc.size(rewriter, loc, sourceType.getRank() - 1),
+          vecWidth));
+    }
+
+    assert(!sizes.empty() && "target memref rank can't be zero");
+
+    // Compute the total number of memref elements.
+    Value cumulativeSize = sizes.front();
+    for (unsigned i = 1, e = sizes.size(); i < e; ++i)
+      cumulativeSize = rewriter.create<LLVM::MulOp>(
+          loc, getIndexType(), ArrayRef<Value>{cumulativeSize, sizes[i]});
+
+    // Calculate the strides.
+    Value runningStride = nullptr;
+    // Iterate strides in reverse order, compute runningStride and strideValues.
+    unsigned nStrides = strides.size();
+    SmallVector<Value, 4> strideValues(nStrides, nullptr);
+    for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) {
+      int64_t index = nStrides - 1 - indexedStride.index();
+      if (strides[index] == MemRefType::getDynamicStrideOrOffset())
+        // Identity layout map is enforced in the match function, so we compute:
+        // `runningStride *= sizes[index + 1]`.
+        runningStride = runningStride
+                            ? rewriter.create<LLVM::MulOp>(loc, runningStride,
+                                                           sizes[index + 1])
+                            : createIndexConstant(rewriter, loc, 1);
+      else
+        runningStride = createIndexConstant(rewriter, loc, strides[index]);
+      strideValues[index] = runningStride;
+    }
+
+    // Fill size and stride descriptors in memref.
+    for (auto indexedSize : llvm::enumerate(sizes)) {
+      int64_t index = indexedSize.index();
+      memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value());
+      memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]);
+    }
+
+    rewriter.replaceOp(op, {memRefDescriptor});
+    return success();
+  }
+};
+
 struct DialectCastOpLowering
     : public ConvertOpToLLVMPattern<LLVM::DialectCastOp> {
   using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
@@ -3337,6 +3453,7 @@
       DimOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
+      MemRefVectorCastOpLowering,
       RankOpLowering,
       StoreOpLowering,
       SubViewOpLowering,
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1925,6 +1925,68 @@
   return impl::foldCastOp(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefVectorCastOp
+//===----------------------------------------------------------------------===//
+
+bool MemRefVectorCastOp::areCastCompatible(Type a, Type b) {
+  auto aT = a.dyn_cast<MemRefType>();
+  auto bT = b.dyn_cast<MemRefType>();
+
+  if (!aT || !bT)
+    return false;
+
+  if (aT.getAffineMaps() != bT.getAffineMaps())
+    return false;
+
+  if (aT.getMemorySpace() != bT.getMemorySpace())
+    return false;
+
+  // A vector cast of a rank-0 memref is not possible.
+  if (aT.getRank() == 0)
+    return false;
+
+  if (aT.getRank() != bT.getRank())
+    return false;
+
+  // The shapes should match along all but the last dimension.
+  if (!std::equal(aT.getShape().begin(), std::prev(aT.getShape().end()),
+                  bT.getShape().begin()))
+    return false;
+
+  // The source memref can't have a vector (or any other shaped) element type.
+  if (auto shapedEltType = aT.getElementType().dyn_cast<ShapedType>())
+    return false;
+
+  // The destination memref element type has to be a vector type.
+  auto vectorEltTypeB = bT.getElementType().dyn_cast<VectorType>();
+  if (!vectorEltTypeB)
+    return false;
+
+  auto eltA = aT.getElementType();
+  auto eltB = vectorEltTypeB.getElementType();
+  if (eltA != eltB)
+    return false;
+
+  int64_t lastDimA = aT.getShape().back();
+  int64_t lastDimB = bT.getShape().back();
+
+  // If one last dimension is dynamic but the other isn't, they are incompatible.
+  if (lastDimA * lastDimB < 0)
+    return false;
+
+  // The last dim of the target should be of the right size.
+  if (lastDimB != MemRefType::kDynamicSize &&
+      lastDimA / vectorEltTypeB.getNumElements() != lastDimB)
+    return false;
+
+  return true;
+}
+
+OpFoldResult MemRefVectorCastOp::fold(ArrayRef<Attribute> operands) {
+  return impl::foldCastOp(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -393,6 +393,34 @@
   return
 }
 
+// CHECK-LABEL: func @memref_vector_cast_dynamic
+func @memref_vector_cast_dynamic(%M : memref<?x?xf32>) -> memref<?x?xvector<16xf32>> {
+// CHECK: %[[SRC:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<float>, ptr<float>,
+// CHECK-NEXT: %[[U:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<vector<16xf32>>, ptr<vector<16xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm.ptr<float> to !llvm.ptr<vector<16xf32>>
+// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<vector<16xf32>>, ptr<vector<16xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm.ptr<float> to !llvm.ptr<vector<16xf32>>
+// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<vector<16xf32>>, ptr<vector<16xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[OFFSET:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: llvm.insertvalue %[[OFFSET]], %{{.*}}[2] : !llvm.struct<(ptr<vector<16xf32>>, ptr<vector<16xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[size1:.*]] = llvm.extractvalue %[[SRC]][3, 0]
+// CHECK-NEXT: %[[VEC_WIDTH:.*]] = llvm.mlir.constant(16 : index)
+// CHECK-NEXT: %[[src_size2:.*]] = llvm.extractvalue %[[SRC]][3, 1]
+// CHECK-NEXT: %[[size2:.*]] = llvm.udiv %[[src_size2]], %[[VEC_WIDTH]]
+// CHECK-NEXT: %[[cum_size:.*]] = llvm.mul %[[size1]], %[[size2]]
+// CHECK-NEXT: %[[st2:.*]] = llvm.mlir.constant(1 : index)
+// CHECK-NEXT: %[[st1:.*]] = llvm.mul %[[st2]], %[[size2]]
+// CHECK-NEXT: llvm.insertvalue %[[size1]], %{{.*}}[3, 0] : !llvm.struct<(ptr<vector<16xf32>>, ptr<vector<16xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.insertvalue %[[st1]], %{{.*}}[4, 0] : !llvm.struct<(ptr<vector<16xf32>>, ptr<vector<16xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.insertvalue %[[size2]], %{{.*}}[3, 1] : !llvm.struct<(ptr<vector<16xf32>>, ptr<vector<16xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.insertvalue %[[st2]], %{{.*}}[4, 1] : !llvm.struct<(ptr<vector<16xf32>>, ptr<vector<16xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.return %{{.*}} : !llvm.struct<(ptr<vector<16xf32>>, ptr<vector<16xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+  %MV = memref_vector_cast %M : memref<?x?xf32> to memref<?x?xvector<16xf32>>
+  return %MV : memref<?x?xvector<16xf32>>
+}
+
 // CHECK-LABEL: func @mixed_memref_dim
 func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
 // CHECK: llvm.mlir.constant(42 : index) : !llvm.i64
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -250,6 +250,34 @@
 
 // -----
 
+// CHECK-LABEL: func @memref_vector_cast
+func @memref_vector_cast(%M : memref<42x16xf32>) -> memref<42x4xvector<4xf32>> {
+// CHECK: %[[SRC:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+// CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr<vector<4xf32>>, ptr<vector<4xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm.ptr<float> to !llvm.ptr<vector<4xf32>>
+// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<vector<4xf32>>, ptr<vector<4xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm.ptr<float> to !llvm.ptr<vector<4xf32>>
+// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<vector<4xf32>>, ptr<vector<4xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %{{.*}} = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK-NEXT: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<vector<4xf32>>, ptr<vector<4xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[size1:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT: %[[size2:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
+// CHECK-NEXT: %[[cumul_size:.*]] = llvm.mul %[[size1]], %[[size2]]
+// CHECK-NEXT: %[[st2:.*]] = llvm.mlir.constant(1 : index)
+// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(4 : index)
+// CHECK-NEXT: llvm.insertvalue %[[size1]], %{{.*}}[3, 0] : !llvm.struct<(ptr<vector<4xf32>>, ptr<vector<4xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.insertvalue %[[st1]], %{{.*}}[4, 0] : !llvm.struct<(ptr<vector<4xf32>>, ptr<vector<4xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.insertvalue %[[size2]], %{{.*}}[3, 1] : !llvm.struct<(ptr<vector<4xf32>>, ptr<vector<4xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.insertvalue %[[st2]], %{{.*}}[4, 1] : !llvm.struct<(ptr<vector<4xf32>>, ptr<vector<4xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: llvm.return %{{.*}} : !llvm.struct<(ptr<vector<4xf32>>, ptr<vector<4xf32>>, i64, array<2 x i64>, array<2 x i64>)>
+  %MV = memref_vector_cast %M : memref<42x16xf32> to memref<42x4xvector<4xf32>>
+  return %MV : memref<42x4xvector<4xf32>>
+}
+
+// -----
+
 // CHECK-LABEL: func @zero_d_load
 // BAREPTR-LABEL: func @zero_d_load(%{{.*}}: !llvm.ptr<float>) -> !llvm.float
 func @zero_d_load(%arg0: memref<f32>) -> f32 {
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -715,6 +715,15 @@
   return
 }
 
+// CHECK-LABEL: func @memref_vector_cast
+func @memref_vector_cast(%arg0: memref<60x8xf32>, %arg1: memref<?x?xf32>) -> memref<60x2xvector<4xf32>> {
+  // CHECK: {{.*}} = memref_vector_cast %arg0 : memref<60x8xf32> to memref<60x2xvector<4xf32>>
+  %0 = memref_vector_cast %arg0 : memref<60x8xf32> to memref<60x2xvector<4xf32>>
+  // CHECK: memref_vector_cast %arg1 : memref<?x?xf32> to memref<?x?xvector<8xf32>>
+  %1 = memref_vector_cast %arg1 : memref<?x?xf32> to memref<?x?xvector<8xf32>>
+  return %0 : memref<60x2xvector<4xf32>>
+}
+
 // Check that unranked memrefs with non-default memory space roundtrip
 // properly.
 // CHECK-LABEL: @unranked_memref_roundtrip(memref<*xf32, 4>)
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1094,6 +1094,60 @@
 
 // -----
 
+func @invalid_memref_vector_cast_1() {
+  %0 = alloc() : memref<2x4xf32>
+  %1 = memref_vector_cast %0 : memref<2x4xf32> to memref<2x2xf64>
+  // expected-error@-1 {{operand type 'memref<2x4xf32>' and result type 'memref<2x2xf64>' are cast incompatible}}
+  return
+}
+
+// -----
+
+func @invalid_memref_vector_cast_2() {
+  %0 = alloc() : memref<2x4xf32>
+  memref_vector_cast %0 : memref<2x4xf32> to memref<2xvector<4xf32>>
+  // expected-error@-1 {{operand type 'memref<2x4xf32>' and result type 'memref<2xvector<4xf32>>' are cast incompatible}}
+  return
+}
+
+// -----
+
+func @invalid_memref_vector_cast_3() {
+  %0 = alloc() : memref<2x4xf32>
+  memref_vector_cast %0 : memref<2x4xf32> to memref<2x?xvector<4xf32>>
+  // expected-error@-1 {{operand type 'memref<2x4xf32>' and result type 'memref<2x?xvector<4xf32>>' are cast incompatible}}
+  return
+}
+
+// -----
+
+func @invalid_memref_vector_cast_4() {
+  %0 = alloc() : memref<2x4xf32>
+  memref_vector_cast %0 : memref<2x4xf32> to memref<?x1xvector<4xf32>>
+  // expected-error@-1 {{operand type 'memref<2x4xf32>' and result type 'memref<?x1xvector<4xf32>>' are cast incompatible}}
+  return
+}
+
+// -----
+
+func @invalid_memref_vector_cast_5() {
+  %0 = alloc() : memref<2x4xf32>
+  %1 = memref_vector_cast %0 : memref<2x4xf32> to memref<4x2xf32>
+  // expected-error@-1 {{operand type 'memref<2x4xf32>' and result type 'memref<4x2xf32>' are cast incompatible}}
+  return
+}
+
+// -----
+
+func @invalid_memref_vector_cast_6() {
+  %0 = alloc() : memref<2x4xf32>
+  %1 = memref_vector_cast %0 : memref<2x4xf32> to memref<2x4xf32>
+  // expected-error@-1 {{operand type 'memref<2x4xf32>' and result type 'memref<2x4xf32>' are cast incompatible}}
+  return
+}
+
+// -----
+
 func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) {
   // expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
   %x = atomic_rmw "addf" %val, %I[%i] : (f32, memref<16x10xf32>) -> f32