diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -905,6 +905,11 @@
     Variadic<Index>:$strides
   );
 
+  // Build `extract_strided_metadata(source)`.
+  // The number and type of the results are inferred from the source
+  // type: its rank, element type, and memory space.
+  let builders = [OpBuilder<(ins "Value":$source)>];
+
   let assemblyFormat = [{
     $source `:` type($source) `->` type(results) attr-dict
   }];
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1240,6 +1240,23 @@
 // ExtractStridedMetadataOp
 //===----------------------------------------------------------------------===//
 
+void ExtractStridedMetadataOp::build(mlir::OpBuilder &builder,
+                                     mlir::OperationState &state,
+                                     Value source) {
+  auto sourceType = source.getType().cast<MemRefType>();
+  unsigned sourceRank = sourceType.getRank();
+  IndexType indexType = builder.getIndexType();
+  // One index-typed result per dimension, for the sizes and for the strides.
+  SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
+  // The base buffer is a rank-0 memref with the source's element type and
+  // memory space, and no explicit layout.
+  auto memrefType =
+      MemRefType::get({}, sourceType.getElementType(),
+                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
+  ExtractStridedMetadataOp::build(builder, state, memrefType, indexType,
+                                  sizeStrideTypes, sizeStrideTypes, source);
+}
+
 void ExtractStridedMetadataOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getBaseBuffer(), "base_buffer");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -59,16 +59,12 @@
     // Build a plain extract_strided_metadata(memref) from
     // extract_strided_metadata(subview(memref)).
     Location origLoc = op.getLoc();
-    IndexType indexType = rewriter.getIndexType();
     Value source = subview.getSource();
     auto sourceType = source.getType().cast<MemRefType>();
     unsigned sourceRank = sourceType.getRank();
-    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
 
     auto newExtractStridedMetadata =
-        rewriter.create<ExtractStridedMetadataOp>(
-            origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
-            sizeStrideTypes, source);
+        rewriter.create<ExtractStridedMetadataOp>(origLoc, source);
 
     SmallVector<int64_t> sourceStrides;
     int64_t sourceOffset;
@@ -486,16 +482,12 @@
     // Build a plain extract_strided_metadata(memref) from
     // extract_strided_metadata(reassociative_reshape_like(memref)).
     Location origLoc = op.getLoc();
-    IndexType indexType = rewriter.getIndexType();
     Value source = reshape.getSrc();
     auto sourceType = source.getType().cast<MemRefType>();
     unsigned sourceRank = sourceType.getRank();
-    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
 
     auto newExtractStridedMetadata =
-        rewriter.create<ExtractStridedMetadataOp>(
-            origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
-            sizeStrideTypes, source);
+        rewriter.create<ExtractStridedMetadataOp>(origLoc, source);
 
     // Collect statically known information.
     SmallVector<int64_t> strides;