diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -15,6 +15,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -14,6 +14,7 @@
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -581,7 +582,6 @@
 
   let builders = [
     OpBuilder<(ins "Value":$source, "int64_t":$index)>,
-    OpBuilder<(ins "Value":$source, "Value":$index)>
   ];
 
   let extraClassDeclaration = [{
@@ -853,7 +853,8 @@
 def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     Pure,
-    SameVariadicResultSize]> {
+    SameVariadicResultSize,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Extracts a buffer base with offset and strides";
   let description = [{
     Extracts a base buffer, offset and strides. This op allows additional layers
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -807,12 +807,6 @@
   build(builder, result, source, indexValue);
 }
 
-void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
-                  Value index) {
-  auto indexTy = builder.getIndexType();
-  build(builder, result, indexTy, source, index);
-}
-
 Optional<int64_t> DimOp::getConstantIndex() {
   if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
     return constantOp.getValue().cast<IntegerAttr>().getInt();
@@ -1254,6 +1248,34 @@
 // ExtractStridedMetadataOp
 //===----------------------------------------------------------------------===//
 
+/// The number and type of the results are inferred from the
+/// shape of the source: a rank-0 base buffer memref (same element type and
+/// memory space as the source), an index offset, and 2 * rank index values
+/// for the sizes and strides. Fails if the source is not a MemRefType.
+LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes, regions);
+  auto sourceType = extractAdaptor.getSource().getType().dyn_cast<MemRefType>();
+  if (!sourceType)
+    return failure();
+
+  unsigned sourceRank = sourceType.getRank();
+  IndexType indexType = IndexType::get(context);
+  auto memrefType =
+      MemRefType::get({}, sourceType.getElementType(),
+                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
+  // Base.
+  inferredReturnTypes.push_back(memrefType);
+  // Offset.
+  inferredReturnTypes.push_back(indexType);
+  // Sizes and strides.
+  for (unsigned i = 0; i < sourceRank * 2; ++i)
+    inferredReturnTypes.push_back(indexType);
+  return success();
+}
+
 void ExtractStridedMetadataOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getBaseBuffer(), "base_buffer");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -59,16 +59,12 @@
     // Build a plain extract_strided_metadata(memref) from
     // extract_strided_metadata(subview(memref)).
     Location origLoc = op.getLoc();
-    IndexType indexType = rewriter.getIndexType();
     Value source = subview.getSource();
     auto sourceType = source.getType().cast<MemRefType>();
     unsigned sourceRank = sourceType.getRank();
-    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
 
     auto newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(
-            origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
-            sizeStrideTypes, source);
+        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
 
     SmallVector<int64_t> sourceStrides;
     int64_t sourceOffset;
@@ -486,16 +482,12 @@
     // Build a plain extract_strided_metadata(memref) from
     // extract_strided_metadata(reassociative_reshape_like(memref)).
     Location origLoc = op.getLoc();
-    IndexType indexType = rewriter.getIndexType();
     Value source = reshape.getSrc();
     auto sourceType = source.getType().cast<MemRefType>();
     unsigned sourceRank = sourceType.getRank();
-    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
 
     auto newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(
-            origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
-            sizeStrideTypes, source);
+        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
 
     // Collect statically known information.
     SmallVector<int64_t> strides;