Index: mlir/include/mlir/Dialect/VectorOps/VectorOps.td
===================================================================
--- mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -943,15 +943,15 @@
 def Vector_TypeCastOp :
   Vector_Op<"type_cast", [NoSideEffect]>,
-    Arguments<(ins StaticShapeMemRefOf<[AnyType]>:$memref)>,
+    Arguments<(ins AnyMemRef:$memref)>,
     Results<(outs AnyMemRef)> {
   let summary = "type_cast op converts a scalar memref to a vector memref";
   let description = [{
-    Performs a conversion from a memref with scalar element to a memref with a
-    *single* vector element, copying the shape of the memref to the vector. This
-    is the minimal viable operation that is required to makeke
-    super-vectorization operational. It can be seen as a special case of the
-    `view` operation but scoped in the super-vectorization context.
+    Performs a conversion from a memref with scalar element type to a memref
+    with vector element type, reshaping trailing static memref dimensions into
+    the result memref shape and the vector shape. This is the minimal viable
+    operation required to make super-vectorization operational. It can be seen
+    as a special case of the `view` operation scoped to super-vectorization.
 
     Syntax:
 
@@ -963,7 +963,7 @@
     ```mlir
     %A  = alloc() : memref<5x4x3xf32>
-    %VA = vector.type_cast %A : memref<5x4x3xf32> to memref<vector<5x4x3xf32>>
+    %VA = vector.type_cast %A : memref<5x4x3xf32> to memref<4xvector<5x3xf32>>
     ```
   }];
Index: mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
===================================================================
--- mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -703,11 +703,7 @@
         castOp.getOperand().getType().cast<MemRefType>();
     MemRefType targetMemRefType =
         castOp.getResult().getType().cast<MemRefType>();
-
-    // Only static shape casts supported atm.
-    if (!sourceMemRefType.hasStaticShape() ||
-        !targetMemRefType.hasStaticShape())
-      return matchFailure();
+    auto vectorType = targetMemRefType.getElementType().cast<VectorType>();
 
     auto llvmSourceDescriptorTy =
         operands[0].getType().dyn_cast<LLVM::LLVMType>();
@@ -720,23 +716,32 @@
     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
       return matchFailure();
 
+    // Only contiguous source tensors supported atm.
     int64_t offset;
     SmallVector<int64_t, 4> strides;
     auto successStrides =
         getStridesAndOffset(sourceMemRefType, strides, offset);
-    bool isContiguous = (strides.back() == 1);
-    if (isContiguous) {
-      auto sizes = sourceMemRefType.getShape();
-      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
-        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
+    bool isContiguous = true;
+    int64_t reshapeStride = 1;
+    int64_t reshapeIndex;
+    {
+      auto memrefSizes = sourceMemRefType.getShape();
+      auto vectorSizes = vectorType.getShape();
+      int64_t vectorSize = 1;
+      for (int index = 0; index < (int)vectorSizes.size(); ++index)
+        vectorSize *= vectorSizes[index];
+      for (reshapeIndex = strides.size() - 1;
+           isContiguous && reshapeIndex >= 0 && memrefSizes[reshapeIndex] > 0;
+           --reshapeIndex)
+        if (reshapeStride != strides[reshapeIndex])
          isContiguous = false;
-          break;
-        }
-      }
+        else
+          reshapeStride *= memrefSizes[reshapeIndex];
+      if (reshapeStride % vectorSize != 0)
+        isContiguous = false;
+      if (failed(successStrides) || !isContiguous)
+        return matchFailure();
     }
-    // Only contiguous source tensors supported atm.
-    if (failed(successStrides) || !isContiguous)
-      return matchFailure();
 
     auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
 
@@ -764,10 +769,17 @@
           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
       desc.setSize(rewriter, loc, index, size);
+      int64_t stride;
+      if (index > reshapeIndex)
+        stride = strides[index];
+      else {
+        stride = reshapeStride = reshapeStride / indexedSize.value();
+      }
       auto strideAttr =
-          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
-      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
-      desc.setStride(rewriter, loc, index, stride);
+          rewriter.getIntegerAttr(rewriter.getIndexType(), stride);
+      auto strideConstant =
+          rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
+      desc.setStride(rewriter, loc, index, strideConstant);
     }
 
     rewriter.replaceOp(op, {desc});
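To illustrate the kind of cast the relaxed lowering now handles, here is a sketch over a dynamically shaped memref; the function name and shapes are illustrative and not taken from the patch. The trailing static dimensions form one contiguous block of 4 * 8 = 32 scalars, exactly the vector size, so the contiguity check above succeeds.

```mlir
func @cast_dynamic_leading_dim(%n : index) {
  // The 4x8 block (32 contiguous f32 elements) is folded into each
  // vector<4x8xf32>; the dynamic leading dimension is carried over.
  %A = alloc(%n) : memref<?x4x8xf32>
  %V = vector.type_cast %A : memref<?x4x8xf32> to memref<?xvector<4x8xf32>>
  return
}
```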
Index: mlir/lib/Dialect/VectorOps/VectorOps.cpp
===================================================================
--- mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -1604,13 +1604,52 @@
 static void print(OpAsmPrinter &p, TypeCastOp op) {
   auto type = op.getOperand().getType().cast<MemRefType>();
   p << op.getOperationName() << ' ' << op.memref() << " : " << type << " to "
-    << inferVectorTypeCastResultType(type);
+    << op.getResultMemRefType();
 }
 
+// Any shape is accepted as long as only the static (fixed) dimensions are
+// reshaped; dynamic dimensions must be carried over unchanged.
 static LogicalResult verify(TypeCastOp op) {
-  auto resultType = inferVectorTypeCastResultType(op.getMemRefType());
-  if (op.getResultMemRefType() != resultType)
-    return op.emitOpError("expects result type to be: ") << resultType;
+  MemRefType memrefType = op.getMemRefType();
+  auto memrefShape = memrefType.getShape();
+  MemRefType resultType = op.getResultMemRefType();
+  auto resultShape = resultType.getShape();
+  int64_t i, j;
+  int64_t leftShape = 1;
+
+  auto elementType = resultType.getElementType();
+  if (!elementType.isa<VectorType>())
+    return op.emitOpError("must cast to a vector type");
+
+  for (i = memrefShape.size() - 1; i >= 0 && memrefShape[i] > 0; --i)
+    leftShape *= memrefShape[i];
+
+  for (j = resultShape.size() - 1; j >= 0 && resultShape[j] > 0; --j)
+    if (leftShape % resultShape[j] != 0)
+      return op.emitOpError("shape mismatch");
+    else
+      leftShape /= resultShape[j];
+
+  bool matchDims = true;
+  if (i != j)
+    matchDims = false;
+  for (; j >= 0 && matchDims; --j)
+    if (resultShape[j] != memrefShape[j])
+      matchDims = false;
+
+  if (!matchDims)
+    return op.emitOpError("rest of the dimensions don't match");
+
+  auto vectorType = elementType.cast<VectorType>();
+  auto vectorShape = vectorType.getShape();
+  for (uint64_t i = 0; i < vectorShape.size(); ++i)
+    if (leftShape % vectorShape[i] != 0)
+      return op.emitOpError("not a perfect reshape");
+    else
+      leftShape /= vectorShape[i];
+
+  if (leftShape > 1)
+    return op.emitOpError("reshape has leftover elements");
+
   return success();
 }
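A worked example of the verifier's divisibility walk (not part of the patch; the shapes are made up for illustration): the trailing static extents of the source are accumulated, the trailing static extents of the result and then the vector shape are divided out, and the cast verifies only if nothing is left over and the remaining leading dimensions match.

```mlir
func @verify_reshape_example(%arg0 : memref<?x?x2x12xf32>) {
  // Trailing static source extents: 2 * 12 = 24 scalar elements.
  // The result's trailing static extent divides them (24 / 2 = 12), the
  // leading ?x? dimensions match pairwise, and the vector shape consumes
  // the remainder exactly (12 / 3 / 4 = 1), so the op verifies.
  %0 = vector.type_cast %arg0
      : memref<?x?x2x12xf32> to memref<?x?x2xvector<3x4xf32>>
  return
}
```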
Index: mlir/test/Dialect/VectorOps/invalid.mlir
===================================================================
--- mlir/test/Dialect/VectorOps/invalid.mlir
+++ mlir/test/Dialect/VectorOps/invalid.mlir
@@ -889,3 +889,36 @@
   %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
     : vector<3x2x4xf32> to vector<2x3x5xf32>
 }
+
+// -----
+
+func @type_cast_not_vector(%arg0 : memref<8xf32>) {
+  // expected-error@+1 {{must cast to a vector type}}
+  %0 = vector.type_cast %arg0 : memref<8xf32> to memref<2x4xf32>
+  return
+}
+
+// -----
+
+func @type_cast_unchanged_dimensions(%arg0 : memref<8x?x8xf32>) {
+  // expected-error@+1 {{rest of the dimensions don't match}}
+  %0 = vector.type_cast %arg0 : memref<8x?x8xf32> to memref<4x?xvector<8xf32>>
+  return
+}
+
+// -----
+
+func @type_cast_not_perfect_reshape(%arg0 : memref<8xf32>) {
+  // expected-error@+1 {{not a perfect reshape}}
+  %0 = vector.type_cast %arg0 : memref<8xf32> to memref<2xvector<3xf32>>
+  return
+}
+
+// -----
+
+func @type_cast_reshape_leftover(%arg0 : memref<8xf32>) {
+  // expected-error@+1 {{reshape has leftover elements}}
+  %0 = vector.type_cast %arg0 : memref<8xf32> to memref<2xvector<2xf32>>
+  return
+}
+
Index: mlir/test/Dialect/VectorOps/ops.mlir
===================================================================
--- mlir/test/Dialect/VectorOps/ops.mlir
+++ mlir/test/Dialect/VectorOps/ops.mlir
@@ -233,3 +233,15 @@
   return %1 : vector<2x3x4xf32>
 }
+
+// CHECK-LABEL: typecast
+func @typecast(%arg0 : memref) {
+  // CHECK: vector.type_cast %arg0 : memref to memref>
+  %0 = vector.type_cast %arg0 : memref to memref>
+  // CHECK: vector.type_cast %arg0 : memref to memref>
+  %1 = vector.type_cast %arg0 : memref to memref>
+  // CHECK: vector.type_cast %arg0 : memref to memref>
+  %2 = vector.type_cast %arg0 : memref to memref>
+  return
+}
+
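For reference, the fully static form that the op documentation previously used still verifies under the new rules, since the entire memref shape is consumed by the vector shape. A small sketch (the function name is illustrative):

```mlir
func @cast_full_shape(%A : memref<5x4x3xf32>) {
  // All 5 * 4 * 3 = 60 scalar elements end up in the vector, so nothing is
  // left over and the cast still verifies after this change.
  %VA = vector.type_cast %A : memref<5x4x3xf32> to memref<vector<5x4x3xf32>>
  return
}
```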