The `contiguous` layout and exposing memref indexing #19851
Addenda:

> Instead of just saying […]

> @MaheshRavishankar I think that'll get me most of what I want? The main catch being that there's a lot of code that wants monotonically-increasing strides from right to left - that is, a lot of code that's asking the question […]

> Yeah you can have a helper method on the […]
The problem
The current suite of `memref` layouts has a gap that makes optimizations difficult: there's no good way to encode "this has the identity layout but carries an offset around". For example, suppose I ask for the offset of an identity-layout memref (a sketch with `memref.extract_strided_metadata`):
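```mlir
// A hedged sketch: %m has the identity layout.
%base, %offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %m
    : memref<?x?xi32> -> memref<i32>, index, index, index, index, index
```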
MLIR will instantly constant-fold `%offset` to `0`, because this memref has the identity layout, which, by definition, doesn't apply an offset.

This assumption prevents allocation coalescing. If you do a rewrite that, say, merges two allocations into views of one larger buffer, a bunch of code may have already assumed that the second view, `%b`, has no offset, and there's no way to find the places where that assumption was made in order to add the new offset back in.

On the other hand, using the strided layout is too weak.
Specifically, if I have a `memref<? x ? x i32>`, we know that that's still in the identity row-major layout. However, if I want to add an offset, I need `memref<? x ? x i32, strided<[?, 1], offset: ?>>`... but here, that `?` in the strides could be anything. A `strided` layout doesn't make promises about row-major contiguous indexing.

Those promises are important for optimizations - for example, they let you use `affine.linearize_index` to compute an index, and there are several places in the vector lowering where `isIdentity()` is really used for the property that we're in a contiguous layout.

The (proposed) plan: The `contiguous` layout
There will be one more built-in memref layout, `contiguous<N, [offset: O]>`, where `N` is an integer and `O` is either a static value or `?`.

This layout, as an affine map, has the form `(d0, d1, ..., d_{N-1}) -> (d0, d1, ..., d_{N-1} + s0)`, where `s0` is the offset (and so might be a constant). That is, it's exactly what we want: the row-major identity layout with an optional offset. That way, things that need to pattern match on "this is contiguous" can do so, and we can avoid the overly-strict implications of the identity layout.
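For illustration, here is how the proposal compares to the closest type expressible today (hedged: `contiguous` is the syntax this proposal suggests, not something MLIR accepts yet):

```mlir
// Today: a unit inner stride is guaranteed, but the `?` outer stride
// could be anything, so row-major contiguity is not implied.
memref<?x?xi32, strided<[?, 1], offset: ?>>

// Proposed: row-major identity indexing plus a (dynamic) offset.
memref<?x?xi32, contiguous<2, offset: ?>>
```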
(Note also that `memref.reshape`, for instance, really just needs a contiguous layout, not the identity.)

Radical extensions
We could reasonably change the default layout from the identity `AffineMap` to `contiguous<[rank]>`, aka `contiguous<[rank], offset: 0>`.
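As a hedged illustration, these two types would then denote the same layout:

```mlir
// A sketch, assuming the default layout becomes contiguous:
memref<4x8xi32>                             // implicit default layout
memref<4x8xi32, contiguous<2, offset: 0>>   // spelled out explicitly
```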
This has the advantage of making it clearer what's going on and having a simpler representation for the default case.

Open questions: "normalized" memrefs
I need to do more research about what a "normalized" memref is in the context of the MLIR standard library in order to understand whether or not `contiguous<N, offset: ?>` fits the definition.

Memref index exposure
The current lowering patterns for `memref.load`, for example, generate a lot of non-trivial arithmetic (multiplying indices by strides, etc.). We want this arithmetic to be visible to MLIR pre-lowering, so that passes like `-int-range-optimizations` or integer narrowing can run on it.

The 0-D form (on reflection, I don't think we want this), retained for history
To that end, we propose an operation `memref.add_offset` and a pass `-memref-expand-indexing`.

`%q = memref.add_offset %p, %i : memref<..., contiguous<N, offset: ?>>` is the operation that takes a memref `%p` and returns one where `offset(%q) == offset(%p) + %i`, where `%i` is an `index`. For the LLVM backend, this could be `getelementptr`, while for the SPIR-V one, all these `offset`s can be accumulated into the index to an access chain.
Then, `-memref-expand-indexing` will walk through a program and replace all `memref<S1 x S2 x ... x Sd x T, contiguous<d, offset: O>>` with `memref<product(S_i) x T, contiguous<1, offset: ?>>` - or perhaps even `memref<T, contiguous<0, offset: ?>>` - using `memref.add_offset` to implement the indexing that's implicit in memref operations.

That is, if we have something like the following (a hedged sketch):
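```mlir
// A hedged sketch of the input: a plain 2-D load, written against the
// proposed contiguous layout.
%v = memref.load %m[%i, %j] : memref<?x?xi32, contiguous<2, offset: ?>>
```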
we'd rewrite it to something along the lines of:
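```mlir
// A hedged sketch of the 0-D output: %m has been rewritten to
// memref<i32, contiguous<0, offset: ?>>, with its sizes (%d0, %d1 here)
// carried separately; the implicit indexing is now explicit arithmetic.
%linear = affine.linearize_index [%i, %j] by (%d0, %d1) : index
%q = memref.add_offset %m, %linear : memref<i32, contiguous<0, offset: ?>>
%v = memref.load %q[] : memref<i32, contiguous<0, offset: ?>>
```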
This exposes the indexing math used by `memref` and allows it to be operated on, instead of making it implicit in complex operations, which is useful late in a lowering process.

Note that this does require mid-level vector ops, like `vector.transfer_read`, to be lowered to `vector.load` in order to be useful.

The 1-D form
To that end, we propose the pass `-memref-linearize-contiguous-indexing`.

We still might want `memref.add_offset`, just as syntactic sugar for "get the old offset, add something to it, `reinterpret_cast`".
That is, the example with the alloc above would be (assuming we allow
alloc
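```mlir
// A hedged sketch of the 1-D form: the allocation and its accesses are
// flattened to one dimension, but the (linearized) size stays on the type.
%n = arith.muli %d0, %d1 : index
%m = memref.alloc(%n) : memref<?xi32, contiguous<1>>
%linear = affine.linearize_index [%i, %j] by (%d0, %d1) : index
%v = memref.load %m[%linear] : memref<?xi32, contiguous<1>>
```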
Prior art
The narrow type emulation patterns, as well as some lowerings to SPIR-V, sort of have patterns for this, but they're somewhat hampered by the lack of a `contiguous` layout.

The narrow type emulation pass could be substantially simplified if the linearization were a separate step, leaving narrow type emulation as "convert all the `memref<i[small], contiguous<offset: ?>>` to `memref<i8, contiguous<offset: ?>>` and divide all the arguments to `memref.add_offset` by the relevant scale factor (+ taking remainders to extract sub-byte values)".
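For instance, here is a hedged sketch of what a load of the `%i`-th `i4` element might become (`%c2` is an index constant; all names are illustrative):

```mlir
// Two i4 values per byte, so divide the linearized index by 2 and keep
// the remainder for the sub-byte position.
%byte_idx = arith.divui %i, %c2 : index
%sub_idx  = arith.remui %i, %c2 : index
%q = memref.add_offset %m, %byte_idx : memref<i8, contiguous<0, offset: ?>>
%byte = memref.load %q[] : memref<i8, contiguous<0, offset: ?>>
// ...then shift and mask %byte by (4 * %sub_idx) to extract the i4 value.
```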
One could even - for the LLVM lowering in particular - introduce a `-llvmir-byteify-memrefs` pass, which converts all memrefs to `i8` memrefs before lowering, emitting a `mlir.llvm.sizeof` operation if it's needed. That would both expose the remaining bit of indexing math to MLIR optimizations and would cause MLIR to stop blocking the `ptradd` transition in LLVM - we've got some pretty fundamental usage of `getelementptr [not i8]` in our lowering. Obviously, other targets, like SPIR-V, which need the type information, wouldn't run this pass.

Concrete steps
- Update everything that calls `getLayout()` or `isIdentity()`, or calls out the `StridedLayoutAttr` by name, in order to give it a new case for `contiguous`