Index: mlir/include/mlir/Dialect/Vector/VectorOps.td
===================================================================
--- mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1710,4 +1710,48 @@
   let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
 }
 
+// Experimental op used by the vector-to-GPU lowering path. This allows
+// generating IR with the correct type for cases where we need to load/store
+// from a memref at a different granularity.
+def Vector_ReinterpretCastOp : Vector_Op<"reinterpret_cast", [NoSideEffect]>,
+    Arguments<(ins StaticShapeMemRefOf<[AnyType]>:$memref)>,
+    Results<(outs StaticShapeMemRefOf<[AnyType]>:$result)> {
+  let summary = "reinterpret_cast converts a memref into an equivalent memref.";
+
+  let description = [{
+    Performs a conversion from a memref to another memref with a different
+    shape and/or element type. This is needed to be able to access a memref
+    at a different granularity without using subviews. This is analogous to
+    a pointer cast.
+
+    This is an experimental op and not all lowerings will be supported.
+
+    Syntax:
+
+    ```
+    operation ::= `vector.reinterpret_cast` ssa-use : memref-type to memref-type
+    ```
+
+    Example:
+
+    ```mlir
+    %A = alloc() : memref<128x2xvector<4xi32>>
+    %RA = vector.reinterpret_cast %A : memref<128x2xvector<4xi32>> to memref<128x32xi8>
+    ```
+  }];
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return memref().getType().cast<MemRefType>();
+    }
+    MemRefType getResultMemRefType() {
+      return getResult().getType().cast<MemRefType>();
+    }
+  }];
+
+  let assemblyFormat = [{
+    $memref attr-dict `:` type($memref) `to` type($result)
+  }];
+}
+
 #endif // VECTOR_OPS
Index: mlir/lib/Dialect/Vector/VectorOps.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorOps.cpp
+++ mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2293,6 +2293,26 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ReinterpretCastOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ReinterpretCastOp op) {
+  MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
+  if (!canonicalType.getAffineMaps().empty())
+    return op.emitOpError("expects operand to be a memref with no layout");
+  if (!op.getResultMemRefType().getAffineMaps().empty())
+    return op.emitOpError("expects result to be a memref with no layout");
+  if (op.getResultMemRefType().getMemorySpace() !=
+      op.getMemRefType().getMemorySpace())
+    return op.emitOpError("expects result in same memory space");
+  if (op.getResultMemRefType().getSizeInBits() !=
+      op.getMemRefType().getSizeInBits())
+    return op.emitOpError(
+        "expects result type to have the same size as operand");
+  return success();
+}
+
 namespace {
 
 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
Index: mlir/lib/IR/StandardTypes.cpp
===================================================================
--- mlir/lib/IR/StandardTypes.cpp
+++ mlir/lib/IR/StandardTypes.cpp
@@ -230,10 +230,11 @@
   if (elementType.isIntOrFloat())
     return elementType.getIntOrFloatBitWidth() * getNumElements();
 
-  // Tensors can have vectors and other tensors as elements, other shaped types
-  // cannot.
-  assert(isa<TensorType>() && "unsupported element type");
-  assert((elementType.isa<VectorType, TensorType>()) &&
+  // Tensors and memrefs can have vectors and other tensors as elements, other
+  // shaped types cannot.
+  assert((isa<TensorType>() || isa<MemRefType>()) &&
+         "unsupported element type");
+  assert((elementType.isa<VectorType, TensorType>()) &&
          "unsupported tensor element type");
   return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
 }
Index: mlir/test/Dialect/Vector/invalid.mlir
===================================================================
--- mlir/test/Dialect/Vector/invalid.mlir
+++ mlir/test/Dialect/Vector/invalid.mlir
@@ -1240,3 +1240,11 @@
   // expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}}
   vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref<?xf32>
 }
+
+// -----
+
+func @reinterpret_cast_size_mismatch(%base: memref<7x24xi32>) -> memref<7x10xvector<4xi32>> {
+  // expected-error@+1 {{'vector.reinterpret_cast' op expects result type to have the same size as operand}}
+  %0 = vector.reinterpret_cast %base : memref<7x24xi32> to memref<7x10xvector<4xi32>>
+  return %0 : memref<7x10xvector<4xi32>>
+}
Index: mlir/test/Dialect/Vector/ops.mlir
===================================================================
--- mlir/test/Dialect/Vector/ops.mlir
+++ mlir/test/Dialect/Vector/ops.mlir
@@ -379,3 +379,17 @@
   vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
   return
 }
+
+// CHECK-LABEL: @reinterpret_cast
+func @reinterpret_cast(%base: memref<7x12xi8>) -> memref<7x3xi32> {
+  // CHECK: %{{.*}} = vector.reinterpret_cast %{{.*}} : memref<7x12xi8> to memref<7x3xi32>
+  %0 = vector.reinterpret_cast %base : memref<7x12xi8> to memref<7x3xi32>
+  return %0 : memref<7x3xi32>
+}
+
+// CHECK-LABEL: @reinterpret_cast_vector
+func @reinterpret_cast_vector(%base: memref<7x24xi32>) -> memref<7x6xvector<4xi32>> {
+  // CHECK: %{{.*}} = vector.reinterpret_cast %{{.*}} : memref<7x24xi32> to memref<7x6xvector<4xi32>>
+  %0 = vector.reinterpret_cast %base : memref<7x24xi32> to memref<7x6xvector<4xi32>>
+  return %0 : memref<7x6xvector<4xi32>>
+}
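
For illustration only (not part of the patch): a minimal sketch of the kind of IR the vector-to-GPU path could build around this op, assuming the standard `load` op is applied to a memref with a vector element type; the function name and shapes below are made up.

```mlir
// Hypothetical example: reinterpret an i8 buffer so that each access moves a
// whole vector<4xi32> (16 bytes) instead of a single byte.
func @load_at_vector_granularity(%buf: memref<128x32xi8>, %i: index, %j: index) -> vector<4xi32> {
  // Both types cover 128 * 32 * 8 = 32768 bits, so the verifier accepts the cast.
  %view = vector.reinterpret_cast %buf : memref<128x32xi8> to memref<128x2xvector<4xi32>>
  %v = load %view[%i, %j] : memref<128x2xvector<4xi32>>
  return %v : vector<4xi32>
}
```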