diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -1437,4 +1437,41 @@
 
 } // namespace llvm
 
+namespace mlir {
+
+/// Extracts the value of type `UnderlyingTy` stored in `attr`, which is read
+/// as an `AttributeTy`. Only explicit specializations are defined below.
+template <typename AttributeTy, typename UnderlyingTy>
+UnderlyingTy extractValue(Attribute attr);
+
+/// Iterator over the elements of an ArrayAttr that yields the underlying
+/// value of each element instead of the Attribute itself. It is templated by
+/// both the attribute type and the value type.
+template <typename AttributeTy, typename UnderlyingTy>
+class attr_value_iterator final
+    : public llvm::mapped_iterator<ArrayAttr::iterator,
+                                   UnderlyingTy (*)(Attribute)> {
+public:
+  using reference = UnderlyingTy;
+  explicit attr_value_iterator(ArrayAttr::iterator it)
+      : llvm::mapped_iterator<ArrayAttr::iterator,
+                              UnderlyingTy (*)(Attribute)>(
+            it, &extractValue<AttributeTy, UnderlyingTy>) {}
+  /// Returns the extracted value by copy; callable on const iterators.
+  UnderlyingTy operator*() const {
+    return extractValue<AttributeTy, UnderlyingTy>(*this->I);
+  }
+};
+
+// Explicit specializations and aliases of attribute value iterators.
+// IntegerAttr -> int64_t: the stored value is sign-extended to 64 bits.
+template <>
+inline int64_t extractValue<IntegerAttr, int64_t>(Attribute attr) {
+  return attr.cast<IntegerAttr>().getValue().getSExtValue();
+}
+using i64_attr_iterator = attr_value_iterator<IntegerAttr, int64_t>;
+using i64_attr_range = llvm::iterator_range<i64_attr_iterator>;
+
+} // namespace mlir
+
 #endif
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -6,10 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -31,6 +32,7 @@
 #include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
+using namespace mlir::vector;
 
 template <typename T>
 static LLVM::LLVMType getPtrToElementType(T
containerType,
@@ -723,15 +725,108 @@
   }
 };
 
+/// Returns the values of `arrayAttr` as int64_t, skipping the first
+/// `dropFront` and the last `dropBack` elements. Elements are read through
+/// i64_attr_iterator.
+static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
+                                              unsigned dropFront = 1,
+                                              unsigned dropBack = 0) {
+  assert(arrayAttr.getValue().size() >= dropFront + dropBack &&
+         "cannot drop more elements than the array holds");
+  auto range =
+      llvm::make_range(i64_attr_iterator(arrayAttr.begin() + dropFront),
+                       i64_attr_iterator(arrayAttr.end() - dropBack));
+  return llvm::to_vector<4>(range);
+}
+
+/// Extracts position `offset` of the leading dimension of `vector`: a
+/// subvector when the rank is greater than 1, a single element (addressed by
+/// a constant index) otherwise.
+static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
+                        int64_t offset) {
+  auto vectorType = vector.getType().cast<VectorType>();
+  if (vectorType.getRank() > 1)
+    return rewriter.create<ExtractOp>(loc, vector, offset);
+  return rewriter.create<vector::ExtractElementOp>(
+      loc, vectorType.getElementType(), vector,
+      rewriter.create<ConstantIndexOp>(loc, offset));
+}
+
+/// Inserts `from` at position `offset` of the leading dimension of `into` and
+/// returns the updated vector; mirrors extractOne.
+static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
+                       Value into, int64_t offset) {
+  auto vectorType = into.getType().cast<VectorType>();
+  if (vectorType.getRank() > 1)
+    return rewriter.create<InsertOp>(loc, from, into, offset);
+  return rewriter.create<vector::InsertElementOp>(
+      loc, vectorType, from, into,
+      rewriter.create<ConstantIndexOp>(loc, offset));
+}
+
+/// Lowers a StridedSliceOp by unrolling its leading dimension: every selected
+/// position is extracted from the source, sliced recursively when further
+/// dimensions remain, and inserted into a zero-initialized result vector.
+class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
+public:
+  using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(StridedSliceOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto dstType = op.getResult().getType().cast<VectorType>();
+
+    // Reject before creating any op so that a failed match leaves the IR
+    // untouched.
+    if (op.offsets().getValue().empty())
+      return matchFailure();
+
+    int64_t offset =
+        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t stride =
+        op.strides().getValue().front().cast<IntegerAttr>().getInt();
+
+    auto loc = op.getLoc();
+    auto elemType = dstType.getElementType();
+    assert(elemType.isIntOrIndexOrFloat());
+    Value zero = rewriter.create<ConstantOp>(loc, elemType,
+                                             rewriter.getZeroAttr(elemType));
+    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
+    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
+         off += stride, ++idx) {
+      Value extracted = extractOne(rewriter, loc, op.vector(), off);
+      if (op.offsets().getValue().size() > 1) {
+        StridedSliceOp stridedSliceOp = rewriter.create<StridedSliceOp>(
+            loc, extracted, getI64SubArray(op.offsets()),
+            getI64SubArray(op.sizes()), getI64SubArray(op.strides()));
+        // The nested op keeps at least one offset, so the recursive rewrite
+        // cannot fail. Assert instead of returning matchFailure(): ops have
+        // already been created and a failed match must not modify the IR.
+        auto success = matchAndRewrite(stridedSliceOp, rewriter);
+        (void)success;
+        assert(success && "Unexpected failure");
+        // NOTE(review): stridedSliceOp was replaced by the call above; using
+        // its result here presumably relies on the rewriter deferring the
+        // erasure -- confirm.
+        extracted = stridedSliceOp;
+      }
+      res = insertOne(rewriter, loc, extracted, res, idx);
+    }
+    rewriter.replaceOp(op, {res});
+    return matchSuccess();
+  }
+};
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  MLIRContext *ctx = converter.getDialect()->getContext();
+  patterns.insert<VectorStridedSliceOpConversion>(ctx);
   patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
                   VectorExtractElementOpConversion, VectorExtractOpConversion,
                   VectorInsertElementOpConversion, VectorInsertOpConversion,
                   VectorOuterProductOpConversion, VectorTypeCastOpConversion,
-                  VectorPrintOpConversion>(converter.getDialect()->getContext(),
-                                           converter);
+                  VectorPrintOpConversion>(ctx, converter);
 }
 
 namespace {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -423,3 +423,64 @@
 // CHECK: llvm.call @print_close() : () -> ()
 // CHECK: llvm.call @print_close() : () -> ()
 // CHECK: llvm.call @print_newline() : () -> ()
+
+
+func @strided_slice(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>, %arg2: vector<4x8x16xf32>) {
+// CHECK-LABEL: llvm.func @strided_slice(
+
+  %0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.extractelement %{{.*}}[%{{.*}} :
!llvm.i64] : !llvm<"<4 x float>">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+
+  %1 = vector.strided_slice %arg1 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x8xf32>) : !llvm<"[2 x <8 x float>]">
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"[2 x <8 x float>]">
+// CHECK: llvm.extractvalue %{{.*}}[3] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"[2 x <8 x float>]">
+
+  %2 = vector.strided_slice %arg1 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm<"[2 x <2 x float>]">
+//
+// Row 2 of %arg1 (a vector<8xf32>): elements 2 and 3 are sliced out and the result is inserted at row 0.
+// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <2 x float>]">
+//
+// Row 3 of %arg1 (a vector<8xf32>): elements 2 and 3 are sliced out and the result is inserted at row 1.
+// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <8 x float>]">
+// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <2 x float>]">
+
+  return
+}
+
+